use super::{result, sys};
use crate::driver::{CudaStream, DevicePtr, DevicePtrMut, DeviceRepr, ValidAsZeroBits};
use core::ffi::c_int;
use std::sync::Arc;
// SAFETY (DeviceRepr): `sys::float2` / `sys::double2` are the complex element types the
// exec wrappers hand to cuFFT (cast to `cufftComplex` / `cufftDoubleComplex`). They are
// assumed to be bindgen-generated plain `#[repr(C)]` pairs of floats (`float2` has
// `x`/`y: f32`) with no pointers or drop glue, so copying their bytes to the device is
// sound — confirm against the `sys` bindings if they are regenerated.
unsafe impl DeviceRepr for sys::float2 {}
unsafe impl DeviceRepr for sys::double2 {}

// SAFETY (ValidAsZeroBits): an all-zero bit pattern is `0.0` for every IEEE float field,
// so a zero-filled buffer of either type is a valid value (used by `alloc_zeros`).
unsafe impl ValidAsZeroBits for sys::float2 {}
unsafe impl ValidAsZeroBits for sys::double2 {}
/// Direction of a complex-to-complex transform (`exec_c2c` / `exec_z2z`).
///
/// The discriminant is forwarded to cuFFT verbatim via `direction as c_int`, so the
/// values must keep matching cuFFT's `CUFFT_FORWARD` (-1) and `CUFFT_INVERSE` (1).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FftDirection {
    /// Forward transform; passed to cuFFT as `-1`.
    Forward = -1,
    /// Inverse transform; passed to cuFFT as `1`. The result is not normalized, so a
    /// forward + inverse round trip scales the data by the transform length.
    Inverse = 1,
}
/// A cuFFT plan together with the [`CudaStream`] its transforms are enqueued on.
///
/// The plan handle is destroyed when this value is dropped. Holding the stream in an
/// `Arc` keeps it alive for as long as the plan may launch work on it.
pub struct CudaFft {
    /// Raw cuFFT plan handle; the `Drop` impl treats `0` as "no plan to destroy".
    pub(crate) handle: sys::cufftHandle,
    /// Stream the plan was attached to via `cufftSetStream`; exec methods borrow device
    /// buffers on this stream.
    pub(crate) stream: Arc<CudaStream>,
}

// SAFETY (Send): the plan is an integer handle and the stream is reference-counted, so
// moving the wrapper to another thread moves nothing thread-affine on the Rust side.
unsafe impl Send for CudaFft {}

// NOTE(review): cuFFT documents its API as thread-safe when different host threads use
// *different* plans. Because the exec methods take `&self`, this impl lets several
// threads drive the same plan concurrently — confirm this is sound or require `&mut self`.
unsafe impl Sync for CudaFft {}
impl CudaFft {
    /// Creates a 1D plan of length `nx` performing `batch` transforms of kind `type_`,
    /// attached to `stream`.
    ///
    /// The stream's context is bound to the calling thread before the plan is created.
    ///
    /// # Errors
    /// Returns the cuFFT error if plan creation or `cufftSetStream` fails. In the latter
    /// case the already-created plan is destroyed instead of being leaked.
    pub fn plan_1d(
        nx: i32,
        type_: sys::cufftType,
        batch: i32,
        stream: Arc<CudaStream>,
    ) -> Result<Self, result::CufftError> {
        let ctx = stream.context();
        ctx.record_err(ctx.bind_to_thread());
        let handle = result::plan_1d(nx, type_, batch)?;
        Self::from_plan(handle, stream)
    }

    /// Creates a 2D plan of size `nx` x `ny` for transforms of kind `type_`, attached to
    /// `stream`.
    ///
    /// # Errors
    /// Returns the cuFFT error if plan creation or `cufftSetStream` fails. In the latter
    /// case the already-created plan is destroyed instead of being leaked.
    pub fn plan_2d(
        nx: i32,
        ny: i32,
        type_: sys::cufftType,
        stream: Arc<CudaStream>,
    ) -> Result<Self, result::CufftError> {
        let ctx = stream.context();
        ctx.record_err(ctx.bind_to_thread());
        let handle = result::plan_2d(nx, ny, type_)?;
        Self::from_plan(handle, stream)
    }

