#[allow(unused_imports)]
use crate::error::Status;
pub mod context;
pub mod descriptor;
pub mod matmul;
pub mod types;
pub(crate) mod utility;
use std::{ffi::CString, path::Path};
use singe_core::{CudaRuntimeVersion, LibraryVersion};
use singe_cublas_sys as sys;
use singe_cuda::types::LibraryProperty;
use crate::{
error::{Error, Result},
try_ffi,
utility::{to_u64, to_usize},
};
/// Logger callback type accepted by [`set_logger_callback`].
///
/// This is a plain alias for the raw `cublasLtLoggerCallback_t` type exposed by the
/// `singe_cublas_sys` bindings.
pub type LoggerCallback = sys::cublasLtLoggerCallback_t;
/// Queries the cuBLASLt library version and wraps it in a [`LibraryVersion`].
///
/// # Errors
///
/// Propagates the error from `to_u64` when the raw value reported by
/// `cublasLtGetVersion` cannot be converted.
pub fn version() -> Result<LibraryVersion> {
    // SAFETY: the call takes no arguments, so there are no pointers to validate.
    let raw_version = unsafe { sys::cublasLtGetVersion() };
    let checked = to_u64(raw_version, "version")?;
    Ok(LibraryVersion::from(checked))
}
/// Queries the CUDA runtime version reported by cuBLASLt and wraps it in a
/// [`CudaRuntimeVersion`].
///
/// # Errors
///
/// Propagates the error from `to_u64` when the raw value reported by
/// `cublasLtGetCudartVersion` cannot be converted.
pub fn cudart_version() -> Result<CudaRuntimeVersion> {
    // SAFETY: the call takes no arguments, so there are no pointers to validate.
    let raw_version = unsafe { sys::cublasLtGetCudartVersion() };
    let checked = to_u64(raw_version, "cudart version")?;
    Ok(CudaRuntimeVersion::from(checked))
}
/// Reads a single library property through `cublasLtGetProperty`.
///
/// `property` selects which value to query; it is converted to the raw FFI
/// representation with `Into` before the call.
///
/// # Errors
///
/// Propagates any error that `try_ffi!` produces for the underlying call.
pub fn library_property(property: LibraryProperty) -> Result<i32> {
    let mut out: i32 = 0;
    // SAFETY: `out` is a live local for the whole call, so the pointer handed to the
    // library is valid for writes.
    let status = unsafe { try_ffi!(sys::cublasLtGetProperty(property.into(), &raw mut out)) };
    status?;
    Ok(out)
}
/// Reads the current heuristics cache capacity through
/// `cublasLtHeuristicsCacheGetCapacity`.
///
/// # Errors
///
/// Propagates any error that `try_ffi!` produces for the underlying call, and the
/// error from `to_usize` when the reported capacity cannot be converted.
pub fn heuristics_cache_capacity() -> Result<usize> {
    let mut raw_capacity = 0;
    // SAFETY: `raw_capacity` is a live local for the whole call, so the pointer handed
    // to the library is valid for writes.
    let status = unsafe { try_ffi!(sys::cublasLtHeuristicsCacheGetCapacity(&raw mut raw_capacity)) };
    status?;
    to_usize(raw_capacity, "heuristics cache capacity")
}
pub fn set_heuristics_cache_capacity(capacity: usize) -> Result<()> {
let capacity = capacity.try_into().map_err(|_| Error::OutOfRange {
name: "heuristics cache capacity".into(),
})?;
unsafe {
try_ffi!(sys::cublasLtHeuristicsCacheSetCapacity(capacity))?;
}
Ok(())
}
/// Forwards `mask` to `cublasLtDisableCpuInstructionsSetMask`.
///
/// Any value returned by the underlying call is intentionally discarded.
pub fn disable_cpu_instructions_set_mask(mask: u32) {
    // SAFETY: only a by-value integer is passed; no pointers are involved.
    let _ = unsafe { sys::cublasLtDisableCpuInstructionsSetMask(mask) };
}
/// Registers `callback` with the library through `cublasLtLoggerSetCallback`.
///
/// # Errors
///
/// Propagates any error that `try_ffi!` produces for the underlying call.
pub fn set_logger_callback(callback: LoggerCallback) -> Result<()> {
    // SAFETY: the callback value is handed to the library unchanged.
    // NOTE(review): presumably the callback must remain callable for as long as it is
    // registered; confirm against the cuBLASLt documentation.
    let status = unsafe { try_ffi!(sys::cublasLtLoggerSetCallback(callback)) };
    status?;
    Ok(())
}
/// Passes `path` to `cublasLtLoggerOpenFile` as a NUL-terminated C string.
///
/// On Unix the path's raw bytes are forwarded unchanged, so paths that are not valid
/// UTF-8 still name the intended file. On other platforms the path is converted
/// lossily to UTF-8, because the OS string has no direct byte representation there.
///
/// # Errors
///
/// Returns an error when the path contains an interior NUL byte (it cannot be
/// represented as a C string), and propagates any error that `try_ffi!` produces for
/// the underlying call.
pub fn logger_open_file(path: &Path) -> Result<()> {
    // A lossy conversion would replace invalid UTF-8 sequences with U+FFFD and thereby
    // point the logger at a different file name, so use the exact bytes where possible.
    #[cfg(unix)]
    let path = CString::new(std::os::unix::ffi::OsStrExt::as_bytes(path.as_os_str()))?;
    #[cfg(not(unix))]
    let path = CString::new(path.as_os_str().to_string_lossy().as_bytes())?;
    // SAFETY: `path` is a valid NUL-terminated string that stays alive until after the
    // call returns.
    unsafe {
        try_ffi!(sys::cublasLtLoggerOpenFile(path.as_ptr()))?;
    }
    Ok(())
}
/// Forwards `level` to `cublasLtLoggerSetLevel`.
///
/// # Errors
///
/// Propagates any error that `try_ffi!` produces for the underlying call.
pub fn set_logger_level(level: i32) -> Result<()> {
    // SAFETY: only a by-value integer is passed; no pointers are involved.
    let status = unsafe { try_ffi!(sys::cublasLtLoggerSetLevel(level)) };
    status?;
    Ok(())
}
/// Forwards `mask` to `cublasLtLoggerSetMask`.
///
/// # Errors
///
/// Propagates any error that `try_ffi!` produces for the underlying call.
pub fn set_logger_mask(mask: i32) -> Result<()> {
    // SAFETY: only a by-value integer is passed; no pointers are involved.
    let status = unsafe { try_ffi!(sys::cublasLtLoggerSetMask(mask)) };
    status?;
    Ok(())
}
/// Invokes `cublasLtLoggerForceDisable`.
///
/// # Errors
///
/// Propagates any error that `try_ffi!` produces for the underlying call.
pub fn force_disable_logger() -> Result<()> {
    // SAFETY: the call takes no arguments, so there are no pointers to validate.
    let status = unsafe { try_ffi!(sys::cublasLtLoggerForceDisable()) };
    status?;
    Ok(())
}