#![warn(missing_debug_implementations)]
use baracuda_cufft_sys::{
cufft, cufftComplex, cufftDoubleComplex, cufftHandle, cufftResult, cufftType,
};
use baracuda_driver::{DeviceBuffer, Stream};
use baracuda_types::{Complex32, Complex64};
/// Error type returned by every fallible call in this crate: `baracuda_core::Error`
/// specialised to the raw cuFFT status type `cufftResult`.
pub type Error = baracuda_core::Error<cufftResult>;
/// Crate-wide result alias; the error type defaults to [`Error`].
pub type Result<T, E = Error> = core::result::Result<T, E>;

/// Turns a raw cuFFT status code into a [`Result`] by delegating to `Error::check`.
///
/// NOTE(review): the success/failure mapping lives in `baracuda_core::Error::check`;
/// presumably `CUFFT_SUCCESS` maps to `Ok(())` — confirm there.
#[inline]
fn check(code: cufftResult) -> Result<()> {
    <Error>::check(code)
}
/// Kind of transform a plan computes; each variant maps 1:1 onto a `cufftType` value.
///
/// `R2C`, `C2R` and `C2C` are the single-precision kinds, `D2Z`, `Z2D` and `Z2Z` the
/// double-precision ones.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum Transform {
    /// Real `f32` input to complex output.
    R2C,
    /// Complex input to real `f32` output.
    C2R,
    /// Complex to complex, single precision.
    C2C,
    /// Real `f64` input to double-complex output.
    D2Z,
    /// Double-complex input to real `f64` output.
    Z2D,
    /// Complex to complex, double precision.
    Z2Z,
}

impl Transform {
    /// Returns the FFI `cufftType` corresponding to this variant.
    fn raw(self) -> cufftType {
        match self {
            Self::R2C => cufftType::R2C,
            Self::C2R => cufftType::C2R,
            Self::C2C => cufftType::C2C,
            Self::D2Z => cufftType::D2Z,
            Self::Z2D => cufftType::Z2D,
            Self::Z2Z => cufftType::Z2Z,
        }
    }
}
#[derive(Copy, Clone, Debug, Eq, PartialEq, Default)]
pub enum Direction {
#[default]
Forward,
Inverse,
}
impl Direction {
fn raw(self) -> core::ffi::c_int {
match self {
Direction::Forward => baracuda_cufft_sys::CUFFT_FORWARD,
Direction::Inverse => baracuda_cufft_sys::CUFFT_INVERSE,
}
}
}
/// A one-dimensional cuFFT plan created with `cufftPlan1d`; the handle is destroyed on drop.
pub struct Plan1d {
    /// Raw plan handle owned exclusively by this wrapper.
    handle: cufftHandle,
}

// SAFETY: the wrapper only carries the plan handle and every method takes `&self`.
// NOTE(review): this relies on cuFFT allowing a plan to be used from a thread other
// than the one that created it, as long as it is not used concurrently — confirm.
unsafe impl Send for Plan1d {}

impl core::fmt::Debug for Plan1d {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let mut out = f.debug_struct("Plan1d");
        out.field("handle", &self.handle);
        out.finish()
    }
}
impl Plan1d {
    /// Creates a (possibly batched) one-dimensional plan via `cufftPlan1d`.
    ///
    /// `nx` is the transform length. `batch` is the number of transforms of that length
    /// computed per execution.
    ///
    /// # Errors
    /// Fails if cuFFT or the `cufftPlan1d` entry point cannot be loaded, or if cuFFT
    /// reports an error while planning.
    pub fn new(nx: i32, transform: Transform, batch: i32) -> Result<Self> {
        let lib = cufft()?;
        let plan_1d = lib.cufft_plan_1d()?;
        let mut handle: cufftHandle = 0;
        // SAFETY: `handle` is a live, writable out-parameter for the whole call.
        let status = unsafe { plan_1d(&mut handle, nx, transform.raw(), batch) };
        check(status)?;
        Ok(Self { handle })
    }

    /// Routes subsequent executions of this plan onto `stream` (`cufftSetStream`).
    ///
    /// NOTE(review): nothing ties `stream`'s lifetime to the plan; the caller must keep
    /// the stream alive while the plan is executed on it.
    pub fn set_stream(&self, stream: &Stream) -> Result<()> {
        let lib = cufft()?;
        let set_stream = lib.cufft_set_stream()?;
        // SAFETY: `self.handle` is the plan owned by `self`; the raw stream comes from a live `Stream`.
        let status = unsafe { set_stream(self.handle, stream.as_raw() as _) };
        check(status)
    }