    /// Creates a 3D plan of size `nx` x `ny` x `nz` for transforms of kind `type_`,
    /// attached to `stream`.
    ///
    /// # Errors
    /// Returns the cuFFT error if plan creation or `cufftSetStream` fails. In the latter
    /// case the already-created plan is destroyed instead of being leaked.
    pub fn plan_3d(
        nx: i32,
        ny: i32,
        nz: i32,
        type_: sys::cufftType,
        stream: Arc<CudaStream>,
    ) -> Result<Self, result::CufftError> {
        let ctx = stream.context();
        ctx.record_err(ctx.bind_to_thread());
        let handle = result::plan_3d(nx, ny, nz, type_)?;
        Self::from_plan(handle, stream)
    }

    /// Creates a batched, possibly strided plan (`cufftPlanMany`) attached to `stream`.
    ///
    /// The transform rank is `n.len()`. `inembed` / `onembed` describe the storage
    /// dimensions of the input/output data; `None` is passed to cuFFT as a null pointer.
    ///
    /// # Panics
    /// Panics if `inembed` or `onembed` is `Some` but shorter than `n`, because cuFFT
    /// reads `rank` entries from each embed array.
    ///
    /// # Errors
    /// Returns the cuFFT error if plan creation or `cufftSetStream` fails. In the latter
    /// case the already-created plan is destroyed instead of being leaked.
    // The argument list mirrors `cufftPlanMany` one-to-one.
    #[allow(clippy::too_many_arguments)]
    pub fn plan_many(
        n: &[c_int],
        inembed: Option<&[c_int]>,
        istride: i32,
        idist: i32,
        onembed: Option<&[c_int]>,
        ostride: i32,
        odist: i32,
        type_: sys::cufftType,
        batch: i32,
        stream: Arc<CudaStream>,
    ) -> Result<Self, result::CufftError> {
        let ctx = stream.context();
        ctx.record_err(ctx.bind_to_thread());
        let rank = n.len() as c_int;
        let inembed_ptr = Self::embed_ptr(inembed, n.len(), "inembed");
        let onembed_ptr = Self::embed_ptr(onembed, n.len(), "onembed");
        // SAFETY: `n` holds `rank` elements. Each non-null embed pointer refers to a slice
        // checked to hold at least `rank` elements. All slices outlive this call. cuFFT
        // treats these arrays as inputs; the `*mut` casts only satisfy its C signature.
        let handle = unsafe {
            result::plan_many(
                rank,
                n.as_ptr() as *mut c_int,
                inembed_ptr,
                istride,
                idist,
                onembed_ptr,
                ostride,
                odist,
                type_,
                batch,
            )
        }?;
        Self::from_plan(handle, stream)
    }

    /// Turns an optional embed slice into the nullable pointer `cufftPlanMany` expects.
    ///
    /// # Panics
    /// Panics if the slice has fewer than `rank` elements.
    fn embed_ptr(embed: Option<&[c_int]>, rank: usize, name: &str) -> *mut c_int {
        match embed {
            Some(slice) => {
                assert!(
                    slice.len() >= rank,
                    "{name} has {} elements but the transform rank is {rank}",
                    slice.len()
                );
                slice.as_ptr() as *mut c_int
            }
            None => std::ptr::null_mut(),
        }
    }

    /// Takes ownership of a freshly created plan and attaches it to `stream`.
    ///
    /// The handle is wrapped in `Self` *before* `cufftSetStream` runs. If that call
    /// fails, `fft` is dropped and its `Drop` impl destroys the plan rather than leaking it.
    fn from_plan(
        handle: sys::cufftHandle,
        stream: Arc<CudaStream>,
    ) -> Result<Self, result::CufftError> {
        let fft = Self { handle, stream };
        // SAFETY: `fft.handle` was just returned by cuFFT. `fft` owns an `Arc` to the
        // stream, keeping it alive while the plan refers to it.
        unsafe { result::set_stream(fft.handle, fft.stream.cu_stream() as _) }?;
        Ok(fft)
    }

