use std::any::Any;
use std::sync::Arc;
use cudarc::driver::{CudaSlice, DeviceRepr, LaunchArgs, PushKernelArg};
use atomr_accel::DType;
use crate::completion::CompletionStrategy;
use crate::device::DeviceState;
use crate::dtype::CudaDtype;
use crate::error::GpuError;
use crate::gpu_ref::GpuRef;
/// A type-erased NVRTC kernel launch request.
///
/// Requests are handled as `Box<dyn NvrtcLaunchDispatch>`; the
/// `Send + 'static` bound lets a boxed request be moved to another thread
/// before [`dispatch`](Self::dispatch) runs.
pub trait NvrtcLaunchDispatch: Send + 'static {
/// Static name of the operation (e.g. `"relu"`); presumably used for
/// logging/metrics — confirm with callers.
fn op_name(&self) -> &'static str;
/// Element dtype of the request, or `None` if it has no single dtype.
fn dtype(&self) -> Option<DType>;
/// Consumes the request and performs it using the stream, completion
/// strategy and device state provided in `ctx`.
///
/// Returns nothing, so any result or error has to be reported through the
/// request itself (NOTE(review): presumably a reply channel — confirm).
fn dispatch(self: Box<Self>, ctx: &NvrtcDispatchCtx<'_>);
}
/// Borrowed resources handed to [`NvrtcLaunchDispatch::dispatch`].
pub struct NvrtcDispatchCtx<'a> {
/// CUDA stream to use for the launch.
pub stream: &'a Arc<cudarc::driver::CudaStream>,
/// Completion strategy for the enqueued work (semantics defined in
/// `crate::completion`).
pub completion: &'a Arc<dyn CompletionStrategy>,
/// Shared state of the device the request runs on.
pub state: &'a Arc<DeviceState>,
}
/// Borrowed resources handed to the cuBLAS dispatch traits
/// ([`GemmDispatch`], [`GemmStridedBatchedDispatch`], [`BlasL1Dispatch`],
/// [`BlasL2Dispatch`], [`BlasL3Dispatch`]). Also available as
/// [`GemmDispatchCtx`].
pub struct BlasDispatchCtx<'a> {
/// cuBLAS handle to issue calls through.
pub cublas: &'a Arc<cudarc::cublas::CudaBlas>,
/// CUDA stream to use for the work.
pub stream: &'a Arc<cudarc::driver::CudaStream>,
/// Completion strategy for the enqueued work.
pub completion: &'a Arc<dyn CompletionStrategy>,
/// Shared state of the device the request runs on.
pub state: &'a Arc<DeviceState>,
}
/// A device-buffer kernel argument, erased over its element type so that
/// buffers of different dtypes can be passed as `Box<dyn DevSliceArg>`.
pub trait DevSliceArg: Send + Sync + 'static {
/// Checks that the buffer is accessible and returns an owned, type-erased
/// handle to it.
///
/// For `GpuRef<T>` the handle is a cloned `Arc<CudaSlice<T>>`; presumably the
/// caller holds it to keep the buffer alive until the kernel finishes —
/// confirm with callers.
///
/// # Errors
/// Returns an error if the buffer cannot be accessed.
fn validate(&self) -> Result<Box<dyn Any + Send>, GpuError>;
/// Pushes the buffer onto `builder` as the next kernel argument.
///
/// # Errors
/// Returns an error if the buffer cannot be accessed.
fn push<'a>(&'a self, builder: &mut LaunchArgs<'a>) -> Result<(), GpuError>;
/// Element dtype of the buffer, if known.
fn dtype(&self) -> Option<DType>;
/// Length of the buffer as reported by the implementor.
fn len(&self) -> usize;
/// `true` when [`len`](Self::len) is zero.
fn is_empty(&self) -> bool {
self.len() == 0
}
}
impl<T> DevSliceArg for GpuRef<T>
where
    T: CudaDtype,
{
    /// Checks the reference is accessible and returns an owned
    /// `Arc<CudaSlice<T>>`, boxed as `dyn Any + Send`.
    ///
    /// # Errors
    /// Propagates the error returned by `GpuRef::access`.
    #[inline]
    fn validate(&self) -> Result<Box<dyn Any + Send>, GpuError> {
        let borrowed = self.access()?;
        let keep_alive: Arc<CudaSlice<T>> = Arc::clone(borrowed);
        Ok(Box::new(keep_alive))
    }

    /// Pushes the underlying `CudaSlice<T>` as the next kernel argument.
    ///
    /// # Errors
    /// Propagates the error returned by `GpuRef::access`.
    #[inline]
    fn push<'a>(&'a self, builder: &mut LaunchArgs<'a>) -> Result<(), GpuError> {
        // Deref-coerce `&Arc<CudaSlice<T>>` down to the slice itself.
        let slice: &CudaSlice<T> = self.access()?;
        builder.arg(slice);
        Ok(())
    }

    /// Always `Some`: the static `AccelDtype::KIND` of `T`.
    #[inline]
    fn dtype(&self) -> Option<DType> {
        let kind = <T as atomr_accel::AccelDtype>::KIND;
        Some(kind)
    }

    /// Delegates to `GpuRef`'s own `len`; the type-qualified path selects the
    /// inherent method rather than recursing into this trait method.
    #[inline]
    fn len(&self) -> usize {
        GpuRef::<T>::len(self)
    }
}
/// A scalar kernel argument, erased over its type so scalars of different
/// dtypes can be passed as `Box<dyn ScalarArg>`.
///
/// Blanket-implemented for every `CudaDtype + DeviceRepr + Sync` type.
pub trait ScalarArg: Send + Sync + 'static {
/// Pushes the scalar onto `builder` as the next kernel argument.
fn push<'a>(&'a self, builder: &mut LaunchArgs<'a>);
/// Dtype of the scalar, if known.
fn dtype(&self) -> Option<DType>;
}
impl<T> ScalarArg for T
where
    T: CudaDtype + DeviceRepr + Sync,
{
    /// Pushes a reference to `self` as the next kernel argument.
    #[inline]
    fn push<'a>(&'a self, builder: &mut LaunchArgs<'a>) {
        let value: &'a T = self;
        builder.arg(value);
    }

    /// Always `Some`: the static `AccelDtype::KIND` of `T`.
    #[inline]
    fn dtype(&self) -> Option<DType> {
        let kind = <T as atomr_accel::AccelDtype>::KIND;
        Some(kind)
    }
}
/// Alias of [`BlasDispatchCtx`] under the name used by GEMM dispatchers.
pub type GemmDispatchCtx<'a> = BlasDispatchCtx<'a>;
/// A type-erased cuBLAS GEMM request.
pub trait GemmDispatch: Send + 'static {
/// Static name of the element dtype (presumably for diagnostics).
fn dtype_name(&self) -> &'static str;
/// Static name of the operation.
fn op_name(&self) -> &'static str;
/// Consumes the request and issues it through `ctx`.
fn dispatch(self: Box<Self>, ctx: &BlasDispatchCtx<'_>);
}
/// A type-erased cuBLAS strided-batched GEMM request.
pub trait GemmStridedBatchedDispatch: Send + 'static {
/// Static name of the element dtype (presumably for diagnostics).
fn dtype_name(&self) -> &'static str;
/// Static name of the operation.
fn op_name(&self) -> &'static str;
/// Consumes the request and issues it through `ctx`.
fn dispatch(self: Box<Self>, ctx: &BlasDispatchCtx<'_>);
}
/// A type-erased cuBLAS Level-1 request.
pub trait BlasL1Dispatch: Send + 'static {
/// Static name of the element dtype (presumably for diagnostics).
fn dtype_name(&self) -> &'static str;
/// Static name of the operation.
fn op_name(&self) -> &'static str;
/// Consumes the request and issues it through `ctx`.
fn dispatch(self: Box<Self>, ctx: &BlasDispatchCtx<'_>);
}
/// A type-erased cuBLAS Level-2 request.
pub trait BlasL2Dispatch: Send + 'static {
/// Static name of the element dtype (presumably for diagnostics).
fn dtype_name(&self) -> &'static str;
/// Static name of the operation.
fn op_name(&self) -> &'static str;
/// Consumes the request and issues it through `ctx`.
fn dispatch(self: Box<Self>, ctx: &BlasDispatchCtx<'_>);
}
/// A type-erased cuBLAS Level-3 request.
pub trait BlasL3Dispatch: Send + 'static {
/// Static name of the element dtype (presumably for diagnostics).
fn dtype_name(&self) -> &'static str;
/// Static name of the operation.
fn op_name(&self) -> &'static str;
/// Consumes the request and issues it through `ctx`.
fn dispatch(self: Box<Self>, ctx: &BlasDispatchCtx<'_>);
}
#[cfg(feature = "cublaslt")]
mod blaslt_dispatch_internal {
    use std::sync::Arc;

