#![allow(non_camel_case_types, non_snake_case, dead_code)]
use core::ffi::{c_int, c_void};
use std::sync::OnceLock;
use cudarc::cufft::sys::{cufftHandle, cufftResult, cufftResult_t};
/// Selects which cuFFT callback slot an operation refers to.
///
/// The enum is `#[repr(i32)]` and every discriminant is written out, because
/// `xt_set_callback` and `xt_clear_callback` cast the variant to an integer
/// and hand it straight to the foreign function.
// NOTE(review): the numeric values are expected to mirror `cufftXtCallbackType`
// in `cufftXt.h`; variant descriptions below are read off the variant names.
// Confirm both against the header of the cuFFT release in use.
#[repr(i32)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CufftXtCallbackType {
    /// Load callback, complex data (value 0).
    LoadComplex = 0,
    /// Load callback, complex data, "double" flavour (value 1).
    LoadComplexDouble = 1,
    /// Load callback, real data (value 2).
    LoadReal = 2,
    /// Load callback, real data, "double" flavour (value 3).
    LoadRealDouble = 3,
    /// Store callback, complex data (value 4).
    StoreComplex = 4,
    /// Store callback, complex data, "double" flavour (value 5).
    StoreComplexDouble = 5,
    /// Store callback, real data (value 6).
    StoreReal = 6,
    /// Store callback, real data, "double" flavour (value 7).
    StoreRealDouble = 7,
}
/// Function-pointer type used for the `cufftXtSetCallback` symbol that
/// `load_xt_syms` resolves from the cuFFT shared library.
///
/// `callback_routine` and `caller_info` are pointer-to-pointer arguments;
/// `xt_set_callback` passes the address of a single local pointer for each.
/// `cb_type` receives a `CufftXtCallbackType` discriminant as a plain integer.
// NOTE(review): presumably mirrors the C prototype declared in `cufftXt.h`;
// confirm the parameter types against the header of the cuFFT release in use.
type CufftXtSetCallbackFn = unsafe extern "C" fn(
plan: cufftHandle,
callback_routine: *mut *mut c_void,
cb_type: i32,
caller_info: *mut *mut c_void,
) -> cufftResult;
/// Function-pointer type used for the `cufftXtClearCallback` symbol that
/// `load_xt_syms` resolves from the cuFFT shared library.
///
/// `cb_type` receives a `CufftXtCallbackType` discriminant as a plain integer.
// NOTE(review): presumably mirrors the C prototype declared in `cufftXt.h` — confirm.
type CufftXtClearCallbackFn = unsafe extern "C" fn(plan: cufftHandle, cb_type: i32) -> cufftResult;
/// Entry points resolved from a dynamically loaded cuFFT library, bundled
/// with the library handle they were resolved from.
///
/// Each symbol is `None` when its lookup in the loaded library failed.
struct XtSyms {
/// Resolved `cufftXtSetCallback`, if the lookup succeeded.
set_cb: Option<libloading::Symbol<'static, CufftXtSetCallbackFn>>,
/// Resolved `cufftXtClearCallback`, if the lookup succeeded.
clear_cb: Option<libloading::Symbol<'static, CufftXtClearCallbackFn>>,
/// Library handle the symbols above were obtained from. The symbols carry a
/// `'static` lifetime that the compiler cannot tie to this handle, so the
/// handle is stored alongside them. It is declared last because struct fields
/// are dropped in declaration order, i.e. the symbols would go first.
_lib: libloading::Library,
}
// SAFETY(review): `XtSyms` contains only two optional function symbols and the
// library handle. In this file it is built once by `load_xt_syms` and afterwards
// only read through shared references (`xt_set_callback`, `xt_clear_callback`).
// NOTE(review): whether the underlying cuFFT entry points may be invoked from
// several threads concurrently is not established by this file — confirm against
// the cuFFT documentation.
// NOTE(review): if `libloading::Symbol` and `libloading::Library` are already
// `Send + Sync` for these type parameters, both impls are redundant — TODO confirm.
unsafe impl Send for XtSyms {}
unsafe impl Sync for XtSyms {}
/// Process-wide cache of the symbol-loading outcome. It is filled on first use
/// by `xt_syms` via `get_or_init(load_xt_syms)`, so the load is attempted at
/// most once; a failure is cached as the `Err(String)` and is not retried.
static XT_SYMS: OnceLock<Result<XtSyms, String>> = OnceLock::new();
// Per-platform list of library names that `load_xt_syms` tries, in order; the
// first name that loads wins. Exactly one of the four definitions below is
// compiled in, selected by `target_os`.
// NOTE(review): the version suffixes are hard-coded; confirm they cover the
// cuFFT releases this crate is meant to support.

