Skip to main content

bitnet_quantize/quantization/
weight.rs

//! Weight quantization for BitNet.
//!
//! Implements AbsMean quantization: `W_q = round(W / mean(|W|))` clamped to {-1, 0, +1},
//! where the mean is taken over each group of `group_size` consecutive input features.
4
5use candle_core::{Device, Tensor};
6use serde::{Deserialize, Serialize};
7use trit_vsa::PackedTritVec;
8
9use crate::config::BitNetConfig;
10use crate::error::{BitNetError, Result};
11
12/// Ternary weight representation with per-group scales.
13///
14/// Weights are quantized to {-1, 0, +1} using AbsMean quantization,
15/// with a scale factor stored per group.
16#[derive(Clone, Serialize, Deserialize)]
17pub struct TernaryWeight {
18    /// Packed ternary values (bitsliced storage).
19    /// Shape: [out_features, in_features] flattened.
20    pub data: Vec<PackedTritVec>,
21
22    /// Scale factors per group.
23    /// For a weight matrix [out, in], scales has shape [out, in/group_size].
24    pub scales: Vec<f32>,
25
26    /// Original shape [out_features, in_features].
27    pub shape: (usize, usize),
28
29    /// Group size used for quantization.
30    pub group_size: usize,
31}
32
33impl std::fmt::Debug for TernaryWeight {
34    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
35        f.debug_struct("TernaryWeight")
36            .field("shape", &self.shape)
37            .field("group_size", &self.group_size)
38            .field("num_scales", &self.scales.len())
39            .field("sparsity", &self.sparsity())
40            .finish_non_exhaustive()
41    }
42}
43
44impl TernaryWeight {
45    /// Get the output features dimension.
46    #[must_use]
47    pub const fn out_features(&self) -> usize {
48        self.shape.0
49    }
50
51    /// Get the input features dimension.
52    #[must_use]
53    pub const fn in_features(&self) -> usize {
54        self.shape.1
55    }
56
57    /// Calculate the sparsity (fraction of zeros).
58    #[must_use]
59    #[allow(clippy::cast_precision_loss)]
60    pub fn sparsity(&self) -> f32 {
61        let total_nonzero: usize = self.data.iter().map(PackedTritVec::count_nonzero).sum();
62        let total_elements = self.shape.0 * self.shape.1;
63        1.0 - (total_nonzero as f32 / total_elements as f32)
64    }
65
66    /// Memory size in bytes.
67    #[must_use]
68    pub fn memory_bytes(&self) -> usize {
69        // Packed trits: 2 bits per trit, so num_words * 4 * 2 per row
70        let trit_bytes: usize = self.data.iter().map(|v| v.num_words() * 8).sum();
71        // Scales: f32 per group
72        let scale_bytes = self.scales.len() * 4;
73        trit_bytes + scale_bytes
74    }
75
76    /// Compression ratio vs FP32.
77    #[must_use]
78    #[allow(clippy::cast_precision_loss)]
79    pub fn compression_ratio(&self) -> f32 {
80        let fp32_bytes = self.shape.0 * self.shape.1 * 4;
81        fp32_bytes as f32 / self.memory_bytes() as f32
82    }
83}
84
85/// Quantize a weight tensor to ternary using AbsMean quantization.
86///
87/// # Algorithm
88///
89/// For each group of weights:
90/// 1. Compute `scale = mean(|W|)`
91/// 2. Compute `W_q = round(W / scale)` clamped to {-1, 0, +1}
92///
93/// # Arguments
94///
95/// * `weight` - Input weight tensor [out_features, in_features]
96/// * `config` - BitNet configuration
97///
98/// # Errors
99///
100/// Returns error if weight has wrong shape or quantization fails.
101pub fn quantize_weights(weight: &Tensor, config: &BitNetConfig) -> Result<TernaryWeight> {
102    let shape = weight.shape().dims();
103    if shape.len() != 2 {
104        return Err(BitNetError::InvalidConfig(
105            "weight must be 2D [out_features, in_features]".to_string(),
106        ));
107    }
108
109    let out_features = shape[0];
110    let in_features = shape[1];
111    let group_size = config.group_size;
112
113    // Ensure in_features is divisible by group_size
114    if !in_features.is_multiple_of(group_size) {
115        return Err(BitNetError::InvalidConfig(format!(
116            "in_features ({in_features}) must be divisible by group_size ({group_size})"
117        )));
118    }
119
120    let num_groups_per_row = in_features / group_size;
121    let mut scales = Vec::with_capacity(out_features * num_groups_per_row);
122    let mut data = Vec::with_capacity(out_features);
123
124    // Convert to f32 for processing
125    let weight_f32 = weight.to_dtype(candle_core::DType::F32)?.to_vec2::<f32>()?;
126
127    for row in &weight_f32 {
128        let mut packed = PackedTritVec::new(in_features);
129
130        for g in 0..num_groups_per_row {
131            let start = g * group_size;
132            let end = start + group_size;
133            let group = &row[start..end];
134
135            // Compute AbsMean scale
136            let abs_mean: f32 = group.iter().map(|x| x.abs()).sum::<f32>() / group_size as f32;
137            let scale = if abs_mean > 0.0 { abs_mean } else { 1.0 };
138            scales.push(scale);
139
140            // Quantize each value in the group
141            for (i, &val) in group.iter().enumerate() {
142                let normalized = val / scale;
143                let quantized = normalized.round().clamp(-1.0, 1.0) as i8;
144                let trit = trit_vsa::Trit::from_value(quantized as i32)?;
145                packed.set(start + i, trit);
146            }
147        }
148
149        data.push(packed);
150    }
151
152    Ok(TernaryWeight {
153        data,
154        scales,
155        shape: (out_features, in_features),
156        group_size,
157    })
158}
159
160/// Dequantize ternary weights back to float tensor.
161///
162/// # Arguments
163///
164/// * `ternary` - Ternary weight to dequantize
165/// * `device` - Device to create output tensor on
166///
167/// # Errors
168///
169/// Returns error if tensor creation fails.
170pub fn dequantize_weights(ternary: &TernaryWeight, device: &Device) -> Result<Tensor> {
171    let out_features = ternary.out_features();
172    let in_features = ternary.in_features();
173    let group_size = ternary.group_size;
174    let num_groups_per_row = in_features / group_size;
175
176    let mut output = vec![0.0f32; out_features * in_features];
177
178    for (row_idx, packed) in ternary.data.iter().enumerate() {
179        let row_start = row_idx * in_features;
180
181        for g in 0..num_groups_per_row {
182            let scale_idx = row_idx * num_groups_per_row + g;
183            let scale = ternary.scales[scale_idx];
184            let group_start = g * group_size;
185
186            for i in 0..group_size {
187                let trit = packed.get(group_start + i);
188                let value = trit.value() as f32 * scale;
189                output[row_start + group_start + i] = value;
190            }
191        }
192    }
193
194    let tensor = Tensor::from_vec(output, (out_features, in_features), device)?;
195    Ok(tensor)
196}
197
#[cfg(test)]
mod tests {
    use super::*;
    use candle_core::DType;