    use cudarc::cublaslt::CudaBlasLT;
    use tokio::sync::oneshot;

    use crate::completion::CompletionStrategy;
    use crate::error::GpuError;
    use crate::kernel::blas_lt::heuristic::HeuristicCacheRef;
    use crate::kernel::blas_lt::workspace::WorkspacePool;

    /// Resources handed to a cuBLASLt dispatch request.
    pub struct BlasLtDispatchCtx<'a> {
        /// cuBLASLt handle (owned `Arc`, unlike the borrowed fields below).
        pub blas_lt: Arc<CudaBlasLT>,
        /// CUDA stream to use for the work.
        pub stream: &'a Arc<cudarc::driver::CudaStream>,
        /// Completion strategy for the enqueued work.
        pub completion: &'a Arc<dyn CompletionStrategy>,
        /// Pool the request draws its cuBLASLt workspace from.
        pub workspace: &'a WorkspacePool,
        /// Handle to the cuBLASLt heuristic cache.
        pub heuristic: HeuristicCacheRef,
        /// SM architecture of the device (NOTE(review): numeric encoding,
        /// e.g. 80 for sm_80, is assumed — confirm with the producer).
        pub sm_arch: u32,
    }

    /// Answers `reply` with a `GpuError::Unrecoverable` saying that
    /// `dtype_name` is not supported in this build.
    ///
    /// If the receiver has already been dropped, the error is discarded.
    pub fn reply_unsupported(
        reply: oneshot::Sender<Result<(), GpuError>>,
        dtype_name: &'static str,
    ) {
        let unsupported = GpuError::Unrecoverable(format!(
            "BlasLtDispatch: dtype {dtype_name} unsupported in this build"
        ));
        // `send` only fails when the receiver is gone; nobody is left to
        // notify in that case, so the result is intentionally ignored.
        let _ = reply.send(Err(unsupported));
    }
}
#[cfg(feature = "cublaslt")]
pub use blaslt_dispatch_internal::{reply_unsupported, BlasLtDispatchCtx};
/// A type-erased cuBLASLt request (only built with the `cublaslt` feature).
#[cfg(feature = "cublaslt")]
pub trait BlasLtDispatch: Send + 'static {
/// Element dtype category of the request.
///
/// NOTE(review): this returns `crate::dtype::DTypeKind`, whereas the other
/// dispatch traits in this module use `DType`; confirm this is intentional.
fn dtype_kind(&self) -> crate::dtype::DTypeKind;
/// Consumes the request and issues it through `ctx`.
fn dispatch(self: Box<Self>, ctx: &BlasLtDispatchCtx<'_>);
}
#[cfg(feature = "cudnn")]
pub use cudnn_dispatch::{CudnnDispatch, CudnnDispatchCtx};
/// cuDNN dispatch context and trait (only built with the `cudnn` feature).
#[cfg(feature = "cudnn")]
mod cudnn_dispatch {
use std::sync::Arc;
use parking_lot::Mutex;
use crate::completion::CompletionStrategy;
/// Resources handed to [`CudnnDispatch::dispatch`].
///
/// Unlike most other contexts in this module, the handle, stream and
/// completion strategy are owned `Arc`s; only the shared caches are borrowed.
pub struct CudnnDispatchCtx<'a> {
/// cuDNN handle.
pub handle: Arc<cudarc::cudnn::Cudnn>,
/// CUDA stream to use for the work.
pub stream: Arc<cudarc::driver::CudaStream>,
/// Completion strategy for the enqueued work.
pub completion: Arc<dyn CompletionStrategy>,
/// cuDNN graph plan cache, shared behind a mutex.
pub plan_cache: &'a Mutex<crate::kernel::cudnn::graph::PlanCache>,
/// Shared scratch workspace; `None` until one has been allocated
/// (NOTE(review): presumably grown on demand by the cuDNN kernels — confirm).
pub workspace: &'a Mutex<Option<cudarc::driver::CudaSlice<u8>>>,
}
/// A type-erased cuDNN request.
pub trait CudnnDispatch: Send + 'static {
/// Static name of the element dtype (presumably for diagnostics).
fn dtype_name(&self) -> &'static str;
/// Static name of the operation kind.
fn op_kind(&self) -> &'static str;
/// Consumes the request and issues it through `ctx`.
fn dispatch(self: Box<Self>, ctx: &CudnnDispatchCtx<'_>);
}
}
/// Resources handed to [`FftDispatch::dispatch`] (only with the `cufft` feature).
#[cfg(feature = "cufft")]
pub struct FftDispatchCtx<'a> {
/// CUDA stream to use for the transform.
pub stream: &'a Arc<cudarc::driver::CudaStream>,
/// Completion strategy for the enqueued work.
pub completion: &'a Arc<dyn CompletionStrategy>,
/// Type-erased cuFFT plan. NOTE(review): presumably looked up from the
/// request's `plan_key` and downcast by the implementor — confirm.
pub plan: Arc<dyn std::any::Any + Send + Sync>,
}
/// A type-erased cuFFT request (only with the `cufft` feature).
#[cfg(feature = "cufft")]
pub trait FftDispatch: Send + 'static {
/// Element dtype of the transform.
fn dtype_kind(&self) -> DType;
/// Key identifying the cuFFT plan this request needs.
fn plan_key(&self) -> crate::kernel::fft::PlanKey;
/// Consumes the request and issues it through `ctx`.
fn dispatch(self: Box<Self>, ctx: &FftDispatchCtx<'_>);
}
/// A type-erased cuRAND fill request.
///
/// Unlike the other dispatch traits here, the resources are passed as
/// separate arguments and the call reports failure through its return value.
pub trait RngDispatch: Send + 'static {
/// Consumes the request and fills its output using `generator` on `stream`.
///
/// `generator` is the raw (sys-level) cuRAND generator handle.
///
/// # Errors
/// Returns a [`GpuError`] if the fill cannot be issued.
fn fill(
self: Box<Self>,
generator: cudarc::curand::sys::curandGenerator_t,
stream: &Arc<cudarc::driver::CudaStream>,
completion: &Arc<dyn CompletionStrategy>,
) -> Result<(), GpuError>;
}
#[cfg(feature = "cusolver")]
pub use crate::kernel::solver::SolverDispatch;
/// Newtype around a raw cuSPARSE handle so it can live inside `Send`/`Sync`
/// containers. The raw handle type is presumably a raw pointer, which is
/// neither `Send` nor `Sync` on its own; hence the manual impls below.
#[cfg(feature = "cusparse")]
pub struct SendSparseHandle(pub cudarc::cusparse::sys::cusparseHandle_t);
#[cfg(feature = "cusparse")]
// SAFETY: NOTE(review) — moving the handle to another thread is assumed sound
// because it is only used by one thread at a time, behind the
// `parking_lot::Mutex` in `SparseDispatchCtx::handle`. The field is `pub`, so
// nothing here enforces that; confirm against cuSPARSE's thread-safety rules.
unsafe impl Send for SendSparseHandle {}
#[cfg(feature = "cusparse")]
// SAFETY: NOTE(review) — `Mutex<T>` only requires `T: Send` to be `Sync`, so
// this impl is not needed for the mutex usage visible here. It additionally
// allows sharing `&SendSparseHandle` (and copying the raw handle out) across
// threads, which is only sound if cuSPARSE tolerates concurrent use of a
// single handle. Confirm before relying on it.
unsafe impl Sync for SendSparseHandle {}
/// Borrowed resources handed to [`SparseDispatch::dispatch`] (only with the
/// `cusparse` feature).
#[cfg(feature = "cusparse")]
pub struct SparseDispatchCtx<'a> {
/// cuSPARSE handle, serialised behind a mutex.
pub handle: &'a parking_lot::Mutex<SendSparseHandle>,
/// CUDA stream to use for the work.
pub stream: &'a Arc<cudarc::driver::CudaStream>,
/// Completion strategy for the enqueued work.
pub completion: &'a Arc<dyn CompletionStrategy>,
/// Shared scratch workspace; `None` until one has been allocated
/// (NOTE(review): presumably grown on demand — confirm with the kernels).
pub workspace: &'a parking_lot::Mutex<Option<cudarc::driver::CudaSlice<u8>>>,
}
/// The kind of cuSPARSE operation a [`SparseDispatch`] request performs.
///
/// Variant names appear to mirror cuSPARSE's generic API entry points
/// (e.g. `SpMv` ~ `cusparseSpMV`) — NOTE(review): confirm the mapping.
#[cfg(feature = "cusparse")]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SparseOp {
    /// Sparse matrix × dense vector.
    SpMv,
    /// Sparse matrix × dense matrix.
    SpMm,
    /// Sparse matrix × sparse matrix.
    SpGemm,
    /// Sparse triangular solve.
    SpSv,
    /// Sampled dense-dense matrix multiplication.
    Sddmm,
    /// Dense → sparse conversion.
    DenseToSparse,
    /// Sparse → dense conversion.
    SparseToDense,
    /// Sparse format conversion.
    Convert,
}
#[cfg(feature = "cusparse")]
impl SparseOp {
    /// Stable lowercase, snake_case name of the operation
    /// (e.g. `"spmv"`, `"dense_to_sparse"`).
    pub fn as_str(self) -> &'static str {
        match self {
            Self::SpMv => "spmv",
            Self::SpMm => "spmm",
            Self::SpGemm => "spgemm",
            Self::SpSv => "spsv",
            Self::Sddmm => "sddmm",
            Self::DenseToSparse => "dense_to_sparse",
            Self::SparseToDense => "sparse_to_dense",
            Self::Convert => "convert",
        }
    }
}
/// A type-erased cuSPARSE request (only with the `cusparse` feature).
#[cfg(feature = "cusparse")]
pub trait SparseDispatch: Send + 'static {
/// The operation this request performs. Unlike the other traits'
/// `op_name`, this returns a typed [`SparseOp`]; use `SparseOp::as_str`
/// for the string form.
fn op_name(&self) -> SparseOp;
/// Element dtype of the request.
fn dtype(&self) -> DType;
/// Consumes the request and issues it through `ctx`.
fn dispatch(self: Box<Self>, ctx: &SparseDispatchCtx<'_>);
}
#[cfg(feature = "cutensor")]
pub use cutensor_dispatch::{TensorDispatch, TensorDispatchCtx, WorkspacePool};
#[cfg(feature = "cutensor")]
mod cutensor_dispatch {
    use std::sync::Arc;

