use candle_core::{Device, Result, Tensor};
use candle_nn::VarBuilder;
use crate::error::BitTTTError;
use crate::kernels::packing::PackedTensor;
use crate::kernels::{cpu::BitLinearCpu, cuda::BitLinearCuda};
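/// Linear layer that runs either as a dense f32 matmul or through a
/// packed low-bit kernel (`PackedTensor` + `BitLinearCpu`/`BitLinearCuda`).
///
/// A minimal usage sketch (assuming a `VarBuilder` `vb` scoped to this
/// layer and an input `x` already on `device`; names are illustrative):
///
/// ```ignore
/// let mut layer = BitLinear::load(in_dim, out_dim, vb, &device)?;
/// layer.precompute_packed()?; // optional: switch to the packed kernels
/// let y = layer.forward(&x)?;
/// ```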
#[derive(Clone)]
pub struct BitLinear {
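    /// Dense f32 weights of shape `(out_features, in_features)`; a zeros
    /// placeholder when the layer was built from packed tensors only.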
pub weight: Tensor,
#[allow(dead_code)]
pub in_features: usize,
#[allow(dead_code)]
pub out_features: usize,
pub packed_params: Option<PackedTensor>,
}
impl BitLinear {
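    /// Loads the `(out_dim, in_dim)` `"weight"` tensor from `vb` (with a
    /// Kaiming-normal init hint) and moves it to `device`.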
pub fn load(in_dim: usize, out_dim: usize, vb: VarBuilder, device: &Device) -> Result<Self> {
let init = candle_nn::init::DEFAULT_KAIMING_NORMAL;
let weight = vb.get_with_hints((out_dim, in_dim), "weight", init)?;
        let weight = if device.is_cpu() {
            // Round-trip through a Vec<f32> to guarantee a contiguous CPU
            // buffer (to_vec1 requires a rank-1 tensor, so flatten first).
            let dims = weight.dims().to_vec();
            let data = weight.flatten_all()?.to_vec1::<f32>()?;
            Tensor::from_vec(data, dims, device)?
        } else {
            weight.to_device(device)?
        };
Ok(Self {
weight,
in_features: in_dim,
out_features: out_dim,
packed_params: None,
})
}
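    /// Builds a layer from an existing weight tensor, casting it to f32
    /// and checking that its shape is `(out_dim, in_dim)`.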
pub fn from_weight_tensor(
weight: &Tensor,
in_dim: usize,
out_dim: usize,
device: &Device,
) -> Result<Self> {
let weight = weight
.to_dtype(candle_core::DType::F32)?
.to_device(device)?;
let dims = weight.dims();
if dims != [out_dim, in_dim] {
return Err(candle_core::Error::Msg(format!(
"Weight shape mismatch: expected [{}, {}], got {:?}",
out_dim, in_dim, dims
)));
}
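        // On CPU, round-trip through a Vec<f32> so the stored weight is
        // guaranteed contiguous.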
let weight = if device.is_cpu() {
let data = weight.flatten_all()?.to_vec1::<f32>()?;
Tensor::from_vec(data, (out_dim, in_dim), device)?
} else {
weight
};
Ok(Self {
weight,
in_features: in_dim,
out_features: out_dim,
packed_params: None,
})
}
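    /// Builds a layer directly from pre-packed weight bytes and their
    /// scales; the dense `weight` field is left as a zeros placeholder.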
pub fn from_packed_tensors(
weight_packed: &Tensor,
scales: &Tensor,
device: &Device,
) -> Result<Self> {
let dims = weight_packed.dims();
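        // Each u8 holds 4 packed values, so the logical input dimension is
        // 4x the packed byte dimension; a 3rd dim counts quantization bases.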
let (out_dim, in_dim, _n_bases) = match dims.len() {
2 => (dims[0], dims[1] * 4, 1usize),
3 => (dims[0], dims[1] * 4, dims[2]),
_ => {
return Err(candle_core::Error::Msg(format!(
"Invalid weight_packed shape: expected 2D or 3D, got {:?}",
dims
)))
}
};
let packed_data = if weight_packed.dtype() != candle_core::DType::U8 {
eprintln!(
"⚠️ [PACKED] Converting {:?} → U8 (VarBuilder dtype issue)",
weight_packed.dtype()
);
weight_packed
.to_dtype(candle_core::DType::U8)?
.to_device(device)?
} else {
weight_packed.to_device(device)?
};
let scales_data = scales
.to_dtype(candle_core::DType::F32)?
.to_device(device)?;
let packed_params =
PackedTensor::from_loaded(packed_data, scales_data, out_dim, in_dim, device)?;
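        // Placeholder only: forward() takes the packed path whenever
        // packed_params is Some, so these zeros are never used in a matmul.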
let weight = Tensor::zeros((out_dim, in_dim), candle_core::DType::F32, device)?;
Ok(Self {
weight,
in_features: in_dim,
out_features: out_dim,
packed_params: Some(packed_params),
})
}
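    /// Loads `"weight_packed"` and `"scales"` from `vb` when both exist,
    /// falling back to the dense `load` path otherwise.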
#[allow(dead_code)]
pub fn load_packed(
in_dim: usize,
out_dim: usize,
n_bases: usize,
vb: VarBuilder,
device: &Device,
) -> Result<Self> {
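        // 4 packed values per u8; multi-basis checkpoints carry a trailing
        // bases dimension.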
let packed_shape = if n_bases == 1 {
vec![out_dim, in_dim / 4]
} else {
vec![out_dim, in_dim / 4, n_bases]
};
let packed_result = vb.get(packed_shape.as_slice(), "weight_packed");
let scales_result = vb.get(&[n_bases], "scales");
match (packed_result, scales_result) {
(Ok(packed_raw), Ok(scales)) => {
let packed_data = packed_raw.to_device(device)?;
let scales_data = scales
.to_dtype(candle_core::DType::F32)?
.to_device(device)?;
let packed_params =
PackedTensor::from_loaded(packed_data, scales_data, out_dim, in_dim, device)?;
let weight = Tensor::zeros((out_dim, in_dim), candle_core::DType::F32, device)?;
Ok(Self {
weight,
in_features: in_dim,
out_features: out_dim,
packed_params: Some(packed_params),
})
}
_ => {
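                // One or both packed tensors are missing from the
                // checkpoint: fall back to dense f32 weights.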
Self::load(in_dim, out_dim, vb, device)
}
}
}
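    /// Packs the current dense weight in place so subsequent `forward`
    /// calls dispatch to the low-bit kernels.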
pub fn precompute_packed(&mut self) -> Result<()> {
let packed = PackedTensor::pack(&self.weight)?;
self.packed_params = Some(packed);
Ok(())
}
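    /// Computes `x @ weight^T`, using the packed CPU/CUDA kernels when
    /// packed params are available and a plain f32 matmul otherwise.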
pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
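        // Flatten leading batch dims to a 2D (tokens, features) matmul,
        // keeping the original shape for the final reshape.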
let (input, original_shape) = if x.rank() > 2 {
let dims = x.dims();
let last_dim = dims[dims.len() - 1];
let flattened_dim = x.elem_count() / last_dim;
(x.reshape(&[flattened_dim, last_dim])?, Some(dims.to_vec()))
} else {
(x.clone(), None)
};
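        // Fast path: dispatch to the device-specific packed kernel.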
if let Some(packed) = &self.packed_params {
let result = match input.device() {
Device::Cpu => {
BitLinearCpu::forward(&input, packed)
}
Device::Cuda(_) => {
BitLinearCuda::forward(&input, packed)
}
_ => {
return Err(BitTTTError::kernel_error(
"Packed params present but Custom Kernel not implemented for this device",
)
.into());
}
}?;
if let Some(mut dims) = original_shape {
let last_idx = dims.len() - 1;
let (_total, out_dim) = result.dims2()?;
dims[last_idx] = out_dim;
return result.reshape(&dims[..]);
} else {
return Ok(result);
}
}
#[cfg(debug_assertions)]
tracing::debug!("📦 BitLinear: Using legacy FP path (no STE quantization)");
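        // Dense fallback: plain f32 matmul against the transposed weight.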
let result = input.matmul(&self.weight.t()?)?;
if let Some(mut dims) = original_shape {
let last_idx = dims.len() - 1;
dims[last_idx] = self.out_features;
result.reshape(&dims[..])
} else {
Ok(result)
}
}
}