use candle_core::cuda::cudarc::cublaslt::result::set_matrix_layout_attribute;
use candle_core::cuda::cudarc::cublaslt::{result, result::CublasError, sys};
use candle_core::cuda::cudarc::driver::sys::{CUdevice_attribute, CUdeviceptr, CUstream};
use candle_core::cuda::cudarc::driver::{
CudaSlice, CudaStream, DevicePtr, DevicePtrMut, DeviceRepr, DriverError,
};
use candle_core::cuda::CudaDType;
use candle_core::DType;
use core::ffi::c_int;
use core::mem;
use float8::F8E4M3;
use half::{bf16, f16};
use std::sync::Arc;
/// Owning wrapper around a raw `cublasLtHandle_t`, together with the device
/// workspace and CUDA stream all matmul calls are issued on.
#[derive(Debug)]
pub struct CudaBlasLT {
    // Raw cuBLASLt handle; destroyed in `Drop` when non-null.
    handle: sys::cublasLtHandle_t,
    // Device scratch buffer handed to every cuBLASLt matmul call.
    workspace: Workspace,
    // Stream used for allocations, pointer resolution and kernel launches.
    stream: Arc<CudaStream>,
}
// SAFETY: the struct only holds an opaque handle value plus already
// Send/Sync members (`Workspace`, `Arc<CudaStream>`).
// NOTE(review): this assumes cuBLASLt handles may be used from multiple
// host threads without external synchronization — confirm against the
// cuBLASLt thread-safety documentation.
unsafe impl Send for CudaBlasLT {}
unsafe impl Sync for CudaBlasLT {}
impl CudaBlasLT {
    /// Creates a cuBLASLt handle plus a device workspace bound to `stream`.
    ///
    /// # Errors
    /// Returns a [`CublasError`] if `cublasLtCreate` fails.
    ///
    /// # Panics
    /// Panics if the workspace allocation fails (driver error) — preserving
    /// the original behavior, since the error types differ and the interface
    /// must stay `Result<_, CublasError>`.
    pub fn new(stream: Arc<CudaStream>) -> Result<Self, CublasError> {
        // Allocate the workspace *before* creating the handle: a half-built
        // `Self` never runs `Drop`, so the previous ordering leaked the
        // cuBLASLt handle whenever the workspace allocation panicked. In this
        // order, a failed `create_handle` simply drops the workspace cleanly.
        let workspace =
            Workspace::new(stream.clone()).expect("failed to allocate cuBLASLt workspace");
        let handle = result::create_handle()?;
        Ok(Self {
            handle,
            workspace,
            stream,
        })
    }
}
impl Drop for CudaBlasLT {
    fn drop(&mut self) {
        // Null out the handle first so it can never be destroyed twice.
        let handle = mem::replace(&mut self.handle, std::ptr::null_mut());
        if !handle.is_null() {
            // SAFETY: `handle` was produced by `result::create_handle` and is
            // only destroyed here.
            if let Err(e) = unsafe { result::destroy_handle(handle) } {
                // Panicking while another panic is unwinding aborts the
                // process; only surface the failure when not already
                // unwinding.
                if !std::thread::panicking() {
                    panic!("Unable to destroy cublasLt handle: {:?}", e);
                }
            }
        }
    }
}
/// Device scratch buffer loaned to cuBLASLt for intermediate results.
#[derive(Debug, Clone)]
pub struct Workspace {
    // Raw device allocation backing the workspace.
    pub(crate) buffer: CudaSlice<u8>,
    // Size of `buffer` in bytes, as advertised to cuBLASLt.
    pub(crate) size: usize,
}
impl Workspace {
    /// Allocates a cuBLASLt scratch buffer on `stream`.
    ///
    /// 32 MiB is reserved on devices with compute capability >= 9
    /// (Hopper and newer), 4 MiB on anything older.
    pub fn new(stream: Arc<CudaStream>) -> Result<Self, DriverError> {
        stream.context().bind_to_thread()?;
        let cc_major = stream
            .context()
            .attribute(CUdevice_attribute::CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR)?;
        let size = if cc_major < 9 {
            4 * 1024 * 1024 // 4 MiB
        } else {
            32 * 1024 * 1024 // 32 MiB
        };
        // SAFETY: the bytes are opaque scratch space for cuBLASLt and are
        // never read by us before being written by the library.
        let buffer = unsafe { stream.alloc::<u8>(size)? };
        Ok(Self { buffer, size })
    }
}
/// Activation fused into the matmul epilogue by cuBLASLt.
#[derive(Debug, Clone)]
pub enum Activation {
    /// Fused ReLU (optionally combined with a bias add).
    Relu,
    /// Fused GELU (optionally combined with a bias add).
    Gelu,
}
/// RAII owner of a `cublasLtMatrixLayout_t`; destroyed in `Drop`.
struct MatrixLayout {
    handle: sys::cublasLtMatrixLayout_t,
}
impl MatrixLayout {
    /// Wraps `cublasLtMatrixLayoutCreate` for a `rows` x `cols` matrix of
    /// element type `matrix_type` with leading dimension `ld`.
    fn new(
        matrix_type: sys::cudaDataType,
        rows: u64,
        cols: u64,
        ld: i64,
    ) -> Result<Self, CublasError> {
        result::create_matrix_layout(matrix_type, rows, cols, ld).map(|handle| Self { handle })
    }