    /// Executes a single-precision real-to-complex transform (`cufftExecR2C`).
    ///
    /// NOTE(review): buffer lengths are not checked against the plan's dimensions.
    pub fn exec_r2c(
        &self,
        input: &mut DeviceBuffer<f32>,
        output: &mut DeviceBuffer<Complex32>,
    ) -> Result<()> {
        let lib = cufft()?;
        let exec = lib.cufft_exec_r2c()?;
        // SAFETY: both device pointers come from buffers borrowed mutably for this call.
        let status = unsafe {
            let idata = input.as_raw().0 as *mut f32;
            let odata = output.as_raw().0 as *mut cufftComplex;
            exec(self.handle, idata, odata)
        };
        check(status)
    }

    /// Executes a single-precision complex-to-real transform (`cufftExecC2R`).
    ///
    /// NOTE(review): buffer lengths are not checked against the plan's dimensions.
    pub fn exec_c2r(
        &self,
        input: &mut DeviceBuffer<Complex32>,
        output: &mut DeviceBuffer<f32>,
    ) -> Result<()> {
        let lib = cufft()?;
        let exec = lib.cufft_exec_c2r()?;
        // SAFETY: both device pointers come from buffers borrowed mutably for this call.
        let status = unsafe {
            let idata = input.as_raw().0 as *mut cufftComplex;
            let odata = output.as_raw().0 as *mut f32;
            exec(self.handle, idata, odata)
        };
        check(status)
    }

    /// Executes a single-precision complex-to-complex transform (`cufftExecC2C`) in the
    /// given `direction`.
    ///
    /// NOTE(review): buffer lengths are not checked against the plan's dimensions.
    pub fn exec_c2c(
        &self,
        input: &mut DeviceBuffer<Complex32>,
        output: &mut DeviceBuffer<Complex32>,
        direction: Direction,
    ) -> Result<()> {
        let lib = cufft()?;
        let exec = lib.cufft_exec_c2c()?;
        // SAFETY: both device pointers come from buffers borrowed mutably for this call.
        let status = unsafe {
            let idata = input.as_raw().0 as *mut cufftComplex;
            let odata = output.as_raw().0 as *mut cufftComplex;
            exec(self.handle, idata, odata, direction.raw())
        };
        check(status)
    }

    /// Returns the raw cuFFT handle; it remains owned by `self`.
    #[inline]
    pub fn as_raw(&self) -> cufftHandle {
        self.handle
    }
}
impl Drop for Plan1d {
    /// Destroys the plan with `cufftDestroy`. This is best-effort: if the library or the
    /// entry point cannot be loaded, or destruction fails, the error is ignored.
    fn drop(&mut self) {
        let Ok(lib) = cufft() else { return };
        let Ok(destroy) = lib.cufft_destroy() else { return };
        // SAFETY: `self.handle` was produced by cuFFT and is destroyed exactly once, here.
        let _ = unsafe { destroy(self.handle) };
    }
}
/// A two-dimensional cuFFT plan created with `cufftPlan2d`; the handle is destroyed on drop.
pub struct Plan2d {
    /// Raw plan handle owned exclusively by this wrapper.
    handle: cufftHandle,
}

// SAFETY: the wrapper only carries the plan handle and every method takes `&self`.
// NOTE(review): this relies on cuFFT allowing a plan to be used from a thread other
// than the one that created it, as long as it is not used concurrently — confirm.
unsafe impl Send for Plan2d {}

impl core::fmt::Debug for Plan2d {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let mut out = f.debug_struct("Plan2d");
        out.field("handle", &self.handle);
        out.finish()
    }
}
impl Plan2d {
pub fn new(nx: i32, ny: i32, transform: Transform) -> Result<Self> {
let c = cufft()?;
let cu = c.cufft_plan_2d()?;
let mut plan: cufftHandle = 0;
check(unsafe { cu(&mut plan, nx, ny, transform.raw()) })?;
Ok(Self { handle: plan })
}
pub fn exec_c2c(
&self,
input: &mut DeviceBuffer<Complex32>,
output: &mut DeviceBuffer<Complex32>,
direction: Direction,
) -> Result<()> {
let c = cufft()?;
let cu = c.cufft_exec_c2c()?;
check(unsafe {
cu(
self.handle,
input.as_raw().0 as *mut cufftComplex,
output.as_raw().0 as *mut cufftComplex,
direction.raw(),
)
})
}
#[inline]
pub fn as_raw(&self) -> cufftHandle {
self.handle
}
}
impl Drop for Plan2d {
    /// Destroys the plan with `cufftDestroy`. This is best-effort: if the library or the
    /// entry point cannot be loaded, or destruction fails, the error is ignored.
    fn drop(&mut self) {
        let Ok(lib) = cufft() else { return };
        let Ok(destroy) = lib.cufft_destroy() else { return };
        // SAFETY: `self.handle` was produced by cuFFT and is destroyed exactly once, here.
        let _ = unsafe { destroy(self.handle) };
    }
}
/// A three-dimensional cuFFT plan created with `cufftPlan3d`; the handle is destroyed on drop.
pub struct Plan3d {
    /// Raw plan handle owned exclusively by this wrapper.
    handle: cufftHandle,
}

