#![allow(non_camel_case_types, non_snake_case)]
use std::ffi::c_void;
use std::os::raw::{c_int, c_uint};
use crate::error::GpuError;
/// Library tag placed in the `lib` field of the `GpuError::LibraryError`
/// values built by `probe` and `ok`.
const LIB: &str = "cusparselt";
/// Caller-owned storage for a cuSPARSELt library handle.
///
/// The buffer is 1408 `u64` words (11,264 bytes) with 8-byte alignment and is
/// handed to the library by pointer (see `cusparseLtInit` /
/// `cusparseLtDestroy`).
///
/// NOTE(review): the buffer must be at least as large as the C definition in
/// the installed `cusparseLt.h` — confirm when bumping the library version.
#[repr(C, align(8))]
pub struct cusparseLtHandle_t(pub [u64; 1408]);

impl cusparseLtHandle_t {
    /// Returns storage with every word set to zero.
    pub fn zeroed() -> Self {
        cusparseLtHandle_t([0_u64; 1408])
    }
}
/// Caller-owned storage for a cuSPARSELt matrix descriptor.
///
/// Laid out as 1408 `u64` words (11,264 bytes), 8-byte aligned, and passed to
/// the library by pointer (see `cusparseLtSpMMACompressedSize2` and
/// `cusparseLtSpMMACompress2`).
///
/// NOTE(review): the buffer must be at least as large as the C definition in
/// the installed `cusparseLt.h` — confirm when bumping the library version.
#[repr(C, align(8))]
pub struct cusparseLtMatDescriptor_t(pub [u64; 1408]);

impl cusparseLtMatDescriptor_t {
    /// Returns storage with every word set to zero.
    pub fn zeroed() -> Self {
        cusparseLtMatDescriptor_t([0_u64; 1408])
    }
}
/// Caller-owned storage for a cuSPARSELt matmul descriptor.
///
/// Laid out as 1408 `u64` words (11,264 bytes), 8-byte aligned, and passed to
/// the library by pointer (see `cusparseLtSpMMAPrune` and
/// `cusparseLtMatmulPlanInit`).
///
/// NOTE(review): the buffer must be at least as large as the C definition in
/// the installed `cusparseLt.h` — confirm when bumping the library version.
#[repr(C, align(8))]
pub struct cusparseLtMatmulDescriptor_t(pub [u64; 1408]);

impl cusparseLtMatmulDescriptor_t {
    /// Returns storage with every word set to zero.
    pub fn zeroed() -> Self {
        cusparseLtMatmulDescriptor_t([0_u64; 1408])
    }
}
/// Caller-owned storage for a cuSPARSELt matmul algorithm selection.
///
/// Laid out as 1408 `u64` words (11,264 bytes), 8-byte aligned, and passed to
/// the library by pointer (see `cusparseLtMatmulPlanInit`).
///
/// NOTE(review): the buffer must be at least as large as the C definition in
/// the installed `cusparseLt.h` — confirm when bumping the library version.
#[repr(C, align(8))]
pub struct cusparseLtMatmulAlgSelection_t(pub [u64; 1408]);

impl cusparseLtMatmulAlgSelection_t {
    /// Returns storage with every word set to zero.
    pub fn zeroed() -> Self {
        cusparseLtMatmulAlgSelection_t([0_u64; 1408])
    }
}
/// Caller-owned storage for a cuSPARSELt matmul plan.
///
/// Laid out as 1408 `u64` words (11,264 bytes), 8-byte aligned, and passed to
/// the library by pointer (see `cusparseLtMatmulPlanInit`, `cusparseLtMatmul`
/// and `cusparseLtMatmulPlanDestroy`).
///
/// NOTE(review): the buffer must be at least as large as the C definition in
/// the installed `cusparseLt.h` — confirm when bumping the library version.
#[repr(C, align(8))]
pub struct cusparseLtMatmulPlan_t(pub [u64; 1408]);

