use std::num::NonZeroUsize;
use std::sync::Arc;
use async_trait::async_trait;
use atomr_core::actor::{Actor, Context, Props};
use cudarc::cufft::sys as cufft_sys;
use cudarc::cufft::{CudaFft, FftDirection};
use lru::LruCache;
use parking_lot::Mutex;
use tokio::sync::oneshot;
use crate::completion::CompletionStrategy;
use crate::device::DeviceState;
use crate::error::GpuError;
use crate::gpu_ref::GpuRef;
use crate::kernel::envelope;
use crate::stream::StreamAllocator;
/// Library tag used in this file for `GpuError::LibraryError { lib, .. }` and
/// as the first argument to `envelope::run_kernel`.
const LIB: &str = "cufft";
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum FftKind {
R2cF32,
C2rF32,
C2cF32,
}
impl FftKind {
fn cufft_type(self) -> cufft_sys::cufftType {
match self {
FftKind::R2cF32 => cufft_sys::cufftType::CUFFT_R2C,
FftKind::C2rF32 => cufft_sys::cufftType::CUFFT_C2R,
FftKind::C2cF32 => cufft_sys::cufftType::CUFFT_C2C,
}
}
}
/// Key under which FFT plans are cached, with one variant per plan
/// dimensionality.
///
/// The fields of each variant are the values used to build the plan when it is
/// not cached yet: the sizes are passed straight to the matching
/// `CudaFft::plan_*` constructor and `kind` is converted with
/// `FftKind::cufft_type`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PlanKey {
/// 1-D plan, created with `CudaFft::plan_1d` from `n`, `kind` and `batch`.
Plan1d {
n: i32,
kind: FftKind,
batch: i32,
},
/// 2-D plan, created with `CudaFft::plan_2d` from `nx`, `ny` and `kind`.
Plan2d {
nx: i32,
ny: i32,
kind: FftKind,
},
/// 3-D plan, created with `CudaFft::plan_3d` from `nx`, `ny`, `nz` and `kind`.
Plan3d {
nx: i32,
ny: i32,
nz: i32,
kind: FftKind,
},
}
/// Requests accepted by `FftActor`.
///
/// Every variant carries a source buffer (`src`), a destination buffer (`dst`)
/// and a one-shot `reply` sender. Errors detected before the transform is
/// launched (mock mode, plan creation, buffer access, non-exclusive buffers)
/// are sent on `reply` directly; otherwise `reply` is handed to
/// `envelope::run_kernel`, which is expected to deliver the final result
/// (its implementation is not part of this file).
pub enum FftMsg {
/// Runs `exec_r2c` with a 1-D `FftKind::R2cF32` plan keyed by `n` and `batch`.
Forward1dR2C {
n: i32,
batch: i32,
src: GpuRef<f32>,
dst: GpuRef<cufft_sys::float2>,
reply: oneshot::Sender<Result<(), GpuError>>,
},
/// Runs `exec_c2r` with a 1-D `FftKind::C2rF32` plan keyed by `n` and `batch`.
Inverse1dC2R {
n: i32,
batch: i32,
src: GpuRef<cufft_sys::float2>,
dst: GpuRef<f32>,
reply: oneshot::Sender<Result<(), GpuError>>,
},
/// Runs `exec_c2c` in the given `direction` with a 1-D `FftKind::C2cF32`
/// plan keyed by `n` and `batch`.
Exec1dC2C {
n: i32,
batch: i32,
direction: FftDirection,
src: GpuRef<cufft_sys::float2>,
dst: GpuRef<cufft_sys::float2>,
reply: oneshot::Sender<Result<(), GpuError>>,
},
/// Runs `exec_r2c` with a 2-D `FftKind::R2cF32` plan keyed by `nx` and `ny`.
Forward2dR2C {
nx: i32,
ny: i32,
src: GpuRef<f32>,
dst: GpuRef<cufft_sys::float2>,
reply: oneshot::Sender<Result<(), GpuError>>,
},
}
/// Actor that serves `FftMsg` requests.
///
/// It is backed either by real CUDA resources (`FftInner::Real`) or by a mock
/// (`FftInner::Mock`) that answers every request with an error.
pub struct FftActor {
/// Selected backend; decided at construction by `props` / `mock_props`.
inner: FftInner,
}
/// Store of already-created FFT plans, keyed by [`PlanKey`].
struct PlanCache {
    /// LRU map from plan key to a shareable plan handle.
    cache: LruCache<PlanKey, Arc<CudaFft>>,
}

