use candle_core::{Device, Tensor};
use candle_nn::Module;
use crate::config::BitNetConfig;
use crate::error::Result;
use crate::quantization::{
dequantize_activations, dequantize_weights, quantize_activations, quantize_weights,
TernaryWeight,
};
/// Prints a one-time notice on stderr when `device` is the CPU.
///
/// The notice is guarded by a process-wide [`std::sync::Once`], so it is
/// written at most once regardless of how many layers are constructed.
/// Non-CPU devices never trigger it.
fn warn_cpu_fallback(device: &Device) {
    const CPU_NOTICE: &str = "bitnet-quantize: CPU device in use. CUDA is the intended default; \
        enable the 'cuda' feature and use Device::cuda_if_available(0) when possible.";
    static NOTICE_GUARD: std::sync::Once = std::sync::Once::new();

    // Nothing to report unless the layer lives on the CPU.
    match device {
        Device::Cpu => {}
        _ => return,
    }

    NOTICE_GUARD.call_once(|| eprintln!("{}", CPU_NOTICE));
}
/// Linear layer whose weight matrix is stored in ternary-quantized form.
///
/// Instances are created with `BitLinear::from_weight` (quantizes a
/// full-precision tensor) or `BitLinear::from_quantized` (wraps an already
/// quantized weight).
#[derive(Debug)]
pub struct BitLinear {
/// Quantized weight; reports the layer's input/output feature counts.
weight: TernaryWeight,
/// Optional bias, broadcast-added to the matmul result when present.
bias: Option<Tensor>,
/// Quantization settings, used when quantizing weights and activations.
config: BitNetConfig,
/// Device passed to the dequantization helpers during the forward pass.
device: Device,
}
impl BitLinear {
    /// Builds a layer by quantizing a full-precision weight tensor.
    ///
    /// The layer's device is taken from `weight`. When that device is the CPU
    /// a one-time notice is written to stderr.
    ///
    /// # Arguments
    ///
    /// * `weight` - full-precision weight to quantize.
    /// * `bias` - optional bias; the tensor handle is cloned into the layer.
    /// * `config` - quantization settings; validated before use and cloned
    ///   into the layer.
    ///
    /// # Errors
    ///
    /// Returns an error if `config` fails validation or if weight
    /// quantization fails.
    pub fn from_weight(
        weight: &Tensor,
        bias: Option<&Tensor>,
        config: &BitNetConfig,
    ) -> Result<Self> {
        config.validate()?;
        let device = weight.device().clone();
        warn_cpu_fallback(&device);
        let quantized_weight = quantize_weights(weight, config)?;
        Ok(Self {
            weight: quantized_weight,
            bias: bias.cloned(),
            config: config.clone(),
            device,
        })
    }

    /// Wraps an already-quantized weight without re-quantizing or validating.
    ///
    /// When `device` is the CPU a one-time notice is written to stderr.
    #[must_use]
    pub fn from_quantized(
        weight: TernaryWeight,
        bias: Option<Tensor>,
        config: BitNetConfig,
        device: Device,
    ) -> Self {
        warn_cpu_fallback(&device);
        Self {
            weight,
            bias,
            config,
            device,
        }
    }

    /// Number of input features, as reported by the quantized weight.
    #[must_use]
    pub fn in_features(&self) -> usize {
        self.weight.in_features()
    }

    /// Number of output features, as reported by the quantized weight.
    #[must_use]
    pub fn out_features(&self) -> usize {
        self.weight.out_features()
    }

    /// Borrows the quantized weight.
    #[must_use]
    pub const fn quantized_weight(&self) -> &TernaryWeight {
        &self.weight
    }

    /// Borrows the bias tensor, if the layer has one.
    #[must_use]
    pub const fn bias(&self) -> Option<&Tensor> {
        self.bias.as_ref()
    }

    /// Borrows the quantization configuration stored in the layer.
    #[must_use]
    pub const fn config(&self) -> &BitNetConfig {
        &self.config
    }

    /// Borrows the device the layer dequantizes onto.
    #[must_use]
    pub const fn device(&self) -> &Device {
        &self.device
    }

    /// Sparsity of the quantized weight, as reported by [`TernaryWeight`].
    #[must_use]
    pub fn sparsity(&self) -> f32 {
        self.weight.sparsity()
    }