// SAFETY: the wrapper only carries the plan handle and every method takes `&self`.
// NOTE(review): this relies on cuFFT allowing a plan to be used from a thread other
// than the one that created it, as long as it is not used concurrently — confirm.
unsafe impl Send for Plan3d {}

impl core::fmt::Debug for Plan3d {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let mut out = f.debug_struct("Plan3d");
        out.field("handle", &self.handle);
        out.finish()
    }
}
impl Plan3d {
pub fn new(nx: i32, ny: i32, nz: i32, transform: Transform) -> Result<Self> {
let c = cufft()?;
let cu = c.cufft_plan_3d()?;
let mut plan: cufftHandle = 0;
check(unsafe { cu(&mut plan, nx, ny, nz, transform.raw()) })?;
Ok(Self { handle: plan })
}
pub fn set_stream(&self, stream: &Stream) -> Result<()> {
let c = cufft()?;
let cu = c.cufft_set_stream()?;
check(unsafe { cu(self.handle, stream.as_raw() as _) })
}
pub fn exec_c2c(
&self,
input: &mut DeviceBuffer<Complex32>,
output: &mut DeviceBuffer<Complex32>,
direction: Direction,
) -> Result<()> {
let c = cufft()?;
let cu = c.cufft_exec_c2c()?;
check(unsafe {
cu(
self.handle,
input.as_raw().0 as *mut cufftComplex,
output.as_raw().0 as *mut cufftComplex,
direction.raw(),
)
})
}
#[inline]
pub fn as_raw(&self) -> cufftHandle {
self.handle
}
}
impl Drop for Plan3d {
    /// Destroys the plan with `cufftDestroy`. This is best-effort: if the library or the
    /// entry point cannot be loaded, or destruction fails, the error is ignored.
    fn drop(&mut self) {
        let Ok(lib) = cufft() else { return };
        let Ok(destroy) = lib.cufft_destroy() else { return };
        // SAFETY: `self.handle` was produced by cuFFT and is destroyed exactly once, here.
        let _ = unsafe { destroy(self.handle) };
    }
}
/// Returns the library version number reported by `cufftGetVersion`.
///
/// # Errors
/// Fails if cuFFT or the entry point cannot be loaded, or if the call reports an error.
pub fn version() -> Result<i32> {
    let lib = cufft()?;
    let get_version = lib.cufft_get_version()?;
    let mut out: core::ffi::c_int = 0;
    // SAFETY: `out` is a valid, writable `c_int` for the duration of the call.
    check(unsafe { get_version(&mut out) }).map(|()| out)
}
/// Generates the double-precision executors (`exec_d2z`, `exec_z2d`, `exec_z2z`) for a
/// plan type that stores its cuFFT handle in a field named `handle`.
macro_rules! exec_z_impls {
    ($plan:ty) => {
        impl $plan {
            /// Executes a double-precision real-to-complex transform (`cufftExecD2Z`).
            ///
            /// NOTE(review): buffer lengths are not checked against the plan's dimensions.
            pub fn exec_d2z(
                &self,
                input: &mut DeviceBuffer<f64>,
                output: &mut DeviceBuffer<Complex64>,
            ) -> Result<()> {
                let c = cufft()?;
                let cu = c.cufft_exec_d2z()?;
                // SAFETY: both device pointers come from buffers borrowed mutably for this call.
                check(unsafe {
                    cu(
                        self.handle,
                        input.as_raw().0 as *mut f64,
                        output.as_raw().0 as *mut cufftDoubleComplex,
                    )
                })
            }

            /// Executes a double-precision complex-to-real transform (`cufftExecZ2D`).
            ///
            /// NOTE(review): buffer lengths are not checked against the plan's dimensions.
            pub fn exec_z2d(
                &self,
                input: &mut DeviceBuffer<Complex64>,
                output: &mut DeviceBuffer<f64>,
            ) -> Result<()> {
                let c = cufft()?;
                let cu = c.cufft_exec_z2d()?;
                // SAFETY: both device pointers come from buffers borrowed mutably for this call.
                check(unsafe {
                    cu(
                        self.handle,
                        input.as_raw().0 as *mut cufftDoubleComplex,
                        output.as_raw().0 as *mut f64,
                    )
                })
            }

            /// Executes a double-precision complex-to-complex transform (`cufftExecZ2Z`)
            /// in the given `direction`.
            ///
            /// NOTE(review): buffer lengths are not checked against the plan's dimensions.
            pub fn exec_z2z(
                &self,
                input: &mut DeviceBuffer<Complex64>,
                output: &mut DeviceBuffer<Complex64>,
                direction: Direction,
            ) -> Result<()> {
                let c = cufft()?;
                let cu = c.cufft_exec_z2z()?;
                // SAFETY: both device pointers come from buffers borrowed mutably for this call.
                check(unsafe {
                    cu(
                        self.handle,
                        input.as_raw().0 as *mut cufftDoubleComplex,
                        output.as_raw().0 as *mut cufftDoubleComplex,
                        direction.raw(),
                    )
                })
            }
        }
    };
}
exec_z_impls!(Plan1d);
exec_z_impls!(Plan2d);
// `Plan3d::new` accepts D2Z/Z2D/Z2Z transforms, so it needs the double-precision
// executors as well; without this invocation such plans could never be executed.
exec_z_impls!(Plan3d);
/// A batched cuFFT plan with an advanced data layout, created with `cufftPlanMany`;
/// the handle is destroyed on drop.
#[derive(Debug)]
pub struct PlanMany {
    /// Raw plan handle owned exclusively by this wrapper.
    handle: cufftHandle,
}
impl PlanMany {
    /// Creates a batched plan with an advanced data layout via `cufftPlanMany`.
    ///
    /// `n` holds the transform size for each of the `rank` dimensions. `inembed` and
    /// `onembed` give the storage dimensions of the input and output. Passing `None`
    /// sends a null pointer, which per the cuFFT docs selects the basic contiguous
    /// layout; in that case cuFFT ignores the stride/dist arguments.
    ///
    /// # Errors
    /// Fails if cuFFT or the `cufftPlanMany` entry point cannot be loaded, or if cuFFT
    /// rejects the parameters.
    ///
    /// # Panics
    /// Panics if `n.len()` is not equal to `rank`, or if a provided `inembed` or
    /// `onembed` slice has fewer than `rank` elements. cuFFT reads `rank` entries from
    /// each of these arrays, so without these checks this safe function could read past
    /// the end of a caller's slice.
    // Mirrors the C signature of `cufftPlanMany` one-to-one.
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        rank: i32,
        n: &mut [i32],
        inembed: Option<&mut [i32]>,
        istride: i32,
        idist: i32,
        onembed: Option<&mut [i32]>,
        ostride: i32,
        odist: i32,
        ty: Transform,
        batch: i32,
    ) -> Result<Self> {
        assert!(
            i32::try_from(n.len()) == Ok(rank),
            "n.len() must equal rank"
        );
        if let Some(dims) = inembed.as_deref() {
            assert!(
                dims.len() >= n.len(),
                "inembed must have at least rank elements"
            );
        }
        if let Some(dims) = onembed.as_deref() {
            assert!(
                dims.len() >= n.len(),
                "onembed must have at least rank elements"
            );
        }
        let c = cufft()?;
        let cu = c.cufft_plan_many()?;
        let mut h: cufftHandle = 0;
        // SAFETY: `n` (and each embed slice, when present) holds at least `rank`
        // elements as asserted above; `h` is a valid out-parameter.
        check(unsafe {
            cu(
                &mut h,
                rank,
                n.as_mut_ptr(),
                inembed.map_or(core::ptr::null_mut(), |s| s.as_mut_ptr()),
                istride,
                idist,
                onembed.map_or(core::ptr::null_mut(), |s| s.as_mut_ptr()),
                ostride,
                odist,
                ty.raw(),
                batch,
            )
        })?;
        Ok(Self { handle: h })
    }

    /// Returns the raw cuFFT handle; it remains owned by `self`.
    #[inline]
    pub fn as_raw(&self) -> cufftHandle {
        self.handle
    }

    /// Routes subsequent executions of this plan onto `stream` (`cufftSetStream`).
    ///
    /// NOTE(review): nothing ties `stream`'s lifetime to the plan; the caller must keep
    /// the stream alive while the plan is executed on it.
    pub fn set_stream(&self, stream: &Stream) -> Result<()> {
        let c = cufft()?;
        let cu = c.cufft_set_stream()?;
        // SAFETY: `self.handle` is the plan owned by `self`; the raw stream comes from a live `Stream`.
        check(unsafe { cu(self.handle, stream.as_raw() as _) })
    }
}
impl Drop for PlanMany {
    /// Destroys the plan with `cufftDestroy`. This is best-effort: if the library or the
    /// entry point cannot be loaded, or destruction fails, the error is ignored.
    fn drop(&mut self) {
        let Ok(lib) = cufft() else { return };
        let Ok(destroy) = lib.cufft_destroy() else { return };
        // SAFETY: `self.handle` was produced by cuFFT and is destroyed exactly once, here.
        let _ = unsafe { destroy(self.handle) };
    }
}
// Double-precision executors (`exec_d2z`, `exec_z2d`, `exec_z2z`) for batched plans.
exec_z_impls!(PlanMany);

