use std::any::Any;
use std::ffi::c_void;
use std::num::NonZeroUsize;
use std::sync::Arc;
use async_trait::async_trait;
use atomr_core::actor::{Actor, Context, Props};
use cudarc::cufft::sys as cufft_sys;
use cudarc::cufft::{CudaFft, FftDirection as CudarcFftDirection};
use lru::LruCache;
use parking_lot::Mutex;
use tokio::sync::oneshot;
use crate::completion::CompletionStrategy;
use crate::device::DeviceState;
use crate::dtype::{DType, FftSupported};
use crate::error::GpuError;
use crate::gpu_ref::GpuRef;
use crate::kernel::dispatch::{FftDispatch, FftDispatchCtx};
use crate::kernel::envelope;
use crate::stream::StreamAllocator;
use crate::sys::cufft as sys_cufft;
/// Library tag passed to `GpuError::LibraryError` and `envelope::run_kernel` by this module.
const LIB: &str = "cufft";
/// Capacity of each actor's LRU plan cache (number of cuFFT plans kept alive).
const DEFAULT_CACHE_SIZE: usize = 64;
/// Direction of a complex-to-complex transform.
///
/// Real-to-complex and complex-to-real kinds have an implied direction; in this
/// module the value is only consulted for the C2C and Z2Z kinds.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum FftDirection {
    /// Forward transform.
    Forward,
    /// Inverse (backward) transform.
    Inverse,
}

impl FftDirection {
    /// Converts to the equivalent `cudarc` direction value.
    pub(crate) fn cudarc(self) -> CudarcFftDirection {
        match self {
            Self::Inverse => CudarcFftDirection::Inverse,
            Self::Forward => CudarcFftDirection::Forward,
        }
    }
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum FftKind {
R2C,
C2R,
C2C,
D2Z,
Z2D,
Z2Z,
}
impl FftKind {
#[allow(non_upper_case_globals)]
pub const R2cF32: FftKind = FftKind::R2C;
#[allow(non_upper_case_globals)]
pub const C2rF32: FftKind = FftKind::C2R;
#[allow(non_upper_case_globals)]
pub const C2cF32: FftKind = FftKind::C2C;
pub fn cufft_type(self) -> cufft_sys::cufftType {
match self {
FftKind::R2C => cufft_sys::cufftType::CUFFT_R2C,
FftKind::C2R => cufft_sys::cufftType::CUFFT_C2R,
FftKind::C2C => cufft_sys::cufftType::CUFFT_C2C,
FftKind::D2Z => cufft_sys::cufftType::CUFFT_D2Z,
FftKind::Z2D => cufft_sys::cufftType::CUFFT_Z2D,
FftKind::Z2Z => cufft_sys::cufftType::CUFFT_Z2Z,
}
}
pub fn scalar_dtype(self) -> DType {
match self {
FftKind::R2C | FftKind::C2R | FftKind::C2C => DType::F32,
FftKind::D2Z | FftKind::Z2D | FftKind::Z2Z => DType::F64,
}
}
}
/// Cache key identifying a cuFFT plan.
///
/// Simple plans (built by `plan_1d`/`plan_2d`/`plan_3d`) zero the unused
/// trailing `dims` entries and carry `many_layout: None`; advanced-layout plans
/// (see [`FftPlanMany::key`]) carry a layout seed so that different strides or
/// embeddings never share a plan.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct PlanKey {
    /// Transform rank; the constructors below produce 1, 2 or 3.
    pub rank: u32,
    /// Extent per dimension; only the first `rank` entries are meaningful.
    pub dims: [i32; 3],
    /// cuFFT transform type.
    pub kind: FftKind,
    /// Scalar precision, derived from `kind` by every constructor.
    pub dtype: DType,
    /// Number of transforms in the batch (always 1 for the 2-D/3-D constructors).
    pub batch: i32,
    /// `Some(seed)` for advanced-layout plans, `None` for simple ones.
    pub many_layout: Option<u64>,
}

impl PlanKey {
    /// Shared constructor for the simple (non-`plan_many`) layouts.
    fn simple_plan(rank: u32, dims: [i32; 3], kind: FftKind, batch: i32) -> Self {
        let dtype = kind.scalar_dtype();
        Self {
            rank,
            dims,
            kind,
            dtype,
            batch,
            many_layout: None,
        }
    }

    /// Key for a batched 1-D plan of `n` points.
    pub fn plan_1d(n: i32, kind: FftKind, batch: i32) -> Self {
        Self::simple_plan(1, [n, 0, 0], kind, batch)
    }

    /// Key for an unbatched `nx` x `ny` plan.
    pub fn plan_2d(nx: i32, ny: i32, kind: FftKind) -> Self {
        Self::simple_plan(2, [nx, ny, 0], kind, 1)
    }