    /// Compression ratio of the quantized weight, as reported by
    /// [`TernaryWeight`].
    #[must_use]
    pub fn compression_ratio(&self) -> f32 {
        self.weight.compression_ratio()
    }

    /// Forward pass that also round-trips the activations through
    /// quantization.
    ///
    /// `input` is quantized with the layer's config and immediately
    /// dequantized onto the layer's device; the result is multiplied by the
    /// transposed weight and the bias (if any) is broadcast-added.
    ///
    /// Inputs whose dequantized form has more than two dimensions have their
    /// leading dimensions collapsed into one for the matmul and restored
    /// afterwards, mirroring what the [`Module`] implementation does. Inputs
    /// of rank two or lower are passed to the matmul unchanged.
    ///
    /// # Errors
    ///
    /// Returns an error if activation (de)quantization, weight
    /// dequantization, a reshape, the matmul, or the bias addition fails.
    pub fn forward_quantized(&self, input: &Tensor) -> Result<Tensor> {
        let quantized_input = quantize_activations(input, &self.config)?;
        let dequant_input = dequantize_activations(&quantized_input, &self.device)?;

        // The matmul needs a rank-2 left operand to pair with the rank-2
        // weight, so record the leading dims of higher-rank activations.
        let flatten_plan = match dequant_input.dims() {
            [lead @ .., hidden] if lead.len() >= 2 => Some((lead.to_vec(), *hidden)),
            _ => None,
        };
        let (flat_input, leading_dims) = match flatten_plan {
            Some((lead, hidden)) => {
                let rows: usize = lead.iter().product();
                (dequant_input.reshape((rows, hidden))?, Some(lead))
            }
            None => (dequant_input, None),
        };

        // Same dispatch rule as `Module::forward`: only take the GPU kernel
        // when CUDA is actually available, otherwise use the dense fallback.
        #[cfg(feature = "cuda")]
        let output = if crate::kernels::cuda_available()
            && crate::kernels::should_use_gpu(&flat_input, &self.weight)
        {
            crate::kernels::ternary_matmul_gpu(&flat_input, &self.weight)?
        } else {
            self.dense_matmul(&flat_input)?
        };
        #[cfg(not(feature = "cuda"))]
        let output = self.dense_matmul(&flat_input)?;

        // Restore the caller's leading dimensions, if any were collapsed.
        let output = match leading_dims {
            Some(mut shape) => {
                shape.push(self.out_features());
                output.reshape(shape)?
            }
            None => output,
        };

        match &self.bias {
            Some(bias) => Ok(output.broadcast_add(bias)?),
            None => Ok(output),
        }
    }