    #[test]
    fn test_quantize_dequantize_roundtrip() {
        let device = Device::Cpu;
        let config = BitNetConfig::default();

        // Create a weight tensor
        let weight = Tensor::randn(0.0f32, 1.0, (64, 128), &device).unwrap();

        // Quantize
        let ternary = quantize_weights(&weight, &config).unwrap();

        // Check structure
        assert_eq!(ternary.shape, (64, 128));
        assert_eq!(ternary.data.len(), 64);
        assert_eq!(ternary.scales.len(), 64 * (128 / 64)); // 64 rows * 2 groups

        // Dequantize
        let restored = dequantize_weights(&ternary, &device).unwrap();
        assert_eq!(restored.shape().dims(), &[64, 128]);

        // Every restored value is 0 or ±(its group's scale), and a non-zero value
        // carries the sign of the original weight.
        let group_size = ternary.group_size;
        let groups_per_row = 128 / group_size;
        let original = weight.to_vec2::<f32>().unwrap();
        let restored = restored.to_vec2::<f32>().unwrap();
        for (r, (orig_row, rest_row)) in original.iter().zip(&restored).enumerate() {
            for (c, (&o, &q)) in orig_row.iter().zip(rest_row).enumerate() {
                let scale = ternary.scales[r * groups_per_row + c / group_size];
                if q != 0.0 {
                    assert_eq!(q.abs(), scale, "row {r} col {c}: {q} is not ±{scale}");
                    assert_eq!(q.signum(), o.signum(), "row {r} col {c}: sign flipped");
                }
            }
        }
    }