    /// Key for an unbatched `nx` x `ny` x `nz` plan.
    pub fn plan_3d(nx: i32, ny: i32, nz: i32, kind: FftKind) -> Self {
        Self::simple_plan(3, [nx, ny, nz], kind, 1)
    }
}
/// Descriptor for an advanced-layout (`plan_many`) cuFFT plan.
///
/// The fields are forwarded to `CudaFft::plan_many` by `build_plan_many`; only
/// the first `rank` entries of `dims` and of the embed arrays are used there.
#[derive(Debug, Clone)]
pub struct FftPlanMany {
    /// Transform rank.
    pub rank: u32,
    /// Logical extent per dimension.
    pub dims: [i32; 3],
    /// Storage dimensions of the input, or `None` to pass no embedding.
    pub in_embed: Option<[i32; 3]>,
    /// Distance between consecutive input elements.
    pub in_stride: i32,
    /// Distance between the first elements of consecutive input batches.
    pub in_dist: i32,
    /// Storage dimensions of the output, or `None` to pass no embedding.
    pub out_embed: Option<[i32; 3]>,
    /// Distance between consecutive output elements.
    pub out_stride: i32,
    /// Distance between the first elements of consecutive output batches.
    pub out_dist: i32,
    /// cuFFT transform type.
    pub kind: FftKind,
    /// Number of transforms in the batch.
    pub batch: i32,
}

impl FftPlanMany {
    /// 64-bit digest of the memory layout (embeddings, strides and distances).
    ///
    /// `DefaultHasher::new` is deterministic within a process, which is all an
    /// in-memory cache key needs; the value is not stable across Rust releases
    /// and must not be persisted. Being a digest, two distinct layouts could in
    /// principle collide.
    pub fn layout_seed(&self) -> u64 {
        use std::collections::hash_map::DefaultHasher;
        use std::hash::{Hash, Hasher};

        let layout = (
            self.in_embed,
            self.in_stride,
            self.in_dist,
            self.out_embed,
            self.out_stride,
            self.out_dist,
        );
        let mut hasher = DefaultHasher::new();
        // Tuples hash their elements in order, so this matches hashing each field in turn.
        layout.hash(&mut hasher);
        hasher.finish()
    }

    /// Cache key for this descriptor, including its layout seed.
    pub fn key(&self) -> PlanKey {
        let Self {
            rank,
            dims,
            kind,
            batch,
            ..
        } = *self;
        PlanKey {
            rank,
            dims,
            kind,
            dtype: kind.scalar_dtype(),
            batch,
            many_layout: Some(self.layout_seed()),
        }
    }
}
/// Kind of cuFFT XT callback: the `Load*` kinds are registered with
/// `FftPlan::with_load_callback`, the `Store*` kinds with `with_store_callback`.
#[derive(Debug, Clone, Copy)]
pub enum FftCallbackKind {
    LoadComplex,
    LoadComplexDouble,
    LoadReal,
    LoadRealDouble,
    StoreComplex,
    StoreComplexDouble,
    StoreReal,
    StoreRealDouble,
}

impl FftCallbackKind {
    /// Maps to the raw callback-type enum used by the `sys::cufft` bindings.
    fn sys(self) -> sys_cufft::CufftXtCallbackType {
        use sys_cufft::CufftXtCallbackType as Raw;
        match self {
            Self::LoadComplex => Raw::LoadComplex,
            Self::LoadComplexDouble => Raw::LoadComplexDouble,
            Self::LoadReal => Raw::LoadReal,
            Self::LoadRealDouble => Raw::LoadRealDouble,
            Self::StoreComplex => Raw::StoreComplex,
            Self::StoreComplexDouble => Raw::StoreComplexDouble,
            Self::StoreReal => Raw::StoreReal,
            Self::StoreRealDouble => Raw::StoreRealDouble,
        }
    }
}
/// A cached cuFFT plan together with the key it was built for.
///
/// Clones share the underlying `CudaFft` through an `Arc`, so callbacks
/// attached through one clone act on the same cuFFT handle as every other.
#[derive(Clone)]
pub struct FftPlan {
    pub key: PlanKey,
    inner: Arc<CudaFft>,
}

// Hand-written so only the key is printed; the cuFFT handle is omitted.
impl std::fmt::Debug for FftPlan {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("FftPlan");
        builder.field("key", &self.key);
        builder.finish()
    }
}
impl FftPlan {
    /// Returns the cache key this plan was built for.
    pub fn key(&self) -> PlanKey {
        self.key
    }

    /// Registers a cuFFT XT callback of the given `kind` on this plan's handle.
    ///
    /// The handle is shared by every clone of this `FftPlan` (and by the actor's
    /// plan cache), so the callback applies to later executions of the same
    /// cached plan. NOTE(review): if the plan is evicted from the actor's LRU
    /// cache, a rebuilt plan will not carry this callback.
    ///
    /// # Errors
    /// `GpuError::LibraryError` if `cufftXtSetCallback` reports a failure.
    ///
    /// # Safety
    /// `cb` and `caller_info` are forwarded unchecked to `cufftXtSetCallback`.
    /// The caller must satisfy that API's requirements: `cb` refers to a device
    /// callback whose signature matches `kind`, and `caller_info` stays valid for
    /// every execution of this plan.
    pub unsafe fn with_callback(
        &self,
        kind: FftCallbackKind,
        cb: *mut c_void,
        caller_info: *mut c_void,
    ) -> Result<(), GpuError> {
        let res = sys_cufft::xt_set_callback(self.inner.handle(), cb, kind.sys(), caller_info);
        res.result().map_err(|e| GpuError::LibraryError {
            lib: LIB,
            msg: format!("cufftXtSetCallback({kind:?}): {e:?}"),
        })
    }