impl PlanMany {
    /// Executes a single-precision real-to-complex transform (`cufftExecR2C`).
    ///
    /// NOTE(review): buffer lengths are not checked against the plan's layout.
    pub fn exec_r2c(
        &self,
        input: &mut DeviceBuffer<f32>,
        output: &mut DeviceBuffer<Complex32>,
    ) -> Result<()> {
        let lib = cufft()?;
        let exec = lib.cufft_exec_r2c()?;
        // SAFETY: both device pointers come from buffers borrowed mutably for this call.
        let status = unsafe {
            let idata = input.as_raw().0 as *mut f32;
            let odata = output.as_raw().0 as *mut cufftComplex;
            exec(self.handle, idata, odata)
        };
        check(status)
    }

    /// Executes a single-precision complex-to-real transform (`cufftExecC2R`).
    ///
    /// NOTE(review): buffer lengths are not checked against the plan's layout.
    pub fn exec_c2r(
        &self,
        input: &mut DeviceBuffer<Complex32>,
        output: &mut DeviceBuffer<f32>,
    ) -> Result<()> {
        let lib = cufft()?;
        let exec = lib.cufft_exec_c2r()?;
        // SAFETY: both device pointers come from buffers borrowed mutably for this call.
        let status = unsafe {
            let idata = input.as_raw().0 as *mut cufftComplex;
            let odata = output.as_raw().0 as *mut f32;
            exec(self.handle, idata, odata)
        };
        check(status)
    }

