#![allow(dead_code)]
pub mod activation;
pub mod attention;
pub mod conv;
pub mod graph;
pub mod norm;
pub mod pool;
pub mod rnn;
use std::sync::Arc;
use async_trait::async_trait;
use atomr_core::actor::{Actor, Context, Props};
use cudarc::cudnn::Cudnn;
use cudarc::driver::CudaSlice;
use parking_lot::Mutex;
use tokio::sync::oneshot;
use crate::completion::CompletionStrategy;
use crate::device::DeviceState;
use crate::error::GpuError;
use crate::gpu_ref::GpuRef;
use crate::kernel::dispatch::{CudnnDispatch, CudnnDispatchCtx};
use crate::stream::StreamAllocator;
pub use activation::{
ActivationFwdRequest, ActivationKind, DropoutFwdRequest, LrnFwdRequest, LrnParams,
SoftmaxFwdRequest, SoftmaxMode,
};
pub use attention::{
AttentionMask, AttentionParams, MultiHeadAttnBwdRequest, MultiHeadAttnFwdRequest,
};
pub use conv::{
ConvBwdDataRequest, ConvBwdFilterRequest, ConvDescParams, ConvFwdRequest, EpilogueKind,
};
pub use graph::{
cache_key, CachedPlan, DtypeTag, NormMode, NormPhase, OpSpec, OperationGraphSpec, PlanCache,
PlanCacheKey, PointwiseMode, PoolKind, ReduceOp, TensorLayout, TensorSpec,
DEFAULT_PLAN_CACHE_SIZE,
};
pub use norm::{
BatchNormRequest, GroupNormRequest, InstanceNormRequest, LayerNormRequest, NormBwdRequest,
};
pub use pool::{PoolBwdRequest, PoolFwdRequest, PoolMode, PoolParams};
pub use rnn::{RnnBwdRequest, RnnDirection, RnnFwdRequest, RnnMode, RnnParams};
/// Library tag placed in the `lib` field of every `GpuError::LibraryError` built in this file.
const LIB: &str = "cudnn";
/// Two-element convolution parameters carried by the legacy [`ConvForwardRequest`].
///
/// The type is plain `Copy` data, so it also derives `PartialEq`, `Eq` and `Hash`;
/// whole values can then be compared or used as map keys instead of being checked
/// field by field.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ConvParams {
    /// Padding, one entry per axis (presumably height then width — TODO confirm order).
    pub pad: [i32; 2],
    /// Stride, one entry per axis (same axis order as `pad`).
    pub stride: [i32; 2],
    /// Dilation, one entry per axis (same axis order as `pad`).
    pub dilation: [i32; 2],
}
/// Payload of the deprecated `CudnnMsg::ConvForward` message (f32 only).
///
/// Nothing in this file executes the request: the real actor answers `reply` with a
/// `GpuError::LibraryError` that points at `ConvFwdRequest<f32>`, and the mock actor
/// answers with `GpuError::Unrecoverable`.
pub struct ConvForwardRequest {
/// Presumably the input tensor — not read anywhere in this file.
pub x: GpuRef<f32>,
/// Four dimensions for `x` (axis order is not established here — TODO confirm).
pub x_dims: [i32; 4],
/// Presumably the filter/weight tensor — not read anywhere in this file.
pub w: GpuRef<f32>,
/// Four dimensions for `w`.
pub w_dims: [i32; 4],
/// Presumably the output tensor — not read anywhere in this file.
pub y: GpuRef<f32>,
/// Four dimensions for `y`.
pub y_dims: [i32; 4],
/// Padding / stride / dilation settings.
pub conv: ConvParams,
/// NOTE(review): looks like the usual cuDNN alpha/beta output-blend scalars — confirm.
pub alpha: f32,
pub beta: f32,
/// One-shot channel on which the outcome is reported to the sender.
pub reply: oneshot::Sender<Result<(), GpuError>>,
}
/// Payload of the deprecated `CudnnMsg::Activation` message (f32 only).
///
/// Nothing in this file executes the request: the real actor answers `reply` with a
/// `GpuError::LibraryError` that points at `ActivationFwdRequest<f32>`, and the mock
/// actor answers with `GpuError::Unrecoverable`.
pub struct ActivationRequest {
/// Which activation function is requested.
pub kind: ActivationKind,
/// Presumably the input tensor — not read anywhere in this file.
pub x: GpuRef<f32>,
/// Presumably the output tensor — not read anywhere in this file.
pub y: GpuRef<f32>,
/// Four dimensions, presumably shared by `x` and `y` — TODO confirm.
pub dims: [i32; 4],
/// NOTE(review): looks like the usual cuDNN alpha/beta output-blend scalars — confirm.
pub alpha: f32,
pub beta: f32,
/// One-shot channel on which the outcome is reported to the sender.
pub reply: oneshot::Sender<Result<(), GpuError>>,
}
/// Payload of the deprecated `CudnnMsg::Softmax` message (f32 only).
///
/// Nothing in this file executes the request: the real actor answers `reply` with a
/// `GpuError::LibraryError` that points at `SoftmaxFwdRequest<f32>`, and the mock
/// actor answers with `GpuError::Unrecoverable`.
pub struct SoftmaxRequest {
/// Presumably the input tensor — not read anywhere in this file.
pub x: GpuRef<f32>,
/// Presumably the output tensor — not read anywhere in this file.
pub y: GpuRef<f32>,
/// Four dimensions, presumably shared by `x` and `y` — TODO confirm.
pub dims: [i32; 4],
/// NOTE(review): looks like the usual cuDNN alpha/beta output-blend scalars — confirm.
pub alpha: f32,
pub beta: f32,
/// One-shot channel on which the outcome is reported to the sender.
pub reply: oneshot::Sender<Result<(), GpuError>>,
}
/// Messages accepted by [`CudnnActor`].
///
/// `Op` is the current entry point. The other variants are deprecated: the actor still
/// accepts them but only replies with an error naming the `Op` request to use instead.
pub enum CudnnMsg {
/// A type-erased operation. A real actor calls its `dispatch` method with a
/// `CudnnDispatchCtx`; a mock actor drops it without sending any reply from this file.
Op(Box<dyn CudnnDispatch>),
/// Legacy f32 convolution forward request.
#[deprecated(note = "use CudnnMsg::Op with ConvFwdRequest<f32>")]
ConvForward(Box<ConvForwardRequest>),
/// Legacy f32 activation request.
#[deprecated(note = "use CudnnMsg::Op with ActivationFwdRequest<f32>")]
Activation(Box<ActivationRequest>),
/// Legacy f32 softmax request.
#[deprecated(note = "use CudnnMsg::Op with SoftmaxFwdRequest<f32>")]
Softmax(Box<SoftmaxRequest>),
}
/// Actor that receives [`CudnnMsg`] values.
///
/// Built through `CudnnActor::props` (real cuDNN handle) or `CudnnActor::mock_props`
/// (no handle; legacy requests are answered with an error).
pub struct CudnnActor {
// Either the live cuDNN resources or the mock marker.
inner: CudnnInner,
}
/// Newtype around the shared `Cudnn` handle so that `Send` and `Sync` can be asserted for it.
struct SendCudnn(Arc<Cudnn>);
// SAFETY — NOTE(review): these impls promise that the wrapped handle may be moved to and
// shared between threads. Nothing in this file proves that; presumably it relies on the
// handle only being used from the owning actor's message handler. Confirm against the
// threading guarantees cudarc documents for `Cudnn` before relying on it.
unsafe impl Send for SendCudnn {}
unsafe impl Sync for SendCudnn {}
/// Backing state of a [`CudnnActor`].
enum CudnnInner {
/// Live cuDNN resources; everything except `state` is lent to each dispatched op.
Real {
// cuDNN handle created from `stream` when the actor is constructed.
handle: SendCudnn,
// CUDA stream the handle was created on.
stream: Arc<cudarc::driver::CudaStream>,
// Completion strategy cloned into every dispatch context.
completion: Arc<dyn CompletionStrategy>,
// Plan cache, initialised with `DEFAULT_PLAN_CACHE_SIZE`.
plan_cache: Mutex<PlanCache>,
// Optional byte buffer on the device; starts out as `None`.
workspace: Mutex<Option<CudaSlice<u8>>>,
// Stored but never read in this file, hence the lint allowance.
#[allow(dead_code)]
state: Arc<DeviceState>,
},
/// No cuDNN handle; legacy requests are answered with `GpuError::Unrecoverable`.
Mock,
}
impl CudnnActor {
    /// Builds [`Props`] for an actor backed by a real cuDNN handle.
    ///
    /// Every time the factory closure runs it creates a fresh `Cudnn` handle on `stream`,
    /// an empty workspace, and a `PlanCache` constructed with `DEFAULT_PLAN_CACHE_SIZE`.
    /// `_allocator` is accepted for signature compatibility and is not used here.
    ///
    /// # Panics
    ///
    /// The factory panics with a `ContextPoisoned: ...` message when `Cudnn::new` fails.
    pub fn props(
        stream: Arc<cudarc::driver::CudaStream>,
        _allocator: Arc<dyn StreamAllocator>,
        completion: Arc<dyn CompletionStrategy>,
        state: Arc<DeviceState>,
    ) -> Props<Self> {
        Props::create(move || {
            let cudnn = Cudnn::new(Arc::clone(&stream))
                .unwrap_or_else(|e| panic!("ContextPoisoned: Cudnn::new failed: {e}"));
            let inner = CudnnInner::Real {
                handle: SendCudnn(cudnn),
                stream: Arc::clone(&stream),
                completion: Arc::clone(&completion),
                plan_cache: Mutex::new(PlanCache::new(DEFAULT_PLAN_CACHE_SIZE)),
                workspace: Mutex::new(None),
                state: Arc::clone(&state),
            };
            CudnnActor { inner }
        })
    }

