use crate::error::BitTTTError;
use crate::kernels::packing::PackedTensor;
use candle_core::{Device, Result, Tensor};
#[cfg(feature = "cuda")]
use candle_core::cuda_backend::cudarc::driver::DevicePtr;
#[cfg(feature = "cuda")]
use candle_core::cuda_backend::cudarc::driver::{LaunchAsync, LaunchConfig};
#[cfg(feature = "cuda")]
const _BIT_OP_PTX: &str = include_str!(concat!(env!("OUT_DIR"), "/bit_op.ptx"));
#[cfg(feature = "cuda")]
const ADAPTIVE_PTX: &str = include_str!(concat!(env!("OUT_DIR"), "/adaptive_bit_op.ptx"));
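/// CUDA execution path for [`PackedTensor`] linear layers. When the packed
/// weights carry `adaptive_scales` (and the `cuda` feature is enabled), the
/// custom adaptive GEMM kernel is launched; otherwise the weights are unpacked
/// to dense f32 and a plain `matmul` is used.
///
/// Minimal usage sketch, assuming an activation tensor `input` and a
/// `PackedTensor` named `packed` built elsewhere in this crate:
///
/// ```ignore
/// let output = BitLinearCuda::forward(&input, &packed)?;
/// ```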
pub struct BitLinearCuda;
impl BitLinearCuda {
    /// Forward pass on CUDA: uses the adaptive GEMM kernel when `adaptive_scales`
    /// are present (and the `cuda` feature is on), otherwise unpack + plain `matmul`.
    pub fn forward(input: &Tensor, weights: &PackedTensor) -> Result<Tensor> {
let (m, k) = input.dims2()?;
let (n_out, k_w) = weights.shape.dims2()?;
if k != k_w {
return Err(BitTTTError::shape_mismatch(format!(
"Input [{}, {}] vs Weight [{}, {}]",
m, k, n_out, k_w
))
.into());
}
let device = input.device();
match device {
Device::Cuda(dev) => {
#[cfg(feature = "cuda")]
if let Some(ref scales) = weights.adaptive_scales {
if let Ok(scales_vec) = scales.to_vec1::<f32>() {
tracing::debug!(
"🔥 [CUDA] forward: n_out={}, k={}, scales={:?}",
n_out,
k,
scales_vec
);
}
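                    // One-shot comparison: the very first forward pass through the
                    // adaptive path also runs the legacy dequantize + matmul and logs
                    // the mean absolute difference against the custom kernel's output.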
static DEBUG_COMPARE: std::sync::atomic::AtomicBool =
std::sync::atomic::AtomicBool::new(true);
if DEBUG_COMPARE.swap(false, std::sync::atomic::Ordering::SeqCst) {
let legacy_result = {
let w_dequant = weights.unpack(&Device::Cuda(dev.clone()))?;
if let Ok(w_flat) = w_dequant.flatten_all()?.to_vec1::<f32>() {
tracing::info!(
"🔬 [DEBUG] Unpacked weights first 16: {:?}",
&w_flat[..16.min(w_flat.len())]
);
let non_zero = w_flat.iter().filter(|&&x| x.abs() > 1e-6).count();
tracing::info!(
"🔬 [DEBUG] Unpacked weights: {} non-zero out of {}",
non_zero,
w_flat.len()
);
}
let w_t = w_dequant.t()?;
input.matmul(&w_t)?
};
let adaptive_result =
Self::adaptive_forward(input, &weights.data, scales, n_out)?;
if let (Ok(leg), Ok(adp)) = (
legacy_result.flatten_all()?.to_vec1::<f32>(),
adaptive_result.flatten_all()?.to_vec1::<f32>(),
) {
let diff: f32 = leg
.iter()
.zip(adp.iter())
.map(|(a, b)| (a - b).abs())
.sum::<f32>()
/ leg.len() as f32;
tracing::info!(
"🔬 [DEBUG] Legacy first 8: {:?}",
&leg[..8.min(leg.len())]
);
tracing::info!(
"🔬 [DEBUG] Adaptive first 8: {:?}",
&adp[..8.min(adp.len())]
);
tracing::info!("🔬 [DEBUG] Mean absolute diff: {}", diff);
}
return Ok(adaptive_result);
}
return Self::adaptive_forward(input, &weights.data, scales, n_out);
}
let w_dequant = weights.unpack(&Device::Cuda(dev.clone()))?;
let w_t = w_dequant.t()?;
let output = input.matmul(&w_t)?;
Ok(output)
}
_ => Err(BitTTTError::device_error("BitLinearCuda called on non-CUDA device").into()),
}
}
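    /// Legacy path: unpack the packed weights to dense f32 on the CUDA device and
    /// run a plain `matmul`, without ever launching the custom adaptive kernel.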
pub fn forward_legacy(input: &Tensor, weights: &PackedTensor) -> Result<Tensor> {
let (m, k) = input.dims2()?;
let (n_out, k_w) = weights.shape.dims2()?;
if k != k_w {
return Err(BitTTTError::shape_mismatch(format!(
"Input [{}, {}] vs Weight [{}, {}]",
m, k, n_out, k_w
))
.into());
}
let device = input.device();
match device {
Device::Cuda(dev) => {
let w_dequant = weights.unpack(&Device::Cuda(dev.clone()))?;
let w_t = w_dequant.t()?;
let output = input.matmul(&w_t)?;
Ok(output)
}
_ => Err(BitTTTError::device_error("BitLinearCuda called on non-CUDA device").into()),
}
}
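    /// Launches the `adaptive_gemm_n3_kernel_f32` PTX kernel directly, passing raw
    /// device pointers for the f32 input, packed u8 weights, f32 scales, and the
    /// freshly allocated f32 output.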
#[cfg(feature = "cuda")]
    pub fn adaptive_forward(
        input: &Tensor,
        weights: &Tensor,
        scales: &Tensor,
        out_dim: usize,
    ) -> Result<Tensor> {
let (batch, in_dim) = input.dims2()?;
let dev = match input.device() {
Device::Cuda(d) => d,
_ => {
return Err(
BitTTTError::device_error("adaptive_forward called on non-CUDA device").into(),
)
}
};
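        // Pull raw CUDA device pointers out of candle's storage so they can be
        // handed to the custom kernel. The borrowed tensors stay alive for the
        // duration of this function, which covers the kernel launch below.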
let inp_ptr = {
let inp_storage = input.storage_and_layout().0;
match &*inp_storage {
candle_core::Storage::Cuda(s) => *s.as_cuda_slice::<f32>()?.device_ptr(),
_ => return Err(BitTTTError::storage_error("Input must be CUDA F32").into()),
}
};
let w_ptr = {
let w_storage = weights.storage_and_layout().0;
match &*w_storage {
candle_core::Storage::Cuda(s) => *s.as_cuda_slice::<u8>()?.device_ptr(),
_ => return Err(BitTTTError::storage_error("Weights must be CUDA U8").into()),
}
};
let s_ptr = {
let s_storage = scales.storage_and_layout().0;
match &*s_storage {
candle_core::Storage::Cuda(s) => *s.as_cuda_slice::<f32>()?.device_ptr(),
_ => return Err(BitTTTError::storage_error("Scales must be CUDA F32").into()),
}
};
let output = Tensor::zeros(
(batch, out_dim),
candle_core::DType::F32,
&Device::Cuda(dev.clone()),
)?;
let out_ptr = {
let out_storage = output.storage_and_layout().0;
match &*out_storage {
candle_core::Storage::Cuda(s) => *s.as_cuda_slice::<f32>()?.device_ptr(),
_ => return Err(BitTTTError::storage_error("Output allocation failed").into()),
}
};
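        // Note: the embedded PTX module is loaded into the driver on every call
        // here; caching the loaded function per device would be a possible
        // optimization, but correctness does not depend on it.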
let module_name = "adaptive_gemm";
let func_name = "adaptive_gemm_n3_kernel_f32";
let core_dev = dev.cuda_device();
core_dev
.load_ptx(ADAPTIVE_PTX.into(), module_name, &[func_name])
.map_err(candle_core::Error::wrap)?;
let f = core_dev
.get_func(module_name, func_name)
.ok_or_else(|| BitTTTError::kernel_error(format!("Kernel '{}' not found", func_name)))
.map_err(|e| candle_core::Error::Msg(e.to_string()))?;
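        // Launch geometry: 256 threads per block, ceil(out_dim / 256) blocks along x,
        // and one grid row per batch row, i.e. the kernel is expected to map one
        // thread to one output element.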
let block_dim = 256;
let grid_x = (out_dim as u32 + block_dim - 1) / block_dim;
let grid_y = batch as u32;
let cfg = LaunchConfig {
grid_dim: (grid_x, grid_y, 1),
block_dim: (block_dim, 1, 1),
shared_mem_bytes: 0,
};
tracing::debug!(
"🔥 [CUDA] adaptive_forward: batch={}, in_dim={}, out_dim={}, grid=({},{}), block={}",
batch,
in_dim,
out_dim,
grid_x,
grid_y,
block_dim
);
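        // Kernel arguments: the order and types here must match the parameter list
        // of `adaptive_gemm_n3_kernel_f32` in the CUDA source exactly.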
        let params = (
            inp_ptr, w_ptr, s_ptr, out_ptr,
            batch as i32, in_dim as i32, out_dim as i32,
        );
unsafe { f.launch(cfg, params) }.map_err(candle_core::Error::wrap)?;
Ok(output)
}
#[cfg(not(feature = "cuda"))]
pub fn adaptive_forward(
_input: &Tensor,
_weights: &Tensor,
_scales: &Tensor,
_out_dim: usize,
) -> Result<Tensor> {
Err(BitTTTError::feature_not_enabled("CUDA (feature 'cuda' missing)").into())
}
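    /// Returns `true` only when the crate was built with the `cuda` feature and a
    /// non-empty adaptive PTX blob was embedded at build time.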
pub fn smoke_test_compile() -> bool {
#[cfg(feature = "cuda")]
return !ADAPTIVE_PTX.is_empty();
#[cfg(not(feature = "cuda"))]
return false;
}
}
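
// Minimal smoke test (a sketch): with the `cuda` feature enabled this only checks
// that the adaptive PTX blob was embedded at build time; it does not launch a
// kernel and needs no GPU at test time.
#[cfg(all(test, feature = "cuda"))]
mod tests {
    use super::BitLinearCuda;

    #[test]
    fn adaptive_ptx_is_embedded() {
        assert!(BitLinearCuda::smoke_test_compile());
    }
}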