use crate::rocfft::bindings;
use crate::rocfft::error::{Error, Result, check_error};
use crate::rocfft::field::Field;
use crate::rocfft::plan::ArrayType;
use std::marker::PhantomData;
use std::ptr;
/// Communicator kind that can be attached to a plan description via
/// `PlanDescription::set_comm`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CommType {
    /// No communicator.
    None,
    /// An MPI communicator.
    MPI,
}
/// Maps a [`CommType`] onto the raw `rocfft_comm_type` value expected by the
/// rocFFT C API.
impl From<CommType> for u32 {
    fn from(kind: CommType) -> u32 {
        match kind {
            CommType::MPI => bindings::rocfft_comm_type_e_rocfft_comm_mpi,
            CommType::None => bindings::rocfft_comm_type_e_rocfft_comm_none,
        }
    }
}
/// Owning wrapper around a rocFFT `rocfft_plan_description` handle.
///
/// The `PhantomData<*mut ()>` marker makes this type neither `Send` nor
/// `Sync`, so a description cannot be moved to or shared with another thread.
pub struct PlanDescription {
    /// Raw rocFFT handle; reset to null when the description is destroyed.
    handle: bindings::rocfft_plan_description,
    /// Opts the type out of the `Send` / `Sync` auto traits.
    _marker: PhantomData<*mut ()>,
}
impl PlanDescription {
pub fn new() -> Result<Self> {
let mut handle: bindings::rocfft_plan_description = ptr::null_mut();
unsafe {
check_error(bindings::rocfft_plan_description_create(&mut handle))?;
}
Ok(PlanDescription {
handle,
_marker: PhantomData,
})
}
pub fn set_scale_factor(&mut self, scale_factor: f64) -> Result<()> {
if self.handle.is_null() {
return Err(Error::ObjectDestroyed);
}
if !scale_factor.is_finite() {
return Err(Error::InvalidArgValue);
}
unsafe {
check_error(bindings::rocfft_plan_description_set_scale_factor(
self.handle,
scale_factor,
))
}
}
pub fn set_data_layout(
&mut self,
in_array_type: ArrayType,
out_array_type: ArrayType,
in_offsets: Option<&[usize]>,
out_offsets: Option<&[usize]>,
in_strides: Option<&[usize]>,
in_distance: usize,
out_strides: Option<&[usize]>,
out_distance: usize,
) -> Result<()> {
if self.handle.is_null() {
return Err(Error::ObjectDestroyed);
}
let in_offsets_ptr = match in_offsets {
Some(offsets) => offsets.as_ptr(),
None => ptr::null(),
};
let out_offsets_ptr = match out_offsets {
Some(offsets) => offsets.as_ptr(),
None => ptr::null(),
};
let (in_strides_ptr, in_strides_size) = match in_strides {
Some(strides) => (strides.as_ptr(), strides.len()),
None => (ptr::null(), 0),
};
let (out_strides_ptr, out_strides_size) = match out_strides {
Some(strides) => (strides.as_ptr(), strides.len()),
None => (ptr::null(), 0),
};
unsafe {
check_error(bindings::rocfft_plan_description_set_data_layout(
self.handle,
in_array_type.into(),
out_array_type.into(),
in_offsets_ptr,
out_offsets_ptr,
in_strides_size,
in_strides_ptr,
in_distance,
out_strides_size,
out_strides_ptr,
out_distance,
))
}
}
pub unsafe fn set_comm(
&mut self,
comm_type: CommType,
comm_handle: *mut std::ffi::c_void,
) -> Result<()> {
if self.handle.is_null() {
return Err(Error::ObjectDestroyed);
}
unsafe {
check_error(bindings::rocfft_plan_description_set_comm(
self.handle,
comm_type.into(),
comm_handle,
))
}
}
pub fn add_infield(&mut self, field: &Field) -> Result<()> {
if self.handle.is_null() {
return Err(Error::ObjectDestroyed);
}
unsafe {
check_error(bindings::rocfft_plan_description_add_infield(
self.handle,
field.as_ptr(),
))
}
}
pub fn add_outfield(&mut self, field: &Field) -> Result<()> {
if self.handle.is_null() {
return Err(Error::ObjectDestroyed);
}
unsafe {
check_error(bindings::rocfft_plan_description_add_outfield(
self.handle,
field.as_ptr(),
))
}
}
pub(crate) fn as_ptr(&self) -> bindings::rocfft_plan_description {
self.handle
}
}
impl Drop for PlanDescription {
fn drop(&mut self) {
if !self.handle.is_null() {
unsafe {
bindings::rocfft_plan_description_destroy(self.handle);
}
self.handle = ptr::null_mut();
}
}
}