bitnet_quantize/layer/bitlinear.rs

//! BitLinear layer - drop-in replacement for nn::Linear with ternary weights.

use candle_core::{Device, Tensor};
use candle_nn::Module;

use crate::config::BitNetConfig;
use crate::error::Result;
use crate::quantization::{
    dequantize_activations, dequantize_weights, quantize_activations, quantize_weights,
    TernaryWeight,
};

/// BitLinear layer with ternary weights and INT8 activations.
///
/// This is a drop-in replacement for `candle_nn::Linear` that uses:
/// - Ternary weights {-1, 0, +1} with per-group scales
/// - INT8 activation quantization with per-token scales
///
/// # Example
///
/// ```ignore
/// use bitnet_rs::{BitLinear, BitNetConfig};
/// use candle_core::{Device, Tensor};
///
/// let device = Device::Cpu;
/// let config = BitNetConfig::default();
///
/// // Create from existing weights
/// let weight = Tensor::randn(0.0f32, 1.0, (512, 256), &device)?;
/// let layer = BitLinear::from_weight(&weight, None, &config)?;
///
/// // Forward pass
/// let input = Tensor::randn(0.0f32, 1.0, (4, 256), &device)?;
/// let output = layer.forward(&input)?;
/// ```
#[derive(Debug)]
pub struct BitLinear {
    /// Quantized ternary weights.
    weight: TernaryWeight,

    /// Optional bias (not quantized).
    bias: Option<Tensor>,

    /// Configuration.
    config: BitNetConfig,

    /// Device for tensor operations.
    device: Device,
}

impl BitLinear {
    /// Create a new BitLinear layer from a weight tensor.
    ///
    /// # Arguments
    ///
    /// * `weight` - Weight tensor [out_features, in_features]
    /// * `bias` - Optional bias tensor [out_features]
    /// * `config` - BitNet configuration
    ///
    /// # Errors
    ///
    /// Returns error if weight quantization fails.
    pub fn from_weight(weight: &Tensor, bias: Option<&Tensor>, config: &BitNetConfig) -> Result<Self> {
        config.validate()?;

        let device = weight.device().clone();
        let quantized_weight = quantize_weights(weight, config)?;

        Ok(Self {
            weight: quantized_weight,
            bias: bias.cloned(),
            config: config.clone(),
            device,
        })
    }

    /// Create a new BitLinear layer from pre-quantized weights.
    ///
    /// # Arguments
    ///
    /// * `weight` - Pre-quantized ternary weight
    /// * `bias` - Optional bias tensor
    /// * `config` - BitNet configuration
    /// * `device` - Device for operations
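    ///
    /// # Example
    ///
    /// Illustrative sketch (not compiled); it assumes `quantize_weights` is
    /// publicly re-exported from the crate's `quantization` module, and follows
    /// the crate path used in the struct-level example:
    ///
    /// ```ignore
    /// use bitnet_rs::quantization::quantize_weights;
    /// use bitnet_rs::{BitLinear, BitNetConfig};
    /// use candle_core::{Device, Tensor};
    ///
    /// let device = Device::Cpu;
    /// let config = BitNetConfig::default();
    /// let weight = Tensor::randn(0.0f32, 1.0, (64, 128), &device)?;
    ///
    /// // Quantize once up front, then build the layer without re-quantizing.
    /// let ternary = quantize_weights(&weight, &config)?;
    /// let layer = BitLinear::from_quantized(ternary, None, config, device);
    /// ```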
    #[must_use]
    pub fn from_quantized(
        weight: TernaryWeight,
        bias: Option<Tensor>,
        config: BitNetConfig,
        device: Device,
    ) -> Self {
        Self {
            weight,
            bias,
            config,
            device,
        }
    }

    /// Get the input features dimension.
    #[must_use]
    pub fn in_features(&self) -> usize {
        self.weight.in_features()
    }

    /// Get the output features dimension.
    #[must_use]
    pub fn out_features(&self) -> usize {
        self.weight.out_features()
    }

    /// Get reference to the quantized weights.
    #[must_use]
    pub const fn quantized_weight(&self) -> &TernaryWeight {
        &self.weight
    }

    /// Get reference to the bias.
    #[must_use]
    pub const fn bias(&self) -> Option<&Tensor> {
        self.bias.as_ref()
    }

    /// Get reference to the configuration.
    #[must_use]
    pub const fn config(&self) -> &BitNetConfig {
        &self.config
    }

    /// Get the device.
    #[must_use]
    pub const fn device(&self) -> &Device {
        &self.device
    }

    /// Get the weight sparsity (fraction of zero-valued ternary weights).
    #[must_use]
    pub fn sparsity(&self) -> f32 {
        self.weight.sparsity()
    }

    /// Get the compression ratio of the ternary weights relative to full-precision storage.
    #[must_use]
    pub fn compression_ratio(&self) -> f32 {
        self.weight.compression_ratio()
    }

    /// Forward pass with explicit activation quantization.
    ///
    /// This method:
    /// 1. Quantizes input activations to INT8
    /// 2. Dequantizes weights for matmul (or uses optimized kernel)
    /// 3. Performs the linear transformation
    /// 4. Adds bias if present
    ///
    /// # Arguments
    ///
    /// * `input` - Input tensor [batch, in_features]; flatten higher-rank inputs to 2-D first
    ///
    /// # Errors
    ///
    /// Returns error if forward pass fails.
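    ///
    /// # Example
    ///
    /// Illustrative sketch (not compiled; mirrors the struct-level example):
    ///
    /// ```ignore
    /// use bitnet_rs::{BitLinear, BitNetConfig};
    /// use candle_core::{Device, Tensor};
    ///
    /// let device = Device::Cpu;
    /// let config = BitNetConfig::default();
    /// let weight = Tensor::randn(0.0f32, 1.0, (64, 128), &device)?;
    /// let layer = BitLinear::from_weight(&weight, None, &config)?;
    ///
    /// // Activations are round-tripped through INT8 before the matmul, so the
    /// // output reflects activation quantization error as well as weight error.
    /// let input = Tensor::randn(0.0f32, 1.0, (4, 128), &device)?;
    /// let output = layer.forward_quantized(&input)?;
    /// assert_eq!(output.shape().dims(), &[4, 64]);
    /// ```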
    pub fn forward_quantized(&self, input: &Tensor) -> Result<Tensor> {
        // Quantize activations to INT8, then dequantize back to float so the
        // matmul below sees the activation quantization error.
        let quantized_input = quantize_activations(input, &self.config)?;
        let dequant_input = dequantize_activations(&quantized_input, &self.device)?;

        // Dequantize weights for matmul
        let dequant_weight = dequantize_weights(&self.weight, &self.device)?;

        // Linear transformation: y = x @ W^T
        let output = dequant_input.matmul(&dequant_weight.t()?)?;

        // Add bias
        let output = if let Some(ref bias) = self.bias {
            output.broadcast_add(bias)?
        } else {
            output
        };

        Ok(output)
    }
}

impl Module for BitLinear {
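    /// Standard forward pass: dequantize the ternary weights and matmul.
    ///
    /// Illustrative sketch of 3-D usage (not compiled): inputs of shape
    /// `[batch, seq_len, hidden]` are flattened to 2-D for the matmul and
    /// reshaped back to `[batch, seq_len, out_features]`.
    ///
    /// ```ignore
    /// use candle_nn::Module;
    /// // `layer` and `device` built as in the struct-level example (in_features = 256).
    /// let input = Tensor::randn(0.0f32, 1.0, (2, 16, 256), &device)?;
    /// let output = layer.forward(&input)?; // shape [2, 16, 512]
    /// ```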
    fn forward(&self, input: &Tensor) -> candle_core::Result<Tensor> {
        // For standard forward, dequantize and compute
        // In a production implementation, this would use optimized kernels
        let dequant_weight = dequantize_weights(&self.weight, &self.device)
            .map_err(|e| candle_core::Error::Msg(e.to_string()))?;

        let dims = input.dims();
        let output = if dims.len() == 3 {
            // Handle 3D input [batch, seq_len, hidden]
            let (batch, seq_len, hidden) = (dims[0], dims[1], dims[2]);
            let flat_input = input.reshape((batch * seq_len, hidden))?;
            let flat_output = flat_input.matmul(&dequant_weight.t()?)?;
            flat_output.reshape((batch, seq_len, self.out_features()))?
        } else {
            // Standard 2D matmul
            input.matmul(&dequant_weight.t()?)?
        };

        let output = if let Some(ref bias) = self.bias {
            output.broadcast_add(bias)?
        } else {
            output
        };

        Ok(output)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_bitlinear_creation() {
        let device = Device::Cpu;
        let config = BitNetConfig::default();

        let weight = Tensor::randn(0.0f32, 1.0, (128, 256), &device).unwrap();
        let layer = BitLinear::from_weight(&weight, None, &config).unwrap();

        assert_eq!(layer.in_features(), 256);
        assert_eq!(layer.out_features(), 128);
    }

    #[test]
    fn test_bitlinear_forward() {
        let device = Device::Cpu;
        let config = BitNetConfig::default().with_group_size(64);

        let weight = Tensor::randn(0.0f32, 1.0, (64, 128), &device).unwrap();
        let layer = BitLinear::from_weight(&weight, None, &config).unwrap();

        let input = Tensor::randn(0.0f32, 1.0, (4, 128), &device).unwrap();
        let output = layer.forward(&input).unwrap();

        assert_eq!(output.shape().dims(), &[4, 64]);
    }

    #[test]
    fn test_bitlinear_forward_quantized() {
        let device = Device::Cpu;
        let config = BitNetConfig::default().with_group_size(64);

        let weight = Tensor::randn(0.0f32, 1.0, (64, 128), &device).unwrap();
        let layer = BitLinear::from_weight(&weight, None, &config).unwrap();

        let input = Tensor::randn(0.0f32, 1.0, (4, 128), &device).unwrap();
        let output = layer.forward_quantized(&input).unwrap();

        assert_eq!(output.shape().dims(), &[4, 64]);
    }

    #[test]
    fn test_bitlinear_with_bias() {
        let device = Device::Cpu;
        let config = BitNetConfig::default().with_group_size(64);

        let weight = Tensor::randn(0.0f32, 1.0, (64, 128), &device).unwrap();
        let bias = Tensor::randn(0.0f32, 1.0, (64,), &device).unwrap();
        let layer = BitLinear::from_weight(&weight, Some(&bias), &config).unwrap();

        let input = Tensor::randn(0.0f32, 1.0, (4, 128), &device).unwrap();
        let output = layer.forward(&input).unwrap();

        assert_eq!(output.shape().dims(), &[4, 64]);
    }

    #[test]
    fn test_bitlinear_3d_input() {
        let device = Device::Cpu;
        let config = BitNetConfig::default().with_group_size(64);

        let weight = Tensor::randn(0.0f32, 1.0, (64, 128), &device).unwrap();
        let layer = BitLinear::from_weight(&weight, None, &config).unwrap();

        // 3D input [batch, seq_len, hidden]
        let input = Tensor::randn(0.0f32, 1.0, (2, 16, 128), &device).unwrap();
        let output = layer.forward(&input).unwrap();

        assert_eq!(output.shape().dims(), &[2, 16, 64]);
    }

    #[test]
    fn test_bitlinear_sparsity() {
        let device = Device::Cpu;
        let config = BitNetConfig::default().with_group_size(64);

        let weight = Tensor::randn(0.0f32, 1.0, (64, 128), &device).unwrap();
        let layer = BitLinear::from_weight(&weight, None, &config).unwrap();

        // Ternary quantization typically results in some sparsity
        let sparsity = layer.sparsity();
        assert!(sparsity >= 0.0 && sparsity <= 1.0);
    }

    #[test]
    fn test_bitlinear_compression() {
        let device = Device::Cpu;
        let config = BitNetConfig::default();

        // Larger weight for meaningful compression measurement
        let weight = Tensor::randn(0.0f32, 1.0, (1024, 4096), &device).unwrap();
        let layer = BitLinear::from_weight(&weight, None, &config).unwrap();

        let ratio = layer.compression_ratio();
        assert!(ratio > 1.0, "should achieve some compression");
    }
}