use anyhow::Result;
use cudarc::driver::{CudaStream, CudaView, CudaViewMut, DevicePtr, DevicePtrMut};
use std::{ffi::c_int, sync::Arc};
// Raw cuFFT FFI declarations, textually included from `cufft_bindings.rs` in the
// build's `OUT_DIR`.
// NOTE(review): presumably produced by a bindgen-style build script — confirm in build.rs.
pub mod bindings {
// The included file follows the C header's naming and may declare items this crate
// never touches, so the corresponding lints are silenced for this module only.
#![allow(non_camel_case_types)]
#![allow(non_upper_case_globals)]
#![allow(non_snake_case)]
#![allow(unused_imports)]
#![allow(dead_code)]
#![allow(clippy::approx_constant)]
// `OUT_DIR` is resolved at compile time; the file must exist before this crate compiles.
include!(concat!(env!("OUT_DIR"), "/cufft_bindings.rs"));
}
/// Returns early with an `anyhow` error unless a cuFFT status equals `CUFFT_SUCCESS`.
///
/// * `$ret` - expression yielding a `bindings::cufftResult_t`. It is evaluated exactly
///   once, so passing the FFI call itself is safe.
/// * `$fn` - label of the call being checked; it is printed in the error message.
///
/// The expansion contains a `return`, so it must be used inside a function returning
/// `anyhow::Result`, with `cufft_result_to_str` in scope at the call site.
macro_rules! ensure_cufft_success {
    ($ret:expr, $fn:expr) => {{
        // Bind the status first: expanding `$ret` in both the comparison and the message
        // would re-run an expression with side effects whenever the check fails.
        let status = $ret;
        anyhow::ensure!(
            status == crate::cufft::bindings::cufftResult_t_CUFFT_SUCCESS,
            "{} failed: {}",
            $fn,
            cufft_result_to_str(status),
        );
    }};
}
/// cuFFT status codes that this file can translate into a readable name.
///
/// Each discriminant is the raw numeric status; the `From` impl below widens it into
/// `bindings::cufftResult_t` so it can be compared with values returned over FFI.
/// The values 10, 12 and 15 have no variant here, so such a status is reported as an
/// unknown return code by `cufft_result_to_str`.
// NOTE(review): the numbers presumably mirror the `cufftResult` enum of the cuFFT header
// the bindings were generated from — re-check them when changing CUDA versions.
#[repr(u8)]
enum CuFFTResult {
// The call completed without error.
Success = 0,
// Reported by `cufft_result_to_str` as "invalid plan".
InvalidPlan = 1,
// Reported as "alloc failed".
AllocFailed = 2,
// Reported as "invalid type".
InvalidType = 3,
// Reported as "invalid value".
InvalidValue = 4,
// Reported as "internal error".
InternalError = 5,
// Reported as "exec failed".
ExecFailed = 6,
// Reported as "setup failed".
SetupFailed = 7,
// Reported as "invalid size".
InvalidSize = 8,
// Reported as "unaligned data".
UnalignedData = 9,
// Reported as "invalid device" (10 is intentionally absent).
InvalidDevice = 11,
// Reported as "no workspace" (12 is intentionally absent).
NoWorkspace = 13,
// Reported as "not implemented".
NotImplemented = 14,
// Reported as "not supported" (15 is intentionally absent).
NotSupported = 16,
// Reported as "missing dependency".
MissingDependency = 17,
// Reported as "nvrtc failure".
NVRTCFailure = 18,
// Reported as "nvjitlink failure".
NVJITLINKFailure = 19,
// Reported as "nvshmem failure".
NVSHMEMFailure = 20,
}
/// Widens a [`CuFFTResult`] into the raw status type used by the generated bindings.
impl From<CuFFTResult> for bindings::cufftResult_t {
    /// Takes the enum's `u8` discriminant and converts it losslessly to `cufftResult_t`.
    fn from(value: CuFFTResult) -> bindings::cufftResult_t {
        // `CuFFTResult` is `#[repr(u8)]`, so this cast yields the declared discriminant.
        let discriminant = value as u8;
        bindings::cufftResult_t::from(discriminant)
    }
}
fn cufft_result_to_str(result: bindings::cufftResult_t) -> String {
match result {
val if val == bindings::cufftResult_t::from(CuFFTResult::AllocFailed) => {
"alloc failed".to_string()
}
val if val == bindings::cufftResult_t::from(CuFFTResult::ExecFailed) => {
"exec failed".to_string()
}
val if val == bindings::cufftResult_t::from(CuFFTResult::InternalError) => {
"internal error".to_string()
}
val if val == bindings::cufftResult_t::from(CuFFTResult::InvalidDevice) => {
"invalid device".to_string()
}
val if val == bindings::cufftResult_t::from(CuFFTResult::InvalidPlan) => {
"invalid plan".to_string()
}
val if val == bindings::cufftResult_t::from(CuFFTResult::InvalidSize) => {
"invalid size".to_string()
}
val if val == bindings::cufftResult_t::from(CuFFTResult::InvalidType) => {
"invalid type".to_string()
}
val if val == bindings::cufftResult_t::from(CuFFTResult::InvalidValue) => {
"invalid value".to_string()
}
val if val == bindings::cufftResult_t::from(CuFFTResult::MissingDependency) => {
"missing dependency".to_string()
}
val if val == bindings::cufftResult_t::from(CuFFTResult::NotImplemented) => {
"not implemented".to_string()
}
val if val == bindings::cufftResult_t::from(CuFFTResult::NotSupported) => {
"not supported".to_string()
}
val if val == bindings::cufftResult_t::from(CuFFTResult::NoWorkspace) => {
"no workspace".to_string()
}
val if val == bindings::cufftResult_t::from(CuFFTResult::NVJITLINKFailure) => {
"nvjitlink failure".to_string()
}
val if val == bindings::cufftResult_t::from(CuFFTResult::NVRTCFailure) => {
"nvrtc failure".to_string()
}
val if val == bindings::cufftResult_t::from(CuFFTResult::NVSHMEMFailure) => {
"nvshmem failure".to_string()
}
val if val == bindings::cufftResult_t::from(CuFFTResult::SetupFailed) => {
"setup failed".to_string()
}
val if val == bindings::cufftResult_t::from(CuFFTResult::Success) => "success".to_string(),
val if val == bindings::cufftResult_t::from(CuFFTResult::UnalignedData) => {
"unaligned data".to_string()
}
_ => format!("unknown return code {}", result),
}
}
/// A cuFFT plan handle bundled with the stream it was bound to and the buffer sizes
/// it expects.
///
/// The handle is passed to `cufftDestroy` when the value is dropped.
#[derive(Debug)]
pub struct CuFFTPlan {
// Raw cuFFT handle filled in by `cufftPlanMany`.
plan: bindings::cufftHandle,
// Stream handed to `cufftSetStream` at construction and used again when obtaining
// device pointers for execution.
stream: Arc<CudaStream>,
// Input extent in complex elements: `(batch_size - 1) * input_distance + nfft`,
// or 0 for an empty batch. `execute_c2c` requires an `f32` view of twice this length.
required_input_len: usize,
// Output extent in complex elements, computed the same way from `output_distance`.
required_output_len: usize,
}
impl Drop for CuFFTPlan {
    /// Hands the plan handle back to cuFFT via `cufftDestroy`.
    fn drop(&mut self) {
        let handle = self.plan;
        // SAFETY: the fields are private and the only constructor in this file stores a
        // handle that `cufftPlanMany` reported as successfully created.
        // The returned status is discarded because `drop` has no way to report a failure.
        let _status = unsafe { bindings::cufftDestroy(handle) };
    }
}
impl CuFFTPlan {
pub fn new_c2c_batch(
stream: &Arc<CudaStream>,
nfft: usize,
batch_size: usize,
) -> Result<CuFFTPlan> {
CuFFTPlan::new_c2c_batch_with_idist(stream, nfft, batch_size, nfft)
}
pub fn new_c2c_batch_with_idist(
stream: &Arc<CudaStream>,
nfft: usize,
batch_size: usize,
input_distance: usize,
) -> Result<CuFFTPlan> {
CuFFTPlan::new_c2c_batch_with_idist_and_odist(
stream,
nfft,
batch_size,
input_distance,
nfft,
)
}
pub fn new_c2c_batch_with_odist(
stream: &Arc<CudaStream>,
nfft: usize,
batch_size: usize,
output_distance: usize,
) -> Result<CuFFTPlan> {
CuFFTPlan::new_c2c_batch_with_idist_and_odist(
stream,
nfft,
batch_size,
nfft,
output_distance,
)
}
pub fn new_c2c_batch_with_idist_and_odist(
stream: &Arc<CudaStream>,
nfft: usize,
batch_size: usize,
input_distance: usize,
output_distance: usize,
) -> Result<CuFFTPlan> {
let mut plan = Default::default();
let plan_ref = &mut plan;
let mut n = [c_int::try_from(nfft).unwrap()];
let n_ref = &mut n[0];
let required_input_len = if batch_size >= 1 {
(batch_size - 1) * input_distance + nfft
} else {
0
};
let mut inembed_array = [c_int::try_from(required_input_len).unwrap()];
let inembed_array_ref = &mut inembed_array[0];
let required_output_len = if batch_size >= 1 {
(batch_size - 1) * output_distance + nfft
} else {
0
};
let mut onembed_array = [c_int::try_from(required_output_len).unwrap()];
let onembed_array_ref = &mut onembed_array[0];
let dimensionality = 1;
let ret = unsafe {
bindings::cufftPlanMany(
std::ptr::from_mut(plan_ref),
dimensionality,
std::ptr::from_mut(n_ref),
std::ptr::from_mut(inembed_array_ref),
1,
c_int::try_from(input_distance).unwrap(),
std::ptr::from_mut(onembed_array_ref),
1,
c_int::try_from(output_distance).unwrap(),
bindings::cufftType_t_CUFFT_C2C,
c_int::try_from(batch_size).unwrap(),
)
};
ensure_cufft_success!(ret, "cufftPlanMany");
let ret = unsafe { bindings::cufftSetStream(plan, stream.cu_stream().cast::<_>()) };
ensure_cufft_success!(ret, "cufftSetStream");
Ok(CuFFTPlan {
plan,
stream: Arc::clone(stream),
required_input_len,
required_output_len,
})
}
/// Runs the plan on `input` and writes the result to `output`.
///
/// Both views are `f32` buffers holding two floats per complex element, so each must be
/// exactly twice the element count recorded when the plan was created. `direction` is
/// passed to `cufftExecC2C` unchanged.
///
/// # Errors
/// Fails when either view has the wrong length or when `cufftExecC2C` reports a
/// non-success status.
pub fn execute_c2c(
    &self,
    input: CudaView<'_, f32>,
    mut output: CudaViewMut<'_, f32>,
    direction: c_int,
) -> Result<()> {
    if input.len() != 2 * self.required_input_len {
        anyhow::bail!(
            "input slice does not have required length; got {} need {}",
            input.len(),
            2 * self.required_input_len
        );
    }
    if output.len() != 2 * self.required_output_len {
        anyhow::bail!(
            "output slice does not have required length; got {} need {}",
            output.len(),
            2 * self.required_output_len
        );
    }

    // The second tuple element from cudarc is bound to a named variable (not `_`) so it
    // stays alive until this function returns, i.e. until after the FFT has been enqueued.
    let (src_addr, _src_guard) = input.device_ptr(&self.stream);
    let (dst_addr, _dst_guard) = output.device_ptr_mut(&self.stream);

    let src = src_addr as *mut bindings::cufftComplex;
    let dst = dst_addr as *mut bindings::cufftComplex;

    // SAFETY: both addresses come from live cudarc views whose lengths were checked
    // against the extents this plan was created with.
    let ret = unsafe { bindings::cufftExecC2C(self.plan, src, dst, direction) };
    ensure_cufft_success!(ret, "cufftExecC2C");

    Ok(())
}
/// Runs the plan with the `CUFFT_FORWARD` direction constant.
///
/// See [`CuFFTPlan::execute_c2c`] for the buffer requirements and error conditions.
pub fn execute_c2c_forward(
    &self,
    input: CudaView<'_, f32>,
    output: CudaViewMut<'_, f32>,
) -> Result<()> {
    let direction = bindings::CUFFT_FORWARD;
    self.execute_c2c(input, output, direction)
}
/// Runs the plan with the `CUFFT_INVERSE` direction constant.
///
/// See [`CuFFTPlan::execute_c2c`] for the buffer requirements and error conditions.
///
/// # Panics
/// Panics if the binding's `CUFFT_INVERSE` constant cannot be represented as a `c_int`.
pub fn execute_c2c_inverse(
    &self,
    input: CudaView<'_, f32>,
    output: CudaViewMut<'_, f32>,
) -> Result<()> {
    // The generated constant's type differs from `c_int`, hence the checked conversion.
    let direction = c_int::try_from(bindings::CUFFT_INVERSE).unwrap();
    self.execute_c2c(input, output, direction)
}
}