    /// Like [`FftPlan::with_callback`], restricted to the `Load*` kinds.
    ///
    /// # Errors
    /// `GpuError::LibraryError` if `kind` is not a load kind. The check runs in
    /// every build profile: registering a store callback where a load callback
    /// is expected would make cuFFT call it with the wrong signature.
    /// Otherwise, as [`FftPlan::with_callback`].
    ///
    /// # Safety
    /// Same contract as [`FftPlan::with_callback`].
    pub unsafe fn with_load_callback(
        &self,
        kind: FftCallbackKind,
        cb: *mut c_void,
        caller_info: *mut c_void,
    ) -> Result<(), GpuError> {
        let is_load = matches!(
            kind,
            FftCallbackKind::LoadComplex
                | FftCallbackKind::LoadComplexDouble
                | FftCallbackKind::LoadReal
                | FftCallbackKind::LoadRealDouble
        );
        if !is_load {
            return Err(GpuError::LibraryError {
                lib: LIB,
                msg: format!("with_load_callback: {kind:?} is not a load callback kind"),
            });
        }
        self.with_callback(kind, cb, caller_info)
    }

    /// Like [`FftPlan::with_callback`], restricted to the `Store*` kinds.
    ///
    /// # Errors
    /// `GpuError::LibraryError` if `kind` is not a store kind. The check runs in
    /// every build profile for the same reason as in `with_load_callback`.
    /// Otherwise, as [`FftPlan::with_callback`].
    ///
    /// # Safety
    /// Same contract as [`FftPlan::with_callback`].
    pub unsafe fn with_store_callback(
        &self,
        kind: FftCallbackKind,
        cb: *mut c_void,
        caller_info: *mut c_void,
    ) -> Result<(), GpuError> {
        let is_store = matches!(
            kind,
            FftCallbackKind::StoreComplex
                | FftCallbackKind::StoreComplexDouble
                | FftCallbackKind::StoreReal
                | FftCallbackKind::StoreRealDouble
        );
        if !is_store {
            return Err(GpuError::LibraryError {
                lib: LIB,
                msg: format!("with_store_callback: {kind:?} is not a store callback kind"),
            });
        }
        self.with_callback(kind, cb, caller_info)
    }
}
/// A typed FFT execution request, sent to the actor as `FftMsg::Exec`.
///
/// `T` names the scalar type (reported through `FftDispatch::dtype_kind`);
/// `I` and `O` are the element types of the input and output buffers.
pub struct FftRequest<T: FftSupported, I = u8, O = u8> {
    /// Key of the plan to execute; the actor builds it on a cache miss.
    pub plan_key: PlanKey,
    /// Transform direction; only consulted for the C2C / Z2Z kinds.
    pub direction: FftDirection,
    /// Source buffer.
    pub input: GpuRef<I>,
    /// Destination buffer.
    pub output: GpuRef<O>,
    /// Receives the outcome of the request.
    pub reply: oneshot::Sender<Result<(), GpuError>>,
    // Ties the request to its scalar type without storing a value of it.
    _scalar: std::marker::PhantomData<T>,
}

