Skip to main content

bitnet_quantize/layer/
bitlinear.rs

1//! BitLinear layer - drop-in replacement for nn::Linear with ternary weights.
2
3use candle_core::{Device, Tensor};
4use candle_nn::Module;
5
6use crate::config::BitNetConfig;
7use crate::error::Result;
8use crate::quantization::{
9    dequantize_activations, dequantize_weights, quantize_activations, quantize_weights,
10    TernaryWeight,
11};
12
13fn warn_cpu_fallback(device: &Device) {
14    static WARN_ONCE: std::sync::Once = std::sync::Once::new();
15    if matches!(device, Device::Cpu) {
16        WARN_ONCE.call_once(|| {
17            eprintln!(
18                "bitnet-quantize: CPU device in use. CUDA is the intended default; enable the 'cuda' feature and use Device::cuda_if_available(0) when possible."
19            );
20        });
21    }
22}
23
24/// BitLinear layer with ternary weights and INT8 activations.
25///
26/// This is a drop-in replacement for `candle_nn::Linear` that uses:
27/// - Ternary weights {-1, 0, +1} with per-group scales
28/// - INT8 activation quantization with per-token scales
29///
30/// # Example
31///
32/// ```ignore
33/// use bitnet_rs::{BitLinear, BitNetConfig};
34/// use candle_core::{Device, Tensor};
35///
36/// let device = Device::Cpu;
37/// let config = BitNetConfig::default();
38///
39/// // Create from existing weights
40/// let weight = Tensor::randn(0.0f32, 1.0, (512, 256), &device)?;
41/// let layer = BitLinear::from_weight(&weight, None, &config)?;
42///
43/// // Forward pass
44/// let input = Tensor::randn(0.0f32, 1.0, (4, 256), &device)?;
45/// let output = layer.forward(&input)?;
46/// ```
47#[derive(Debug)]
48pub struct BitLinear {
49    /// Quantized ternary weights.
50    weight: TernaryWeight,
51
52    /// Optional bias (not quantized).
53    bias: Option<Tensor>,
54
55    /// Configuration.
56    config: BitNetConfig,
57
58    /// Device for tensor operations.
59    device: Device,
60}
61
62impl BitLinear {
63    /// Create a new BitLinear layer from a weight tensor.
64    ///
65    /// # Arguments
66    ///
67    /// * `weight` - Weight tensor [out_features, in_features]
68    /// * `bias` - Optional bias tensor [out_features]
69    /// * `config` - BitNet configuration
70    ///
71    /// # Errors
72    ///
73    /// Returns error if weight quantization fails.
74    pub fn from_weight(weight: &Tensor, bias: Option<&Tensor>, config: &BitNetConfig) -> Result<Self> {
75        config.validate()?;
76
77        let device = weight.device().clone();
78        warn_cpu_fallback(&device);
79        let quantized_weight = quantize_weights(weight, config)?;
80
81        Ok(Self {
82            weight: quantized_weight,
83            bias: bias.cloned(),
84            config: config.clone(),
85            device,
86        })
87    }
88
89    /// Create a new BitLinear layer from pre-quantized weights.
90    ///
91    /// # Arguments
92    ///
93    /// * `weight` - Pre-quantized ternary weight
94    /// * `bias` - Optional bias tensor
95    /// * `config` - BitNet configuration
96    /// * `device` - Device for operations
97    #[must_use]
98    pub fn from_quantized(
99        weight: TernaryWeight,
100        bias: Option<Tensor>,
101        config: BitNetConfig,
102        device: Device,
103    ) -> Self {
104        warn_cpu_fallback(&device);
105        Self {
106            weight,
107            bias,
108            config,
109            device,
110        }
111    }
112
113    /// Get the input features dimension.
114    #[must_use]
115    pub fn in_features(&self) -> usize {
116        self.weight.in_features()
117    }
118
119    /// Get the output features dimension.
120    #[must_use]
121    pub fn out_features(&self) -> usize {
122        self.weight.out_features()
123    }
124
125    /// Get reference to the quantized weights.
126    #[must_use]
127    pub const fn quantized_weight(&self) -> &TernaryWeight {
128        &self.weight
129    }
130
131    /// Get reference to the bias.
132    #[must_use]
133    pub const fn bias(&self) -> Option<&Tensor> {
134        self.bias.as_ref()
135    }
136
137    /// Get reference to the configuration.
138    #[must_use]
139    pub const fn config(&self) -> &BitNetConfig {
140        &self.config
141    }
142
143    /// Get the device.
144    #[must_use]
145    pub const fn device(&self) -> &Device {
146        &self.device
147    }
148
149    /// Get the weight sparsity.
150    #[must_use]
151    pub fn sparsity(&self) -> f32 {
152        self.weight.sparsity()
153    }
154
155    /// Get the compression ratio.
156    #[must_use]
157    pub fn compression_ratio(&self) -> f32 {
158        self.weight.compression_ratio()
159    }
160
161    /// Forward pass with explicit activation quantization.
162    ///
163    /// This method:
164    /// 1. Quantizes input activations to INT8
165    /// 2. Dequantizes weights for matmul (or uses optimized kernel)
166    /// 3. Performs the linear transformation
167    /// 4. Adds bias if present
168    ///
169    /// # Arguments
170    ///
171    /// * `input` - Input tensor [batch, ..., in_features]
172    ///
173    /// # Errors
174    ///
175    /// Returns error if forward pass fails.
176    pub fn forward_quantized(&self, input: &Tensor) -> Result<Tensor> {
177        // Quantize activations
178        let quantized_input = quantize_activations(input, &self.config)?;
179        let dequant_input = dequantize_activations(&quantized_input, &self.device)?;
180
181        // Dequantize weights for matmul
182        let dequant_weight = dequantize_weights(&self.weight, &self.device)?;
183
184        // Linear transformation: y = x @ W^T
185        let output = dequant_input.matmul(&dequant_weight.t()?)?;
186
187        // Add bias
188        let output = if let Some(ref bias) = self.bias {
189            output.broadcast_add(bias)?
190        } else {
191            output
192        };
193
194        Ok(output)
195    }
196}
197
198impl Module for BitLinear {
199    fn forward(&self, input: &Tensor) -> candle_core::Result<Tensor> {
200        // For standard forward, dequantize and compute
201        // In a production implementation, this would use optimized kernels
202        let dequant_weight = dequantize_weights(&self.weight, &self.device)
203            .map_err(|e| candle_core::Error::Msg(e.to_string()))?;
204
205        let dims = input.dims();
206        let output = if dims.len() == 3 {
207            // Handle 3D input [batch, seq_len, hidden]
208            let (batch, seq_len, hidden) = (dims[0], dims[1], dims[2]);
209            let flat_input = input.reshape((batch * seq_len, hidden))?;
210            let flat_output = flat_input.matmul(&dequant_weight.t()?)?;
211            flat_output.reshape((batch, seq_len, self.out_features()))?
212        } else {
213            // Standard 2D matmul
214            input.matmul(&dequant_weight.t()?)?
215        };
216
217        let output = if let Some(ref bias) = self.bias {
218            output.broadcast_add(bias)?
219        } else {
220            output
221        };
222
223        Ok(output)
224    }
225}
226
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_bitlinear_creation() {
        let device = Device::Cpu;
        let config = BitNetConfig::default();

        let weight = Tensor::randn(0.0f32, 1.0, (128, 256), &device).unwrap();
        let layer = BitLinear::from_weight(&weight, None, &config).unwrap();

        assert_eq!(layer.in_features(), 256);
        assert_eq!(layer.out_features(), 128);
    }

    #[test]
    fn test_bitlinear_forward() {
        let device = Device::Cpu;
        let config = BitNetConfig::default().with_group_size(64);

        let weight = Tensor::randn(0.0f32, 1.0, (64, 128), &device).unwrap();
        let layer = BitLinear::from_weight(&weight, None, &config).unwrap();

        let input = Tensor::randn(0.0f32, 1.0, (4, 128), &device).unwrap();
        let output = layer.forward(&input).unwrap();

        assert_eq!(output.shape().dims(), &[4, 64]);
    }

    #[test]
    fn test_bitlinear_forward_quantized() {
        let device = Device::Cpu;
        let config = BitNetConfig::default().with_group_size(64);

        let weight = Tensor::randn(0.0f32, 1.0, (64, 128), &device).unwrap();
        let layer = BitLinear::from_weight(&weight, None, &config).unwrap();

        let input = Tensor::randn(0.0f32, 1.0, (4, 128), &device).unwrap();
        let output = layer.forward_quantized(&input).unwrap();

        assert_eq!(output.shape().dims(), &[4, 64]);
    }

    #[test]
    fn test_bitlinear_with_bias() {
        let device = Device::Cpu;
        let config = BitNetConfig::default().with_group_size(64);

        let weight = Tensor::randn(0.0f32, 1.0, (64, 128), &device).unwrap();
        let bias = Tensor::randn(0.0f32, 1.0, (64,), &device).unwrap();
        let biased = BitLinear::from_weight(&weight, Some(&bias), &config).unwrap();
        let unbiased = BitLinear::from_weight(&weight, None, &config).unwrap();

        let input = Tensor::randn(0.0f32, 1.0, (4, 128), &device).unwrap();
        let output = biased.forward(&input).unwrap();
        assert_eq!(output.shape().dims(), &[4, 64]);

        // Both layers share the same quantized weights, so each row of the
        // difference must equal the bias vector (up to float rounding).
        let reference = unbiased.forward(&input).unwrap();
        let diff = output.sub(&reference).unwrap();
        let bias_values = bias.to_vec1::<f32>().unwrap();
        for row in diff.to_vec2::<f32>().unwrap() {
            for (got, want) in row.iter().zip(&bias_values) {
                assert!(
                    (got - want).abs() < 1e-4,
                    "bias not applied: got {got}, want {want}"
                );
            }
        }
    }

    #[test]
    fn test_bitlinear_3d_input() {
        let device = Device::Cpu;
        let config = BitNetConfig::default().with_group_size(64);

        let weight = Tensor::randn(0.0f32, 1.0, (64, 128), &device).unwrap();
        let layer = BitLinear::from_weight(&weight, None, &config).unwrap();

        // 3D input [batch, seq_len, hidden]
        let input = Tensor::randn(0.0f32, 1.0, (2, 16, 128), &device).unwrap();
        let output = layer.forward(&input).unwrap();

        assert_eq!(output.shape().dims(), &[2, 16, 64]);
    }

    #[test]
    fn test_bitlinear_sparsity() {
        let device = Device::Cpu;
        let config = BitNetConfig::default().with_group_size(64);

        let weight = Tensor::randn(0.0f32, 1.0, (64, 128), &device).unwrap();
        let layer = BitLinear::from_weight(&weight, None, &config).unwrap();

        // Sparsity is a fraction of zero weights, so it must lie in [0, 1].
        let sparsity = layer.sparsity();
        assert!(
            (0.0..=1.0).contains(&sparsity),
            "sparsity out of range: {sparsity}"
        );
    }

    #[test]
    fn test_bitlinear_compression() {
        let device = Device::Cpu;
        let config = BitNetConfig::default();

        // Larger weight for meaningful compression measurement
        let weight = Tensor::randn(0.0f32, 1.0, (1024, 4096), &device).unwrap();
        let layer = BitLinear::from_weight(&weight, None, &config).unwrap();

        let ratio = layer.compression_ratio();
        assert!(ratio > 1.0, "should achieve some compression");
    }
}