#![cfg(all(feature = "cuda", feature = "cusparselt"))]
#![allow(non_snake_case, non_camel_case_types, non_upper_case_globals)]
#![allow(dead_code)]
use cudarc::driver::DevicePtr;
use crate::buffer::CudaBuffer;
use crate::device::GpuDevice;
use crate::error::{GpuError, GpuResult};
/// Raw, auto-generated FFI bindings to the cuSPARSELt C library.
///
/// The bindings file is produced at build time (by the crate's build script)
/// and spliced in from `OUT_DIR`; nothing in this module is hand-written.
pub mod sys {
    #![allow(clippy::all)]
    #![allow(unused, non_snake_case, non_camel_case_types, non_upper_case_globals)]
    include!(concat!(env!("OUT_DIR"), "/cusparselt_sys.rs"));
}
/// Converts a cuSPARSELt status code into a `GpuResult`, tagging any failure
/// with the name of the FFI call (`op`) that produced it.
#[inline]
fn check(status: sys::cusparseStatus_t, op: &'static str) -> GpuResult<()> {
    // Guard clause: anything other than SUCCESS becomes an error carrying the
    // operation name and the raw status for diagnostics.
    if status != sys::cusparseStatus_t::CUSPARSE_STATUS_SUCCESS {
        return Err(GpuError::InvalidState {
            message: format!("{op} returned cuSPARSELt status {status:?}"),
        });
    }
    Ok(())
}
/// RAII wrapper owning a cuSPARSELt library handle (`cusparseLtHandle_t`);
/// the handle is initialized by [`CusparseLtHandle::new`] and torn down on drop.
pub struct CusparseLtHandle {
    inner: sys::cusparseLtHandle_t,
}
// SAFETY: NOTE(review): these impls assume a cuSPARSELt handle may be moved to
// and shared across threads — confirm against the cuSPARSELt thread-safety
// documentation for the library version being linked.
unsafe impl Send for CusparseLtHandle {}
unsafe impl Sync for CusparseLtHandle {}
impl std::fmt::Debug for CusparseLtHandle {
    /// The inner handle is an opaque C struct, so only the type name is shown.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "CusparseLtHandle")
    }
}
impl CusparseLtHandle {
    /// Creates and initializes a new cuSPARSELt library handle.
    ///
    /// # Errors
    /// Returns `GpuError::InvalidState` when `cusparseLtInit` reports any
    /// status other than success.
    pub fn new() -> GpuResult<Self> {
        // Fix: the previous version produced a value via
        // `MaybeUninit::zeroed().assume_init()` *before* the FFI call had
        // initialized it. Keep the storage uninitialized and only
        // `assume_init` after cusparseLtInit reports success.
        let mut inner = std::mem::MaybeUninit::<sys::cusparseLtHandle_t>::uninit();
        // SAFETY: cusparseLtInit writes a fully-initialized handle into the
        // provided storage; the pointer is valid for the duration of the call.
        let status = unsafe { sys::cusparseLtInit(inner.as_mut_ptr()) };
        check(status, "cusparseLtInit")?;
        // SAFETY: status was SUCCESS, so `inner` has been initialized above.
        let inner = unsafe { inner.assume_init() };
        Ok(Self { inner })
    }

    /// Const pointer to the underlying handle, for read-only FFI parameters.
    #[inline]
    pub fn raw(&self) -> *const sys::cusparseLtHandle_t {
        &self.inner as *const _
    }