impl<T: FftSupported, I, O> FftRequest<T, I, O> {
    /// Bundles the plan key, direction, buffers and reply channel into a request.
    pub fn new(
        plan_key: PlanKey,
        direction: FftDirection,
        input: GpuRef<I>,
        output: GpuRef<O>,
        reply: oneshot::Sender<Result<(), GpuError>>,
    ) -> Self {
        let _scalar = std::marker::PhantomData;
        Self {
            _scalar,
            reply,
            output,
            input,
            direction,
            plan_key,
        }
    }
}
// Type-erased execution path for `FftRequest`: the actor resolves the plan,
// wraps it in an `FftDispatchCtx` (stream, completion strategy, plan as
// `dyn Any`), and the request enqueues itself.
impl<T, I, O> FftDispatch for FftRequest<T, I, O>
where
    T: FftSupported,
    I: Send + Sync + 'static,
    O: Send + Sync + 'static,
{
    // Scalar dtype of the request, taken from `T`.
    fn dtype_kind(&self) -> DType {
        T::KIND
    }
    // Key the actor uses to look up (or build) the plan.
    fn plan_key(&self) -> PlanKey {
        self.plan_key
    }
    // Enqueues the transform on `ctx.stream`; the outcome is sent on `reply`.
    //
    // NOTE(review): `T::KIND` is not compared with `self.plan_key.dtype`, and
    // `I`/`O` are not checked against the plan's element types. `exec_kernel`
    // reinterprets the raw buffers according to `plan_key.kind`, so a mismatched
    // request would run with the wrong precision. Confirm that callers guarantee
    // consistency.
    fn dispatch(self: Box<Self>, ctx: &FftDispatchCtx<'_>) {
        // The plan arrives type-erased. `FftActor::handle` passes a placeholder
        // `()` when plan creation failed, which also ends up in this error branch.
        let plan = match ctx.plan.clone().downcast::<CudaFft>() {
            Ok(p) => p,
            Err(_) => {
                let _ = self.reply.send(Err(GpuError::Unrecoverable(
                    "FftDispatchCtx.plan downcast to CudaFft failed".into(),
                )));
                return;
            }
        };
        let stream = ctx.stream.clone();
        // Second clone is moved into the kernel closure; `stream` itself is
        // borrowed by `run_kernel` below.
        let stream_for_exec = stream.clone();
        let completion = ctx.completion.clone();
        let kind = self.plan_key.kind;
        let direction = self.direction;
        // Resolve both buffers first; on failure the request is answered
        // without recording anything on the stream.
        let (src_arc, dst_arc) = match envelope::access_all_2(&self.input, &self.output) {
            Ok(t) => t,
            Err(e) => {
                let _ = self.reply.send(Err(e));
                return;
            }
        };
        // Record the pending write to `output` on this stream before the work is enqueued.
        self.output.record_write(&stream);
        let reply = self.reply;
        envelope::run_kernel(LIB, &stream, &completion, (), reply, move || {
            // SAFETY (as far as visible here): the plan was looked up by
            // `plan_key`, so its kind matches `kind`; buffer sizes and element
            // types are NOT validated here (see the NOTE above).
            let res = unsafe {
                exec_kernel(&plan, &src_arc, &dst_arc, kind, direction, &stream_for_exec)
            };
            // Buffers and plan are returned with the result, presumably so that
            // `run_kernel` keeps them alive until the GPU work completes.
            res.map(|_| (src_arc, dst_arc, plan))
                .map_err(|e| GpuError::LibraryError {
                    lib: LIB,
                    msg: format!("exec_{:?}: {:?}", kind, e),
                })
        });
    }
}
unsafe fn exec_kernel<I, O>(
plan: &Arc<CudaFft>,
src: &Arc<cudarc::driver::CudaSlice<I>>,
dst: &Arc<cudarc::driver::CudaSlice<O>>,
kind: FftKind,
direction: FftDirection,
stream: &Arc<cudarc::driver::CudaStream>,
) -> Result<(), cudarc::cufft::result::CufftError> {
use cudarc::driver::DevicePtr;
let (src_ptr, _src_rec) = src.device_ptr(stream);
let (dst_ptr, _dst_rec) = dst.device_ptr(stream);
let src_ptr = src_ptr as *mut c_void;
let dst_ptr = dst_ptr as *mut c_void;
let h = plan.handle();
use cudarc::cufft::sys as s;
let r = match kind {
FftKind::R2C => s::cufftExecR2C(
h,
src_ptr as *mut s::cufftReal,
dst_ptr as *mut s::cufftComplex,
),
FftKind::C2R => s::cufftExecC2R(
h,
src_ptr as *mut s::cufftComplex,
dst_ptr as *mut s::cufftReal,
),
FftKind::C2C => s::cufftExecC2C(
h,
src_ptr as *mut s::cufftComplex,
dst_ptr as *mut s::cufftComplex,
direction.cudarc() as i32,
),
FftKind::D2Z => s::cufftExecD2Z(
h,
src_ptr as *mut s::cufftDoubleReal,
dst_ptr as *mut s::cufftDoubleComplex,
),
FftKind::Z2D => s::cufftExecZ2D(
h,
src_ptr as *mut s::cufftDoubleComplex,
dst_ptr as *mut s::cufftDoubleReal,
),
FftKind::Z2Z => s::cufftExecZ2Z(
h,
src_ptr as *mut s::cufftDoubleComplex,
dst_ptr as *mut s::cufftDoubleComplex,
direction.cudarc() as i32,
),
};
r.result()
}
/// Messages accepted by [`FftActor`].
///
/// `Exec` is the general entry point. The other variants are deprecated
/// single-precision shortcuts (`f32` / `float2` buffers). Their handlers need
/// sole ownership of the buffers they pass mutably to cudarc (`Arc::try_unwrap`)
/// and reply with `GpuError::Unrecoverable` when another reference is alive.
#[allow(deprecated)]
pub enum FftMsg {
    /// Execute a type-erased request (e.g. an [`FftRequest`]) with the plan
    /// cached under its `plan_key()`.
    Exec(Box<dyn FftDispatch>),
    /// 1-D real-to-complex transform of `n` points, `batch` times.
    #[deprecated(note = "use FftMsg::Exec with FftRequest<f32> { kind: R2C, .. }")]
    Forward1dR2C {
        n: i32,
        batch: i32,
        src: GpuRef<f32>,
        dst: GpuRef<cufft_sys::float2>,
        reply: oneshot::Sender<Result<(), GpuError>>,
    },
    /// 1-D complex-to-real transform of `n` points, `batch` times.
    #[deprecated(note = "use FftMsg::Exec with FftRequest<f32> { kind: C2R, .. }")]
    Inverse1dC2R {
        n: i32,
        batch: i32,
        src: GpuRef<cufft_sys::float2>,
        dst: GpuRef<f32>,
        reply: oneshot::Sender<Result<(), GpuError>>,
    },
    /// 1-D complex-to-complex transform of `n` points, `batch` times, in `direction`.
    #[deprecated(note = "use FftMsg::Exec with FftRequest<f32> { kind: C2C, .. }")]
    Exec1dC2C {
        n: i32,
        batch: i32,
        direction: CudarcFftDirection,
        src: GpuRef<cufft_sys::float2>,
        dst: GpuRef<cufft_sys::float2>,
        reply: oneshot::Sender<Result<(), GpuError>>,
    },
    /// 2-D `nx` x `ny` real-to-complex transform (unbatched).
    #[deprecated(note = "use FftMsg::Exec with FftRequest<f32> { kind: R2C, rank=2, .. }")]
    Forward2dR2C {
        nx: i32,
        ny: i32,
        src: GpuRef<f32>,
        dst: GpuRef<cufft_sys::float2>,
        reply: oneshot::Sender<Result<(), GpuError>>,
    },
}
/// Actor that caches cuFFT plans and enqueues transforms on a single CUDA stream.
pub struct FftActor {
    inner: FftInner,
}