    /// Executes a single-precision complex-to-complex transform (`cufftExecC2C`) in the
    /// given `direction`.
    ///
    /// NOTE(review): buffer lengths are not checked against the plan's layout.
    pub fn exec_c2c(
        &self,
        input: &mut DeviceBuffer<Complex32>,
        output: &mut DeviceBuffer<Complex32>,
        direction: Direction,
    ) -> Result<()> {
        let lib = cufft()?;
        let exec = lib.cufft_exec_c2c()?;
        // SAFETY: both device pointers come from buffers borrowed mutably for this call.
        let status = unsafe {
            let idata = input.as_raw().0 as *mut cufftComplex;
            let odata = output.as_raw().0 as *mut cufftComplex;
            exec(self.handle, idata, odata, direction.raw())
        };
        check(status)
    }
}
/// Returns cuFFT's work-area size estimate for a batched 1-D plan (`cufftEstimate1d`),
/// without creating a plan.
pub fn estimate_1d(nx: i32, ty: Transform, batch: i32) -> Result<usize> {
    let lib = cufft()?;
    let estimate = lib.cufft_estimate_1d()?;
    let mut bytes: usize = 0;
    // SAFETY: `bytes` is a valid out-parameter for the duration of the call.
    check(unsafe { estimate(nx, ty.raw(), batch, &mut bytes) }).map(|()| bytes)
}
/// Returns cuFFT's work-area size estimate for a 2-D plan (`cufftEstimate2d`), without
/// creating a plan.
pub fn estimate_2d(nx: i32, ny: i32, ty: Transform) -> Result<usize> {
    let lib = cufft()?;
    let estimate = lib.cufft_estimate_2d()?;
    let mut bytes: usize = 0;
    // SAFETY: `bytes` is a valid out-parameter for the duration of the call.
    check(unsafe { estimate(nx, ny, ty.raw(), &mut bytes) }).map(|()| bytes)
}
/// Returns cuFFT's work-area size estimate for a 3-D plan (`cufftEstimate3d`), without
/// creating a plan.
pub fn estimate_3d(nx: i32, ny: i32, nz: i32, ty: Transform) -> Result<usize> {
    let lib = cufft()?;
    let estimate = lib.cufft_estimate_3d()?;
    let mut bytes: usize = 0;
    // SAFETY: `bytes` is a valid out-parameter for the duration of the call.
    check(unsafe { estimate(nx, ny, nz, ty.raw(), &mut bytes) }).map(|()| bytes)
}
pub mod xt {
    //! Raw wrappers over cuFFT's multi-GPU `cufftXt*` entry points.
    //!
    //! Every function takes a raw `cufftHandle` and/or untyped pointers, so all of them
    //! are `unsafe`: the caller vouches for handle validity and pointer provenance.
    use super::*;

    /// Restricts `plan` to the GPUs listed in `which_gpus` (`cufftXtSetGPUs`).
    ///
    /// # Safety
    /// `plan` must be a live cuFFT handle. NOTE(review): cuFFT expects this to be called
    /// after `cufftCreate` and before the plan is made — confirm against the cuFFT docs.
    pub unsafe fn set_gpus(plan: cufftHandle, which_gpus: &mut [i32]) -> Result<()> {
        let lib = cufft()?;
        let set = lib.cufft_xt_set_gpus()?;
        let count = which_gpus.len() as i32;
        // SAFETY: the caller guarantees `plan`; the pointer covers `count` elements.
        check(unsafe { set(plan, count, which_gpus.as_mut_ptr()) })
    }

