Skip to main content

bitnet_quantize/layer/
bitlinear.rs

1//! BitLinear layer - drop-in replacement for nn::Linear with ternary weights.
2
3use candle_core::{Device, Tensor};
4use candle_nn::Module;
5
6use crate::config::BitNetConfig;
7use crate::error::Result;
8use crate::quantization::{
9    dequantize_activations, dequantize_weights, quantize_activations, quantize_weights,
10    TernaryWeight,
11};
12
13/// Emit a warning when CPU fallback is used.
14///
15/// GPU (CUDA) is the intended default for BitNet operations.
16/// This function warns users once per process when CPU is being used.
17fn warn_cpu_fallback(device: &Device) {
18    static WARN_ONCE: std::sync::Once = std::sync::Once::new();
19    if matches!(device, Device::Cpu) {
20        WARN_ONCE.call_once(|| {
21            eprintln!(
22                "bitnet-quantize: CPU device in use. CUDA is the intended default; \
23                 enable the 'cuda' feature and use Device::cuda_if_available(0) when possible."
24            );
25        });
26    }
27}
28
29/// BitLinear layer with ternary weights and INT8 activations.
30///
31/// This is a drop-in replacement for `candle_nn::Linear` that uses:
32/// - Ternary weights {-1, 0, +1} with per-group scales
33/// - INT8 activation quantization with per-token scales
34///
35/// # Example
36///
37/// ```ignore
38/// use bitnet_rs::{BitLinear, BitNetConfig};
39/// use candle_core::{Device, Tensor};
40///
41/// let device = Device::Cpu;
42/// let config = BitNetConfig::default();
43///
44/// // Create from existing weights
45/// let weight = Tensor::randn(0.0f32, 1.0, (512, 256), &device)?;
46/// let layer = BitLinear::from_weight(&weight, None, &config)?;
47///
48/// // Forward pass
49/// let input = Tensor::randn(0.0f32, 1.0, (4, 256), &device)?;
50/// let output = layer.forward(&input)?;
51/// ```
52#[derive(Debug)]
53pub struct BitLinear {
54    /// Quantized ternary weights.
55    weight: TernaryWeight,
56
57    /// Optional bias (not quantized).
58    bias: Option<Tensor>,
59
60    /// Configuration.
61    config: BitNetConfig,
62
63    /// Device for tensor operations.
64    device: Device,
65}
66
67impl BitLinear {
68    /// Create a new BitLinear layer from a weight tensor.
69    ///
70    /// # Arguments
71    ///
72    /// * `weight` - Weight tensor [out_features, in_features]
73    /// * `bias` - Optional bias tensor [out_features]
74    /// * `config` - BitNet configuration
75    ///
76    /// # Errors
77    ///
78    /// Returns error if weight quantization fails.
79    pub fn from_weight(
80        weight: &Tensor,
81        bias: Option<&Tensor>,
82        config: &BitNetConfig,
83    ) -> Result<Self> {
84        config.validate()?;
85
86        let device = weight.device().clone();
87        warn_cpu_fallback(&device);
88        let quantized_weight = quantize_weights(weight, config)?;
89
90        Ok(Self {
91            weight: quantized_weight,
92            bias: bias.cloned(),
93            config: config.clone(),
94            device,
95        })
96    }
97
98    /// Create a new BitLinear layer from pre-quantized weights.
99    ///
100    /// # Arguments
101    ///
102    /// * `weight` - Pre-quantized ternary weight
103    /// * `bias` - Optional bias tensor
104    /// * `config` - BitNet configuration
105    /// * `device` - Device for operations
106    #[must_use]
107    pub fn from_quantized(
108        weight: TernaryWeight,
109        bias: Option<Tensor>,
110        config: BitNetConfig,
111        device: Device,
112    ) -> Self {
113        warn_cpu_fallback(&device);
114        Self {
115            weight,
116            bias,
117            config,
118            device,
119        }
120    }
121
122    /// Get the input features dimension.
123    #[must_use]
124    pub fn in_features(&self) -> usize {
125        self.weight.in_features()
126    }
127
128    /// Get the output features dimension.
129    #[must_use]
130    pub fn out_features(&self) -> usize {
131        self.weight.out_features()
132    }
133
134    /// Get reference to the quantized weights.
135    #[must_use]
136    pub const fn quantized_weight(&self) -> &TernaryWeight {
137        &self.weight
138    }
139
140    /// Get reference to the bias.
141    #[must_use]
142    pub const fn bias(&self) -> Option<&Tensor> {
143        self.bias.as_ref()
144    }
145
146    /// Get reference to the configuration.
147    #[must_use]
148    pub const fn config(&self) -> &BitNetConfig {
149        &self.config
150    }
151
152    /// Get the device.
153    #[must_use]
154    pub const fn device(&self) -> &Device {
155        &self.device
156    }
157
158    /// Get the weight sparsity.
159    #[must_use]
160    pub fn sparsity(&self) -> f32 {
161        self.weight.sparsity()
162    }
163
164    /// Get the compression ratio.
165    #[must_use]
166    pub fn compression_ratio(&self) -> f32 {
167        self.weight.compression_ratio()
168    }
169
170    /// Forward pass with explicit activation quantization.
171    ///
172    /// This method:
173    /// 1. Quantizes input activations to INT8
174    /// 2. Uses GPU-optimized ternary matmul if available, otherwise dequantizes
175    /// 3. Performs the linear transformation
176    /// 4. Adds bias if present
177    ///
178    /// When the `cuda` feature is enabled and a CUDA device is detected,
179    /// this uses optimized GPU kernels that exploit ternary weight sparsity.
180    ///
181    /// # Arguments
182    ///
183    /// * `input` - Input tensor [batch, ..., in_features]
184    ///
185    /// # Errors
186    ///
187    /// Returns error if forward pass fails.
188    pub fn forward_quantized(&self, input: &Tensor) -> Result<Tensor> {
189        // Quantize activations
190        let quantized_input = quantize_activations(input, &self.config)?;
191        let dequant_input = dequantize_activations(&quantized_input, &self.device)?;
192
193        // Try GPU-optimized ternary matmul if available
194        #[cfg(feature = "cuda")]
195        let output = {
196            if crate::kernels::should_use_gpu(&dequant_input, &self.weight) {
197                crate::kernels::ternary_matmul_gpu(&dequant_input, &self.weight)?
198            } else {
199                // Fallback to standard matmul
200                let dequant_weight = dequantize_weights(&self.weight, &self.device)?;
201                dequant_input.matmul(&dequant_weight.t()?)?
202            }
203        };
204
205        #[cfg(not(feature = "cuda"))]
206        let output = {
207            // Standard dequantize + matmul path
208            let dequant_weight = dequantize_weights(&self.weight, &self.device)?;
209            dequant_input.matmul(&dequant_weight.t()?)?
210        };
211
212        // Add bias
213        let output = if let Some(ref bias) = self.bias {
214            output.broadcast_add(bias)?
215        } else {
216            output
217        };
218
219        Ok(output)
220    }
221}
222
223impl Module for BitLinear {
224    fn forward(&self, input: &Tensor) -> candle_core::Result<Tensor> {
225        let dims = input.dims();
226
227        // Handle 3D input [batch, seq_len, hidden] by flattening
228        let (flat_input, original_shape) = if dims.len() == 3 {
229            let (batch, seq_len, hidden) = (dims[0], dims[1], dims[2]);
230            (
231                input.reshape((batch * seq_len, hidden))?,
232                Some((batch, seq_len)),
233            )
234        } else {
235            (input.clone(), None)
236        };
237
238        // Use GPU-optimized ternary matmul if available
239        #[cfg(feature = "cuda")]
240        let output = {
241            if crate::kernels::cuda_available()
242                && crate::kernels::should_use_gpu(&flat_input, &self.weight)
243            {
244                crate::kernels::ternary_matmul_gpu(&flat_input, &self.weight)
245                    .map_err(|e| candle_core::Error::Msg(e.to_string()))?
246            } else {
247                // Fallback to standard dequantize + matmul
248                let dequant_weight = dequantize_weights(&self.weight, &self.device)
249                    .map_err(|e| candle_core::Error::Msg(e.to_string()))?;
250                flat_input.matmul(&dequant_weight.t()?)?
251            }
252        };
253
254        #[cfg(not(feature = "cuda"))]
255        let output = {
256            let dequant_weight = dequantize_weights(&self.weight, &self.device)
257                .map_err(|e| candle_core::Error::Msg(e.to_string()))?;
258            flat_input.matmul(&dequant_weight.t()?)?
259        };
260
261        // Reshape back to 3D if needed
262        let output = if let Some((batch, seq_len)) = original_shape {
263            output.reshape((batch, seq_len, self.out_features()))?
264        } else {
265            output
266        };
267
268        // Add bias
269        let output = if let Some(ref bias) = self.bias {
270            output.broadcast_add(bias)?
271        } else {
272            output
273        };
274
275        Ok(output)
276    }
277}
278
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_bitlinear_creation() {
        let device = Device::Cpu;
        let config = BitNetConfig::default();

        let weight = Tensor::randn(0.0f32, 1.0, (128, 256), &device).unwrap();
        let layer = BitLinear::from_weight(&weight, None, &config).unwrap();

        assert_eq!(layer.in_features(), 256);
        assert_eq!(layer.out_features(), 128);
    }

    #[test]
    fn test_bitlinear_forward() {
        let device = Device::Cpu;
        let config = BitNetConfig::default().with_group_size(64);

        let weight = Tensor::randn(0.0f32, 1.0, (64, 128), &device).unwrap();
        let layer = BitLinear::from_weight(&weight, None, &config).unwrap();

        let input = Tensor::randn(0.0f32, 1.0, (4, 128), &device).unwrap();
        let output = layer.forward(&input).unwrap();

        assert_eq!(output.shape().dims(), &[4, 64]);
    }

    #[test]
    fn test_bitlinear_forward_quantized() {
        let device = Device::Cpu;
        let config = BitNetConfig::default().with_group_size(64);

        let weight = Tensor::randn(0.0f32, 1.0, (64, 128), &device).unwrap();
        let layer = BitLinear::from_weight(&weight, None, &config).unwrap();

        let input = Tensor::randn(0.0f32, 1.0, (4, 128), &device).unwrap();
        let output = layer.forward_quantized(&input).unwrap();

        assert_eq!(output.shape().dims(), &[4, 64]);
    }

    #[test]
    fn test_bitlinear_with_bias() {
        let device = Device::Cpu;
        let config = BitNetConfig::default().with_group_size(64);

        let weight = Tensor::randn(0.0f32, 1.0, (64, 128), &device).unwrap();
        let bias = Tensor::randn(0.0f32, 1.0, (64,), &device).unwrap();
        let layer = BitLinear::from_weight(&weight, Some(&bias), &config).unwrap();

        let input = Tensor::randn(0.0f32, 1.0, (4, 128), &device).unwrap();
        let output = layer.forward(&input).unwrap();

        assert_eq!(output.shape().dims(), &[4, 64]);
    }

    #[test]
    fn test_bitlinear_bias_is_added_elementwise() {
        let device = Device::Cpu;
        let config = BitNetConfig::default().with_group_size(64);

        // Same weight for both layers, so the only difference is the bias.
        let weight = Tensor::randn(0.0f32, 1.0, (64, 128), &device).unwrap();
        let bias = Tensor::randn(0.0f32, 1.0, (64,), &device).unwrap();
        let plain = BitLinear::from_weight(&weight, None, &config).unwrap();
        let biased = BitLinear::from_weight(&weight, Some(&bias), &config).unwrap();

        let input = Tensor::randn(0.0f32, 1.0, (4, 128), &device).unwrap();
        let residual = biased
            .forward(&input)
            .unwrap()
            .sub(&plain.forward(&input).unwrap())
            .unwrap()
            .broadcast_sub(&bias)
            .unwrap();

        for row in residual.to_vec2::<f32>().unwrap() {
            for value in row {
                assert!(value.abs() < 1e-3, "bias not applied element-wise: {value}");
            }
        }
    }

    #[test]
    fn test_bitlinear_3d_input() {
        let device = Device::Cpu;
        let config = BitNetConfig::default().with_group_size(64);

        let weight = Tensor::randn(0.0f32, 1.0, (64, 128), &device).unwrap();
        let layer = BitLinear::from_weight(&weight, None, &config).unwrap();

        // 3D input [batch, seq_len, hidden]
        let input = Tensor::randn(0.0f32, 1.0, (2, 16, 128), &device).unwrap();
        let output = layer.forward(&input).unwrap();

        assert_eq!(output.shape().dims(), &[2, 16, 64]);
    }

    #[test]
    fn test_bitlinear_sparsity() {
        let device = Device::Cpu;
        let config = BitNetConfig::default().with_group_size(64);

        let weight = Tensor::randn(0.0f32, 1.0, (64, 128), &device).unwrap();
        let layer = BitLinear::from_weight(&weight, None, &config).unwrap();

        // Ternary quantization typically results in some sparsity
        let sparsity = layer.sparsity();
        assert!(
            (0.0..=1.0).contains(&sparsity),
            "sparsity must be a fraction in [0, 1], got {sparsity}"
        );
    }

    #[test]
    fn test_bitlinear_compression() {
        let device = Device::Cpu;
        let config = BitNetConfig::default();

        // Larger weight for meaningful compression measurement
        let weight = Tensor::randn(0.0f32, 1.0, (1024, 4096), &device).unwrap();
        let layer = BitLinear::from_weight(&weight, None, &config).unwrap();

        let ratio = layer.compression_ratio();
        assert!(ratio > 1.0, "should achieve some compression");
    }
}