/// LRU cache of built cuFFT plans, keyed by [`PlanKey`].
struct PlanCache {
    cache: LruCache<PlanKey, Arc<CudaFft>>,
}

impl PlanCache {
    /// Creates an empty cache that holds at most `cap` plans.
    fn new(cap: NonZeroUsize) -> Self {
        let cache = LruCache::new(cap);
        Self { cache }
    }
}

/// Backing mode of an [`FftActor`].
enum FftInner {
    /// GPU-backed execution.
    Real {
        /// Stream that plans are bound to and transforms are enqueued on.
        stream: Arc<cudarc::driver::CudaStream>,
        /// Completion strategy forwarded to `envelope::run_kernel`.
        completion: Arc<dyn CompletionStrategy>,
        /// Plans built so far; the lock is held only around lookups and inserts.
        plans: Mutex<PlanCache>,
        // Not read in this module (hence `dead_code`); presumably held to keep
        // the device state alive as long as the actor.
        #[allow(dead_code)]
        state: Arc<DeviceState>,
    },
    /// No GPU: deprecated messages get an error reply, `Exec` requests are dropped.
    Mock,
}
impl FftActor {
    /// Props for a GPU-backed actor bound to `stream`.
    ///
    /// Each actor instance created from these props starts with an empty plan
    /// cache of `DEFAULT_CACHE_SIZE` entries. `_allocator` and `_ctx` are
    /// accepted but not used by this actor.
    pub fn props(
        stream: Arc<cudarc::driver::CudaStream>,
        _allocator: Arc<dyn StreamAllocator>,
        completion: Arc<dyn CompletionStrategy>,
        state: Arc<DeviceState>,
        _ctx: Arc<cudarc::driver::CudaContext>,
    ) -> Props<Self> {
        Props::create(move || {
            let capacity = NonZeroUsize::new(DEFAULT_CACHE_SIZE).unwrap();
            let inner = FftInner::Real {
                stream: Arc::clone(&stream),
                completion: Arc::clone(&completion),
                plans: Mutex::new(PlanCache::new(capacity)),
                state: Arc::clone(&state),
            };
            Self { inner }
        })
    }