    /// Marks this layout as a strided batch: `size` matrices separated by
    /// `stride` elements.
    fn set_batch(&self, size: c_int, stride: i64) -> Result<(), CublasError> {
        type Attr = sys::cublasLtMatrixLayoutAttribute_t;
        // SAFETY: each pointer/length pair exactly describes the value being
        // written, and the attribute enums match the expected value types
        // (c_int for the batch count, i64 for the stride).
        unsafe {
            set_matrix_layout_attribute(
                self.handle,
                Attr::CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT,
                (&size) as *const c_int as *const _,
                mem::size_of_val(&size),
            )?;
            set_matrix_layout_attribute(
                self.handle,
                Attr::CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET,
                (&stride) as *const i64 as *const _,
                mem::size_of_val(&stride),
            )?;
        }
        Ok(())
    }
}
impl Drop for MatrixLayout {
    fn drop(&mut self) {
        // SAFETY: `handle` came from `create_matrix_layout` and is only
        // destroyed here.
        if let Err(e) = unsafe { result::destroy_matrix_layout(self.handle) } {
            // A panic inside `drop` during unwinding aborts the process;
            // only panic when not already unwinding.
            if !std::thread::panicking() {
                panic!("Unable to destroy matrix layout: {:?}", e);
            }
        }
    }
}
/// Selects which operand of `D = alpha*A*B + beta*C` an attribute applies to.
enum Matrix {
    A,
    B,
    C,
    D,
}
/// RAII owner of a `cublasLtMatmulDesc_t`; destroyed in `Drop`.
struct MatmulDesc {
    handle: sys::cublasLtMatmulDesc_t,
}
impl MatmulDesc {
    /// Wraps `cublasLtMatmulDescCreate` with the given compute/scale types.
    fn new(
        compute_type: sys::cublasComputeType_t,
        scale_type: sys::cudaDataType,
    ) -> Result<Self, CublasError> {
        let handle = result::create_matmul_desc(compute_type, scale_type)?;
        Ok(Self { handle })
    }

    /// Sets the transpose operation for `matrix` (A, B or C).
    ///
    /// # Panics
    /// Panics on `Matrix::D` — D has no transpose attribute in cuBLASLt.
    fn set_transpose(&self, transpose: bool, matrix: Matrix) -> Result<(), CublasError> {
        // `false as i32 == 0` / `true as i32 == 1` line up with the numeric
        // values of `cublasOperation_t` (CUBLAS_OP_N / CUBLAS_OP_T).
        let transpose = transpose as i32;
        let attr = match matrix {
            Matrix::A => sys::cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_TRANSA,
            Matrix::B => sys::cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_TRANSB,
            Matrix::C => sys::cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_TRANSC,
            Matrix::D => unreachable!(),
        };
        // SAFETY: the pointer/length pair describes the `i32` value being set.
        // (Was `size_of::<u32>()` — same width, but the value is an `i32`.)
        unsafe {
            result::set_matmul_desc_attribute(
                self.handle,
                attr,
                (&transpose) as *const _ as *const _,
                mem::size_of::<i32>(),
            )?;
        }
        Ok(())
    }