impl cusparseLtMatmulPlan_t {
    /// Returns storage with every word set to zero.
    pub fn zeroed() -> Self {
        cusparseLtMatmulPlan_t([0_u64; 1408])
    }
}
/// Status code returned by every cuSPARSELt entry point bound in this file.
pub type cusparseStatus_t = c_uint;
/// The status value that `ok` treats as success; any other value is reported
/// as an error.
pub const CUSPARSE_STATUS_SUCCESS: cusparseStatus_t = 0;
/// Pruning-algorithm selector passed by value to `cusparseLtSpMMAPrune`.
///
/// `repr(u32)` with explicit discriminants so the value crosses the FFI
/// boundary as a fixed 32-bit integer.
///
/// NOTE(review): discriminants are presumably copied from `cusparseLt.h` —
/// confirm against the header of the targeted release.
#[repr(u32)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum cusparseLtPruneAlg_t {
/// Discriminant 0.
CUSPARSELT_PRUNE_SPMMA_TILE = 0,
/// Discriminant 1.
CUSPARSELT_PRUNE_SPMMA_STRIP = 1,
}
/// Compute-type selector for cuSPARSELt, encoded as a `u32` with explicit
/// discriminants.
///
/// Not referenced by any binding visible in this part of the file.
///
/// NOTE(review): the numeric values of this enum may differ between
/// cuSPARSELt releases — confirm these discriminants against the
/// `cusparseLt.h` of the release being loaded.
#[repr(u32)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum cusparseComputeType {
/// Discriminant 0.
CUSPARSE_COMPUTE_16F = 0,
/// Discriminant 1.
CUSPARSE_COMPUTE_32I = 1,
/// Discriminant 2.
CUSPARSE_COMPUTE_TF32 = 2,
/// Discriminant 3.
CUSPARSE_COMPUTE_TF32_FAST = 3,
}
/// Heap-allocated [`cusparseLtHandle_t`] that is explicitly marked `Send` and
/// `Sync`.
///
/// `Default` only allocates zeroed storage; it does not call into the
/// library, and nothing here releases the handle on drop.
pub struct SendCuSparseLtHandle {
    /// Boxed handle storage. The heap address does not change when the
    /// wrapper itself is moved (presumably why it is boxed — the library is
    /// given raw pointers to this storage).
    pub raw: Box<cusparseLtHandle_t>,
}

// SAFETY: NOTE(review): the Rust-side storage is plain `u64` data, but whether
// the library permits concurrent use of one handle from several threads is not
// established here — confirm against the cuSPARSELt thread-safety notes.
unsafe impl Send for SendCuSparseLtHandle {}
unsafe impl Sync for SendCuSparseLtHandle {}

impl Default for SendCuSparseLtHandle {
    /// Allocates zero-filled handle storage on the heap.
    fn default() -> Self {
        let raw = Box::new(cusparseLtHandle_t::zeroed());
        Self { raw }
    }
}
/// Runtime loader for the cuSPARSELt shared library on unix targets.
#[cfg(unix)]
mod linux {
    use std::ffi::{c_void, CStr};
    use std::os::raw::c_char;
    use std::sync::OnceLock;

    extern "C" {
        fn dlopen(filename: *const c_char, flag: i32) -> *mut c_void;
        fn dlsym(handle: *mut c_void, symbol: *const c_char) -> *mut c_void;
    }

    // NOTE(review): these flag values look like the Linux <dlfcn.h> ones, while
    // the module is compiled for every `unix` target — confirm that is intended.
    const RTLD_NOW: i32 = 2;
    const RTLD_GLOBAL: i32 = 0x100;

    /// A successfully opened library. The handle is never closed in this module.
    pub struct LtLib {
        handle: *mut c_void,
    }

    // SAFETY: the handle is only ever passed to `dlsym` through `&self`.
    // NOTE(review): relies on `dlsym` being callable from any thread — confirm.
    unsafe impl Send for LtLib {}
    unsafe impl Sync for LtLib {}

    /// Tries each candidate soname in order and returns the first that opens.
    fn open_first_available() -> Option<LtLib> {
        const CANDIDATES: [&[u8]; 2] = [b"libcusparseLt.so.0\0", b"libcusparseLt.so\0"];
        CANDIDATES.iter().find_map(|soname| {
            // SAFETY: `soname` is a NUL-terminated byte literal that outlives the call.
            let handle =
                unsafe { dlopen(soname.as_ptr().cast::<c_char>(), RTLD_NOW | RTLD_GLOBAL) };
            (!handle.is_null()).then_some(LtLib { handle })
        })
    }