    /// Props for a GPU-less actor used in tests: it never executes a transform.
    pub fn mock_props() -> Props<Self> {
        Props::create(|| Self {
            inner: FftInner::Mock,
        })
    }
}
impl FftActor {
pub fn ensure_plan(&self, key: PlanKey) -> Result<FftPlan, GpuError> {
let arc = self.get_or_create_plan(key)?;
Ok(FftPlan { key, inner: arc })
}
pub fn ensure_plan_many(&self, builder: &FftPlanMany) -> Result<FftPlan, GpuError> {
let key = builder.key();
let FftInner::Real { stream, plans, .. } = &self.inner else {
return Err(GpuError::Unrecoverable("fft mock".into()));
};
{
let mut g = plans.lock();
if let Some(plan) = g.cache.get(&key) {
return Ok(FftPlan {
key,
inner: plan.clone(),
});
}
}
let plan = build_plan_many(builder, stream).map_err(|e| GpuError::LibraryError {
lib: LIB,
msg: format!("plan_many {key:?}: {e}"),
})?;
let plan = Arc::new(plan);
{
let mut g = plans.lock();
g.cache.put(key, plan.clone());
}
Ok(FftPlan { key, inner: plan })
}
fn get_or_create_plan(&self, key: PlanKey) -> Result<Arc<CudaFft>, GpuError> {
let FftInner::Real { stream, plans, .. } = &self.inner else {
return Err(GpuError::Unrecoverable("fft mock".into()));
};
{
let mut g = plans.lock();
if let Some(plan) = g.cache.get(&key) {
return Ok(plan.clone());
}
}
let plan = build_simple_plan(&key, stream).map_err(|e| GpuError::LibraryError {
lib: LIB,
msg: format!("plan {key:?}: {e}"),
})?;
let plan = Arc::new(plan);
{
let mut g = plans.lock();
g.cache.put(key, plan.clone());
}
Ok(plan)
}
}
fn build_simple_plan(
key: &PlanKey,
stream: &Arc<cudarc::driver::CudaStream>,
) -> Result<CudaFft, cudarc::cufft::result::CufftError> {
match key.rank {
1 => CudaFft::plan_1d(
key.dims[0],
key.kind.cufft_type(),
key.batch,
stream.clone(),
),
2 => CudaFft::plan_2d(
key.dims[0],
key.dims[1],
key.kind.cufft_type(),
stream.clone(),
),
3 => CudaFft::plan_3d(
key.dims[0],
key.dims[1],
key.dims[2],
key.kind.cufft_type(),
stream.clone(),
),
_ => CudaFft::plan_1d(1, key.kind.cufft_type(), 1, stream.clone()),
}
}
fn build_plan_many(
b: &FftPlanMany,
stream: &Arc<cudarc::driver::CudaStream>,
) -> Result<CudaFft, cudarc::cufft::result::CufftError> {
let n: &[i32] = &b.dims[..b.rank as usize];
let in_embed = b.in_embed;
let out_embed = b.out_embed;
let inembed: Option<&[i32]> = in_embed.as_ref().map(|e| &e[..b.rank as usize]);
let onembed: Option<&[i32]> = out_embed.as_ref().map(|e| &e[..b.rank as usize]);
CudaFft::plan_many(
n,
inembed,
b.in_stride,
b.in_dist,
onembed,
b.out_stride,
b.out_dist,
b.kind.cufft_type(),
b.batch,
stream.clone(),
)
}
// Message handler: mock actors answer through `reply_mock`. Real actors resolve
// (or build) the plan from the cache and enqueue the transform on their stream.
#[allow(deprecated)]
#[async_trait]
impl Actor for FftActor {
    type Msg = FftMsg;
    async fn handle(&mut self, _ctx: &mut Context<Self>, msg: FftMsg) {
        let (stream, completion) = match &self.inner {
            FftInner::Mock => {
                reply_mock(msg);
                return;
            }
            FftInner::Real {
                stream, completion, ..
            } => (stream.clone(), completion.clone()),
        };
        match msg {
            FftMsg::Exec(req) => {
                let key = req.plan_key();
                let plan_arc = match self.get_or_create_plan(key) {
                    Ok(p) => p,
                    Err(_e) => {
                        // The request owns its reply channel and, as used here,
                        // `FftDispatch` only exposes `dispatch`. So a placeholder
                        // plan is passed, whose downcast fails in the request,
                        // which then replies with a generic error.
                        // NOTE(review): the real plan-build error `_e` is lost
                        // here; consider extending `FftDispatch` to forward it.
                        let dummy: Arc<dyn Any + Send + Sync> = Arc::new(());
                        let dispatch_ctx = FftDispatchCtx {
                            stream: &stream,
                            completion: &completion,
                            plan: dummy,
                        };
                        req.dispatch(&dispatch_ctx);
                        return;
                    }
                };
                let plan_any: Arc<dyn Any + Send + Sync> = plan_arc;
                let dispatch_ctx = FftDispatchCtx {
                    stream: &stream,
                    completion: &completion,
                    plan: plan_any,
                };
                req.dispatch(&dispatch_ctx);
            }
            // Deprecated fixed-shape paths below. They call cudarc's safe
            // `exec_*` methods, which take the written buffers by `&mut`, hence
            // the `Arc::try_unwrap`. It fails whenever another reference to the
            // slice is alive. NOTE(review): if `GpuRef` itself keeps an `Arc` to
            // its slice, these paths can never succeed — confirm.
            FftMsg::Forward1dR2C {
                n,
                batch,
                src,
                dst,
                reply,
            } => {
                let plan = match self.get_or_create_plan(PlanKey::plan_1d(n, FftKind::R2C, batch)) {
                    Ok(p) => p,
                    Err(e) => {
                        let _ = reply.send(Err(e));
                        return;
                    }
                };
                let (src_slice, dst_slice) = match envelope::access_all_2(&src, &dst) {
                    Ok(t) => t,
                    Err(e) => {
                        let _ = reply.send(Err(e));
                        return;
                    }
                };
                let mut dst_owned = match Arc::try_unwrap(dst_slice) {
                    Ok(s) => s,
                    Err(_) => {
                        let _ = reply.send(Err(GpuError::Unrecoverable(
                            "FFT dst has multiple live references".into(),
                        )));
                        return;
                    }
                };
                dst.record_write(&stream);
                envelope::run_kernel(LIB, &stream, &completion, (), reply, move || {
                    // Buffers and plan are returned alongside the result,
                    // presumably so `run_kernel` keeps them alive until completion.
                    plan.exec_r2c(&*src_slice, &mut dst_owned)
                        .map(|_| (src_slice, dst_owned, plan))
                        .map_err(|e| GpuError::LibraryError {
                            lib: LIB,
                            msg: format!("exec_r2c: {e}"),
                        })
                });
            }
            FftMsg::Inverse1dC2R {
                n,
                batch,
                src,
                dst,
                reply,
            } => {
                let plan = match self.get_or_create_plan(PlanKey::plan_1d(n, FftKind::C2R, batch)) {
                    Ok(p) => p,
                    Err(e) => {
                        let _ = reply.send(Err(e));
                        return;
                    }
                };
                let (src_slice, dst_slice) = match envelope::access_all_2(&src, &dst) {
                    Ok(t) => t,
                    Err(e) => {
                        let _ = reply.send(Err(e));
                        return;
                    }
                };
                // C2R also needs the source by `&mut` (`exec_c2r` below).
                let mut src_owned = match Arc::try_unwrap(src_slice) {
                    Ok(s) => s,
                    Err(_) => {
                        let _ = reply.send(Err(GpuError::Unrecoverable(
                            "FFT C2R src has multiple live references".into(),
                        )));
                        return;
                    }
                };
                let mut dst_owned = match Arc::try_unwrap(dst_slice) {
                    Ok(s) => s,
                    Err(_) => {
                        let _ = reply.send(Err(GpuError::Unrecoverable(
                            "FFT C2R dst has multiple live references".into(),
                        )));
                        return;
                    }
                };
                dst.record_write(&stream);
                envelope::run_kernel(LIB, &stream, &completion, (), reply, move || {
                    plan.exec_c2r(&mut src_owned, &mut dst_owned)
                        .map(|_| (src_owned, dst_owned, plan))
                        .map_err(|e| GpuError::LibraryError {
                            lib: LIB,
                            msg: format!("exec_c2r: {e}"),
                        })
                });
            }
            FftMsg::Exec1dC2C {
                n,
                batch,
                direction,
                src,
                dst,
                reply,
            } => {
                let plan = match self.get_or_create_plan(PlanKey::plan_1d(n, FftKind::C2C, batch)) {
                    Ok(p) => p,
                    Err(e) => {
                        let _ = reply.send(Err(e));
                        return;
                    }
                };
                let (src_slice, dst_slice) = match envelope::access_all_2(&src, &dst) {
                    Ok(t) => t,
                    Err(e) => {
                        let _ = reply.send(Err(e));
                        return;
                    }
                };
                let mut src_owned = match Arc::try_unwrap(src_slice) {
                    Ok(s) => s,
                    Err(_) => {
                        let _ = reply.send(Err(GpuError::Unrecoverable(
                            "FFT C2C src has multiple live references".into(),
                        )));
                        return;
                    }
                };
                let mut dst_owned = match Arc::try_unwrap(dst_slice) {
                    Ok(s) => s,
                    Err(_) => {
                        let _ = reply.send(Err(GpuError::Unrecoverable(
                            "FFT C2C dst has multiple live references".into(),
                        )));
                        return;
                    }
                };
                dst.record_write(&stream);
                envelope::run_kernel(LIB, &stream, &completion, (), reply, move || {
                    plan.exec_c2c(&mut src_owned, &mut dst_owned, direction)
                        .map(|_| (src_owned, dst_owned, plan))
                        .map_err(|e| GpuError::LibraryError {
                            lib: LIB,
                            msg: format!("exec_c2c: {e}"),
                        })
                });
            }
            FftMsg::Forward2dR2C {
                nx,
                ny,
                src,
                dst,
                reply,
            } => {
                let plan = match self.get_or_create_plan(PlanKey::plan_2d(nx, ny, FftKind::R2C)) {
                    Ok(p) => p,
                    Err(e) => {
                        let _ = reply.send(Err(e));
                        return;
                    }
                };
                let (src_slice, dst_slice) = match envelope::access_all_2(&src, &dst) {
                    Ok(t) => t,
                    Err(e) => {
                        let _ = reply.send(Err(e));
                        return;
                    }
                };
                let mut dst_owned = match Arc::try_unwrap(dst_slice) {
                    Ok(s) => s,
                    Err(_) => {
                        let _ = reply.send(Err(GpuError::Unrecoverable(
                            "FFT 2D dst has multiple live references".into(),
                        )));
                        return;
                    }
                };
                dst.record_write(&stream);
                envelope::run_kernel(LIB, &stream, &completion, (), reply, move || {
                    plan.exec_r2c(&*src_slice, &mut dst_owned)
                        .map(|_| (src_slice, dst_owned, plan))
                        .map_err(|e| GpuError::LibraryError {
                            lib: LIB,
                            msg: format!("exec_r2c (2d): {e}"),
                        })
                });
            }
        }
    }
}
/// Answers a message received by a mock-mode actor.
///
/// Deprecated variants get an explicit "FftActor in mock mode" error on their
/// reply channel. `Exec` requests are dropped unanswered: no dispatch context
/// exists in mock mode. For an `FftRequest`, dropping it drops its reply
/// sender, so the requester sees a closed channel rather than an error value.
#[allow(deprecated)]
fn reply_mock(msg: FftMsg) {
    match msg {
        FftMsg::Exec(request) => drop(request),
        FftMsg::Forward1dR2C { reply, .. }
        | FftMsg::Inverse1dC2R { reply, .. }
        | FftMsg::Exec1dC2C { reply, .. }
        | FftMsg::Forward2dR2C { reply, .. } => {
            let mock_error = GpuError::Unrecoverable("FftActor in mock mode".into());
            let _ = reply.send(Err(mock_error));
        }
    }
}
// Host-only unit tests: they exercise keys, descriptors and type-level wiring;
// none of them touches a GPU.
#[cfg(test)]
mod tests {
    #![allow(deprecated)]
    use super::*;
    // NOTE(review): not referenced by any test below (the f16 test uses
    // `atomr_accel::AccelDtype`); confirm which trait is intended.
    #[cfg(feature = "f16")]
    use crate::dtype::CudaDtype;
    // Simple-plan constructors fill rank/dims/dtype and leave `many_layout` unset.
    #[test]
    fn plan_key_for_simple_plans_zeroes_unused_dims() {
        let k1 = PlanKey::plan_1d(1024, FftKind::R2C, 1);
        assert_eq!(k1.rank, 1);
        assert_eq!(k1.dims, [1024, 0, 0]);
        assert_eq!(k1.dtype, DType::F32);
        assert!(k1.many_layout.is_none());
        let k2 = PlanKey::plan_2d(64, 64, FftKind::R2C);
        assert_eq!(k2.rank, 2);
        assert_eq!(k2.dims, [64, 64, 0]);
        assert_eq!(k2.dtype, DType::F32);
        let k3 = PlanKey::plan_3d(32, 32, 32, FftKind::Z2Z);
        assert_eq!(k3.rank, 3);
        assert_eq!(k3.dims, [32, 32, 32]);
        assert_eq!(k3.dtype, DType::F64);
    }
    // The 3-D constructor keeps the dimension order (nx, ny, nz).
    #[test]
    fn fft_3d_plan_dim_handling() {
        let k = PlanKey::plan_3d(8, 16, 32, FftKind::C2C);
        assert_eq!(k.rank, 3);
        assert_eq!(k.dims[0], 8);
        assert_eq!(k.dims[1], 16);
        assert_eq!(k.dims[2], 32);
        assert_eq!(k.kind, FftKind::C2C);
    }
    // `FftPlanMany::key` carries a layout seed that changes with the layout
    // (here: `in_dist`), so different layouts get different cache keys.
    #[test]
    fn plan_many_descriptor_correct() {
        let many = FftPlanMany {
            rank: 2,
            dims: [4, 8, 0],
            in_embed: Some([4, 8, 0]),
            in_stride: 1,
            in_dist: 32,
            out_embed: Some([4, 5, 0]),
            out_stride: 1,
            out_dist: 20,
            kind: FftKind::R2C,
            batch: 2,
        };
        let key = many.key();
        assert_eq!(key.rank, 2);
        assert_eq!(key.dims, [4, 8, 0]);
        assert_eq!(key.kind, FftKind::R2C);
        assert_eq!(key.dtype, DType::F32);
        assert_eq!(key.batch, 2);
        assert!(
            key.many_layout.is_some(),
            "plan_many keys must carry a layout discriminator"
        );
        let mut other = many.clone();
        other.in_dist = 64;
        let key2 = other.key();
        assert_ne!(
            key.many_layout, key2.many_layout,
            "different in_dist must produce different layout seeds"
        );
        assert_ne!(key, key2);
    }
    // Exercises the `lru` eviction policy directly with `PlanKey` keys (not
    // `PlanCache` itself): at capacity 2, the least recently used key goes first.
    #[test]
    fn plan_cache_hit_miss() {
        let cap = NonZeroUsize::new(2).unwrap();
        let mut cache: LruCache<PlanKey, ()> = LruCache::new(cap);
        let k1 = PlanKey::plan_1d(1024, FftKind::R2C, 1);
        let k2 = PlanKey::plan_2d(64, 64, FftKind::C2C);
        let k3 = PlanKey::plan_3d(8, 8, 8, FftKind::Z2Z);
        assert!(cache.get(&k1).is_none());
        cache.put(k1, ());
        assert!(cache.get(&k1).is_some(), "k1 hit after insert");
        cache.put(k2, ());
        assert!(cache.get(&k2).is_some());
        cache.put(k3, ());
        assert!(cache.get(&k3).is_some());
        assert!(cache.get(&k1).is_none(), "k1 should have been LRU-evicted");
        assert!(cache.get(&k2).is_some());
    }
    // Compile-time check only: the exhaustive match fails to build if an
    // `FftMsg` variant is added or removed. No message is actually constructed.
    #[test]
    fn deprecated_r2c1d_still_constructs() {
        fn _shape_check() {
            let (tx, _rx) = oneshot::channel::<Result<(), GpuError>>();
            fn handle(msg: FftMsg) {
                match msg {
                    FftMsg::Forward1dR2C { .. }
                    | FftMsg::Inverse1dC2R { .. }
                    | FftMsg::Exec1dC2C { .. }
                    | FftMsg::Forward2dR2C { .. } => {}
                    FftMsg::Exec(_) => {}
                }
            }
            drop(tx);
            let _ = handle;
        }
        _shape_check();
    }
    // Each scalar type's `KIND` agrees with the dtype derived from the
    // transform kind. Both match arms build the same key; the match only
    // documents the single/double-precision grouping.
    #[test]
    fn request_round_trip_f32_f64_f16() {
        fn check<T: FftSupported>(scalar_kind: DType, transform: FftKind) {
            assert_eq!(T::KIND, scalar_kind);
            let key = match transform {
                FftKind::R2C | FftKind::C2R | FftKind::C2C => PlanKey::plan_1d(8, transform, 1),
                FftKind::D2Z | FftKind::Z2D | FftKind::Z2Z => PlanKey::plan_1d(8, transform, 1),
            };
            assert_eq!(key.dtype, scalar_kind);
            assert_eq!(key.kind, transform);
        }
        check::<f32>(DType::F32, FftKind::R2C);
        check::<f32>(DType::F32, FftKind::C2C);
        check::<f64>(DType::F64, FftKind::D2Z);
        check::<f64>(DType::F64, FftKind::Z2Z);
        #[cfg(feature = "f16")]
        {
            assert_eq!(<half::f16 as atomr_accel::AccelDtype>::KIND, DType::F16);
        }
    }
    // Compile-time check that `FftRequest<T>` implements `FftDispatch` for each
    // supported scalar type (f16 only with the `f16` feature).
    #[test]
    fn fft_request_implements_fft_dispatch_for_all_dtypes() {
        fn assert_dispatch<U: FftDispatch>() {}
        assert_dispatch::<FftRequest<f32>>();
        assert_dispatch::<FftRequest<f64>>();
        #[cfg(feature = "f16")]
        assert_dispatch::<FftRequest<half::f16>>();
    }
}