    /// Attaches a per-operand dequantization scale pointer (FP8 path).
    fn set_scale_ptr(&self, device_ptr: &CUdeviceptr, matrix: Matrix) -> Result<(), CublasError> {
        let attr = match matrix {
            Matrix::A => sys::cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_A_SCALE_POINTER,
            Matrix::B => sys::cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_B_SCALE_POINTER,
            Matrix::C => sys::cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_C_SCALE_POINTER,
            Matrix::D => sys::cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_D_SCALE_POINTER,
        };
        // SAFETY: the pointer/length pair describes the `CUdeviceptr` value.
        unsafe {
            result::set_matmul_desc_attribute(
                self.handle,
                attr,
                device_ptr as *const CUdeviceptr as *const _,
                mem::size_of::<CUdeviceptr>(),
            )?;
        }
        Ok(())
    }

    /// Configures the epilogue: optional fused activation, optional bias
    /// pointer and — for batched matmuls — an optional bias batch stride.
    fn set_epilogue(
        &self,
        act: Option<&Activation>,
        bias_ptr: Option<&CUdeviceptr>,
        stride_bias: Option<i64>,
    ) -> Result<(), CublasError> {
        let epilogue = if let Some(bias_ptr) = bias_ptr {
            // Bias present: pick the fused activation+bias variant, or plain
            // bias when no activation was requested.
            let epilogue = act
                .map(|act| match act {
                    Activation::Relu => sys::cublasLtEpilogue_t::CUBLASLT_EPILOGUE_RELU_BIAS,
                    Activation::Gelu => sys::cublasLtEpilogue_t::CUBLASLT_EPILOGUE_GELU_BIAS,
                })
                .unwrap_or(sys::cublasLtEpilogue_t::CUBLASLT_EPILOGUE_BIAS);
            // SAFETY: the pointer/length pair describes the `CUdeviceptr`.
            unsafe {
                result::set_matmul_desc_attribute(
                    self.handle,
                    sys::cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_BIAS_POINTER,
                    bias_ptr as *const CUdeviceptr as *const _,
                    mem::size_of::<CUdeviceptr>(),
                )?;
            }
            if let Some(stride_bias) = stride_bias {
                // SAFETY: the pointer/length pair describes the `i64` stride.
                unsafe {
                    result::set_matmul_desc_attribute(
                        self.handle,
                        sys::cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_BIAS_BATCH_STRIDE,
                        (&stride_bias) as *const _ as *const _,
                        mem::size_of::<i64>(),
                    )?;
                }
            }
            epilogue
        } else if let Some(act) = act {
            // Activation only, no bias.
            match act {
                Activation::Relu => sys::cublasLtEpilogue_t::CUBLASLT_EPILOGUE_RELU,
                Activation::Gelu => sys::cublasLtEpilogue_t::CUBLASLT_EPILOGUE_GELU,
            }
        } else {
            sys::cublasLtEpilogue_t::CUBLASLT_EPILOGUE_DEFAULT
        };
        // SAFETY: the pointer/length pair describes the epilogue value.
        // Fix: the size previously used `cublasLtMatmulDescAttributes_t`
        // (the attribute enum) instead of the epilogue enum actually being
        // written; both happen to be 4-byte C enums, but the type must match
        // the value.
        unsafe {
            result::set_matmul_desc_attribute(
                self.handle,
                sys::cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_EPILOGUE,
                (&epilogue) as *const _ as *const _,
                mem::size_of::<sys::cublasLtEpilogue_t>(),
            )?;
        }
        Ok(())
    }
}
impl Drop for MatmulDesc {
    fn drop(&mut self) {
        // SAFETY: `handle` came from `create_matmul_desc` and is only
        // destroyed here.
        if let Err(e) = unsafe { result::destroy_matmul_desc(self.handle) } {
            // Never panic while already unwinding — that aborts the process.
            if !std::thread::panicking() {
                panic!("Unable to destroy matmul desc: {:?}", e);
            }
        }
    }
}
/// RAII owner of a `cublasLtMatmulPreference_t`; destroyed in `Drop`.
struct MatmulPref {
    handle: sys::cublasLtMatmulPreference_t,
}
impl MatmulPref {
    /// Wraps `cublasLtMatmulPreferenceCreate`.
    fn new() -> Result<Self, CublasError> {
        result::create_matmul_pref().map(|handle| Self { handle })
    }