/// Linux: unversioned name first, then two versioned sonames.
#[cfg(target_os = "linux")]
const CUFFT_LIB_CANDIDATES: &[&str] = &["libcufft.so", "libcufft.so.11", "libcufft.so.10"];
/// macOS: a single unversioned dylib name.
#[cfg(target_os = "macos")]
const CUFFT_LIB_CANDIDATES: &[&str] = &["libcufft.dylib"];
/// Windows: versioned DLL names, highest version number first.
#[cfg(target_os = "windows")]
const CUFFT_LIB_CANDIDATES: &[&str] = &["cufft64_11.dll", "cufft64_10.dll", "cufft64_9.dll"];
/// Any other OS: no candidates, so `load_xt_syms` returns its
/// "no libcufft candidates configured" error.
#[cfg(not(any(target_os = "linux", target_os = "macos", target_os = "windows")))]
const CUFFT_LIB_CANDIDATES: &[&str] = &[];
/// Opens the first loadable entry of `CUFFT_LIB_CANDIDATES` and resolves the
/// two callback entry points from it.
///
/// The search stops at the first library that loads, even if that library
/// exports neither symbol; a missing symbol is recorded as `None` in the
/// returned `XtSyms`.
///
/// # Errors
/// If no candidate loads, returns `"<name>: <loader error>"` for the last
/// candidate tried, or `"no libcufft candidates configured"` when the
/// candidate list is empty.
fn load_xt_syms() -> Result<XtSyms, String> {
    let mut most_recent_failure: Option<String> = None;

    for cand in CUFFT_LIB_CANDIDATES {
        // SAFETY (caller-visible caveat): loading a shared library executes
        // whatever initialisation code that library carries.
        let lib = match unsafe { libloading::Library::new(*cand) } {
            Ok(handle) => handle,
            Err(e) => {
                most_recent_failure = Some(format!("{cand}: {e}"));
                continue;
            }
        };

        // SAFETY: the requested symbol type must match the real signature of
        // the exported function; that is an assumption about the cuFFT ABI.
        let set_lookup = unsafe { lib.get::<CufftXtSetCallbackFn>(b"cufftXtSetCallback\0") };
        let set_cb = match set_lookup {
            // SAFETY: only the lifetime parameter changes. The library handle
            // is moved into the same `XtSyms` below and stored next to the
            // symbol, so the symbol is not used without its library.
            Ok(sym) => Some(unsafe {
                std::mem::transmute::<
                    libloading::Symbol<'_, CufftXtSetCallbackFn>,
                    libloading::Symbol<'static, CufftXtSetCallbackFn>,
                >(sym)
            }),
            Err(_) => None,
        };

        // SAFETY: same ABI assumption as for the lookup above.
        let clear_lookup =
            unsafe { lib.get::<CufftXtClearCallbackFn>(b"cufftXtClearCallback\0") };
        let clear_cb = match clear_lookup {
            // SAFETY: lifetime-only transmute; see the note on `set_cb`.
            Ok(sym) => Some(unsafe {
                std::mem::transmute::<
                    libloading::Symbol<'_, CufftXtClearCallbackFn>,
                    libloading::Symbol<'static, CufftXtClearCallbackFn>,
                >(sym)
            }),
            Err(_) => None,
        };

        return Ok(XtSyms {
            set_cb,
            clear_cb,
            _lib: lib,
        });
    }

    Err(most_recent_failure.unwrap_or_else(|| "no libcufft candidates configured".into()))
}
/// Returns the cached symbol table, running `load_xt_syms` on the first call.
///
/// # Errors
/// When the cached load result is an error, returns a fixed message; the
/// detailed loader error stored in `XT_SYMS` is not exposed here.
fn xt_syms() -> Result<&'static XtSyms, &'static str> {
    XT_SYMS
        .get_or_init(load_xt_syms)
        .as_ref()
        .map_err(|_| "cuFFT shared library not loadable on this host")
}
/// Status returned by the wrappers in this file when the cuFFT library or the
/// requested entry point could not be resolved at runtime.
fn fail_not_supported() -> cufftResult {
cufftResult_t::CUFFT_NOT_SUPPORTED
}
/// Forwards to the dynamically resolved `cufftXtSetCallback` entry point.
///
/// `cb` and `caller_info` are each copied into a local and passed by address,
/// so the foreign function sees a pointer to exactly one pointer for each.
///
/// Returns `CUFFT_NOT_SUPPORTED` without calling into cuFFT when the library
/// could not be loaded or does not export the symbol; otherwise returns
/// whatever the foreign function returns.
///
/// # Safety
/// The arguments are handed to foreign code unchecked. The caller must ensure
/// that `plan`, `cb` and `caller_info` satisfy whatever `cufftXtSetCallback`
/// requires of them.
pub unsafe fn xt_set_callback(
    plan: cufftHandle,
    cb: *mut c_void,
    cb_type: CufftXtCallbackType,
    caller_info: *mut c_void,
) -> cufftResult {
    let Ok(table) = xt_syms() else {
        return fail_not_supported();
    };
    let Some(set_callback) = table.set_cb.as_ref() else {
        return fail_not_supported();
    };

    // One-element "arrays" living on this stack frame for the duration of the call.
    let mut routine_slot = cb;
    let mut info_slot = caller_info;

    // SAFETY: the symbol was resolved from the loaded library under the name
    // `cufftXtSetCallback`; validity of the arguments is the caller's
    // obligation per this function's safety contract.
    unsafe { set_callback(plan, &mut routine_slot, cb_type as c_int, &mut info_slot) }
}
/// Forwards to the dynamically resolved `cufftXtClearCallback` entry point.
///
/// Returns `CUFFT_NOT_SUPPORTED` without calling into cuFFT when the library
/// could not be loaded or does not export the symbol; otherwise returns
/// whatever the foreign function returns.
///
/// # Safety
/// `plan` is handed to foreign code unchecked. The caller must ensure it
/// satisfies whatever `cufftXtClearCallback` requires of it.
pub unsafe fn xt_clear_callback(plan: cufftHandle, cb_type: CufftXtCallbackType) -> cufftResult {
    let Ok(table) = xt_syms() else {
        return fail_not_supported();
    };
    let Some(clear_callback) = table.clear_cb.as_ref() else {
        return fail_not_supported();
    };

    // SAFETY: the symbol was resolved from the loaded library under the name
    // `cufftXtClearCallback`; validity of `plan` is the caller's obligation.
    unsafe { clear_callback(plan, cb_type as c_int) }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Spot-checks that selected load/store and single/double variants do not
    /// share a discriminant.
    #[test]
    fn callback_kinds_are_distinct() {
        let must_differ = [
            (
                CufftXtCallbackType::LoadComplex,
                CufftXtCallbackType::StoreComplex,
            ),
            (
                CufftXtCallbackType::LoadReal,
                CufftXtCallbackType::StoreReal,
            ),
            (
                CufftXtCallbackType::LoadComplex,
                CufftXtCallbackType::LoadComplexDouble,
            ),
        ];
        for (lhs, rhs) in must_differ {
            assert_ne!(lhs as i32, rhs as i32);
        }
    }

    /// Smoke test: the wrapper must return (with any status) rather than
    /// panic, whether or not a cuFFT library can be loaded on this machine.
    #[test]
    fn xt_set_callback_is_safe_to_call_without_gpu() {
        // SAFETY: null pointers and plan 0 are passed on purpose; the status
        // code is deliberately ignored because it depends on the host.
        let _status = unsafe {
            xt_set_callback(
                0,
                std::ptr::null_mut(),
                CufftXtCallbackType::LoadComplex,
                std::ptr::null_mut(),
            )
        };
    }
}