    use parking_lot::Mutex;

    use crate::completion::CompletionStrategy;
    use crate::error::GpuError;
    use crate::kernel::tensor::plan_cache::PlanCache;
    use crate::kernel::tensor::SendHandle;

    /// Rounds a workspace request of `n` bytes up to its bucket size.
    ///
    /// Returns `None` for `n == 0` (no workspace needed) and when the next
    /// power of two does not fit in `usize`. The checked form matters because
    /// `usize::next_power_of_two` panics on overflow in debug builds and wraps
    /// to 0 in release builds, which would silently hand out a zero-byte
    /// bucket for a huge request.
    fn round_to_bucket(n: usize) -> Option<usize> {
        if n == 0 {
            None
        } else {
            n.checked_next_power_of_two()
        }
    }

    /// Pool of zero-initialised device byte buffers used as cuTENSOR
    /// workspace, bucketed by power-of-two size.
    ///
    /// A bucket is allocated on first demand by [`WorkspacePool::ensure`] and
    /// then lent out by [`WorkspacePool::with_bucket`] for every request that
    /// rounds up to the same size. Buckets are never freed while the pool
    /// lives.
    pub struct WorkspacePool {
        /// Stream used to allocate bucket buffers.
        stream: Arc<cudarc::driver::CudaStream>,
        /// Allocated buckets; at most one per distinct size.
        buckets: Mutex<Vec<Bucket>>,
    }