    /// Builds [`Props`] for an actor in mock mode, which holds no cuDNN handle.
    pub fn mock_props() -> Props<Self> {
        Props::create(|| {
            let inner = CudnnInner::Mock;
            CudnnActor { inner }
        })
    }
}
#[async_trait]
impl Actor for CudnnActor {
type Msg = CudnnMsg;
async fn handle(&mut self, _ctx: &mut Context<Self>, msg: CudnnMsg) {
match &self.inner {
CudnnInner::Mock => reply_mock(msg),
CudnnInner::Real {
handle,
stream,
completion,
plan_cache,
workspace,
..
} => match msg {
CudnnMsg::Op(op) => {
let ctx = CudnnDispatchCtx {
handle: handle.0.clone(),
stream: stream.clone(),
completion: completion.clone(),
plan_cache,
workspace,
};
op.dispatch(&ctx);
}
#[allow(deprecated)]
CudnnMsg::ConvForward(req) => {
handle_legacy_conv_fwd(*req);
}
#[allow(deprecated)]
CudnnMsg::Activation(req) => {
handle_legacy_activation(*req);
}
#[allow(deprecated)]
CudnnMsg::Softmax(req) => {
handle_legacy_softmax(*req);
}
},
}
}
}
fn reply_mock(msg: CudnnMsg) {
let err = || GpuError::Unrecoverable("CudnnActor in mock mode".into());
match msg {
CudnnMsg::Op(_) => {
}
#[allow(deprecated)]
CudnnMsg::ConvForward(r) => {
let _ = r.reply.send(Err(err()));
}
#[allow(deprecated)]
CudnnMsg::Activation(r) => {
let _ = r.reply.send(Err(err()));
}
#[allow(deprecated)]
CudnnMsg::Softmax(r) => {
let _ = r.reply.send(Err(err()));
}
}
}
/// Rejects a legacy convolution-forward request.
///
/// No work is performed; the requester is sent a `GpuError::LibraryError` naming the
/// replacement message. A closed reply channel is ignored.
#[allow(deprecated)]
fn handle_legacy_conv_fwd(req: ConvForwardRequest) {
    let rejection = GpuError::LibraryError {
        lib: LIB,
        msg: String::from(
            "ConvForward (legacy) is deprecated; send CudnnMsg::Op(ConvFwdRequest<f32>) \
             for v9 frontend dispatch",
        ),
    };
    let _ = req.reply.send(Err(rejection));
}
/// Rejects a legacy activation request.
///
/// No work is performed; the requester is sent a `GpuError::LibraryError` naming the
/// replacement message. A closed reply channel is ignored.
#[allow(deprecated)]
fn handle_legacy_activation(req: ActivationRequest) {
    let rejection = GpuError::LibraryError {
        lib: LIB,
        msg: String::from(
            "Activation (legacy) is deprecated; send CudnnMsg::Op(ActivationFwdRequest<f32>)",
        ),
    };
    let _ = req.reply.send(Err(rejection));
}
/// Rejects a legacy softmax request.
///
/// No work is performed; the requester is sent a `GpuError::LibraryError` naming the
/// replacement message. A closed reply channel is ignored.
#[allow(deprecated)]
fn handle_legacy_softmax(req: SoftmaxRequest) {
    let rejection = GpuError::LibraryError {
        lib: LIB,
        msg: String::from(
            "Softmax (legacy) is deprecated; send CudnnMsg::Op(SoftmaxFwdRequest<f32>)",
        ),
    };
    let _ = req.reply.send(Err(rejection));
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The legacy `ConvParams` type still builds, and a custom `CudnnDispatch`
    /// implementation can be wrapped in `CudnnMsg::Op`.
    #[test]
    #[allow(deprecated)]
    fn deprecated_conv_forward_alias_still_constructs() {
        // Minimal dispatchable op that acknowledges on its reply channel.
        struct Probe(oneshot::Sender<Result<(), GpuError>>);

        impl CudnnDispatch for Probe {
            fn dtype_name(&self) -> &'static str {
                "f32"
            }

            fn op_kind(&self) -> &'static str {
                "probe"
            }

            fn dispatch(self: Box<Self>, _ctx: &CudnnDispatchCtx<'_>) {
                let _ = self.0.send(Ok(()));
            }
        }

        fn _takes_msg(_: &CudnnMsg) {}

        let params = ConvParams {
            pad: [0, 0],
            stride: [1, 1],
            dilation: [1, 1],
        };
        assert_eq!(params.pad, [0, 0]);
        assert_eq!(params.stride, [1, 1]);

        let (reply_tx, _reply_rx) = oneshot::channel();
        let boxed_probe = Box::new(Probe(reply_tx));
        let msg = CudnnMsg::Op(boxed_probe);
        _takes_msg(&msg);
    }

    /// Compile-time check: `CudnnDispatch` can be used as a boxed trait object.
    #[test]
    fn cudnn_dispatch_is_object_safe() {
        fn _accept(_: Box<dyn CudnnDispatch>) {}
    }

    /// `PlanCache::default()` must report the advertised default capacity.
    #[test]
    fn plan_cache_default_size_matches_constant() {
        let cache = PlanCache::default();
        assert_eq!(cache.cap(), DEFAULT_PLAN_CACHE_SIZE);
    }
}