    /// Allocates a multi-GPU data descriptor for `plan` (`cufftXtMalloc`) and returns it
    /// as an untyped pointer. Release it with [`free`].
    ///
    /// `subformat` is forwarded unchanged. NOTE(review): presumably a `cufftXtSubFormat`
    /// value — confirm.
    ///
    /// # Safety
    /// `plan` must be a live cuFFT handle.
    pub unsafe fn malloc(
        plan: cufftHandle,
        subformat: i32,
    ) -> Result<*mut core::ffi::c_void> {
        let lib = cufft()?;
        let alloc = lib.cufft_xt_malloc()?;
        let mut desc = core::ptr::null_mut::<core::ffi::c_void>();
        // SAFETY: `desc` is a valid out-parameter; the caller guarantees `plan`.
        check(unsafe { alloc(plan, &mut desc, subformat) })?;
        Ok(desc)
    }

    /// Releases a descriptor previously returned by [`malloc`] (`cufftXtFree`).
    ///
    /// # Safety
    /// `desc` must come from [`malloc`] and must not have been freed already.
    pub unsafe fn free(desc: *mut core::ffi::c_void) -> Result<()> {
        let lib = cufft()?;
        let release = lib.cufft_xt_free()?;
        // SAFETY: the caller guarantees `desc` is a live descriptor.
        check(unsafe { release(desc) })
    }

    /// Copies data between host memory and/or multi-GPU descriptors (`cufftXtMemcpy`).
    ///
    /// `ty` is forwarded unchanged. NOTE(review): presumably a `cufftXtCopyType` value
    /// — confirm.
    ///
    /// # Safety
    /// `plan` must be a live cuFFT handle, and `dst`/`src` must be valid for the copy
    /// kind selected by `ty`.
    pub unsafe fn memcpy(
        plan: cufftHandle,
        dst: *mut core::ffi::c_void,
        src: *mut core::ffi::c_void,
        ty: i32,
    ) -> Result<()> {
        let lib = cufft()?;
        let copy = lib.cufft_xt_memcpy()?;
        // SAFETY: all pointer and handle requirements are upheld by the caller.
        check(unsafe { copy(plan, dst, src, ty) })
    }

    /// Executes `plan` on multi-GPU descriptors (`cufftXtExecDescriptor`).
    ///
    /// # Safety
    /// `plan` must be a live cuFFT handle, and `input`/`output` must be descriptors
    /// compatible with it.
    pub unsafe fn exec_descriptor(
        plan: cufftHandle,
        input: *mut core::ffi::c_void,
        output: *mut core::ffi::c_void,
        direction: Direction,
    ) -> Result<()> {
        let lib = cufft()?;
        let exec = lib.cufft_xt_exec_descriptor()?;
        // SAFETY: all pointer and handle requirements are upheld by the caller.
        check(unsafe { exec(plan, input, output, direction.raw()) })
    }
}
/// Supplies a caller-owned work area for `plan` (`cufftSetWorkArea`).
///
/// # Safety
/// `plan` must be a live cuFFT handle. `work_area` must point to device memory at least
/// as large as the plan's work size (see [`get_size`]), and that memory must stay valid
/// while the plan executes.
pub unsafe fn set_work_area(plan: cufftHandle, work_area: *mut core::ffi::c_void) -> Result<()> {
    let lib = cufft()?;
    let set_area = lib.cufft_set_work_area()?;
    // SAFETY: both arguments are vouched for by the caller.
    check(unsafe { set_area(plan, work_area) })
}
/// Enables or disables cuFFT's automatic work-area allocation for `plan`
/// (`cufftSetAutoAllocation`). `auto` is sent as `1` or `0`.
///
/// NOTE(review): when auto allocation is disabled, cuFFT expects a work area to be
/// provided (see [`set_work_area`]) before execution — confirm against the cuFFT docs.
pub fn set_auto_allocation(plan: cufftHandle, auto: bool) -> Result<()> {
    let lib = cufft()?;
    let set_auto = lib.cufft_set_auto_allocation()?;
    let flag = if auto { 1 } else { 0 };
    // SAFETY: only integers are passed. NOTE(review): presumably cuFFT rejects an
    // unknown handle with an error code rather than misbehaving.
    check(unsafe { set_auto(plan, flag) })
}
/// Returns the work-area size cuFFT reports for `plan` (`cufftGetSize`).
pub fn get_size(plan: cufftHandle) -> Result<usize> {
    let lib = cufft()?;
    let query = lib.cufft_get_size()?;
    let mut bytes: usize = 0;
    // SAFETY: `bytes` is a valid out-parameter for the duration of the call.
    check(unsafe { query(plan, &mut bytes) }).map(|()| bytes)
}
/// A cuFFT handle created with `cufftCreate` and configured later through one of the
/// `make_plan_*` methods; the handle is destroyed on drop.
pub struct Plan {
    /// Raw plan handle owned exclusively by this wrapper.
    handle: cufftHandle,
}

// SAFETY: the wrapper only carries the plan handle and every method takes `&self`.
// NOTE(review): this relies on cuFFT allowing a plan to be used from a thread other
// than the one that created it, as long as it is not used concurrently — confirm.
unsafe impl Send for Plan {}