    /// Dense fallback: dequantizes the weight and computes
    /// `input x weight^T`.
    fn dense_matmul(&self, input: &Tensor) -> Result<Tensor> {
        let dequant_weight = dequantize_weights(&self.weight, &self.device)?;
        Ok(input.matmul(&dequant_weight.t()?)?)
    }
}
impl Module for BitLinear {
    /// Applies the layer: `input x weight^T`, plus the bias when present.
    ///
    /// Only the weight is quantized on this path; activations are used as
    /// given. Inputs with more than two dimensions have every leading
    /// dimension collapsed into a single row dimension for the matmul and
    /// restored on the output, so `[d0, .., dk, in]` yields
    /// `[d0, .., dk, out_features]`. Inputs of rank two or lower are passed
    /// to the matmul unchanged.
    ///
    /// With the `cuda` feature enabled the ternary GPU kernel is used when
    /// CUDA is available and the kernel module opts in; otherwise the weight
    /// is dequantized and a dense matmul is performed.
    ///
    /// # Errors
    ///
    /// Returns an error if a reshape, weight dequantization, the matmul, or
    /// the bias addition fails. Errors from this crate's helpers are
    /// converted into `candle_core::Error::Msg`.
    fn forward(&self, input: &Tensor) -> candle_core::Result<Tensor> {
        // Collapse all leading dims so the matmul sees a rank-2 operand;
        // remember them so the output can be reshaped back.
        let (flat_input, leading_dims) = match input.dims() {
            [lead @ .., hidden] if lead.len() >= 2 => {
                let rows: usize = lead.iter().product();
                (input.reshape((rows, *hidden))?, Some(lead.to_vec()))
            }
            _ => (input.clone(), None),
        };

        // Dense fallback shared by both feature configurations.
        let dense_matmul = |x: &Tensor| -> candle_core::Result<Tensor> {
            let dequant_weight = dequantize_weights(&self.weight, &self.device)
                .map_err(|e| candle_core::Error::Msg(e.to_string()))?;
            x.matmul(&dequant_weight.t()?)
        };

        #[cfg(feature = "cuda")]
        let output = if crate::kernels::cuda_available()
            && crate::kernels::should_use_gpu(&flat_input, &self.weight)
        {
            crate::kernels::ternary_matmul_gpu(&flat_input, &self.weight)
                .map_err(|e| candle_core::Error::Msg(e.to_string()))?
        } else {
            dense_matmul(&flat_input)?
        };
        #[cfg(not(feature = "cuda"))]
        let output = dense_matmul(&flat_input)?;

        // Restore the caller's leading dimensions, if any were collapsed.
        let output = match leading_dims {
            Some(mut shape) => {
                shape.push(self.out_features());
                output.reshape(shape)?
            }
            None => output,
        };

        match &self.bias {
            Some(bias) => output.broadcast_add(bias),
            None => Ok(output),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Standard-normal `f32` tensor of the given shape on the CPU.
    fn random_tensor<S: Into<candle_core::Shape>>(shape: S) -> Tensor {
        Tensor::randn(0.0f32, 1.0, shape, &Device::Cpu).unwrap()
    }

    /// Default configuration with the group size set to 64.
    fn grouped_config() -> BitNetConfig {
        BitNetConfig::default().with_group_size(64)
    }

    /// Layer built from a random `(64, 128)` weight using `grouped_config`.
    fn small_layer(bias: Option<&Tensor>) -> BitLinear {
        BitLinear::from_weight(&random_tensor((64, 128)), bias, &grouped_config()).unwrap()
    }

    /// Feature counts follow the `(out, in)` layout of the weight tensor.
    #[test]
    fn test_bitlinear_creation() {
        let config = BitNetConfig::default();
        let weight = random_tensor((128, 256));
        let layer = BitLinear::from_weight(&weight, None, &config).unwrap();
        assert_eq!(layer.in_features(), 256);
        assert_eq!(layer.out_features(), 128);
    }

    /// A 2-D input maps `[rows, in]` to `[rows, out]`.
    #[test]
    fn test_bitlinear_forward() {
        let layer = small_layer(None);
        let input = random_tensor((4, 128));
        let output = layer.forward(&input).unwrap();
        assert_eq!(output.shape().dims(), &[4, 64]);
    }

    /// The activation-quantizing path produces the same output shape.
    #[test]
    fn test_bitlinear_forward_quantized() {
        let layer = small_layer(None);
        let input = random_tensor((4, 128));
        let output = layer.forward_quantized(&input).unwrap();
        assert_eq!(output.shape().dims(), &[4, 64]);
    }

    /// Adding a bias does not change the output shape.
    #[test]
    fn test_bitlinear_with_bias() {
        let bias = random_tensor((64,));
        let layer = small_layer(Some(&bias));
        let input = random_tensor((4, 128));
        let output = layer.forward(&input).unwrap();
        assert_eq!(output.shape().dims(), &[4, 64]);
    }

    /// A 3-D input keeps its leading `[batch, seq]` dimensions.
    #[test]
    fn test_bitlinear_3d_input() {
        let layer = small_layer(None);
        let input = random_tensor((2, 16, 128));
        let output = layer.forward(&input).unwrap();
        assert_eq!(output.shape().dims(), &[2, 16, 64]);
    }

    /// Sparsity is a fraction in `[0, 1]`.
    #[test]
    fn test_bitlinear_sparsity() {
        let layer = small_layer(None);
        let sparsity = layer.sparsity();
        assert!((0.0..=1.0).contains(&sparsity));
    }

    /// Quantizing a large weight yields a compression ratio above one.
    #[test]
    fn test_bitlinear_compression() {
        let config = BitNetConfig::default();
        let weight = random_tensor((1024, 4096));
        let layer = BitLinear::from_weight(&weight, None, &config).unwrap();
        let ratio = layer.compression_ratio();
        assert!(ratio > 1.0, "should achieve some compression");
    }
}