#![allow(unused_variables, unused_imports, dead_code)]
use candle_core::{Device, DeviceLocation, Result, Tensor};
use candle_nn::Activation as CandleActivation;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{LazyLock, Mutex, Once};
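/// Synchronizes access to the process-wide cuBLASLt handle.
///
/// The handle is created once by [`maybe_init_cublas_lt_wrapper`] and leaked
/// to obtain a `&'static` reference. `inhibit` temporarily disables the fused
/// path, and `device_location` records the device the handle was built for.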
pub struct CublasLtController {
handle: Mutex<Option<&'static CublasLtWrapper>>,
inhibit: AtomicBool,
device_location: Mutex<Option<DeviceLocation>>,
}
impl CublasLtController {
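/// Enable or disable handing out the handle, e.g. to force callers onto a
/// non-fused fallback path.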
pub fn set_inhibit(&self, value: bool) {
self.inhibit.store(value, Ordering::SeqCst);
}
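/// Returns the shared handle, or `None` if inhibited or not initialized.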
pub fn get(&self) -> Option<&'static CublasLtWrapper> {
if self.inhibit.load(Ordering::SeqCst) {
return None;
}
let handle_opt = self.handle.lock().unwrap();
*handle_opt
}
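/// Returns the shared handle only if it was initialized on the same device
/// location as `device`; otherwise (or if inhibited) returns `None`.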
pub fn get_for_device(&self, device: &Device) -> Option<&'static CublasLtWrapper> {
if self.inhibit.load(Ordering::SeqCst) {
return None;
}
{
// Scope this guard so it is released before taking the handle lock below;
// `maybe_init_cublas_lt_wrapper` acquires the two locks in the opposite
// order, so holding both here at once could deadlock.
let device_loc = self.device_location.lock().unwrap();
if let Some(init_loc) = *device_loc {
if device.location() != init_loc {
return None;
}
}
}
let handle_opt = self.handle.lock().unwrap();
*handle_opt
}
}
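// A minimal sketch of the inhibit switch (assuming the handle has been
// initialized on a CUDA device): inhibiting forces callers back onto a
// non-fused matmul path, e.g. while debugging or benchmarking.
//
//     CUBLASLT_CONTROLLER.set_inhibit(true);
//     assert!(CUBLASLT_CONTROLLER.get().is_none());
//     CUBLASLT_CONTROLLER.set_inhibit(false);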
pub static CUBLASLT_CONTROLLER: LazyLock<CublasLtController> =
LazyLock::new(|| CublasLtController {
handle: Mutex::new(None),
inhibit: AtomicBool::new(false),
device_location: Mutex::new(None),
});
#[cfg(feature = "cuda")]
mod api;
#[cfg(feature = "cuda")]
mod matmul;
#[cfg(test)]
#[cfg(feature = "cuda")]
mod tests;
#[cfg(feature = "cuda")]
pub use api::{fused_batch_matmul, fused_batch_matmul_f8, CublasLt};
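/// Initializes the global cuBLASLt handle for `device`, at most once per
/// process. Only the first call has any effect; later calls, even with a
/// different device, are ignored, which is why
/// [`CublasLtController::get_for_device`] checks the recorded device
/// location. On non-CUDA devices or builds the handle is left unset.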
pub fn maybe_init_cublas_lt_wrapper(device: Device) {
static INIT: Once = Once::new();
INIT.call_once(|| {
#[cfg(feature = "cuda")]
{
match device {
Device::Cuda(_) => {
// Build the handle and leak it so it can be shared as a
// `&'static` reference for the lifetime of the process.
let wrapper = Box::new(CublasLtWrapper {
cublaslt: CublasLt::new(&device).unwrap(),
});
let wrapper_ptr = Box::leak(wrapper) as &'static CublasLtWrapper;
let mut handle_lock = CUBLASLT_CONTROLLER.handle.lock().unwrap();
*handle_lock = Some(wrapper_ptr);
// Record which device the handle was created on so that
// `get_for_device` can reject mismatched devices.
let mut device_loc = CUBLASLT_CONTROLLER.device_location.lock().unwrap();
*device_loc = Some(device.location());
}
_ => {
let mut handle_lock = CUBLASLT_CONTROLLER.handle.lock().unwrap();
*handle_lock = None;
}
}
}
#[cfg(not(feature = "cuda"))]
{
let mut handle_lock = CUBLASLT_CONTROLLER.handle.lock().unwrap();
*handle_lock = None;
}
});
}
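/// Thin wrapper around the cuBLASLt handle exposing fused batch-matmul
/// helpers. Without the `cuda` feature the wrapper is a unit struct and its
/// methods return an error.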
#[derive(Debug, Clone)]
pub struct CublasLtWrapper {
#[cfg(feature = "cuda")]
pub cublaslt: CublasLt,
}
impl CublasLtWrapper {
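/// FP8 batched matmul with dequantization scales for `a` and `b`, a
/// quantization scale for the output, and optional bias/activation.
/// Relu/Gelu are fused into the cuBLASLt epilogue; Swiglu is applied to the
/// result afterwards.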
#[allow(clippy::too_many_arguments)]
pub fn batch_matmul_f8(
&self,
a: &Tensor,
b: &Tensor,
dequant_a_scale: &Tensor,
dequant_b_scale: &Tensor,
quantize_scale: &Tensor,
out: Option<&Tensor>,
alpha: Option<f32>,
beta: Option<f32>,
bias: Option<&Tensor>,
act: Option<CandleActivation>,
) -> Result<Tensor> {
#[cfg(feature = "cuda")]
{
// Relu and Gelu are fused into the cuBLASLt epilogue. Swiglu cannot be
// fused, so it maps to no epilogue here and is instead applied to the
// result below.
let inner_act = act.and_then(|a| match a {
CandleActivation::Relu => Some(matmul::Activation::Relu),
CandleActivation::Gelu => Some(matmul::Activation::Gelu),
CandleActivation::Swiglu => None,
_ => unreachable!("Unsupported activation in cublaslt matmul"),
});
let mut result = fused_batch_matmul_f8(
a,
b,
dequant_a_scale,
dequant_b_scale,
quantize_scale,
out,
alpha,
beta,
bias,
inner_act,
self.cublaslt.clone(),
)?;
if Some(CandleActivation::Swiglu) == act {
result = candle_nn::ops::swiglu(&result)?;
}
Ok(result)
}
#[cfg(not(feature = "cuda"))]
{
candle_core::bail!("`cuda` feature is not enabled")
}
}
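/// Batched matmul with optional output accumulation, bias, and activation.
/// Relu/Gelu are fused into the cuBLASLt epilogue; Swiglu is applied to the
/// result afterwards.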
#[allow(clippy::too_many_arguments)]
pub fn batch_matmul(
&self,
a: &Tensor,
b: &Tensor,
out: Option<&Tensor>,
alpha: Option<f32>,
beta: Option<f32>,
bias: Option<&Tensor>,
act: Option<CandleActivation>,
) -> Result<Tensor> {
#[cfg(feature = "cuda")]
{
// As in `batch_matmul_f8`: Relu/Gelu are fused into the epilogue, while
// Swiglu maps to no epilogue and is applied to the result below.
let inner_act = act.and_then(|a| match a {
CandleActivation::Relu => Some(matmul::Activation::Relu),
CandleActivation::Gelu => Some(matmul::Activation::Gelu),
CandleActivation::Swiglu => None,
_ => unreachable!("Unsupported activation in cublaslt matmul"),
});
let mut result = fused_batch_matmul(
a,
b,
out,
alpha,
beta,
bias,
inner_act,
self.cublaslt.clone(),
)?;
if Some(CandleActivation::Swiglu) == act {
result = candle_nn::ops::swiglu(&result)?;
}
Ok(result)
}
#[cfg(not(feature = "cuda"))]
{
candle_core::bail!("`cuda` feature is not enabled")
}
}
}
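// A minimal end-to-end sketch (hypothetical device index and shapes; the
// tensor layouts must satisfy `fused_batch_matmul`'s contract):
//
//     let dev = Device::new_cuda(0)?;
//     maybe_init_cublas_lt_wrapper(dev.clone());
//     let a = Tensor::zeros((1, 4, 8), candle_core::DType::F16, &dev)?;
//     let b = Tensor::zeros((1, 2, 8), candle_core::DType::F16, &dev)?;
//     if let Some(lt) = CUBLASLT_CONTROLLER.get_for_device(&dev) {
//         let c = lt.batch_matmul(&a, &b, None, None, None, None, None)?;
//     }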