    /// Returns the raw cuFFT plan handle.
    ///
    /// The handle remains owned by `self` and is destroyed on drop; callers must not
    /// destroy it themselves.
    pub fn handle(&self) -> sys::cufftHandle {
        self.handle
    }

    /// Re-attaches the plan to `stream` and keeps that stream alive for later executions.
    ///
    /// `self.stream` is replaced only after `cufftSetStream` succeeds. On error, the
    /// stored stream still matches the stream the plan actually uses.
    ///
    /// # Safety
    /// No synchronization between the previous stream and `stream` is performed here.
    /// Presumably the caller must ensure work already enqueued through this plan is
    /// properly ordered with respect to `stream` — TODO confirm the intended contract.
    ///
    /// # Errors
    /// Returns the cuFFT error from `cufftSetStream`; the plan is left unchanged.
    pub unsafe fn set_stream(&mut self, stream: Arc<CudaStream>) -> Result<(), result::CufftError> {
        result::set_stream(self.handle, stream.cu_stream() as _)?;
        self.stream = stream;
        Ok(())
    }

    /// Executes a single-precision real-to-complex transform from `input` into `output`.
    ///
    /// Both buffers are borrowed on the plan's stream. The guards returned by
    /// `device_ptr`/`device_ptr_mut` are held until the launch call returns.
    ///
    /// NOTE(review): buffer lengths are not checked against the plan's dimensions or
    /// batch count here; an undersized buffer is not caught on the Rust side.
    ///
    /// # Errors
    /// Returns the error cuFFT reports for the launch.
    pub fn exec_r2c<Src: DevicePtr<f32>, Dst: DevicePtrMut<sys::float2>>(
        &self,
        input: &Src,
        output: &mut Dst,
    ) -> Result<(), result::CufftError> {
        let (idata, _record_src) = input.device_ptr(&self.stream);
        let (odata, _record_dst) = output.device_ptr_mut(&self.stream);
        // SAFETY: both device pointers stay valid while the guards above are alive.
        unsafe {
            result::exec_r2c(
                self.handle,
                idata as *mut sys::cufftReal,
                odata as *mut sys::cufftComplex,
            )
        }
    }

    /// Executes a single-precision complex-to-real transform from `input` into `output`.
    ///
    /// `input` is borrowed mutably because cuFFT may overwrite the input buffer of an
    /// out-of-place complex-to-real transform.
    ///
    /// NOTE(review): buffer lengths are not checked against the plan here.
    ///
    /// # Errors
    /// Returns the error cuFFT reports for the launch.
    pub fn exec_c2r<Src: DevicePtrMut<sys::float2>, Dst: DevicePtrMut<f32>>(
        &self,
        input: &mut Src,
        output: &mut Dst,
    ) -> Result<(), result::CufftError> {
        let (idata, _record_src) = input.device_ptr_mut(&self.stream);
        let (odata, _record_dst) = output.device_ptr_mut(&self.stream);
        // SAFETY: both device pointers stay valid while the guards above are alive.
        unsafe {
            result::exec_c2r(
                self.handle,
                idata as *mut sys::cufftComplex,
                odata as *mut sys::cufftReal,
            )
        }
    }

    /// Executes a single-precision complex-to-complex transform in `direction`.
    ///
    /// The inverse direction is not normalized by cuFFT.
    ///
    /// NOTE(review): buffer lengths are not checked against the plan here.
    ///
    /// # Errors
    /// Returns the error cuFFT reports for the launch.
    pub fn exec_c2c<Src: DevicePtrMut<sys::float2>, Dst: DevicePtrMut<sys::float2>>(
        &self,
        input: &mut Src,
        output: &mut Dst,
        direction: FftDirection,
    ) -> Result<(), result::CufftError> {
        let (idata, _record_src) = input.device_ptr_mut(&self.stream);
        let (odata, _record_dst) = output.device_ptr_mut(&self.stream);
        // SAFETY: both device pointers stay valid while the guards above are alive.
        unsafe {
            result::exec_c2c(
                self.handle,
                idata as *mut sys::cufftComplex,
                odata as *mut sys::cufftComplex,
                direction as c_int,
            )
        }
    }

