use std::sync::Arc;
use cudarc::cublas::sys::cublasOperation_t;
use cudarc::cublas::{Gemm, GemmConfig};
use tokio::sync::oneshot;
use crate::dtype::GemmSupported;
use crate::error::GpuError;
use crate::gpu_ref::GpuRef;
use crate::kernel::dispatch::{BlasDispatchCtx, GemmDispatch};
use crate::kernel::envelope;
/// Library label used by this module: it is passed to `envelope::run_kernel`
/// and stored in the `lib` field of `GpuError::LibraryError` when the GEMM
/// call reports an error.
const LIB: &str = "cublas";
/// One GEMM job for element type `T`, ready to be routed to the cuBLAS dispatcher.
///
/// When the request is dispatched, `trans_a`, `trans_b`, `m`, `n`, `k`, `alpha`,
/// `lda`, `ldb`, `beta` and `ldc` are copied unchanged into a cudarc
/// [`GemmConfig`]; the parenthesised meanings below are the usual cuBLAS GEMM
/// conventions, not something this file enforces.
pub struct GemmRequest<T>
where
    T: GemmSupported,
{
    /// First input operand (`A`); only read by the GEMM call.
    pub a: GpuRef<T>,
    /// Second input operand (`B`); only read by the GEMM call.
    pub b: GpuRef<T>,
    /// Output operand (`C`); the GEMM call writes into it, and dispatch fails
    /// if its underlying buffer has any other live reference.
    pub c: GpuRef<T>,
    /// Forwarded as `GemmConfig::m` (cuBLAS: rows of `op(A)` and of `C`).
    pub m: i32,
    /// Forwarded as `GemmConfig::n` (cuBLAS: columns of `op(B)` and of `C`).
    pub n: i32,
    /// Forwarded as `GemmConfig::k` (cuBLAS: columns of `op(A)`, rows of `op(B)`).
    pub k: i32,
    /// Forwarded as `GemmConfig::alpha` (cuBLAS: scale applied to `op(A) * op(B)`).
    pub alpha: T,
    /// Forwarded as `GemmConfig::beta` (cuBLAS: scale applied to the existing `C`).
    pub beta: T,
    /// Forwarded as `GemmConfig::transa` (operation applied to `A`).
    pub trans_a: cublasOperation_t,
    /// Forwarded as `GemmConfig::transb` (operation applied to `B`).
    pub trans_b: cublasOperation_t,
    /// Forwarded as `GemmConfig::lda` (leading dimension of `A`).
    pub lda: i32,
    /// Forwarded as `GemmConfig::ldb` (leading dimension of `B`).
    pub ldb: i32,
    /// Forwarded as `GemmConfig::ldc` (leading dimension of `C`).
    pub ldc: i32,
    /// One-shot channel for the outcome. The dispatcher sends an `Err` on it
    /// directly when operand access or the uniqueness check on `C` fails;
    /// otherwise the sender is handed to `envelope::run_kernel`.
    pub reply: oneshot::Sender<Result<(), GpuError>>,
}
impl<T> GemmRequest<T>
where
T: GemmSupported,
GemmRequest<T>: GemmDispatch,
{
pub fn into_msg(self) -> crate::kernel::BlasMsg {
crate::kernel::BlasMsg::Gemm(Box::new(self))
}
}
/// Executes one [`GemmRequest`] using the cuBLAS handle and stream carried by `ctx`.
///
/// The function never returns an error itself; every outcome is reported
/// through the request's `reply` channel. Steps, in order:
///
/// 1. Resolve the three operands with `envelope::access_all_3`. If that fails,
///    the error is sent on `reply` and the function returns.
/// 2. Take sole ownership of the output buffer with [`Arc::try_unwrap`]. If the
///    buffer is still shared, a `GpuError::Unrecoverable` is sent on `reply`
///    and the function returns.
/// 3. Call `record_write` on the output `GpuRef` with `ctx.stream`.
/// 4. Hand `reply` and a closure performing the cuBLAS call to
///    `envelope::run_kernel`. A cuBLAS error is mapped to
///    `GpuError::LibraryError` with the message prefix `"gemm enqueue: "`.
///
/// Failures to deliver on `reply` (receiver already dropped) are ignored.
fn dispatch_gemm<T>(req: GemmRequest<T>, ctx: &BlasDispatchCtx<'_>)
where
    T: GemmSupported + Copy,
    cudarc::cublas::CudaBlas: Gemm<T>,
{
    let GemmRequest {
        a: lhs,
        b: rhs,
        c: target,
        m,
        n,
        k,
        alpha,
        beta,
        trans_a,
        trans_b,
        lda,
        ldb,
        ldc,
        reply,
    } = req;

    // Resolve all three operands up front; any access failure ends the request.
    let (lhs_buf, rhs_buf, target_buf) = match envelope::access_all_3(&lhs, &rhs, &target) {
        Ok(buffers) => buffers,
        Err(err) => {
            let _ = reply.send(Err(err));
            return;
        }
    };

    // Writing requires exclusive ownership of the output buffer: refuse to run
    // while any other `Arc` clone of it is alive.
    let mut target_owned = match Arc::try_unwrap(target_buf) {
        Ok(unique) => unique,
        Err(_still_shared) => {
            let message = "GEMM target buffer C has more than one live reference; \
                           caller must hold the unique GpuRef to write to it";
            let _ = reply.send(Err(GpuError::Unrecoverable(message.into())));
            return;
        }
    };

    target.record_write(ctx.stream);

    let config = GemmConfig::<T> {
        transa: trans_a,
        transb: trans_b,
        m,
        n,
        k,
        alpha,
        lda,
        ldb,
        beta,
        ldc,
    };
    let handle = ctx.cublas.clone();

    envelope::run_kernel(LIB, ctx.stream, ctx.completion, (), reply, move || {
        // SAFETY: the three buffers are owned by this closure for the whole
        // call and the output buffer is uniquely owned (checked above).
        // NOTE(review): nothing in this function checks `m`/`n`/`k`/`lda`/
        // `ldb`/`ldc` against the buffer lengths; presumably callers guarantee
        // the dimensions fit the operands - confirm.
        if let Err(cause) =
            unsafe { handle.gemm(config, &*lhs_buf, &*rhs_buf, &mut target_owned) }
        {
            return Err(GpuError::LibraryError {
                lib: LIB,
                msg: format!("gemm enqueue: {cause}"),
            });
        }

        // The handle and buffers are returned to `run_kernel` rather than
        // dropped here; presumably it keeps them alive until the work has
        // completed - confirm against `envelope::run_kernel`.
        Ok((handle, lhs_buf, rhs_buf, target_owned))
    });
}
/// Type-erased dispatch support for single-precision (`f32`) GEMM requests.
impl GemmDispatch for GemmRequest<f32> {
    /// Operation label for this request kind; always `"gemm"`.
    fn op_name(&self) -> &'static str {
        "gemm"
    }

    /// Element-type label, taken from `AccelDtype::NAME` for `f32`.
    fn dtype_name(&self) -> &'static str {
        <f32 as atomr_accel::AccelDtype>::NAME
    }

    /// Unboxes the request and runs it through `dispatch_gemm`.
    fn dispatch(self: Box<Self>, ctx: &BlasDispatchCtx<'_>) {
        let request: GemmRequest<f32> = *self;
        dispatch_gemm(request, ctx);
    }
}
/// Type-erased dispatch support for double-precision (`f64`) GEMM requests.
impl GemmDispatch for GemmRequest<f64> {
    /// Operation label for this request kind; always `"gemm"`.
    fn op_name(&self) -> &'static str {
        "gemm"
    }

    /// Element-type label, taken from `AccelDtype::NAME` for `f64`.
    fn dtype_name(&self) -> &'static str {
        <f64 as atomr_accel::AccelDtype>::NAME
    }

    /// Unboxes the request and runs it through `dispatch_gemm`.
    fn dispatch(self: Box<Self>, ctx: &BlasDispatchCtx<'_>) {
        let request: GemmRequest<f64> = *self;
        dispatch_gemm(request, ctx);
    }
}
/// Type-erased dispatch support for `half::f16` GEMM requests; compiled only
/// when the `f16` feature is enabled.
#[cfg(feature = "f16")]
impl GemmDispatch for GemmRequest<half::f16> {
    /// Operation label for this request kind; always `"gemm"`.
    fn op_name(&self) -> &'static str {
        "gemm"
    }

    /// Element-type label, taken from `AccelDtype::NAME` for `half::f16`.
    fn dtype_name(&self) -> &'static str {
        <half::f16 as atomr_accel::AccelDtype>::NAME
    }

    /// Unboxes the request and runs it through `dispatch_gemm`.
    fn dispatch(self: Box<Self>, ctx: &BlasDispatchCtx<'_>) {
        let request: GemmRequest<half::f16> = *self;
        dispatch_gemm(request, ctx);
    }
}
/// Type-erased dispatch support for `half::bf16` GEMM requests; compiled only
/// when the `f16` feature is enabled.
#[cfg(feature = "f16")]
impl GemmDispatch for GemmRequest<half::bf16> {
    /// Operation label for this request kind; always `"gemm"`.
    fn op_name(&self) -> &'static str {
        "gemm"
    }

    /// Element-type label, taken from `AccelDtype::NAME` for `half::bf16`.
    fn dtype_name(&self) -> &'static str {
        <half::bf16 as atomr_accel::AccelDtype>::NAME
    }

    /// Unboxes the request and runs it through `dispatch_gemm`.
    fn dispatch(self: Box<Self>, ctx: &BlasDispatchCtx<'_>) {
        let request: GemmRequest<half::bf16> = *self;
        dispatch_gemm(request, ctx);
    }
}
/// Crate-visible helpers for unit tests; compiled only in test builds.
#[cfg(test)]
pub(crate) mod tests_helpers {
    use crate::gpu_ref::GpuRef;