impl PlanCache {
    /// Builds an empty cache, passing `cap` to `LruCache::new` as its capacity.
    fn new(cap: NonZeroUsize) -> Self {
        let cache = LruCache::new(cap);
        PlanCache { cache }
    }
}
/// Backend of an `FftActor`.
enum FftInner {
/// Backend that owns the CUDA resources needed to plan and run transforms.
Real {
/// Stream passed to the `CudaFft::plan_*` constructors and to
/// `envelope::run_kernel`.
stream: Arc<cudarc::driver::CudaStream>,
/// Completion strategy passed to `envelope::run_kernel`.
completion: Arc<dyn CompletionStrategy>,
/// Cache of created plans, guarded by a mutex so it can be updated
/// through `&self`.
plans: Mutex<PlanCache>,
/// Device state kept alongside the other handles; nothing in the code
/// shown here reads it, hence the `dead_code` allowance.
#[allow(dead_code)]
state: Arc<DeviceState>,
},
/// Backend without CUDA resources; every request is rejected with an error.
Mock,
}
/// Capacity of the plan cache that [`FftActor::props`] gives each actor.
const DEFAULT_CACHE_SIZE: usize = 64;

impl FftActor {
    /// Builds `Props` for an actor backed by real CUDA resources.
    ///
    /// Each actor instance produced by the factory gets clones of `stream`,
    /// `completion` and `state`, plus its own empty plan cache of
    /// `DEFAULT_CACHE_SIZE` entries. `_allocator` and `_ctx` are accepted but
    /// not used.
    pub fn props(
        stream: Arc<cudarc::driver::CudaStream>,
        _allocator: Arc<dyn StreamAllocator>,
        completion: Arc<dyn CompletionStrategy>,
        state: Arc<DeviceState>,
        _ctx: Arc<cudarc::driver::CudaContext>,
    ) -> Props<Self> {
        Props::create(move || {
            // DEFAULT_CACHE_SIZE is a non-zero constant, so this cannot fail.
            let capacity = NonZeroUsize::new(DEFAULT_CACHE_SIZE).unwrap();
            let inner = FftInner::Real {
                stream: Arc::clone(&stream),
                completion: Arc::clone(&completion),
                plans: Mutex::new(PlanCache::new(capacity)),
                state: Arc::clone(&state),
            };
            FftActor { inner }
        })
    }