    /// Returns the process-wide library, opening it on first use.
    ///
    /// The outcome of the first attempt, including failure, is remembered for
    /// the lifetime of the process.
    pub fn lib() -> Option<&'static LtLib> {
        static LOADED: OnceLock<Option<LtLib>> = OnceLock::new();
        LOADED.get_or_init(open_first_available).as_ref()
    }

    impl LtLib {
        /// Looks up `name` in the library; `None` when `dlsym` returns null.
        pub fn sym(&self, name: &CStr) -> Option<*mut c_void> {
            // SAFETY: `self.handle` came from a successful `dlopen` and is never
            // closed; `name` is NUL-terminated by `CStr`'s invariant.
            let addr = unsafe { dlsym(self.handle, name.as_ptr()) };
            (!addr.is_null()).then_some(addr)
        }
    }
}
/// Stand-in loader for non-unix targets: same API as the unix module, but the
/// library is never available.
#[cfg(not(unix))]
mod linux {
/// Placeholder library type; no value of it is ever handed out by `lib`.
pub struct LtLib;
/// Always `None` on this target.
pub fn lib() -> Option<&'static LtLib> {
None
}
impl LtLib {
/// Always `None`; the symbol name is ignored.
pub fn sym(&self, _: &std::ffi::CStr) -> Option<*mut std::ffi::c_void> {
None
}
}
}
/// Checks whether the cuSPARSELt shared library can be loaded.
///
/// # Errors
///
/// Returns `GpuError::LibraryError` when `linux::lib()` yields no library.
pub fn probe() -> Result<(), GpuError> {
    match linux::lib() {
        Some(_) => Ok(()),
        None => Err(GpuError::LibraryError {
            lib: LIB,
            msg: "libcusparseLt.so not loadable; install cuSPARSELt or unset cusparse-lt".into(),
        }),
    }
}
/// Generates a resolver function for one cuSPARSELt entry point.
///
/// `lt_sym!(pub name: fn(Args...) -> Ret)` expands to
/// `pub fn name() -> Option<unsafe extern "C" fn(Args...) -> Ret>`, which looks
/// up the symbol called `name` through `linux::lib()` (must be in scope at the
/// invocation site) and returns `None` when either the library or the symbol
/// is unavailable.
///
/// The symbol name is a compile-time NUL-terminated literal, and the lookup
/// result (including a miss) is cached per generated function, so repeated
/// calls neither allocate nor repeat the symbol lookup. Caching is sound
/// because the loaded library is itself cached for the process lifetime.
macro_rules! lt_sym {
    ($vis:vis $name:ident: fn($($arg:ty),* $(,)?) -> $ret:ty) => {
        $vis fn $name() -> Option<unsafe extern "C" fn($($arg),*) -> $ret> {
            type Entry = unsafe extern "C" fn($($arg),*) -> $ret;
            static RESOLVED: ::std::sync::OnceLock<Option<Entry>> =
                ::std::sync::OnceLock::new();

            *RESOLVED.get_or_init(|| {
                let lib = linux::lib()?;
                // An identifier cannot contain an interior NUL, so this only
                // fails if the literal is somehow malformed; stay panic-free.
                let symbol = ::std::ffi::CStr::from_bytes_with_nul(
                    concat!(stringify!($name), "\0").as_bytes(),
                )
                .ok()?;
                let raw = lib.sym(symbol)?;
                // SAFETY: `raw` is non-null (guaranteed by `sym` returning
                // `Some`). The caller of the macro asserts that the declared
                // signature matches the C symbol of the same name.
                Some(unsafe {
                    ::std::mem::transmute::<*mut ::std::ffi::c_void, Entry>(raw)
                })
            })
        }
    };
}
// Resolvers for the cuSPARSELt entry points used by this crate. Each call
// returns `None` when the library or the symbol is unavailable.
//
// NOTE(review): the per-argument comments below name parameters as I understand
// the cuSPARSELt C API; they are not checked by the compiler — confirm against
// the installed `cusparseLt.h`.