    #[test]
    fn test_quantize_preserves_sign() {
        let device = Device::Cpu;
        let config = BitNetConfig::default().with_group_size(4);

        // Create a simple weight with known values
        let values: Vec<f32> = vec![1.0, -1.0, 0.5, -0.5, 2.0, -2.0, 0.1, -0.1];
        let weight = Tensor::from_vec(values, (2, 4), &device).unwrap();

        let ternary = quantize_weights(&weight, &config).unwrap();

        // Row 0: [1, -1, 0.5, -0.5] -> scale = (1+1+0.5+0.5)/4 = 0.75
        // Normalized: [1.33, -1.33, 0.67, -0.67] -> [+1, -1, +1, -1]
        assert_eq!(ternary.data[0].get(0), trit_vsa::Trit::P);
        assert_eq!(ternary.data[0].get(1), trit_vsa::Trit::N);
        assert_eq!(ternary.data[0].get(2), trit_vsa::Trit::P);
        assert_eq!(ternary.data[0].get(3), trit_vsa::Trit::N);

        // Row 1: [2, -2, 0.1, -0.1] -> scale = (2+2+0.1+0.1)/4 = 1.05
        // Normalized: [1.90, -1.90, 0.10, -0.10] -> [+1, -1, 0, 0]
        assert_eq!(ternary.data[1].get(0), trit_vsa::Trit::P);
        assert_eq!(ternary.data[1].get(1), trit_vsa::Trit::N);
        assert_eq!(ternary.data[1].get(2).value(), 0);
        assert_eq!(ternary.data[1].get(3).value(), 0);
    }

    #[test]
    fn test_sparsity() {
        let device = Device::Cpu;
        let config = BitNetConfig::default().with_group_size(4);

        // Two non-zero weights out of eight; each survives quantization as ±1.
        let values: Vec<f32> = vec![1.0, 0.0, 0.0, 0.0, 0.0, -1.0, 0.0, 0.0];
        let weight = Tensor::from_vec(values, (2, 4), &device).unwrap();

        let ternary = quantize_weights(&weight, &config).unwrap();

        // 6 zeros out of 8 elements.
        let sparsity = ternary.sparsity();
        assert!((sparsity - 0.75).abs() < 1e-6, "expected sparsity 0.75, got {sparsity}");
    }

    #[test]
    fn test_compression_ratio() {
        let device = Device::Cpu;
        let config = BitNetConfig::default();

        // The ratio depends only on group size and per-row packing, not on the
        // matrix size, so a moderate matrix keeps the test fast.
        let weight = Tensor::randn(0.0f32, 1.0, (256, 1024), &device).unwrap();
        let ternary = quantize_weights(&weight, &config).unwrap();

        let ratio = ternary.compression_ratio();
        // Should achieve significant compression (typically 8-16x)
        assert!(ratio > 4.0, "expected >4x compression, got {ratio:.2}x");
    }

    #[test]
    fn test_invalid_shape() {
        let device = Device::Cpu;
        let config = BitNetConfig::default();

        // 1D tensor should fail
        let weight = Tensor::zeros(&[64], DType::F32, &device).unwrap();
        assert!(quantize_weights(&weight, &config).is_err());

        // 3D tensor should fail
        let weight = Tensor::zeros(&[2, 64, 64], DType::F32, &device).unwrap();
        assert!(quantize_weights(&weight, &config).is_err());
    }

    #[test]
    fn test_indivisible_group_size() {
        let device = Device::Cpu;
        let config = BitNetConfig::default().with_group_size(64);

        // in_features=100 is not divisible by 64
        let weight = Tensor::zeros(&[32, 100], DType::F32, &device).unwrap();
        assert!(quantize_weights(&weight, &config).is_err());
    }
}