use std::sync::Arc;
use cudarc::cublaslt::{Activation, Matmul, MatmulConfig};
use tokio::sync::oneshot;
use crate::dtype::{DTypeKind, GemmSupported};
use crate::error::GpuError;
use crate::gpu_ref::GpuRef;
use crate::kernel::blas_lt::epilogue::Epilogue;
use crate::kernel::blas_lt::scaling::ScaleSet;
use crate::kernel::dispatch::{BlasLtDispatch, BlasLtDispatchCtx};
use crate::kernel::envelope;
/// Library tag for this module: stored in `GpuError::LibraryError::lib` and passed as the
/// first argument to `envelope::run_kernel`.
const LIB: &str = "cublaslt";
/// One matmul job for element type `T`, together with the channel its outcome is reported on.
///
/// The `BlasLtDispatch` impls in this file decide what happens to a request: `f32`, `f16`
/// and `bf16` are executed by `dispatch_safe_path`; `f64` and the FP8 types are answered
/// immediately with an "not yet implemented" error.
///
/// Field notes below describe what `dispatch_safe_path` does with each field; several
/// fields are carried by the request but not consumed by that path.
pub struct MatmulRequest<T: GemmSupported> {
/// Input operand passed as the `a` argument of cudarc's `matmul`.
pub a: GpuRef<T>,
/// Input operand passed as the `b` argument of cudarc's `matmul`.
pub b: GpuRef<T>,
/// Buffer passed mutably as the `c` argument of `matmul`; a write is recorded on it
/// before launch. It must be uniquely referenced or the request is rejected.
pub c: GpuRef<T>,
/// Optional separate output buffer. Not read by `dispatch_safe_path`.
pub d: Option<GpuRef<T>>,
/// Problem dimension copied into `MatmulConfig::m`.
pub m: i32,
/// Problem dimension copied into `MatmulConfig::n`.
pub n: i32,
/// Problem dimension copied into `MatmulConfig::k`.
pub k: i32,
/// Scalar forwarded as `MatmulConfig::alpha` after conversion to `f32`.
pub alpha: T::Scalar,
/// Scalar forwarded as `MatmulConfig::beta` after conversion to `f32`.
pub beta: T::Scalar,
/// Transpose flag for `a`; forwarded to `MatmulConfig` and part of the heuristic key.
pub transa: bool,
/// Transpose flag for `b`; forwarded to `MatmulConfig` and part of the heuristic key.
pub transb: bool,
/// Leading dimension forwarded as `MatmulConfig::lda`.
pub lda: i64,
/// Leading dimension forwarded as `MatmulConfig::ldb`.
pub ldb: i64,
/// Leading dimension forwarded as `MatmulConfig::ldc`.
pub ldc: i64,
/// Leading dimension for `d`. Not read by `dispatch_safe_path`.
pub ldd: i64,
/// Requested epilogue; selects the ReLU/GELU activation handed to `matmul` and is part
/// of the heuristic key.
pub epilogue: Epilogue,
/// Optional bias buffer; when present it is accessed and passed to `matmul`.
pub bias: Option<GpuRef<T>>,
/// Optional auxiliary buffer. Not read by `dispatch_safe_path`.
pub gelu_aux: Option<GpuRef<T>>,
/// Scaling descriptors. Not read by `dispatch_safe_path`.
pub scales: ScaleSet,
/// Requested workspace size. Shown in the `Debug` output but not read by
/// `dispatch_safe_path`.
pub workspace_size: usize,
/// One-shot channel that receives the request's `Result<(), GpuError>`.
pub reply: oneshot::Sender<Result<(), GpuError>>,
}
/// Compact `Debug` view of a request.
///
/// Only the element type name, the problem shape, the transpose flags, the epilogue and
/// the workspace size are printed; buffers, scalars, leading dimensions and the reply
/// channel are left out.
impl<T: GemmSupported> std::fmt::Debug for MatmulRequest<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Borrow just the fields that are printed; `..` skips everything else.
        let Self {
            m,
            n,
            k,
            transa,
            transb,
            epilogue,
            workspace_size,
            ..
        } = self;

        let mut out = f.debug_struct("MatmulRequest");
        out.field("dtype", &T::NAME);
        out.field("m", m);
        out.field("n", n);
        out.field("k", k);
        out.field("transa", transa);
        out.field("transb", transb);
        out.field("epilogue", epilogue);
        out.field("workspace_size", workspace_size);
        out.finish()
    }
}
/// Element types whose matmul requests are executed through `dispatch_safe_path`.
///
/// Implemented in this file for `f32` and, behind the `f16` feature, for `half::f16` and
/// `half::bf16`.
trait CudarcMatmulPath: GemmSupported {
/// Consumes `req` and runs it using the resources carried by `ctx`.
fn dispatch_safe(req: Box<MatmulRequest<Self>>, ctx: &BlasLtDispatchCtx<'_>);
}
/// `f32` requests are executed by the generic `dispatch_safe_path`.
impl CudarcMatmulPath for f32 {
    fn dispatch_safe(req: Box<MatmulRequest<Self>>, ctx: &BlasLtDispatchCtx<'_>) {
        // The element type is inferred from `req`.
        dispatch_safe_path(req, ctx)
    }
}
/// `half::f16` requests are executed by the generic `dispatch_safe_path`
/// (only compiled with the `f16` feature).
#[cfg(feature = "f16")]
impl CudarcMatmulPath for half::f16 {
    fn dispatch_safe(req: Box<MatmulRequest<Self>>, ctx: &BlasLtDispatchCtx<'_>) {
        // The element type is inferred from `req`.
        dispatch_safe_path(req, ctx)
    }
}
/// `half::bf16` requests are executed by the generic `dispatch_safe_path`
/// (only compiled with the `f16` feature).
#[cfg(feature = "f16")]
impl CudarcMatmulPath for half::bf16 {
    fn dispatch_safe(req: Box<MatmulRequest<Self>>, ctx: &BlasLtDispatchCtx<'_>) {
        // The element type is inferred from `req`.
        dispatch_safe_path(req, ctx)
    }
}
/// Fallback used by element types whose matmul is not wired up yet.
trait UnsupportedMatmulPath {
/// Answers `reply` with an error that names `dtype`; no GPU work is attempted.
fn dispatch_unsupported(reply: oneshot::Sender<Result<(), GpuError>>, dtype: &'static str);
}
/// Blanket implementation: every type gets the same "not yet implemented" reply.
impl<T> UnsupportedMatmulPath for T {
    fn dispatch_unsupported(reply: oneshot::Sender<Result<(), GpuError>>, dtype: &'static str) {
        let message = format!(
            "BlasLtActor: matmul<{dtype}> not yet implemented (Phase 1 sys-level wiring pending)"
        );
        let outcome = Err(GpuError::Unrecoverable(message));
        // A failed send means the receiver is already gone; nothing else to do.
        let _ = reply.send(outcome);
    }
}
/// `f64` matmul is not implemented: dispatching answers the request with an error.
impl BlasLtDispatch for MatmulRequest<f64> {
    fn dtype_kind(&self) -> DTypeKind {
        DTypeKind::F64
    }