// Handle lifecycle.
lt_sym!(pub cusparseLtInit: fn(*mut cusparseLtHandle_t) -> cusparseStatus_t);
lt_sym!(pub cusparseLtDestroy: fn(*const cusparseLtHandle_t) -> cusparseStatus_t);

// Pruning.
lt_sym!(
    pub cusparseLtSpMMAPrune: fn(
        *const cusparseLtHandle_t,
        *const cusparseLtMatmulDescriptor_t,
        *const c_void, // presumably the dense input buffer
        *mut c_void, // presumably the pruned output buffer
        cusparseLtPruneAlg_t,
        *mut c_void, // presumably a `cudaStream_t`
    ) -> cusparseStatus_t
);

// Compression.
lt_sym!(
    pub cusparseLtSpMMACompressedSize2: fn(
        *const cusparseLtHandle_t,
        *const cusparseLtMatDescriptor_t,
        *mut usize, // out: presumably the compressed size in bytes
        *mut usize, // out: presumably the scratch-buffer size in bytes
    ) -> cusparseStatus_t
);
lt_sym!(
    pub cusparseLtSpMMACompress2: fn(
        *const cusparseLtHandle_t,
        *const cusparseLtMatDescriptor_t,
        c_int, // presumably `isSparseA`
        c_int, // presumably a `cusparseOperation_t`
        *const c_void, // presumably the dense input
        *mut c_void, // presumably the compressed output
        *mut c_void, // presumably the scratch buffer
        *mut c_void, // presumably a `cudaStream_t`
    ) -> cusparseStatus_t
);

// Planning and execution.
lt_sym!(
    pub cusparseLtMatmulPlanInit: fn(
        *const cusparseLtHandle_t,
        *mut cusparseLtMatmulPlan_t,
        *const cusparseLtMatmulDescriptor_t,
        *const cusparseLtMatmulAlgSelection_t,
    ) -> cusparseStatus_t
);
lt_sym!(
    pub cusparseLtMatmul: fn(
        *const cusparseLtHandle_t,
        *const cusparseLtMatmulPlan_t,
        *const c_void, // presumably alpha
        *const c_void, // presumably A
        *const c_void, // presumably B
        *const c_void, // presumably beta
        *const c_void, // presumably C
        *mut c_void, // presumably D (output)
        *mut c_void, // presumably the workspace
        *mut *mut c_void, // presumably an array of `cudaStream_t`
        c_uint, // presumably the stream count; NOTE(review): the C API appears to use `int32_t` — confirm
    ) -> cusparseStatus_t
);
lt_sym!(
    pub cusparseLtMatmulPlanDestroy: fn(
        *const cusparseLtMatmulPlan_t,
    ) -> cusparseStatus_t
);
/// Converts a cuSPARSELt status code into a `Result`.
///
/// `what` names the operation that produced `status` and is included in the
/// error message.
///
/// # Errors
///
/// Returns `GpuError::LibraryError` for every status other than
/// `CUSPARSE_STATUS_SUCCESS`; the message carries the raw numeric status.
#[inline]
pub fn ok(status: cusparseStatus_t, what: &'static str) -> Result<(), GpuError> {
    if status != CUSPARSE_STATUS_SUCCESS {
        return Err(GpuError::LibraryError {
            lib: LIB,
            msg: format!("{what}: status={status}"),
        });
    }
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The handle storage must honour its declared 8-byte alignment, both as a
    /// type property and for an actual value.
    #[test]
    fn handle_alignment() {
        assert_eq!(std::mem::align_of::<cusparseLtHandle_t>(), 8);

        let h = cusparseLtHandle_t::zeroed();
        let p = &h as *const _ as usize;
        assert_eq!(p % 8, 0);
    }

    /// `probe` must either succeed (library installed) or fail with the typed
    /// `LibraryError` tagged with this file's library name — never panic and
    /// never report some other error variant.
    #[test]
    fn probe_returns_typed_error_when_library_missing() {
        match probe() {
            // Library is installed on this machine; nothing further to check.
            Ok(()) => {}
            Err(GpuError::LibraryError { lib, .. }) => assert_eq!(lib, LIB),
            // Only reachable if `GpuError` has other variants.
            #[allow(unreachable_patterns)]
            Err(_) => panic!("probe() must report failure as GpuError::LibraryError"),
        }
    }
}