    /// Executes a double-precision real-to-complex transform from `input` into `output`.
    ///
    /// NOTE(review): buffer lengths are not checked against the plan here.
    ///
    /// # Errors
    /// Returns the error cuFFT reports for the launch.
    pub fn exec_d2z<Src: DevicePtr<f64>, Dst: DevicePtrMut<sys::double2>>(
        &self,
        input: &Src,
        output: &mut Dst,
    ) -> Result<(), result::CufftError> {
        let (idata, _record_src) = input.device_ptr(&self.stream);
        let (odata, _record_dst) = output.device_ptr_mut(&self.stream);
        // SAFETY: both device pointers stay valid while the guards above are alive.
        unsafe {
            result::exec_d2z(
                self.handle,
                idata as *mut sys::cufftDoubleReal,
                odata as *mut sys::cufftDoubleComplex,
            )
        }
    }

    /// Executes a double-precision complex-to-real transform from `input` into `output`.
    ///
    /// `input` is borrowed mutably because cuFFT may overwrite the input buffer of an
    /// out-of-place complex-to-real transform.
    ///
    /// NOTE(review): buffer lengths are not checked against the plan here.
    ///
    /// # Errors
    /// Returns the error cuFFT reports for the launch.
    pub fn exec_z2d<Src: DevicePtrMut<sys::double2>, Dst: DevicePtrMut<f64>>(
        &self,
        input: &mut Src,
        output: &mut Dst,
    ) -> Result<(), result::CufftError> {
        let (idata, _record_src) = input.device_ptr_mut(&self.stream);
        let (odata, _record_dst) = output.device_ptr_mut(&self.stream);
        // SAFETY: both device pointers stay valid while the guards above are alive.
        unsafe {
            result::exec_z2d(
                self.handle,
                idata as *mut sys::cufftDoubleComplex,
                odata as *mut sys::cufftDoubleReal,
            )
        }
    }

    /// Executes a double-precision complex-to-complex transform in `direction`.
    ///
    /// The inverse direction is not normalized by cuFFT.
    ///
    /// NOTE(review): buffer lengths are not checked against the plan here.
    ///
    /// # Errors
    /// Returns the error cuFFT reports for the launch.
    pub fn exec_z2z<Src: DevicePtrMut<sys::double2>, Dst: DevicePtrMut<sys::double2>>(
        &self,
        input: &mut Src,
        output: &mut Dst,
        direction: FftDirection,
    ) -> Result<(), result::CufftError> {
        let (idata, _record_src) = input.device_ptr_mut(&self.stream);
        let (odata, _record_dst) = output.device_ptr_mut(&self.stream);
        // SAFETY: both device pointers stay valid while the guards above are alive.
        unsafe {
            result::exec_z2z(
                self.handle,
                idata as *mut sys::cufftDoubleComplex,
                odata as *mut sys::cufftDoubleComplex,
                direction as c_int,
            )
        }
    }
}
impl Drop for CudaFft {
    /// Destroys the cuFFT plan, if one is held.
    ///
    /// The stream's context is bound first, mirroring the `plan_*` constructors.
    /// `CudaFft` is `Send`, so it may be dropped on a thread where that context is not
    /// current.
    ///
    /// # Panics
    /// Panics if destroying the plan reports an error.
    fn drop(&mut self) {
        // Zero the field before destroying so the handle can never be released twice.
        let handle = std::mem::replace(&mut self.handle, 0);
        if handle == 0 {
            return;
        }
        let ctx = self.stream.context();
        ctx.record_err(ctx.bind_to_thread());
        // SAFETY: `handle` is a plan owned exclusively by `self` and has not been
        // destroyed yet (the field was non-zero).
        unsafe { result::destroy(handle) }.unwrap();
    }
}
#[cfg(test)]
mod tests {
    //! GPU tests: they require CUDA device 0. They cover plan creation and unnormalized
    //! forward/inverse round trips (a round trip scales every sample by the length).
    use super::*;
    use crate::cufft::sys;
    use crate::driver::*;
    use std::vec::Vec;