    fn dispatch(self: Box<Self>, _ctx: &BlasLtDispatchCtx<'_>) {
        // Only the reply channel is needed; the context is unused.
        let reply = self.reply;
        f64::dispatch_unsupported(reply, "f64")
    }
}
/// `f32` requests are routed to the cudarc-backed path.
impl BlasLtDispatch for MatmulRequest<f32> {
    fn dtype_kind(&self) -> DTypeKind {
        DTypeKind::F32
    }

    fn dispatch(self: Box<Self>, ctx: &BlasLtDispatchCtx<'_>) {
        f32::dispatch_safe(self, ctx)
    }
}
/// `half::f16` requests are routed to the cudarc-backed path
/// (only compiled with the `f16` feature).
#[cfg(feature = "f16")]
impl BlasLtDispatch for MatmulRequest<half::f16> {
    fn dtype_kind(&self) -> DTypeKind {
        DTypeKind::F16
    }

    fn dispatch(self: Box<Self>, ctx: &BlasLtDispatchCtx<'_>) {
        half::f16::dispatch_safe(self, ctx)
    }
}
/// `half::bf16` requests are routed to the cudarc-backed path
/// (only compiled with the `f16` feature).
#[cfg(feature = "f16")]
impl BlasLtDispatch for MatmulRequest<half::bf16> {
    fn dtype_kind(&self) -> DTypeKind {
        DTypeKind::Bf16
    }

    fn dispatch(self: Box<Self>, ctx: &BlasLtDispatchCtx<'_>) {
        half::bf16::dispatch_safe(self, ctx)
    }
}
/// FP8 E4M3 matmul is not implemented: dispatching answers the request with an error
/// (only compiled with the `cublas-fp8` feature).
#[cfg(feature = "cublas-fp8")]
impl BlasLtDispatch for MatmulRequest<crate::dtype::F8E4m3> {
    fn dtype_kind(&self) -> DTypeKind {
        DTypeKind::F8E4m3
    }