impl core::fmt::Debug for Plan {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let mut out = f.debug_struct("Plan");
        out.field("handle", &self.handle);
        out.finish()
    }
}
impl Plan {
    /// Creates an empty cuFFT handle with `cufftCreate`.
    ///
    /// The handle cannot execute transforms until a `make_plan_*` method has configured it.
    pub fn create() -> Result<Self> {
        let lib = cufft()?;
        let create = lib.cufft_create()?;
        let mut handle: cufftHandle = 0;
        // SAFETY: `handle` is a valid out-parameter for the call.
        check(unsafe { create(&mut handle) })?;
        Ok(Self { handle })
    }

    /// Configures this handle as a batched 1-D plan (`cufftMakePlan1d`) and returns the
    /// work-area size written by cuFFT.
    pub fn make_plan_1d(&self, nx: i32, transform: Transform, batch: i32) -> Result<usize> {
        let lib = cufft()?;
        let make = lib.cufft_make_plan_1d()?;
        let mut work_size: usize = 0;
        // SAFETY: `self.handle` is owned by `self`; `work_size` is a valid out-parameter.
        check(unsafe { make(self.handle, nx, transform.raw(), batch, &mut work_size) })
            .map(|()| work_size)
    }

    /// Configures this handle as a 2-D plan (`cufftMakePlan2d`) and returns the work-area
    /// size written by cuFFT.
    pub fn make_plan_2d(&self, nx: i32, ny: i32, transform: Transform) -> Result<usize> {
        let lib = cufft()?;
        let make = lib.cufft_make_plan_2d()?;
        let mut work_size: usize = 0;
        // SAFETY: `self.handle` is owned by `self`; `work_size` is a valid out-parameter.
        check(unsafe { make(self.handle, nx, ny, transform.raw(), &mut work_size) })
            .map(|()| work_size)
    }

    /// Configures this handle as a 3-D plan (`cufftMakePlan3d`) and returns the work-area
    /// size written by cuFFT.
    pub fn make_plan_3d(
        &self,
        nx: i32,
        ny: i32,
        nz: i32,
        transform: Transform,
    ) -> Result<usize> {
        let lib = cufft()?;
        let make = lib.cufft_make_plan_3d()?;
        let mut work_size: usize = 0;
        // SAFETY: `self.handle` is owned by `self`; `work_size` is a valid out-parameter.
        check(unsafe { make(self.handle, nx, ny, nz, transform.raw(), &mut work_size) })
            .map(|()| work_size)
    }

    /// Configures this handle as an advanced-layout batched plan (`cufftMakePlanMany`)
    /// and returns the work-area size written by cuFFT.
    ///
    /// # Safety
    /// `inembed` and `onembed` are forwarded unchecked. Each must be null or point to
    /// at least `rank` readable `i32` values for the duration of the call.
    ///
    /// # Panics
    /// Panics if `n.len()` does not equal `rank`.
    // Mirrors the C signature one-to-one.
    #[allow(clippy::too_many_arguments)]
    pub unsafe fn make_plan_many(
        &self,
        rank: i32,
        n: &mut [i32],
        inembed: *mut i32,
        istride: i32,
        idist: i32,
        onembed: *mut i32,
        ostride: i32,
        odist: i32,
        transform: Transform,
        batch: i32,
    ) -> Result<usize> {
        assert_eq!(n.len() as i32, rank, "n.len() must equal rank");
        let lib = cufft()?;
        let make = lib.cufft_make_plan_many()?;
        let mut work_size: usize = 0;
        let dims = n.as_mut_ptr();
        let kind = transform.raw();
        // SAFETY: `dims` covers `rank` entries (asserted above); the embed pointers are
        // the caller's responsibility under this function's contract.
        check(unsafe {
            make(
                self.handle,
                rank,
                dims,
                inembed,
                istride,
                idist,
                onembed,
                ostride,
                odist,
                kind,
                batch,
                &mut work_size,
            )
        })
        .map(|()| work_size)
    }

    /// 64-bit variant of [`Plan::make_plan_many`] (`cufftMakePlanMany64`): sizes,
    /// strides, distances and batch are `i64`.
    ///
    /// # Safety
    /// `inembed` and `onembed` are forwarded unchecked. Each must be null or point to
    /// at least `rank` readable `i64` values for the duration of the call.
    ///
    /// # Panics
    /// Panics if `n.len()` does not equal `rank`.
    // Mirrors the C signature one-to-one.
    #[allow(clippy::too_many_arguments)]
    pub unsafe fn make_plan_many64(
        &self,
        rank: i32,
        n: &mut [i64],
        inembed: *mut i64,
        istride: i64,
        idist: i64,
        onembed: *mut i64,
        ostride: i64,
        odist: i64,
        transform: Transform,
        batch: i64,
    ) -> Result<usize> {
        assert_eq!(n.len() as i32, rank, "n.len() must equal rank");
        let lib = cufft()?;
        let make = lib.cufft_make_plan_many64()?;
        let mut work_size: usize = 0;
        let dims = n.as_mut_ptr();
        let kind = transform.raw();
        // SAFETY: `dims` covers `rank` entries (asserted above); the embed pointers are
        // the caller's responsibility under this function's contract.
        check(unsafe {
            make(
                self.handle,
                rank,
                dims,
                inembed,
                istride,
                idist,
                onembed,
                ostride,
                odist,
                kind,
                batch,
                &mut work_size,
            )
        })
        .map(|()| work_size)
    }