    /// Tells the algorithm heuristic how many workspace bytes it may assume.
    fn set_workspace_size(&self, size: usize) -> Result<(), CublasError> {
        // SAFETY: the pointer/length pair describes the `usize` value.
        unsafe {
            result::set_matmul_pref_attribute(
                self.handle,
                sys::cublasLtMatmulPreferenceAttributes_t::CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
                (&size) as *const usize as *const _,
                mem::size_of_val(&size),
            )
        }
    }
}
impl Drop for MatmulPref {
    fn drop(&mut self) {
        // SAFETY: `handle` came from `create_matmul_pref` and is only
        // destroyed here.
        if let Err(e) = unsafe { result::destroy_matmul_pref(self.handle) } {
            // Never panic while already unwinding — that aborts the process.
            if !std::thread::panicking() {
                panic!("Unable to destroy matmul pref: {:?}", e);
            }
        }
    }
}
/// State accessors required by the [`Matmul`] default methods.
pub trait MatmulShared {
    /// Raw cuBLASLt handle used for every API call.
    fn handle(&self) -> &sys::cublasLtHandle_t;
    /// Scratch buffer advertised to and passed into each matmul.
    fn workspace(&self) -> &Workspace;
    /// Stream on which pointers are resolved and kernels launched.
    fn stream(&self) -> &Arc<CudaStream>;
}
/// Parameters for a (possibly strided-batched) GEMM of the form
/// `C = alpha * op(A) * op(B) + beta * C`.
#[derive(Debug, Copy, Clone)]
pub struct MatmulConfig {
    /// Transpose A when true.
    pub transa: bool,
    /// Transpose B when true.
    pub transb: bool,
    /// Rows of op(A) / rows of C.
    pub m: u64,
    /// Columns of op(B) / columns of C.
    pub n: u64,
    /// Inner (contraction) dimension.
    pub k: u64,
    /// Scalar applied to A*B.
    pub alpha: f32,
    /// Leading dimension of A.
    pub lda: i64,
    /// Leading dimension of B.
    pub ldb: i64,
    /// Scalar applied to the existing contents of C.
    pub beta: f32,
    /// Leading dimension of C (and of the output D in the FP8 path).
    pub ldc: i64,
    // Batching is only enabled for an operand when BOTH `batch_size` and
    // that operand's stride are `Some` (see the `set_batch` call sites).
    /// Element stride between consecutive A matrices in a batch.
    pub stride_a: Option<i64>,
    /// Element stride between consecutive B matrices in a batch.
    pub stride_b: Option<i64>,
    /// Element stride between consecutive C matrices in a batch.
    pub stride_c: Option<i64>,
    /// Element stride between consecutive bias vectors in a batch.
    pub stride_bias: Option<i64>,
    /// Number of matrices in the batch, if batched.
    pub batch_size: Option<c_int>,
}
/// Internal tag naming the element dtypes this wrapper can multiply.
pub enum CublasLTInternalDType {
    F32,
    BF16,
    F16,
    F8E4M3,
}
/// Element types usable with [`Matmul`]; `T` is the matching internal tag.
pub trait CublasLTDType: CudaDType + DeviceRepr {
    const T: CublasLTInternalDType;
}
// Marker impls tying each supported element type to its internal dtype tag.
impl CublasLTDType for f32 {
    const T: CublasLTInternalDType = CublasLTInternalDType::F32;
}
impl CublasLTDType for f16 {
    const T: CublasLTInternalDType = CublasLTInternalDType::F16;
}
impl CublasLTDType for bf16 {
    const T: CublasLTInternalDType = CublasLTInternalDType::BF16;
}
impl CublasLTDType for F8E4M3 {
    const T: CublasLTInternalDType = CublasLTInternalDType::F8E4M3;
}
/// cuBLASLt matmul entry points for element type `T`.
///
/// Both methods build fresh descriptor/layout objects per call, query a
/// single algorithm heuristic, and launch asynchronously on the shared
/// stream using the shared workspace.
pub trait Matmul<T: CublasLTDType>: MatmulShared {
    /// cuBLASLt storage type corresponding to `T`.
    fn matrix_type() -> sys::cudaDataType;
    /// cuBLASLt compute type used when multiplying `T` matrices.
    fn compute_type() -> sys::cublasComputeType_t;
    /// FP8-style matmul with a bf16 `C` input and a separate bf16 output:
    /// roughly `out = epilogue(alpha * (scale_a*A) * (scale_b*B) + beta * C)`.
    ///
    /// # Safety
    /// Every device pointer must be valid for the extents implied by `cfg`
    /// and must stay alive until the launched kernel completes.
    ///
    /// # Panics
    /// Requires `cfg.transa && !cfg.transb` (the "TN" operand layout —
    /// presumably cuBLASLt's FP8 requirement; TODO confirm against docs).
    #[allow(clippy::too_many_arguments)]
    unsafe fn matmul_fp8_like<
        I: DevicePtr<T>,
        C: DevicePtr<bf16>,
        OB: DevicePtrMut<bf16>,
        S: DevicePtr<f32>,
        B: DevicePtr<bf16>,
    >(
        &self,
        cfg: MatmulConfig,
        a: &I,
        b: &I,
        scale_a: &S,
        scale_b: &S,
        scale_d: &S,
        c: &C,
        out: &mut OB,
        bias: Option<&B>,
        act: Option<&Activation>,
    ) -> Result<(), CublasError> {
        // Stored shapes for the asserted TN layout: A is k x m, B is k x n.
        let (a_rows, a_cols) = (cfg.k, cfg.m);
        let (b_rows, b_cols) = (cfg.k, cfg.n);
        assert!(cfg.transa);
        assert!(!cfg.transb);
        // f32 accumulation with f32 scale type.
        let matmul_desc = MatmulDesc::new(
            sys::cublasComputeType_t::CUBLAS_COMPUTE_32F,
            sys::cudaDataType_t::CUDA_R_32F,
        )?;
        matmul_desc.set_transpose(cfg.transa, Matrix::A)?;
        matmul_desc.set_transpose(cfg.transb, Matrix::B)?;
        // Operand layouts; strided batching is enabled only when both
        // `batch_size` and the matching stride are provided.
        let a_layout = MatrixLayout::new(Self::matrix_type(), a_rows, a_cols, cfg.lda)?;
        if let (Some(batch_size), Some(stride_a)) = (cfg.batch_size, cfg.stride_a) {
            a_layout.set_batch(batch_size, stride_a)?;
        }
        let b_layout = MatrixLayout::new(Self::matrix_type(), b_rows, b_cols, cfg.ldb)?;
        if let (Some(batch_size), Some(stride_b)) = (cfg.batch_size, cfg.stride_b) {
            b_layout.set_batch(batch_size, stride_b)?;
        }
        // C is always bf16 here: m x n with leading dimension ldc.
        let c_layout = MatrixLayout::new(sys::cudaDataType_t::CUDA_R_16BF, cfg.m, cfg.n, cfg.ldc)?;
        if let (Some(batch_size), Some(stride_c)) = (cfg.batch_size, cfg.stride_c) {
            c_layout.set_batch(batch_size, stride_c)?;
        }
        // The output D reuses C's shape, leading dimension and batch stride.
        let out_ty = sys::cudaDataType_t::CUDA_R_16BF;
        let d_layout = MatrixLayout::new(out_ty, cfg.m, cfg.n, cfg.ldc)?;
        if let (Some(batch_size), Some(stride_c)) = (cfg.batch_size, cfg.stride_c) {
            d_layout.set_batch(batch_size, stride_c)?;
        }
        // The `_guard` values keep the underlying buffers alive for the call.
        let (scale_a, _scale_a_guard) = scale_a.device_ptr(self.stream());
        let (scale_b, _scale_b_guard) = scale_b.device_ptr(self.stream());
        // NOTE(review): `scale_d` is resolved but never attached via
        // `set_scale_ptr(.., Matrix::D)` — confirm whether D-scaling is
        // intentionally unused (the output here is bf16, not FP8).
        let (_scale_d, _scale_d_guard) = scale_d.device_ptr(self.stream());
        matmul_desc.set_scale_ptr(&scale_a, Matrix::A)?;
        matmul_desc.set_scale_ptr(&scale_b, Matrix::B)?;
        let (bias_ptr, _bias_ptr_guard) = bias.map(|b| b.device_ptr(self.stream())).unzip();
        matmul_desc.set_epilogue(act, bias_ptr.as_ref(), cfg.stride_bias)?;
        // Let the heuristic assume the full workspace is available.
        let matmul_pref = MatmulPref::new()?;
        matmul_pref.set_workspace_size(self.workspace().size)?;
        let heuristic = result::get_matmul_algo_heuristic(
            *self.handle(),
            matmul_desc.handle,
            a_layout.handle,
            b_layout.handle,
            c_layout.handle,
            d_layout.handle,
            matmul_pref.handle,
        )?;
        let (out_ptr, _out_guard) = out.device_ptr_mut(self.stream());
        let (a, _a_guard) = a.device_ptr(self.stream());
        let (b, _b_guard) = b.device_ptr(self.stream());
        let (c, _c_guard) = c.device_ptr(self.stream());
        let workspace = &self.workspace().buffer;
        let (workspace, _workspace_guard) = workspace.device_ptr(self.stream());
        // Launch on our stream with the heuristic's chosen algorithm.
        result::matmul(
            *self.handle(),
            matmul_desc.handle,
            (&cfg.alpha) as *const _ as *const _,
            (&cfg.beta) as *const _ as *const _,
            a as *const _,
            a_layout.handle,
            b as *const _,
            b_layout.handle,
            c as *const _,
            c_layout.handle,
            out_ptr as *mut _,
            d_layout.handle,
            (&heuristic.algo) as *const _,
            workspace as *mut _,
            self.workspace().size,
            self.stream().cu_stream() as *mut _,
        )
    }
    /// General matmul for `T` operands, computed in place into `c`:
    /// `c = epilogue(alpha * op(a) * op(b) + beta * c)`.
    ///
    /// # Safety
    /// Every device pointer must be valid for the extents implied by `cfg`
    /// and must stay alive until the launched kernel completes.
    unsafe fn matmul<I: DevicePtr<T>, O: DevicePtrMut<T>>(
        &self,
        cfg: MatmulConfig,
        a: &I,
        b: &I,
        c: &mut O,
        bias: Option<&I>,
        act: Option<&Activation>,
    ) -> Result<(), CublasError> {
        // Stored operand shapes depend on the transpose flags.
        let (a_rows, a_cols) = if cfg.transa {
            (cfg.k, cfg.m)
        } else {
            (cfg.m, cfg.k)
        };
        let (b_rows, b_cols) = if cfg.transb {
            (cfg.n, cfg.k)
        } else {
            (cfg.k, cfg.n)
        };
        // Operand layouts; batching only when batch_size AND stride are set.
        let a_layout = MatrixLayout::new(Self::matrix_type(), a_rows, a_cols, cfg.lda)?;
        if let (Some(batch_size), Some(stride_a)) = (cfg.batch_size, cfg.stride_a) {
            a_layout.set_batch(batch_size, stride_a)?;
        }
        let b_layout = MatrixLayout::new(Self::matrix_type(), b_rows, b_cols, cfg.ldb)?;
        if let (Some(batch_size), Some(stride_b)) = (cfg.batch_size, cfg.stride_b) {
            b_layout.set_batch(batch_size, stride_b)?;
        }
        let c_layout = MatrixLayout::new(Self::matrix_type(), cfg.m, cfg.n, cfg.ldc)?;
        if let (Some(batch_size), Some(stride_c)) = (cfg.batch_size, cfg.stride_c) {
            c_layout.set_batch(batch_size, stride_c)?;
        }
        let matmul_desc = MatmulDesc::new(Self::compute_type(), sys::cudaDataType_t::CUDA_R_32F)?;
        matmul_desc.set_transpose(cfg.transa, Matrix::A)?;
        matmul_desc.set_transpose(cfg.transb, Matrix::B)?;
        let (bias_ptr, _bias_ptr_guard) = bias.map(|b| b.device_ptr(self.stream())).unzip();
        matmul_desc.set_epilogue(act, bias_ptr.as_ref(), cfg.stride_bias)?;
        let matmul_pref = MatmulPref::new()?;
        matmul_pref.set_workspace_size(self.workspace().size)?;
        // `c_layout` is passed for both C and D: the result overwrites `c`.
        let heuristic = result::get_matmul_algo_heuristic(
            *self.handle(),
            matmul_desc.handle,
            a_layout.handle,
            b_layout.handle,
            c_layout.handle,
            c_layout.handle,
            matmul_pref.handle,
        )?;
        let (a, _a_guard) = a.device_ptr(self.stream());
        let (b, _b_guard) = b.device_ptr(self.stream());
        let (c, _c_guard) = c.device_ptr_mut(self.stream());
        let workspace = &self.workspace().buffer;
        let (workspace, _workspace_guard) = workspace.device_ptr(self.stream());
        // `c` appears twice: as the beta-scaled input C and the output D.
        result::matmul(
            *self.handle(),
            matmul_desc.handle,
            (&cfg.alpha) as *const _ as *const _,
            (&cfg.beta) as *const _ as *const _,
            a as *const _,
            a_layout.handle,
            b as *const _,
            b_layout.handle,
            c as *const _,
            c_layout.handle,
            c as *mut _,
            c_layout.handle,
            (&heuristic.algo) as *const _,
            workspace as *mut _,
            self.workspace().size,
            self.stream().cu_stream() as *mut _,
        )
    }
}
// Straightforward field accessors backing the `Matmul` default methods.
impl MatmulShared for CudaBlasLT {
    fn handle(&self) -> &sys::cublasLtHandle_t {
        &self.handle
    }
    fn workspace(&self) -> &Workspace {
        &self.workspace
    }
    fn stream(&self) -> &Arc<CudaStream> {
        &self.stream
    }
}
impl<T: CublasLTDType> Matmul<T> for CudaBlasLT {
    fn matrix_type() -> sys::cudaDataType {
        // One-to-one mapping from the internal dtype tag to cuBLASLt's
        // storage type enum.
        match T::T {
            CublasLTInternalDType::F32 => sys::cudaDataType_t::CUDA_R_32F,
            CublasLTInternalDType::BF16 => sys::cudaDataType_t::CUDA_R_16BF,
            CublasLTInternalDType::F16 => sys::cudaDataType_t::CUDA_R_16F,
            CublasLTInternalDType::F8E4M3 => sys::cudaDataType_t::CUDA_R_8F_E4M3,
        }
    }
    fn compute_type() -> sys::cublasComputeType_t {
        // f32 opts into the TF32 fast path (reduced-precision tensor-core
        // math on supporting GPUs — see cuBLAS compute-type docs); the other
        // dtypes accumulate in plain f32.
        match T::T {
            CublasLTInternalDType::F32 => sys::cublasComputeType_t::CUBLAS_COMPUTE_32F_FAST_TF32,
            CublasLTInternalDType::BF16
            | CublasLTInternalDType::F16
            | CublasLTInternalDType::F8E4M3 => sys::cublasComputeType_t::CUBLAS_COMPUTE_32F,
        }
    }
}