    fn dispatch(self: Box<Self>, _ctx: &BlasLtDispatchCtx<'_>) {
        // Only the reply channel is needed; the context is unused.
        let reply = self.reply;
        <crate::dtype::F8E4m3 as UnsupportedMatmulPath>::dispatch_unsupported(reply, "fp8e4m3")
    }
}
/// FP8 E5M2 matmul is not implemented: dispatching answers the request with an error
/// (only compiled with the `cublas-fp8` feature).
#[cfg(feature = "cublas-fp8")]
impl BlasLtDispatch for MatmulRequest<crate::dtype::F8E5m2> {
    fn dtype_kind(&self) -> DTypeKind {
        DTypeKind::F8E5m2
    }

    fn dispatch(self: Box<Self>, _ctx: &BlasLtDispatchCtx<'_>) {
        // Only the reply channel is needed; the context is unused.
        let reply = self.reply;
        <crate::dtype::F8E5m2 as UnsupportedMatmulPath>::dispatch_unsupported(reply, "fp8e5m2")
    }
}
/// Executes a [`MatmulRequest`] through cudarc's `Matmul::matmul`.
///
/// Every failure detected here is reported on `req.reply` and the function returns
/// without launching anything. On the launch path `reply` is handed to
/// `envelope::run_kernel` together with the closure that performs the call.
///
/// Steps, in order:
/// 1. Reject negative `m`/`n`/`k`.
/// 2. Look up the heuristic cache (the result is currently unused).
/// 3. Map the epilogue to an optional ReLU/GELU activation.
/// 4. Build the `MatmulConfig` (no batching, no strides, `transc = false`).
/// 5. Access the `a`/`b`/`c` buffers and the optional bias.
/// 6. Take exclusive ownership of `c` and record the pending write.
/// 7. Hand the launch closure to `envelope::run_kernel`.
///
/// The request fields `d`, `ldd`, `gelu_aux`, `scales` and `workspace_size` are not
/// forwarded to the library by this path.
fn dispatch_safe_path<T>(req: Box<MatmulRequest<T>>, ctx: &BlasLtDispatchCtx<'_>)
where
    T: GemmSupported + cudarc::driver::DeviceRepr + cudarc::driver::ValidAsZeroBits,
    cudarc::cublaslt::CudaBlasLT: Matmul<T>,
    T::Scalar: Into<f32> + Copy,
{
    let MatmulRequest {
        a,
        b,
        c,
        d: _d,
        m,
        n,
        k,
        alpha,
        beta,
        transa,
        transb,
        lda,
        ldb,
        ldc,
        ldd: _ldd,
        epilogue,
        bias,
        gelu_aux: _gelu_aux,
        scales: _scales,
        workspace_size: _workspace_size,
        reply,
    } = *req;

    // `as u64` on a negative `i32` sign-extends to an enormous value, which would then be
    // passed to the unsafe library call below. Refuse such requests up front.
    if m < 0 || n < 0 || k < 0 {
        let _ = reply.send(Err(GpuError::LibraryError {
            lib: LIB,
            msg: format!("matmul: negative dimension (m={m}, n={n}, k={k})"),
        }));
        return;
    }

    // NOTE(review): the heuristic entry is looked up but not used to pick an algorithm.
    let _entry = ctx
        .heuristic
        .get(&crate::kernel::blas_lt::heuristic::HeuristicKey::new(
            m,
            n,
            k,
            T::KIND,
            transa,
            transb,
            epilogue,
            ctx.sm_arch,
        ));

    // The `Aux` and `Bias` variants collapse onto the plain activation; every other
    // epilogue runs without an activation.
    let activation = match epilogue {
        Epilogue::Relu | Epilogue::ReluBias | Epilogue::ReluAux | Epilogue::ReluAuxBias => {
            Some(Activation::Relu)
        }
        Epilogue::Gelu | Epilogue::GeluBias | Epilogue::GeluAux | Epilogue::GeluAuxBias => {
            Some(Activation::Gelu)
        }
        _ => None,
    };

    let cfg = MatmulConfig {
        transa,
        transb,
        transc: false,
        // Lossless: all three were checked to be non-negative above.
        m: m as u64,
        n: n as u64,
        k: k as u64,
        alpha: alpha.into(),
        lda,
        ldb,
        beta: beta.into(),
        ldc,
        stride_a: None,
        stride_b: None,
        stride_c: None,
        stride_bias: None,
        batch_size: None,
    };

    let (a_slice, b_slice, c_slice) = match envelope::access_all_3(&a, &b, &c) {
        Ok(t) => t,
        Err(e) => {
            let _ = reply.send(Err(e));
            return;
        }
    };

    // The bias is passed to the library whenever the request carries one.
    let bias_slice = match bias.as_ref() {
        None => None,
        Some(g) => match g.access() {
            Ok(s) => Some(s.clone()),
            Err(e) => {
                let _ = reply.send(Err(e));
                return;
            }
        },
    };

    // `matmul` needs `&mut` access to C, so the shared handle must be the only one.
    let mut c_owned = match Arc::try_unwrap(c_slice) {
        Ok(s) => s,
        Err(_) => {
            let _ = reply.send(Err(GpuError::Unrecoverable(
                "BlasLt C has multiple live references".into(),
            )));
            return;
        }
    };

    c.record_write(ctx.stream);

    let blas_lt = ctx.blas_lt.clone();
    let stream = ctx.stream;
    let completion = ctx.completion;
    envelope::run_kernel(LIB, stream, completion, (), reply, move || {
        let bias_ref = bias_slice.as_ref().map(|s| &**s);
        let act_ref = activation.as_ref();
        // SAFETY: `matmul` trusts `cfg` to describe the buffers it is given. The slices
        // come from successful `access` calls and C is exclusively owned, but the
        // dimensions and leading dimensions are caller-supplied and are not checked
        // against buffer lengths here.
        // NOTE(review): confirm callers validate shapes against buffer sizes.
        let res =
            unsafe { blas_lt.matmul(cfg, &*a_slice, &*b_slice, &mut c_owned, bias_ref, act_ref) };
        match res {
            // The buffers and the handle are returned to `run_kernel` on success,
            // presumably so they stay alive until the work completes.
            Ok(()) => Ok((a_slice, b_slice, c_owned, bias_slice, blas_lt)),
            Err(e) => Err(GpuError::LibraryError {
                lib: LIB,
                msg: format!("matmul: {e}"),
            }),
        }
    });
}
// Compile-time checks only: nothing in this module touches a GPU or builds a real request.
#[cfg(test)]
mod tests {
use super::*;
// Type-instantiation helper. It is only ever referenced as a function pointer; calling
// it would hit the `unreachable!` below.
fn make_request<T: GemmSupported>() -> MatmulRequest<T>
where
T::Scalar: Default,
{
let (tx, _rx) = oneshot::channel::<Result<(), GpuError>>();
let _ = (T::NAME, tx);
unreachable!("type-instantiation-only helper")
}
#[test]
fn matmul_request_dispatches_for_f32_f16_bf16() {
// NOTE(review): `_accepts_f32` is declared but never called, so the
// `Box<dyn BlasLtDispatch>` coercion is not actually exercised by this test.
fn _accepts_f32(b: Box<dyn BlasLtDispatch>) -> Box<dyn BlasLtDispatch> {
b
}
// Naming each `dtype_kind` as a fn pointer only compiles if the matching
// `BlasLtDispatch` impl exists.
let _f32_kind: fn(&MatmulRequest<f32>) -> DTypeKind = MatmulRequest::<f32>::dtype_kind;
let _f64_kind: fn(&MatmulRequest<f64>) -> DTypeKind = MatmulRequest::<f64>::dtype_kind;
#[cfg(feature = "f16")]
let _f16_kind: fn(&MatmulRequest<half::f16>) -> DTypeKind =
MatmulRequest::<half::f16>::dtype_kind;
#[cfg(feature = "f16")]
let _bf16_kind: fn(&MatmulRequest<half::bf16>) -> DTypeKind =
MatmulRequest::<half::bf16>::dtype_kind;
// The `KIND` constant of each element type must agree with the `DTypeKind` the impls
// above report.
assert_eq!(<f32 as atomr_accel::AccelDtype>::KIND, DTypeKind::F32);
assert_eq!(<f64 as atomr_accel::AccelDtype>::KIND, DTypeKind::F64);
#[cfg(feature = "f16")]
{
assert_eq!(<half::f16 as atomr_accel::AccelDtype>::KIND, DTypeKind::F16);
assert_eq!(
<half::bf16 as atomr_accel::AccelDtype>::KIND,
DTypeKind::Bf16
);
}
// Forces `make_request::<f32>` to type-check without calling it.
let _ = make_request::<f32> as fn() -> MatmulRequest<f32>;
}
}