    #[test]
    fn test_plan_1d_c2c() {
        let ctx = CudaContext::new(0).unwrap();
        let stream = ctx.default_stream();
        let plan = CudaFft::plan_1d(256, sys::cufftType::CUFFT_C2C, 1, stream.clone());
        let _fft = plan.unwrap();
    }

    #[test]
    fn test_plan_2d_r2c() {
        let ctx = CudaContext::new(0).unwrap();
        let stream = ctx.default_stream();
        let plan = CudaFft::plan_2d(64, 64, sys::cufftType::CUFFT_R2C, stream.clone());
        let _fft = plan.unwrap();
    }

    #[test]
    fn test_plan_many_r2c_batched() {
        let ctx = CudaContext::new(0).unwrap();
        let stream = ctx.default_stream();
        let (rows, cols, batch) = (4i32, 8i32, 2i32);
        // An R2C transform keeps only the non-redundant half (+1) of the last axis.
        let half_cols = cols / 2 + 1;
        let _fft = CudaFft::plan_many(
            &[rows, cols],
            Some(&[rows, cols]),
            1,
            rows * cols,
            Some(&[rows, half_cols]),
            1,
            rows * half_cols,
            sys::cufftType::CUFFT_R2C,
            batch,
            stream.clone(),
        )
        .unwrap();
    }

    #[test]
    fn test_exec_c2c_roundtrip() {
        let ctx = CudaContext::new(0).unwrap();
        let stream = ctx.default_stream();
        let n = 4;
        let fft = CudaFft::plan_1d(n as i32, sys::cufftType::CUFFT_C2C, 1, stream.clone())
            .unwrap();

        let host_in: Vec<sys::float2> = (1..=n)
            .map(|k| sys::float2 {
                x: k as f32,
                y: 0.0,
            })
            .collect();
        let mut signal = stream.clone_htod(&host_in).unwrap();
        let mut spectrum = stream.alloc_zeros::<sys::float2>(n).unwrap();
        let mut restored = stream.alloc_zeros::<sys::float2>(n).unwrap();

        fft.exec_c2c(&mut signal, &mut spectrum, FftDirection::Forward)
            .unwrap();
        fft.exec_c2c(&mut spectrum, &mut restored, FftDirection::Inverse)
            .unwrap();

        let host_out: Vec<sys::float2> = stream.clone_dtoh(&restored).unwrap();
        assert!(host_out.len() >= n);
        for (i, (got, src)) in host_out.iter().zip(&host_in).enumerate() {
            let expected = src.x * n as f32;
            assert!(
                (got.x - expected).abs() < 1e-3,
                "real mismatch at {i}: got {} expected {expected}",
                got.x
            );
            assert!(
                got.y.abs() < 1e-3,
                "imag mismatch at {i}: got {}",
                got.y
            );
        }
    }

    #[test]
    fn test_exec_r2c_c2r_roundtrip() {
        let ctx = CudaContext::new(0).unwrap();
        let stream = ctx.default_stream();
        let n = 8;
        let bins = n / 2 + 1;
        let forward =
            CudaFft::plan_1d(n as i32, sys::cufftType::CUFFT_R2C, 1, stream.clone()).unwrap();
        let inverse =
            CudaFft::plan_1d(n as i32, sys::cufftType::CUFFT_C2R, 1, stream.clone()).unwrap();

        let host_in: Vec<f32> = (1..=n).map(|k| k as f32).collect();
        let signal = stream.clone_htod(&host_in).unwrap();
        let mut spectrum = stream.alloc_zeros::<sys::float2>(bins).unwrap();
        let mut restored = stream.alloc_zeros::<f32>(n).unwrap();

        forward.exec_r2c(&signal, &mut spectrum).unwrap();
        inverse.exec_c2r(&mut spectrum, &mut restored).unwrap();

        let host_out: Vec<f32> = stream.clone_dtoh(&restored).unwrap();
        assert!(host_out.len() >= n);
        for (i, (&got, &src)) in host_out.iter().zip(&host_in).enumerate() {
            let expected = src * n as f32;
            assert!(
                (got - expected).abs() < 1e-2,
                "mismatch at {i}: got {} expected {expected}",
                got
            );
        }
    }
}