    /// One allocated workspace buffer and its power-of-two size in bytes.
    struct Bucket {
        size: usize,
        slice: cudarc::driver::CudaSlice<u8>,
    }

    impl WorkspacePool {
        /// Creates an empty pool that allocates its buffers on `stream`.
        pub fn new(stream: Arc<cudarc::driver::CudaStream>) -> Self {
            Self {
                stream,
                buckets: Mutex::new(Vec::new()),
            }
        }

        /// Makes sure a bucket large enough for `n` bytes exists and returns
        /// its size (`n` rounded up to a power of two), or `0` when `n == 0`.
        ///
        /// # Errors
        /// Returns [`GpuError::OutOfMemory`] if `n` rounds up past
        /// `usize::MAX`, or if allocating the bucket fails.
        pub fn ensure(&self, n: usize) -> Result<usize, GpuError> {
            if n == 0 {
                return Ok(0);
            }
            let bucket_size = round_to_bucket(n).ok_or_else(|| {
                GpuError::OutOfMemory(format!(
                    "cutensor workspace: request of {n} bytes exceeds the largest power-of-two bucket"
                ))
            })?;
            // The lock is held across the allocation so two concurrent
            // callers cannot both allocate the same bucket.
            let mut buckets = self.buckets.lock();
            if buckets.iter().any(|b| b.size == bucket_size) {
                return Ok(bucket_size);
            }
            let slice = self
                .stream
                .alloc_zeros::<u8>(bucket_size)
                .map_err(|e| GpuError::OutOfMemory(format!("cutensor workspace: {e}")))?;
            buckets.push(Bucket {
                size: bucket_size,
                slice,
            });
            Ok(bucket_size)
        }