    /// Returns the placeholder `GpuRef<T>` produced by
    /// `GpuRef::for_test_no_gpu_leaked`.
    pub fn gpu_ref_stub<T>() -> GpuRef<T> {
        GpuRef::for_test_no_gpu_leaked()
    }
}
#[cfg(test)]
mod tests {
    use super::tests_helpers::gpu_ref_stub;
    use super::*;
    use tokio::sync::oneshot;

    /// Compile-time probe: instantiating it proves `T: Send`.
    fn require_send<T: Send>() {}

    /// Checks the dtype label reported by a type-erased request, then leaks it.
    ///
    /// The request is leaked instead of dropped, mirroring the stub's
    /// `for_test_no_gpu_leaked` constructor; presumably dropping a stub
    /// `GpuRef` is not valid without a GPU - confirm.
    fn expect_dtype_then_leak(erased: Box<dyn GemmDispatch>, expected: &str) {
        assert_eq!(erased.dtype_name(), expected);
        Box::leak(erased);
    }

    /// Every supported element type yields a `Send` request that can be boxed
    /// as `dyn GemmDispatch` and reports the expected labels.
    #[test]
    fn gemm_request_dispatches_for_f32_f64_f16_bf16() {
        require_send::<GemmRequest<f32>>();
        require_send::<GemmRequest<f64>>();
        #[cfg(feature = "f16")]
        {
            require_send::<GemmRequest<half::f16>>();
            require_send::<GemmRequest<half::bf16>>();
        }

        // The operation label is only checked once, on the f32 request.
        let single: Box<dyn GemmDispatch> = Box::new(stub_request::<f32>());
        assert_eq!(single.op_name(), "gemm");
        expect_dtype_then_leak(single, "f32");

        expect_dtype_then_leak(Box::new(stub_request::<f64>()), "f64");

        #[cfg(feature = "f16")]
        {
            expect_dtype_then_leak(Box::new(stub_request::<half::f16>()), "f16");
            expect_dtype_then_leak(Box::new(stub_request::<half::bf16>()), "bf16");
        }
    }