    /// Routes subsequent executions of this plan onto `stream` (`cufftSetStream`).
    ///
    /// NOTE(review): nothing ties `stream`'s lifetime to the plan; the caller must keep
    /// the stream alive while the plan is executed on it.
    pub fn set_stream(&self, stream: &Stream) -> Result<()> {
        let lib = cufft()?;
        let set_stream = lib.cufft_set_stream()?;
        // SAFETY: `self.handle` is the plan owned by `self`; the raw stream comes from a live `Stream`.
        check(unsafe { set_stream(self.handle, stream.as_raw() as _) })
    }

    /// Returns the raw cuFFT handle; it remains owned by `self`.
    #[inline]
    pub fn as_raw(&self) -> cufftHandle {
        self.handle
    }
}
impl Drop for Plan {
    /// Destroys the handle with `cufftDestroy`. This is best-effort: if the library or
    /// the entry point cannot be loaded, or destruction fails, the error is ignored.
    fn drop(&mut self) {
        let Ok(lib) = cufft() else { return };
        let Ok(destroy) = lib.cufft_destroy() else { return };
        // SAFETY: `self.handle` was produced by cuFFT and is destroyed exactly once, here.
        let _ = unsafe { destroy(self.handle) };
    }
}
pub mod callback {
    //! Wrappers over cuFFT's `cufftXtSetCallback` family of entry points.
    use super::*;

    /// Callback slot kind; the discriminants are the values passed to cuFFT as the
    /// callback type.
    #[derive(Copy, Clone, Debug, Eq, PartialEq)]
    #[repr(i32)]
    pub enum CallbackType {
        /// Load callback, single-precision complex input.
        LoadComplex = 0,
        /// Load callback, double-precision complex input.
        LoadDoubleComplex = 1,
        /// Load callback, single-precision real input.
        LoadReal = 2,
        /// Load callback, double-precision real input.
        LoadDoubleReal = 3,
        /// Store callback, single-precision complex output.
        StoreComplex = 4,
        /// Store callback, double-precision complex output.
        StoreDoubleComplex = 5,
        /// Store callback, single-precision real output.
        StoreReal = 6,
        /// Store callback, double-precision real output.
        StoreDoubleReal = 7,
    }

    /// Installs device callbacks on `plan` (`cufftXtSetCallback`).
    ///
    /// # Safety
    /// `plan` must be a live cuFFT handle. Each entry of `callback_routine` must be a
    /// device function pointer of the kind selected by `cb_type`, and each entry of
    /// `caller_info` must be null or device memory that the callback may read.
    ///
    /// # Panics
    /// Panics if the two slices differ in length.
    pub unsafe fn set(
        plan: cufftHandle,
        callback_routine: &mut [*mut core::ffi::c_void],
        cb_type: CallbackType,
        caller_info: &mut [*mut core::ffi::c_void],
    ) -> Result<()> {
        assert_eq!(
            callback_routine.len(),
            caller_info.len(),
            "callback_routine and caller_info must have the same length"
        );
        let lib = cufft()?;
        let install = lib.cufft_xt_set_callback()?;
        let kind = cb_type as i32;
        // SAFETY: pointer validity is the caller's responsibility under this contract.
        check(unsafe {
            install(
                plan,
                callback_routine.as_mut_ptr(),
                kind,
                caller_info.as_mut_ptr(),
            )
        })
    }

    /// Removes the callback of kind `cb_type` from `plan` (`cufftXtClearCallback`).
    pub fn clear(plan: cufftHandle, cb_type: CallbackType) -> Result<()> {
        let lib = cufft()?;
        let remove = lib.cufft_xt_clear_callback()?;
        let kind = cb_type as i32;
        // SAFETY: only plain integers are passed.
        check(unsafe { remove(plan, kind) })
    }

    /// Sets how much shared memory the callback of kind `cb_type` may use
    /// (`cufftXtSetCallbackSharedSize`).
    pub fn set_shared_size(
        plan: cufftHandle,
        cb_type: CallbackType,
        shared_size: usize,
    ) -> Result<()> {
        let lib = cufft()?;
        let set_size = lib.cufft_xt_set_callback_shared_size()?;
        let kind = cb_type as i32;
        // SAFETY: only plain integers are passed.
        check(unsafe { set_size(plan, kind, shared_size) })
    }
}