        /// Runs `f` on the bucket that serves `n`-byte requests.
        ///
        /// Returns `None` when `n == 0`, when `n` cannot be rounded to a
        /// bucket size, or when [`ensure`](Self::ensure) has not allocated
        /// that bucket yet.
        ///
        /// `f` runs while the pool lock is held, so it must not call back
        /// into this pool (the mutex is not reentrant).
        pub fn with_bucket<F, R>(&self, n: usize, f: F) -> Option<R>
        where
            F: FnOnce(&mut cudarc::driver::CudaSlice<u8>) -> R,
        {
            let bucket_size = round_to_bucket(n)?;
            let mut buckets = self.buckets.lock();
            let bucket = buckets.iter_mut().find(|b| b.size == bucket_size)?;
            Some(f(&mut bucket.slice))
        }
    }

    /// Shared resources handed to [`TensorDispatch::dispatch`]; all fields
    /// are owned `Arc`s.
    pub struct TensorDispatchCtx {
        /// cuTENSOR handle, serialised behind a mutex.
        pub handle: Arc<Mutex<SendHandle>>,
        /// CUDA stream to use for the work.
        pub stream: Arc<cudarc::driver::CudaStream>,
        /// Completion strategy for the enqueued work.
        pub completion: Arc<dyn CompletionStrategy>,
        /// Cache of cuTENSOR plans.
        pub plan_cache: Arc<PlanCache>,
        /// Workspace buffers for the operations.
        pub workspace: Arc<WorkspacePool>,
    }