    /// The deprecated `SgemmRequest` can still be built and wrapped in
    /// `BlasMsg::Sgemm`.
    #[test]
    fn deprecated_sgemm_alias_still_constructs() {
        #[allow(deprecated)]
        {
            let (reply, _rx) = oneshot::channel();
            let legacy = crate::device::SgemmRequest {
                a: gpu_ref_stub::<f32>(),
                b: gpu_ref_stub::<f32>(),
                c: gpu_ref_stub::<f32>(),
                m: 1,
                n: 1,
                k: 1,
                alpha: 1.0,
                beta: 0.0,
                reply,
            };
            let wrapped = crate::kernel::BlasMsg::Sgemm(Box::new(legacy));
            // Leaked rather than dropped, like the boxed requests above.
            Box::leak(Box::new(wrapped));
        }
    }

    /// Builds a 1x1x1 request with stub operands, `alpha = 1`, `beta = 0`,
    /// no transposition and unit leading dimensions. The receiving half of the
    /// reply channel is dropped when this function returns.
    fn stub_request<T>() -> GemmRequest<T>
    where
        T: GemmSupported + num_one_zero::NumOneZero,
        GemmRequest<T>: GemmDispatch,
    {
        let (reply, _rx) = oneshot::channel();
        let no_transpose = cublasOperation_t::CUBLAS_OP_N;
        GemmRequest::<T> {
            a: gpu_ref_stub::<T>(),
            b: gpu_ref_stub::<T>(),
            c: gpu_ref_stub::<T>(),
            m: 1,
            n: 1,
            k: 1,
            alpha: <T as num_one_zero::NumOneZero>::one(),
            beta: <T as num_one_zero::NumOneZero>::zero(),
            trans_a: no_transpose,
            trans_b: no_transpose,
            lda: 1,
            ldb: 1,
            ldc: 1,
            reply,
        }
    }

    /// Minimal one/zero abstraction so `stub_request` can stay generic over
    /// the element type.
    mod num_one_zero {
        /// Supplies the multiplicative and additive identities of a scalar type.
        pub trait NumOneZero: Copy {
            fn one() -> Self;
            fn zero() -> Self;
        }

        impl NumOneZero for f32 {
            fn one() -> Self {
                1.0
            }

            fn zero() -> Self {
                0.0
            }
        }

        impl NumOneZero for f64 {
            fn one() -> Self {
                1.0
            }

            fn zero() -> Self {
                0.0
            }
        }

        #[cfg(feature = "f16")]
        impl NumOneZero for half::f16 {
            fn one() -> Self {
                half::f16::ONE
            }

            fn zero() -> Self {
                half::f16::ZERO
            }
        }

        #[cfg(feature = "f16")]
        impl NumOneZero for half::bf16 {
            fn one() -> Self {
                half::bf16::ONE
            }

            fn zero() -> Self {
                half::bf16::ZERO
            }
        }
    }
}