    /// Mutable pointer to the underlying handle, for FFI calls that modify it.
    #[inline]
    pub fn raw_mut(&mut self) -> *mut sys::cusparseLtHandle_t {
        &mut self.inner as *mut _
    }
}
impl Drop for CusparseLtHandle {
    fn drop(&mut self) {
        // Best-effort teardown: the destroy status is deliberately discarded
        // because Drop has no channel to report failure.
        unsafe {
            let _ = sys::cusparseLtDestroy(&mut self.inner as *mut _);
        }
    }
}
/// Element data types supported by the 2:4 structured-sparse matmul path.
#[derive(Debug, Clone, Copy)]
pub enum CuSpLtDType {
    /// IEEE half precision (maps to `CUDA_R_16F`).
    F16,
    /// bfloat16 (maps to `CUDA_R_16BF`).
    Bf16,
    /// Single precision (maps to `CUDA_R_32F`).
    F32,
}
impl CuSpLtDType {
fn cuda_dtype(self) -> sys::cudaDataType_t {
match self {
CuSpLtDType::F16 => sys::cudaDataType_t::CUDA_R_16F,
CuSpLtDType::Bf16 => sys::cudaDataType_t::CUDA_R_16BF,
CuSpLtDType::F32 => sys::cudaDataType_t::CUDA_R_32F,
}
}
fn compute_type(self) -> sys::cusparseComputeType {
match self {
CuSpLtDType::F16 | CuSpLtDType::Bf16 => sys::cusparseComputeType::CUSPARSE_COMPUTE_32F,
CuSpLtDType::F32 => sys::cusparseComputeType::CUSPARSE_COMPUTE_TF32,
}
}
fn elem_bytes(self) -> usize {
match self {
CuSpLtDType::F16 | CuSpLtDType::Bf16 => 2,
CuSpLtDType::F32 => 4,
}
}
fn alignment(self) -> u32 {
16
}
}
/// Computes `C = A · B` on the GPU via cuSPARSELt, where `A` is dense and `B`
/// is treated as 2:4 (50%) structured-sparse, returning a freshly allocated
/// row-major `m × n` output buffer.
///
/// * `handle` — initialized cuSPARSELt library handle.
/// * `a_dense` — dense operand A, row-major, `m * k` elements.
/// * `b_dense_decompressed` — operand B in dense layout, `k * n` elements; it
///   is compressed on-device here via `cusparseLtSpMMACompress`. NOTE(review):
///   no pruning step runs first, so B is presumably expected to already match
///   the 2:4 sparsity pattern — confirm with callers.
/// * `m`, `k`, `n` — matrix dimensions (A is m×k, B is k×n, C is m×n).
/// * `dtype` — element type of all three matrices; also selects compute type.
/// * `device` — device whose stream all work is enqueued on.
///
/// # Errors
/// Returns `ShapeMismatch` when buffer lengths disagree with `m`/`k`/`n`, and
/// `InvalidState` for dimension-granularity violations, `i64` overflow of a
/// dimension, or any cuSPARSELt call failure.
#[allow(clippy::too_many_arguments)]
pub fn gpu_sparse_matmul_24<T>(
    handle: &CusparseLtHandle,
    a_dense: &CudaBuffer<T>,
    b_dense_decompressed: &CudaBuffer<T>,
    m: usize,
    k: usize,
    n: usize,
    dtype: CuSpLtDType,
    device: &GpuDevice,
) -> GpuResult<CudaBuffer<T>>
where
    T: cudarc::driver::DeviceRepr + Default + Copy + 'static,
{
    // A must hold exactly m*k elements.
    if a_dense.len() != m * k {
        return Err(GpuError::ShapeMismatch {
            op: "cusparselt::sparse_matmul_24",
            expected: vec![m, k],
            got: vec![a_dense.len()],
        });
    }
    // B must hold exactly k*n elements.
    if b_dense_decompressed.len() != k * n {
        return Err(GpuError::ShapeMismatch {
            op: "cusparselt::sparse_matmul_24",
            expected: vec![k, n],
            got: vec![b_dense_decompressed.len()],
        });
    }
    // Degenerate shapes: skip the library entirely and return an all-zero
    // (possibly zero-length) m*n buffer.
    if m == 0 || n == 0 || k == 0 {
        let stream = device.stream();
        let slice = stream.alloc_zeros::<T>(m * n)?;
        return Ok(CudaBuffer::<T> {
            data: Some(slice),
            len: m * n,
            alloc_len: m * n,
            device_ordinal: device.ordinal(),
            pool_fn: None,
        });
    }
    // Dimension granularity enforced before calling into cuSPARSELt:
    // 8 elements for 16-bit dtypes, 4 for f32 (16 bytes either way).
    // NOTE(review): `m` is not a leading dimension in this row-major layout
    // yet is checked too — presumably a conservative requirement; confirm
    // against the cuSPARSELt shape constraints for the linked version.
    let elem_align: usize = match dtype {
        CuSpLtDType::F16 | CuSpLtDType::Bf16 => 8,
        CuSpLtDType::F32 => 4,
    };
    if k % elem_align != 0 || n % elem_align != 0 || m % elem_align != 0 {
        return Err(GpuError::InvalidState {
            message: format!(
                "cusparselt::sparse_matmul_24: dims (m={m}, k={k}, n={n}) must each be a multiple of {elem_align} for dtype {dtype:?}"
            ),
        });
    }
    let stream = device.stream();
    let cu_stream = stream.cu_stream() as sys::cudaStream_t;
    let dtype_cuda = dtype.cuda_dtype();
    let compute = dtype.compute_type();
    let align: u32 = dtype.alignment();
    // cuSPARSELt descriptor/plan structs are plain C structs that the *Init
    // calls below fill in; zeroed storage is their pre-initialization state.
    // NOTE(review): assumes all-zero bytes are a valid (if inert) bit pattern
    // for these bindgen-generated structs — true for plain C data structs.
    let mut a_descr: sys::cusparseLtMatDescriptor_t =
        unsafe { std::mem::MaybeUninit::zeroed().assume_init() };
    let mut b_descr: sys::cusparseLtMatDescriptor_t =
        unsafe { std::mem::MaybeUninit::zeroed().assume_init() };
    let mut c_descr: sys::cusparseLtMatDescriptor_t =
        unsafe { std::mem::MaybeUninit::zeroed().assume_init() };
    let mut matmul_descr: sys::cusparseLtMatmulDescriptor_t =
        unsafe { std::mem::MaybeUninit::zeroed().assume_init() };
    let mut alg_sel: sys::cusparseLtMatmulAlgSelection_t =
        unsafe { std::mem::MaybeUninit::zeroed().assume_init() };
    let mut plan: sys::cusparseLtMatmulPlan_t =
        unsafe { std::mem::MaybeUninit::zeroed().assume_init() };
    // Output buffer; zero-filled up front (beta = 0 below, so the matmul fully
    // overwrites it on success).
    let mut out_slice = stream.alloc_zeros::<T>(m * n)?;
    // cuSPARSELt takes i64 dimensions; guard each usize -> i64 narrowing.
    let m_i64 = i64::try_from(m).map_err(|_| GpuError::InvalidState {
        message: format!("cusparselt: m={m} exceeds i64::MAX"),
    })?;
    let n_i64 = i64::try_from(n).map_err(|_| GpuError::InvalidState {
        message: format!("cusparselt: n={n} exceeds i64::MAX"),
    })?;
    let k_i64 = i64::try_from(k).map_err(|_| GpuError::InvalidState {
        message: format!("cusparselt: k={k} exceeds i64::MAX"),
    })?;
    // All fallible library calls run inside this immediately-invoked closure
    // so the descriptor/plan teardown after it runs on both success and error.
    let result = (|| -> GpuResult<CudaBuffer<T>> {
        // A: dense, row-major m x k, leading dimension k.
        let status = unsafe {
            sys::cusparseLtDenseDescriptorInit(
                handle.raw(),
                &mut a_descr as *mut _,
                m_i64,
                k_i64,
                k_i64,
                align,
                dtype_cuda,
                sys::cusparseOrder_t::CUSPARSE_ORDER_ROW,
            )
        };
        check(status, "cusparseLtDenseDescriptorInit (A)")?;
        // B: 2:4 structured-sparse, row-major k x n, leading dimension n.
        let status = unsafe {
            sys::cusparseLtStructuredDescriptorInit(
                handle.raw(),
                &mut b_descr as *mut _,
                k_i64,
                n_i64,
                n_i64,
                align,
                dtype_cuda,
                sys::cusparseOrder_t::CUSPARSE_ORDER_ROW,
                sys::cusparseLtSparsity_t::CUSPARSELT_SPARSITY_50_PERCENT,
            )
        };
        check(status, "cusparseLtStructuredDescriptorInit (B)")?;
        // C (also reused as D): dense, row-major m x n, leading dimension n.
        let status = unsafe {
            sys::cusparseLtDenseDescriptorInit(
                handle.raw(),
                &mut c_descr as *mut _,
                m_i64,
                n_i64,
                n_i64,
                align,
                dtype_cuda,
                sys::cusparseOrder_t::CUSPARSE_ORDER_ROW,
            )
        };
        check(status, "cusparseLtDenseDescriptorInit (C)")?;
        // Describe D = op(A) * op(B) with no transposes; C and D share a descriptor.
        let status = unsafe {
            sys::cusparseLtMatmulDescriptorInit(
                handle.raw(),
                &mut matmul_descr as *mut _,
                sys::cusparseOperation_t::CUSPARSE_OPERATION_NON_TRANSPOSE,
                sys::cusparseOperation_t::CUSPARSE_OPERATION_NON_TRANSPOSE,
                &a_descr as *const _,
                &b_descr as *const _,
                &c_descr as *const _,
                &c_descr as *const _,
                compute,
            )
        };
        check(status, "cusparseLtMatmulDescriptorInit")?;
        // Library-default algorithm; no algorithm search/tuning is done here.
        let status = unsafe {
            sys::cusparseLtMatmulAlgSelectionInit(
                handle.raw(),
                &mut alg_sel as *mut _,
                &matmul_descr as *const _,
                sys::cusparseLtMatmulAlg_t::CUSPARSELT_MATMUL_ALG_DEFAULT,
            )
        };
        check(status, "cusparseLtMatmulAlgSelectionInit")?;
        let status = unsafe {
            sys::cusparseLtMatmulPlanInit(
                handle.raw(),
                &mut plan as *mut _,
                &matmul_descr as *const _,
                &alg_sel as *const _,
            )
        };
        check(status, "cusparseLtMatmulPlanInit")?;
        // Query scratch sizes: matmul workspace plus compression output/buffer.
        let mut workspace_size: usize = 0;
        let status = unsafe {
            sys::cusparseLtMatmulGetWorkspace(
                handle.raw(),
                &plan as *const _,
                &mut workspace_size as *mut _,
            )
        };
        check(status, "cusparseLtMatmulGetWorkspace")?;
        let mut compressed_size: usize = 0;
        let mut compressed_buffer_size: usize = 0;
        let status = unsafe {
            sys::cusparseLtSpMMACompressedSize(
                handle.raw(),
                &plan as *const _,
                &mut compressed_size as *mut _,
                &mut compressed_buffer_size as *mut _,
            )
        };
        check(status, "cusparseLtSpMMACompressedSize")?;
        // .max(1): avoid zero-sized device allocations when a size is 0.
        let mut workspace = stream.alloc_zeros::<u8>(workspace_size.max(1))?;
        let mut compressed = stream.alloc_zeros::<u8>(compressed_size.max(1))?;
        let mut compressed_scratch = stream.alloc_zeros::<u8>(compressed_buffer_size.max(1))?;
        use cudarc::driver::DevicePtrMut;
        // Compress B's dense representation into the packed 2:4 layout that
        // cusparseLtMatmul consumes. The inner scope bounds the pointer/sync
        // guards returned by device_ptr{,_mut}.
        {
            let (b_dense_ptr, _b_dense_sync) = b_dense_decompressed.inner().device_ptr(&stream);
            let (compressed_ptr, _compressed_sync) = compressed.device_ptr_mut(&stream);
            let (compressed_scratch_ptr, _compressed_scratch_sync) =
                compressed_scratch.device_ptr_mut(&stream);
            let status = unsafe {
                sys::cusparseLtSpMMACompress(
                    handle.raw(),
                    &plan as *const _,
                    b_dense_ptr as *const std::ffi::c_void,
                    compressed_ptr as *mut std::ffi::c_void,
                    compressed_scratch_ptr as *mut std::ffi::c_void,
                    cu_stream,
                )
            };
            check(status, "cusparseLtSpMMACompress")?;
        }
        // D = 1.0 * A * B_compressed + 0.0 * C. Alpha/beta are f32 host
        // scalars, matching the 32-bit compute types chosen above.
        let alpha: f32 = 1.0;
        let beta: f32 = 0.0;
        {
            let (a_ptr, _a_sync) = a_dense.inner().device_ptr(&stream);
            let (compressed_ptr_ro, _compressed_sync_ro) = compressed.device_ptr_mut(&stream);
            let (out_ptr, _out_sync) = out_slice.device_ptr_mut(&stream);
            let (workspace_ptr, _workspace_sync) = workspace.device_ptr_mut(&stream);
            let mut streams: [sys::cudaStream_t; 1] = [cu_stream];
            let status = unsafe {
                sys::cusparseLtMatmul(
                    handle.raw(),
                    &plan as *const _,
                    std::ptr::from_ref::<f32>(&alpha).cast::<std::ffi::c_void>(),
                    a_ptr as *const std::ffi::c_void,
                    compressed_ptr_ro as *const std::ffi::c_void,
                    std::ptr::from_ref::<f32>(&beta).cast::<std::ffi::c_void>(),
                    // out_slice doubles as C (read, scaled by beta=0) and D (written).
                    out_ptr as *const std::ffi::c_void,
                    out_ptr as *mut std::ffi::c_void,
                    workspace_ptr as *mut std::ffi::c_void,
                    streams.as_mut_ptr(),
                    1,
                )
            };
            check(status, "cusparseLtMatmul")?;
        }
        Ok(CudaBuffer::<T> {
            data: Some(out_slice),
            len: m * n,
            alloc_len: m * n,
            device_ordinal: device.ordinal(),
            pool_fn: None,
        })
    })();
    // Teardown runs regardless of `result`, in reverse creation order; destroy
    // statuses are discarded (no recovery path here).
    // NOTE(review): on an early error some descriptors are still all-zero
    // (never *Init'ed) when destroyed — any failures those calls report are
    // deliberately ignored.
    unsafe {
        let _ = sys::cusparseLtMatmulPlanDestroy(&mut plan as *mut _);
        let _ = sys::cusparseLtMatDescriptorDestroy(&mut c_descr as *mut _);
        let _ = sys::cusparseLtMatDescriptorDestroy(&mut b_descr as *mut _);
        let _ = sys::cusparseLtMatDescriptorDestroy(&mut a_descr as *mut _);
    }
    result
}