    /// A type-erased cuTENSOR request (only with the `cutensor` feature).
    pub trait TensorDispatch: Send + 'static {
        /// Static tag of the operation.
        fn op_tag(&self) -> &'static str;
        /// Static tag of the element dtype.
        fn dtype_tag(&self) -> &'static str;
        /// Consumes the request and issues it through `ctx`.
        fn dispatch(self: Box<Self>, ctx: &TensorDispatchCtx);
        /// Consumes the request without running it. NOTE(review): presumably
        /// used by a mock/no-GPU path to fail the request — confirm.
        fn fail_mock(self: Box<Self>);
    }
}
#[cfg(feature = "nccl")]
pub use atomr_accel::DType as DispatchDType;
/// A type-erased NCCL collective request (only with the `nccl` feature).
#[cfg(feature = "nccl")]
pub trait CollectiveDispatch: Send + 'static {
/// Element dtype of the collective.
fn dtype_kind(&self) -> DispatchDType;
/// Device the request targets, or `None` if it does not name one.
fn device_id(&self) -> Option<u32>;
/// Consumes the request and issues it through `ctx`.
fn dispatch(self: Box<Self>, ctx: &CollectiveDispatchCtx<'_>);
}
/// Borrowed resources handed to [`CollectiveDispatch::dispatch`].
#[cfg(feature = "nccl")]
pub struct CollectiveDispatchCtx<'a> {
/// NCCL communicator to run the collective on.
pub comm: &'a cudarc::nccl::Comm,
/// Shared state of the device the request runs on.
pub state: &'a Arc<DeviceState>,
/// Completion strategy for the enqueued work.
pub completion: &'a Arc<dyn CompletionStrategy>,
}
// Mostly compile-time checks that the dispatch traits are object-safe and
// that the blanket/generic argument impls cover the expected dtypes. No GPU
// is touched: no test here calls `dispatch` or `push`.
#[cfg(test)]
mod tests {
use super::*;
// Minimal NVRTC request used to exercise the trait through a `Box<dyn _>`.
struct DummyNvrtc {
op: &'static str,
d: Option<DType>,
// Set by `dispatch`. NOTE(review): no test calls `dispatch` (it needs a
// real `NvrtcDispatchCtx`), so this flag is never asserted.
called: std::sync::atomic::AtomicBool,
}
impl NvrtcLaunchDispatch for DummyNvrtc {
fn op_name(&self) -> &'static str {
self.op
}
fn dtype(&self) -> Option<DType> {
self.d
}
fn dispatch(self: Box<Self>, _ctx: &NvrtcDispatchCtx<'_>) {
self.called.store(true, std::sync::atomic::Ordering::SeqCst);
}
}
// Checks that the accessors answer the same through a trait object and
// through a direct (non-boxed) call, including the `dtype() == None` case.
#[test]
fn nvrtc_dispatch_box_round_trip() {
let req = DummyNvrtc {
op: "relu",
d: Some(DType::F32),
called: std::sync::atomic::AtomicBool::new(false),
};
let boxed: Box<dyn NvrtcLaunchDispatch> = Box::new(req);
assert_eq!(boxed.op_name(), "relu");
assert_eq!(boxed.dtype(), Some(DType::F32));
let req2 = DummyNvrtc {
op: "noop",
d: None,
called: std::sync::atomic::AtomicBool::new(false),
};
assert_eq!(req2.op_name(), "noop");
assert_eq!(req2.dtype(), None);
}
// Compile-time check: `GpuRef<T>` coerces to `Box<dyn DevSliceArg>` for each
// listed dtype (the closures are only coerced to fn pointers, never called).
// NOTE(review): this fn is called by `dev_slice_arg_for_gpu_ref`, so the
// `dead_code` allow appears unnecessary.
#[allow(dead_code)]
fn _assert_dev_slice_arg_object_safe() {
fn takes_box(_: Box<dyn DevSliceArg>) {}
let _: fn(GpuRef<f32>) -> Box<dyn DevSliceArg> = |g| Box::new(g);
let _: fn(GpuRef<f64>) -> Box<dyn DevSliceArg> = |g| Box::new(g);
let _: fn(GpuRef<u8>) -> Box<dyn DevSliceArg> = |g| Box::new(g);
let _: fn(GpuRef<i32>) -> Box<dyn DevSliceArg> = |g| Box::new(g);
let _ = takes_box;
// Half-precision types are only covered when the `f16` feature is on.
#[cfg(feature = "f16")]
{
let _: fn(GpuRef<half::f16>) -> Box<dyn DevSliceArg> = |g| Box::new(g);
let _: fn(GpuRef<half::bf16>) -> Box<dyn DevSliceArg> = |g| Box::new(g);
}
}
#[test]
fn dev_slice_arg_for_gpu_ref() {
_assert_dev_slice_arg_object_safe();
}
// The blanket `ScalarArg` impl must cover these host scalar types.
#[test]
fn scalar_arg_blanket_impls_compile() {
fn takes(_: Box<dyn ScalarArg>) {}
takes(Box::new(1.0f32));
takes(Box::new(2.0f64));
takes(Box::new(3i32));
takes(Box::new(4u32));
takes(Box::new(5u64));
#[cfg(feature = "f16")]
{
takes(Box::new(half::f16::ONE));
takes(Box::new(half::bf16::ONE));
}
}
// Each dispatch trait enabled in this build must be usable as `Box<dyn _>`.
#[test]
fn stub_dispatch_traits_compile() {
fn _gemm(_: Box<dyn GemmDispatch>) {}
#[cfg(feature = "cublaslt")]
fn _blaslt(_: Box<dyn BlasLtDispatch>) {}
#[cfg(feature = "cudnn")]
fn _cudnn(_: Box<dyn CudnnDispatch>) {}
#[cfg(feature = "cufft")]
fn _fft(_: Box<dyn FftDispatch>) {}
fn _rng(_: Box<dyn RngDispatch>) {}
#[cfg(feature = "cusolver")]
fn _solver(_: Box<dyn crate::kernel::solver::SolverDispatch>) {}
#[cfg(feature = "cusparse")]
fn _sparse(_: Box<dyn SparseDispatch>) {}
#[cfg(feature = "cutensor")]
fn _tensor(_: Box<dyn TensorDispatch>) {}
#[cfg(feature = "nccl")]
fn _coll(_: Box<dyn CollectiveDispatch>) {}
}
}