use crate::{
compat::*,
cuda::*,
error::{CudaKernelError, Result},
kernel::Kernels,
kernels::macros::ops,
source::Source,
};
use cudarc::cublas::{CudaBlas, Gemm, GemmConfig, StridedBatchedConfig};
use half::{bf16, f16};
// Declares the `matmul` and `dot` ops — presumably this generates the
// `Kernel` identifiers consumed by `load_function` below (see
// `crate::kernels::macros::ops`; definition not visible in this file).
ops!(matmul, dot);
/// Generates a per-type matmul entry point that prefers cuBLAS and silently
/// falls back to the hand-written tiled kernel when the cuBLAS path refuses
/// the input (e.g. a non-contiguous innermost dimension) or fails.
macro_rules! impl_cublas_matmul {
    ($ty:ty) => {
        paste::paste! {
            #[doc = "cuBLAS-accelerated matmul for " $ty]
            #[allow(dead_code)]
            pub fn [<call_ops_matmul_cublas_ $ty>](
                _kernel: crate::kernels::macros::Kernel,
                _kernels: &Kernels,
                context: &Arc<CudaContext>,
                lhs: &CudaSlice<$ty>,
                rhs: &CudaSlice<$ty>,
                output: &mut CudaSlice<$ty>,
                metadata: &[usize],
            ) -> Result<()> {
                // Fast path first; any error routes to the fallback kernel.
                call_ops_matmul_cublas(context, lhs, rhs, output, metadata).or_else(|_| {
                    call_ops_matmul_kernel(_kernel, _kernels, context, lhs, rhs, output, metadata)
                })
            }
        }
    };
}
/// Abstraction over the element types this file routes through cuBLAS
/// (`bf16`, `f16`, `f32`, `f64`).
///
/// Supplies the `alpha`/`beta` identity constants plus thin wrappers around
/// `CudaBlas::gemm` and `CudaBlas::gemm_strided_batched`, so the generic
/// drivers (`call_ops_matmul_cublas`, `call_ops_dot_cublas`) can stay
/// monomorphic over `T`.
trait CublasGemm: cudarc::driver::DeviceRepr + Sized {
    /// Multiplicative identity — used as GEMM's `alpha`.
    fn one() -> Self;
    /// Additive identity — used as GEMM's `beta` (output is overwritten).
    fn zero() -> Self;
    /// Single (non-batched) GEMM with the column-major layout encoded in `cfg`.
    fn cublas_gemm<A, B, C>(
        blas: &CudaBlas,
        cfg: GemmConfig<Self>,
        lhs: &A,
        rhs: &B,
        output: &mut C,
    ) -> core::result::Result<(), cudarc::cublas::result::CublasError>
    where
        A: cudarc::driver::DevicePtr<Self>,
        B: cudarc::driver::DevicePtr<Self>,
        C: cudarc::driver::DevicePtrMut<Self>;
    /// Strided-batched GEMM; batch count and per-batch strides come from `cfg`.
    fn cublas_gemm_batched<A, B, C>(
        blas: &CudaBlas,
        cfg: StridedBatchedConfig<Self>,
        lhs: &A,
        rhs: &B,
        output: &mut C,
    ) -> core::result::Result<(), cudarc::cublas::result::CublasError>
    where
        A: cudarc::driver::DevicePtr<Self>,
        B: cudarc::driver::DevicePtr<Self>,
        C: cudarc::driver::DevicePtrMut<Self>;
}
impl CublasGemm for bf16 {
fn one() -> Self {
bf16::from_f32(1.0)
}
fn zero() -> Self {
bf16::from_f32(0.0)
}
fn cublas_gemm<A, B, C>(
blas: &CudaBlas,
cfg: GemmConfig<Self>,
lhs: &A,
rhs: &B,
output: &mut C,
) -> core::result::Result<(), cudarc::cublas::result::CublasError>
where
A: cudarc::driver::DevicePtr<Self>,
B: cudarc::driver::DevicePtr<Self>,
C: cudarc::driver::DevicePtrMut<Self>,
{
unsafe { blas.gemm(cfg, lhs, rhs, output) }
}
fn cublas_gemm_batched<A, B, C>(
blas: &CudaBlas,
cfg: StridedBatchedConfig<Self>,
lhs: &A,
rhs: &B,
output: &mut C,
) -> core::result::Result<(), cudarc::cublas::result::CublasError>
where
A: cudarc::driver::DevicePtr<Self>,
B: cudarc::driver::DevicePtr<Self>,
C: cudarc::driver::DevicePtrMut<Self>,
{
unsafe { blas.gemm_strided_batched(cfg, lhs, rhs, output) }
}
}
impl CublasGemm for f16 {
fn one() -> Self {
f16::from_f32(1.0)
}
fn zero() -> Self {
f16::from_f32(0.0)
}
fn cublas_gemm<A, B, C>(
blas: &CudaBlas,
cfg: GemmConfig<Self>,
lhs: &A,
rhs: &B,
output: &mut C,
) -> core::result::Result<(), cudarc::cublas::result::CublasError>
where
A: cudarc::driver::DevicePtr<Self>,
B: cudarc::driver::DevicePtr<Self>,
C: cudarc::driver::DevicePtrMut<Self>,
{
unsafe { blas.gemm(cfg, lhs, rhs, output) }
}
fn cublas_gemm_batched<A, B, C>(
blas: &CudaBlas,
cfg: StridedBatchedConfig<Self>,
lhs: &A,
rhs: &B,
output: &mut C,
) -> core::result::Result<(), cudarc::cublas::result::CublasError>
where
A: cudarc::driver::DevicePtr<Self>,
B: cudarc::driver::DevicePtr<Self>,
C: cudarc::driver::DevicePtrMut<Self>,
{
unsafe { blas.gemm_strided_batched(cfg, lhs, rhs, output) }
}
}
impl CublasGemm for f32 {
fn one() -> Self {
1.0
}
fn zero() -> Self {
0.0
}
fn cublas_gemm<A, B, C>(
blas: &CudaBlas,
cfg: GemmConfig<Self>,
lhs: &A,
rhs: &B,
output: &mut C,
) -> core::result::Result<(), cudarc::cublas::result::CublasError>
where
A: cudarc::driver::DevicePtr<Self>,
B: cudarc::driver::DevicePtr<Self>,
C: cudarc::driver::DevicePtrMut<Self>,
{
unsafe { blas.gemm(cfg, lhs, rhs, output) }
}
fn cublas_gemm_batched<A, B, C>(
blas: &CudaBlas,
cfg: StridedBatchedConfig<Self>,
lhs: &A,
rhs: &B,
output: &mut C,
) -> core::result::Result<(), cudarc::cublas::result::CublasError>
where
A: cudarc::driver::DevicePtr<Self>,
B: cudarc::driver::DevicePtr<Self>,
C: cudarc::driver::DevicePtrMut<Self>,
{
unsafe { blas.gemm_strided_batched(cfg, lhs, rhs, output) }
}
}
impl CublasGemm for f64 {
fn one() -> Self {
1.0
}
fn zero() -> Self {
0.0
}
fn cublas_gemm<A, B, C>(
blas: &CudaBlas,
cfg: GemmConfig<Self>,
lhs: &A,
rhs: &B,
output: &mut C,
) -> core::result::Result<(), cudarc::cublas::result::CublasError>
where
A: cudarc::driver::DevicePtr<Self>,
B: cudarc::driver::DevicePtr<Self>,
C: cudarc::driver::DevicePtrMut<Self>,
{
unsafe { blas.gemm(cfg, lhs, rhs, output) }
}
fn cublas_gemm_batched<A, B, C>(
blas: &CudaBlas,
cfg: StridedBatchedConfig<Self>,
lhs: &A,
rhs: &B,
output: &mut C,
) -> core::result::Result<(), cudarc::cublas::result::CublasError>
where
A: cudarc::driver::DevicePtr<Self>,
B: cudarc::driver::DevicePtr<Self>,
C: cudarc::driver::DevicePtrMut<Self>,
{
unsafe { blas.gemm_strided_batched(cfg, lhs, rhs, output) }
}
}
/// Row-major, optionally strided-batched matmul via cuBLAS.
///
/// `metadata` packing (all `usize` words):
/// - `[0]`: not read here (presumably a total-size field — TODO confirm against the caller)
/// - `[1]` = lhs rank, `[2]` = rhs rank, `[3]` = batch rank
/// - then, in order: lhs shape, rhs shape, batch shape, lhs strides, rhs strides
/// - then at `metadata_base`: lhs offset, rhs offset, `m`, `k`, `n`
///
/// # Errors
/// Returns `Err` when either input's innermost stride is not 1 or when cuBLAS
/// itself fails; the generated wrappers use that to fall back to the kernel.
fn call_ops_matmul_cublas<T>(
    context: &Arc<CudaContext>,
    lhs: &CudaSlice<T>,
    rhs: &CudaSlice<T>,
    output: &mut CudaSlice<T>,
    metadata: &[usize],
) -> Result<()>
where
    T: CublasGemm,
{
    let lhs_ndim = metadata[1];
    let rhs_ndim = metadata[2];
    let batch_ndim = metadata[3];
    // Skip the 4-word header, both shapes, the batch shape and both stride
    // arrays to reach the scalar tail (offsets, then m/k/n).
    let metadata_base = 4 + lhs_ndim + rhs_ndim + batch_ndim + lhs_ndim + rhs_ndim;
    let lhs_offset = metadata[metadata_base];
    let rhs_offset = metadata[metadata_base + 1];
    let m = metadata[metadata_base + 2];
    let k = metadata[metadata_base + 3];
    let n = metadata[metadata_base + 4];
    let num_batches = if batch_ndim == 0 {
        1
    } else {
        let batch_shape = &metadata[4 + lhs_ndim + rhs_ndim..4 + lhs_ndim + rhs_ndim + batch_ndim];
        batch_shape.iter().product()
    };
    let lhs_strides = &metadata[4 + lhs_ndim + rhs_ndim + batch_ndim..4 + 2 * lhs_ndim + rhs_ndim + batch_ndim];
    let rhs_strides = &metadata[4 + 2 * lhs_ndim + rhs_ndim + batch_ndim..4 + 2 * lhs_ndim + 2 * rhs_ndim + batch_ndim];
    // cuBLAS can only consume these views when the innermost dimension is
    // dense; otherwise bail so the wrapper falls back to the tiled kernel.
    let lhs_last_contiguous = lhs_strides[lhs_ndim - 1] == 1;
    let rhs_last_contiguous = rhs_strides[rhs_ndim - 1] == 1;
    if lhs_last_contiguous && rhs_last_contiguous {
        let stream = context.default_stream();
        // NOTE(review): this creates a fresh cuBLAS handle on every call;
        // caching one per context would avoid repeated cublasCreate overhead.
        let blas = CudaBlas::new(stream)
            .map_err(|e| CudaKernelError::LaunchError(format!("Failed to create cuBLAS: {:?}", e)))?;
        // Leading dimensions come from the row strides; clamp up to the
        // logical row length when the recorded stride is smaller (e.g. a
        // collapsed/broadcast view — TODO confirm the intended cases).
        let mut lda = lhs_strides[lhs_ndim - 2] as i32;
        let mut ldb = rhs_strides[rhs_ndim - 2] as i32;
        if lda < k as i32 {
            lda = k as i32;
        }
        if ldb < n as i32 {
            ldb = n as i32;
        }
        // Batch stride is read from the outermost stride only — this assumes
        // the batch dimensions are contiguous/collapsed; verify for
        // batch_ndim > 1.
        let lhs_batch_stride = if batch_ndim > 0 { lhs_strides[0] as i64 } else { 0 };
        let rhs_batch_stride = if batch_ndim > 0 { rhs_strides[0] as i64 } else { 0 };
        let lhs_view = lhs.slice(lhs_offset..);
        let rhs_view = rhs.slice(rhs_offset..);
        // cuBLAS is column-major. A row-major C = A*B is computed as the
        // column-major C^T = B^T * A^T, so m/n and lda/ldb are exchanged here
        // and the operands are passed in (rhs, lhs) order below.
        let cfg = StridedBatchedConfig {
            gemm: GemmConfig {
                transa: cudarc::cublas::sys::cublasOperation_t::CUBLAS_OP_N,
                transb: cudarc::cublas::sys::cublasOperation_t::CUBLAS_OP_N,
                m: n as i32, n: m as i32,
                k: k as i32,
                alpha: T::one(),
                lda: ldb, ldb: lda, beta: T::zero(),
                ldc: n as i32,
            },
            batch_size: num_batches as i32,
            // Output is assumed dense: one m*n matrix per batch.
            stride_a: rhs_batch_stride, stride_b: lhs_batch_stride, stride_c: (m * n) as i64,
        };
        T::cublas_gemm_batched(&blas, cfg, &rhs_view, &lhs_view, output)
            .map_err(|e| CudaKernelError::LaunchError(format!("cuBLAS GEMM failed: {:?}", e)))?;
        Ok(())
    } else {
        Err(CudaKernelError::LaunchError(
            "Only row-major matrices with contiguous last dimension are supported by cuBLAS path".into(),
        ))
    }
}
/// Launches the custom tiled matmul kernel (the non-cuBLAS path).
///
/// Pulls `m`, `n` and the batch count out of the packed metadata, stages the
/// metadata in device memory, and dispatches one thread block per 16x16
/// output tile, with the grid's z dimension covering the batches.
fn call_ops_matmul_kernel<T>(
    kernel: crate::kernels::macros::Kernel,
    kernels: &Kernels,
    context: &Arc<CudaContext>,
    lhs: &CudaSlice<T>,
    rhs: &CudaSlice<T>,
    output: &mut CudaSlice<T>,
    metadata: &[usize],
) -> Result<()>
where
    T: cudarc::driver::DeviceRepr,
{
    const TILE_SIZE: u32 = 16;
    let func = kernels.load_function(context, Source::OpsMatrix, kernel.0)?;
    // Header words: [1] = lhs rank, [2] = rhs rank, [3] = batch rank.
    let (lhs_ndim, rhs_ndim, batch_ndim) = (metadata[1], metadata[2], metadata[3]);
    // Scalar tail sits past the header, both shapes, the batch shape and both
    // stride arrays; m and n live at offsets +2 and +4 within it.
    let base = 4 + 2 * lhs_ndim + 2 * rhs_ndim + batch_ndim;
    let rows = metadata[base + 2] as u32;
    let cols = metadata[base + 4] as u32;
    let batch_start = 4 + lhs_ndim + rhs_ndim;
    let num_batches: usize = if batch_ndim == 0 {
        1
    } else {
        metadata[batch_start..batch_start + batch_ndim].iter().product()
    };
    let cfg = LaunchConfig {
        grid_dim: (
            cols.div_ceil(TILE_SIZE).max(1),
            rows.div_ceil(TILE_SIZE).max(1),
            num_batches as u32,
        ),
        block_dim: (TILE_SIZE, TILE_SIZE, 1),
        shared_mem_bytes: 0,
    };
    let stream = context.default_stream();
    // The kernel reads shapes/strides from device memory, so copy the packed
    // metadata over first.
    let metadata_dev = stream
        .memcpy_stod(metadata)
        .map_err(|e| CudaKernelError::MemoryError(format!("Failed to copy metadata: {:?}", e)))?;
    // SAFETY: argument order must match the kernel's signature:
    // (lhs, rhs, output, metadata).
    unsafe {
        func.launch(&stream, cfg, |args| {
            args.arg(lhs).arg(rhs).arg(output).arg(&metadata_dev);
        })
        .map_err(|e| CudaKernelError::LaunchError(format!("Failed to launch kernel: {:?}", e)))?;
    }
    Ok(())
}
/// Generic matmul entry point for any `DeviceRepr` element type.
///
/// Always dispatches to the custom tiled kernel; for `bf16`/`f16`/`f32`/`f64`
/// the generated `call_ops_matmul_cublas_*` wrappers should be preferred, as
/// they try cuBLAS first and fall back to this same kernel path.
pub fn call_ops_matmul<T>(
    kernel: crate::kernels::macros::Kernel,
    kernels: &Kernels,
    context: &Arc<CudaContext>,
    lhs: &CudaSlice<T>,
    rhs: &CudaSlice<T>,
    output: &mut CudaSlice<T>,
    metadata: &[usize],
) -> Result<()>
where
    T: cudarc::driver::DeviceRepr,
{
    call_ops_matmul_kernel(kernel, kernels, context, lhs, rhs, output, metadata)
}
// Instantiate the cuBLAS-backed matmul entry points for every element type
// cuBLAS supports here; other dtypes go through `call_ops_matmul` directly.
impl_cublas_matmul!(bf16);
impl_cublas_matmul!(f16);
impl_cublas_matmul!(f32);
impl_cublas_matmul!(f64);
/// Generates a per-type dot (2-D matmul) entry point that prefers cuBLAS and
/// silently falls back to the hand-written kernel when the cuBLAS path
/// refuses the input layout or fails.
macro_rules! impl_cublas_dot {
    ($ty:ty) => {
        paste::paste! {
            #[doc = "cuBLAS-accelerated dot for " $ty]
            #[allow(dead_code)]
            pub fn [<call_ops_dot_cublas_ $ty>](
                _kernel: crate::kernels::macros::Kernel,
                _kernels: &Kernels,
                context: &Arc<CudaContext>,
                lhs: &CudaSlice<$ty>,
                rhs: &CudaSlice<$ty>,
                output: &mut CudaSlice<$ty>,
                metadata: &[usize],
            ) -> Result<()> {
                // Fast path first; any error routes to the fallback kernel.
                call_ops_dot_cublas(context, lhs, rhs, output, metadata).or_else(|_| {
                    call_ops_dot_kernel(_kernel, _kernels, context, lhs, rhs, output, metadata)
                })
            }
        }
    };
}
/// Single (non-batched) row-major matmul via cuBLAS, backing the `dot` op.
///
/// `metadata` packing (all `usize` words):
/// `[0..3]` = m, k, n; `[3..5]` = lhs strides (row, col); `[5..7]` = rhs
/// strides (row, col); `[7..9]` = lhs/rhs element offsets.
///
/// # Errors
/// Returns `Err` when either innermost stride is not 1 or when cuBLAS fails;
/// the generated wrappers use that to fall back to the kernel.
fn call_ops_dot_cublas<T>(
    context: &Arc<CudaContext>,
    lhs: &CudaSlice<T>,
    rhs: &CudaSlice<T>,
    output: &mut CudaSlice<T>,
    metadata: &[usize],
) -> Result<()>
where
    T: CublasGemm,
{
    let m = metadata[0];
    let k = metadata[1];
    let n = metadata[2];
    let lhs_stride_m = metadata[3];
    let lhs_stride_k = metadata[4];
    let rhs_stride_k = metadata[5];
    let rhs_stride_n = metadata[6];
    let lhs_offset = metadata[7];
    let rhs_offset = metadata[8];
    // cuBLAS needs dense innermost dimensions; otherwise bail so the wrapper
    // falls back to the tiled kernel.
    let lhs_last_contiguous = lhs_stride_k == 1;
    let rhs_last_contiguous = rhs_stride_n == 1;
    if lhs_last_contiguous && rhs_last_contiguous {
        let stream = context.default_stream();
        // NOTE(review): fresh cuBLAS handle per call; caching one per context
        // would avoid repeated cublasCreate overhead.
        let blas = CudaBlas::new(stream)
            .map_err(|e| CudaKernelError::LaunchError(format!("Failed to create cuBLAS: {:?}", e)))?;
        // Leading dimensions are the row strides; clamp up to the logical row
        // length when the recorded stride is smaller — TODO confirm which
        // callers rely on this.
        let mut lda = lhs_stride_m as i32;
        let mut ldb = rhs_stride_k as i32;
        if lda < k as i32 {
            lda = k as i32;
        }
        if ldb < n as i32 {
            ldb = n as i32;
        }
        let lhs_view = lhs.slice(lhs_offset..);
        let rhs_view = rhs.slice(rhs_offset..);
        // cuBLAS is column-major. A row-major C = A*B is computed as the
        // column-major C^T = B^T * A^T, so m/n and lda/ldb are exchanged and
        // the operands are passed in (rhs, lhs) order below.
        let cfg = GemmConfig {
            transa: cudarc::cublas::sys::cublasOperation_t::CUBLAS_OP_N,
            transb: cudarc::cublas::sys::cublasOperation_t::CUBLAS_OP_N,
            m: n as i32, n: m as i32,
            k: k as i32,
            alpha: T::one(),
            lda: ldb, ldb: lda, beta: T::zero(),
            ldc: n as i32,
        };
        T::cublas_gemm(&blas, cfg, &rhs_view, &lhs_view, output)
            .map_err(|e| CudaKernelError::LaunchError(format!("cuBLAS GEMM failed: {:?}", e)))?;
        Ok(())
    } else {
        Err(CudaKernelError::LaunchError(
            "Only row-major matrices with contiguous last dimension are supported by cuBLAS path".into(),
        ))
    }
}
/// Launches the hand-written dot kernel (the non-cuBLAS path).
///
/// Each thread block covers a 32x32 output tile with an 8x8 thread layout —
/// presumably each thread computes a 4x4 sub-block (BLOCK_SIZE); the kernel
/// source is not visible here.
fn call_ops_dot_kernel<T>(
    kernel: crate::kernels::macros::Kernel,
    kernels: &Kernels,
    context: &Arc<CudaContext>,
    lhs: &CudaSlice<T>,
    rhs: &CudaSlice<T>,
    output: &mut CudaSlice<T>,
    metadata: &[usize],
) -> Result<()>
where
    T: cudarc::driver::DeviceRepr,
{
    // Output tile edge, per-thread sub-block edge, resulting threads per axis.
    const DOT_TILE_SIZE: u32 = 32;
    const BLOCK_SIZE: u32 = 4;
    const THREADS_PER_DIM: u32 = DOT_TILE_SIZE / BLOCK_SIZE;
    let func = kernels.load_function(context, Source::OpsMatrix, kernel.0)?;
    // metadata[0] = m (rows), metadata[2] = n (cols); metadata[1] (k) is only
    // read inside the kernel itself.
    let rows = metadata[0] as u32;
    let cols = metadata[2] as u32;
    let cfg = LaunchConfig {
        grid_dim: (
            cols.div_ceil(DOT_TILE_SIZE).max(1),
            rows.div_ceil(DOT_TILE_SIZE).max(1),
            1,
        ),
        block_dim: (THREADS_PER_DIM, THREADS_PER_DIM, 1),
        shared_mem_bytes: 0,
    };
    let stream = context.default_stream();
    // Stage the packed metadata in device memory for the kernel to read.
    let metadata_dev = stream
        .memcpy_stod(metadata)
        .map_err(|e| CudaKernelError::MemoryError(format!("Failed to copy metadata: {:?}", e)))?;
    // SAFETY: argument order must match the kernel's signature:
    // (lhs, rhs, output, metadata).
    unsafe {
        func.launch(&stream, cfg, |args| {
            args.arg(lhs).arg(rhs).arg(output).arg(&metadata_dev);
        })
        .map_err(|e| CudaKernelError::LaunchError(format!("Failed to launch kernel: {:?}", e)))?;
    }
    Ok(())
}
/// Generic dot entry point for any `DeviceRepr` element type.
///
/// Always dispatches to the custom kernel; for `bf16`/`f16`/`f32`/`f64` the
/// generated `call_ops_dot_cublas_*` wrappers should be preferred, as they
/// try cuBLAS first and fall back to this same kernel path.
pub fn call_ops_dot<T>(
    kernel: crate::kernels::macros::Kernel,
    kernels: &Kernels,
    context: &Arc<CudaContext>,
    lhs: &CudaSlice<T>,
    rhs: &CudaSlice<T>,
    output: &mut CudaSlice<T>,
    metadata: &[usize],
) -> Result<()>
where
    T: cudarc::driver::DeviceRepr,
{
    call_ops_dot_kernel(kernel, kernels, context, lhs, rhs, output, metadata)
}
// Instantiate the cuBLAS-backed dot entry points for every element type
// cuBLAS supports here; other dtypes go through `call_ops_dot` directly.
impl_cublas_dot!(bf16);
impl_cublas_dot!(f16);
impl_cublas_dot!(f32);
impl_cublas_dot!(f64);