    /// Builds `Props` for an actor in mock mode, which has no CUDA resources.
    pub fn mock_props() -> Props<Self> {
        Props::create(|| {
            let inner = FftInner::Mock;
            FftActor { inner }
        })
    }
}
impl FftActor {
    /// Returns the plan for `key`, creating and caching it on a miss.
    ///
    /// # Errors
    ///
    /// * `GpuError::Unrecoverable` when the actor is in mock mode.
    /// * `GpuError::LibraryError` when the `CudaFft::plan_*` constructor fails;
    ///   the message includes the key and the underlying error.
    fn get_or_create_plan(&self, key: PlanKey) -> Result<Arc<CudaFft>, GpuError> {
        let (stream, plans) = match &self.inner {
            FftInner::Real { stream, plans, .. } => (stream, plans),
            _ => return Err(GpuError::Unrecoverable("fft mock".into())),
        };

        // The lock is held for the whole lookup-or-create sequence.
        let mut guard = plans.lock();
        if let Some(cached) = guard.cache.get(&key) {
            return Ok(Arc::clone(cached));
        }

        let created = match key {
            PlanKey::Plan1d { n, kind, batch } => {
                CudaFft::plan_1d(n, kind.cufft_type(), batch, Arc::clone(stream))
            }
            PlanKey::Plan2d { nx, ny, kind } => {
                CudaFft::plan_2d(nx, ny, kind.cufft_type(), Arc::clone(stream))
            }
            PlanKey::Plan3d { nx, ny, nz, kind } => {
                CudaFft::plan_3d(nx, ny, nz, kind.cufft_type(), Arc::clone(stream))
            }
        }
        .map_err(|e| GpuError::LibraryError {
            lib: LIB,
            msg: format!("plan {key:?}: {e}"),
        })?;

        let shared = Arc::new(created);
        guard.cache.put(key, Arc::clone(&shared));
        Ok(shared)
    }
}
#[async_trait]
impl Actor for FftActor {
type Msg = FftMsg;
async fn handle(&mut self, _ctx: &mut Context<Self>, msg: FftMsg) {
let (stream, completion) = match &self.inner {
FftInner::Mock => {
reply_mock(msg);
return;
}
FftInner::Real {
stream, completion, ..
} => (stream.clone(), completion.clone()),
};
match msg {
FftMsg::Forward1dR2C {
n,
batch,
src,
dst,
reply,
} => {
let plan = match self.get_or_create_plan(PlanKey::Plan1d {
n,
kind: FftKind::R2cF32,
batch,
}) {
Ok(p) => p,
Err(e) => {
let _ = reply.send(Err(e));
return;
}
};
let (src_slice, dst_slice) = match envelope::access_all_2(&src, &dst) {
Ok(t) => t,
Err(e) => {
let _ = reply.send(Err(e));
return;
}
};
let mut dst_owned = match Arc::try_unwrap(dst_slice) {
Ok(s) => s,
Err(_) => {
let _ = reply.send(Err(GpuError::Unrecoverable(
"FFT dst has multiple live references".into(),
)));
return;
}
};
dst.record_write(&stream);
envelope::run_kernel(LIB, &stream, &completion, (), reply, move || {
plan.exec_r2c(&*src_slice, &mut dst_owned)
.map(|_| (src_slice, dst_owned, plan))
.map_err(|e| GpuError::LibraryError {
lib: LIB,
msg: format!("exec_r2c: {e}"),
})
});
}
FftMsg::Inverse1dC2R {
n,
batch,
src,
dst,
reply,
} => {
let plan = match self.get_or_create_plan(PlanKey::Plan1d {
n,
kind: FftKind::C2rF32,
batch,
}) {
Ok(p) => p,
Err(e) => {
let _ = reply.send(Err(e));
return;
}
};
let (src_slice, dst_slice) = match envelope::access_all_2(&src, &dst) {
Ok(t) => t,
Err(e) => {
let _ = reply.send(Err(e));
return;
}
};
let mut src_owned = match Arc::try_unwrap(src_slice) {
Ok(s) => s,
Err(_) => {
let _ = reply.send(Err(GpuError::Unrecoverable(
"FFT C2R src has multiple live references".into(),
)));
return;
}
};
let mut dst_owned = match Arc::try_unwrap(dst_slice) {
Ok(s) => s,
Err(_) => {
let _ = reply.send(Err(GpuError::Unrecoverable(
"FFT C2R dst has multiple live references".into(),
)));
return;
}
};
dst.record_write(&stream);
envelope::run_kernel(LIB, &stream, &completion, (), reply, move || {
plan.exec_c2r(&mut src_owned, &mut dst_owned)
.map(|_| (src_owned, dst_owned, plan))
.map_err(|e| GpuError::LibraryError {
lib: LIB,
msg: format!("exec_c2r: {e}"),
})
});
}
FftMsg::Exec1dC2C {
n,
batch,
direction,
src,
dst,
reply,
} => {
let plan = match self.get_or_create_plan(PlanKey::Plan1d {
n,
kind: FftKind::C2cF32,
batch,
}) {
Ok(p) => p,
Err(e) => {
let _ = reply.send(Err(e));
return;
}
};
let (src_slice, dst_slice) = match envelope::access_all_2(&src, &dst) {
Ok(t) => t,
Err(e) => {
let _ = reply.send(Err(e));
return;
}
};
let mut src_owned = match Arc::try_unwrap(src_slice) {
Ok(s) => s,
Err(_) => {
let _ = reply.send(Err(GpuError::Unrecoverable(
"FFT C2C src has multiple live references".into(),
)));
return;
}
};
let mut dst_owned = match Arc::try_unwrap(dst_slice) {
Ok(s) => s,
Err(_) => {
let _ = reply.send(Err(GpuError::Unrecoverable(
"FFT C2C dst has multiple live references".into(),
)));
return;
}
};
dst.record_write(&stream);
envelope::run_kernel(LIB, &stream, &completion, (), reply, move || {
plan.exec_c2c(&mut src_owned, &mut dst_owned, direction)
.map(|_| (src_owned, dst_owned, plan))
.map_err(|e| GpuError::LibraryError {
lib: LIB,
msg: format!("exec_c2c: {e}"),
})
});
}
FftMsg::Forward2dR2C {
nx,
ny,
src,
dst,
reply,
} => {
let plan = match self.get_or_create_plan(PlanKey::Plan2d {
nx,
ny,
kind: FftKind::R2cF32,
}) {
Ok(p) => p,
Err(e) => {
let _ = reply.send(Err(e));
return;
}
};
let (src_slice, dst_slice) = match envelope::access_all_2(&src, &dst) {
Ok(t) => t,
Err(e) => {
let _ = reply.send(Err(e));
return;
}
};
let mut dst_owned = match Arc::try_unwrap(dst_slice) {
Ok(s) => s,
Err(_) => {
let _ = reply.send(Err(GpuError::Unrecoverable(
"FFT 2D dst has multiple live references".into(),
)));
return;
}
};
dst.record_write(&stream);
envelope::run_kernel(LIB, &stream, &completion, (), reply, move || {
plan.exec_r2c(&*src_slice, &mut dst_owned)
.map(|_| (src_slice, dst_owned, plan))
.map_err(|e| GpuError::LibraryError {
lib: LIB,
msg: format!("exec_r2c (2d): {e}"),
})
});
}
}
}
}
/// Answers `msg` with the fixed "mock mode" error.
///
/// Whatever the variant, its `reply` sender receives
/// `GpuError::Unrecoverable("FftActor in mock mode")`; a failed send (receiver
/// already gone) is ignored.
fn reply_mock(msg: FftMsg) {
    // Every variant carries a `reply` of the same type, so one or-pattern can
    // pull it out regardless of which request was sent.
    let reply = match msg {
        FftMsg::Forward1dR2C { reply, .. }
        | FftMsg::Inverse1dC2R { reply, .. }
        | FftMsg::Exec1dC2C { reply, .. }
        | FftMsg::Forward2dR2C { reply, .. } => reply,
    };
    let _ = reply.send(Err(GpuError::Unrecoverable(
        "FftActor in mock mode".into(),
    )));
}