Skip to main content

bitnet_quantize/quantization/
activation.rs

//! Activation quantization for BitNet.
//!
//! Implements per-token AbsMax quantization to INT8.

5use candle_core::{Device, Tensor};
6use serde::{Deserialize, Serialize};
7
8use crate::config::BitNetConfig;
9use crate::error::{BitNetError, Result};
10
/// Quantized activations with per-token scales.
///
/// Produced by `quantize_activations` and turned back into a float tensor by
/// `dequantize_activations`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QuantizedActivations {
    /// Quantized INT8 values (stored as i8 in a Vec).
    ///
    /// Flattened row by row: all `hidden_dim` values of the first token,
    /// then all values of the second token, and so on.
    pub data: Vec<i8>,

    /// Scale factors, one per token (one per row of `hidden_dim` values).
    ///
    /// `quantize_activations` pushes exactly one scale per row, so this holds
    /// `batch * seq_len` entries for 3D input and `batch` entries for 2D input.
    /// NOTE(review): the quantizer in this file does not consult `per_token`
    /// when computing scales; the layout is the same for either flag value.
    pub scales: Vec<f32>,

    /// Shape of the tensor that was quantized:
    /// `[batch, seq_len, hidden_dim]` or `[batch, hidden_dim]`.
    pub shape: Vec<usize>,

    /// Copy of `per_token_activation` from the config used for quantization.
    pub per_token: bool,
}
27
28impl QuantizedActivations {
29    /// Get the batch size.
30    #[must_use]
31    pub fn batch_size(&self) -> usize {
32        self.shape.first().copied().unwrap_or(1)
33    }
34
35    /// Get the sequence length.
36    #[must_use]
37    pub fn seq_len(&self) -> usize {
38        self.shape.get(1).copied().unwrap_or(1)
39    }
40
41    /// Get the hidden dimension.
42    #[must_use]
43    pub fn hidden_dim(&self) -> usize {
44        self.shape.last().copied().unwrap_or(0)
45    }
46
47    /// Get the total number of elements.
48    #[must_use]
49    pub fn numel(&self) -> usize {
50        self.shape.iter().product()
51    }
52}
53
54/// Quantize activations using per-token AbsMax scaling to INT8.
55///
56/// # Algorithm
57///
58/// For each token (row):
59/// 1. Compute `scale = max(|X|) / 127`
60/// 2. Compute `X_q = round(X / scale)` clamped to [-127, 127]
61///
62/// # Arguments
63///
64/// * `activations` - Input tensor [batch, seq_len, hidden_dim] or [batch, hidden_dim]
65/// * `config` - BitNet configuration
66///
67/// # Errors
68///
69/// Returns error if quantization fails.
70pub fn quantize_activations(
71    activations: &Tensor,
72    config: &BitNetConfig,
73) -> Result<QuantizedActivations> {
74    let shape = activations.shape().dims().to_vec();
75
76    // Handle both 2D [batch, hidden] and 3D [batch, seq, hidden] inputs
77    let (batch_size, seq_len, hidden_dim) = match shape.len() {
78        2 => (shape[0], 1, shape[1]),
79        3 => (shape[0], shape[1], shape[2]),
80        _ => {
81            return Err(BitNetError::InvalidConfig(
82                "activations must be 2D or 3D".to_string(),
83            ))
84        }
85    };
86
87    // Reshape to [batch * seq_len, hidden_dim] for uniform processing
88    let flat = activations.reshape((batch_size * seq_len, hidden_dim))?;
89    let flat_f32 = flat.to_dtype(candle_core::DType::F32)?.to_vec2::<f32>()?;
90
91    let max_val = (1 << (config.activation_bits - 1)) - 1; // 127 for 8-bit
92    let max_val_f32 = max_val as f32;
93
94    let mut data = Vec::with_capacity(batch_size * seq_len * hidden_dim);
95    let mut scales = Vec::with_capacity(batch_size * seq_len);
96
97    for row in &flat_f32 {
98        // Compute AbsMax for this token
99        let abs_max = row.iter().map(|x| x.abs()).fold(0.0f32, f32::max);
100        let scale = if abs_max > 0.0 {
101            abs_max / max_val_f32
102        } else {
103            1.0
104        };
105        scales.push(scale);
106
107        // Quantize
108        for &val in row {
109            let quantized = (val / scale).round().clamp(-max_val_f32, max_val_f32) as i8;
110            data.push(quantized);
111        }
112    }
113
114    Ok(QuantizedActivations {
115        data,
116        scales,
117        shape,
118        per_token: config.per_token_activation,
119    })
120}
121
122/// Dequantize INT8 activations back to float tensor.
123///
124/// # Arguments
125///
126/// * `quantized` - Quantized activations
127/// * `device` - Device to create output tensor on
128///
129/// # Errors
130///
131/// Returns error if tensor creation fails.
132pub fn dequantize_activations(quantized: &QuantizedActivations, device: &Device) -> Result<Tensor> {
133    let shape = &quantized.shape;
134    let (batch_size, seq_len, hidden_dim) = match shape.len() {
135        2 => (shape[0], 1, shape[1]),
136        3 => (shape[0], shape[1], shape[2]),
137        _ => {
138            return Err(BitNetError::InvalidConfig(
139                "invalid shape for dequantization".to_string(),
140            ))
141        }
142    };
143
144    let mut output = vec![0.0f32; batch_size * seq_len * hidden_dim];
145
146    for token_idx in 0..(batch_size * seq_len) {
147        let scale = quantized.scales[token_idx];
148        let token_start = token_idx * hidden_dim;
149
150        for i in 0..hidden_dim {
151            let q_val = quantized.data[token_start + i];
152            output[token_start + i] = q_val as f32 * scale;
153        }
154    }
155
156    let tensor = Tensor::from_vec(output, shape.clone(), device)?;
157    Ok(tensor)
158}
159
160/// Apply quantization in a differentiable way using Straight-Through Estimator.
161///
162/// During forward pass: quantize -> dequantize
163/// During backward pass: gradients flow through unchanged
164///
165/// # Arguments
166///
167/// * `activations` - Input tensor
168/// * `config` - BitNet configuration
169///
170/// # Errors
171///
172/// Returns error if quantization fails.
173pub fn quantize_ste(activations: &Tensor, config: &BitNetConfig) -> Result<Tensor> {
174    let device = activations.device();
175    let quantized = quantize_activations(activations, config)?;
176    dequantize_activations(&quantized, device)
177}
178
#[cfg(test)]
mod tests {
    use super::*;

    /// 2D input: shape is kept and there is one scale per row.
    #[test]
    fn test_quantize_dequantize_roundtrip_2d() {
        let cpu = Device::Cpu;
        let cfg = BitNetConfig::default();

        // [batch, hidden]
        let input = Tensor::randn(0.0f32, 1.0, (4, 128), &cpu).unwrap();
        let packed = quantize_activations(&input, &cfg).unwrap();

        assert_eq!(packed.shape, vec![4, 128]);
        assert_eq!(packed.scales.len(), 4); // One per batch item
        assert_eq!(packed.data.len(), 4 * 128);

        let unpacked = dequantize_activations(&packed, &cpu).unwrap();
        assert_eq!(unpacked.shape().dims(), &[4, 128]);
    }

    /// 3D input: shape is kept and there is one scale per token.
    #[test]
    fn test_quantize_dequantize_roundtrip_3d() {
        let cpu = Device::Cpu;
        let cfg = BitNetConfig::default();

        // [batch, seq, hidden]
        let input = Tensor::randn(0.0f32, 1.0, (2, 16, 128), &cpu).unwrap();
        let packed = quantize_activations(&input, &cfg).unwrap();

        assert_eq!(packed.shape, vec![2, 16, 128]);
        assert_eq!(packed.scales.len(), 2 * 16); // Per token
        assert_eq!(packed.data.len(), 2 * 16 * 128);

        let unpacked = dequantize_activations(&packed, &cpu).unwrap();
        assert_eq!(unpacked.shape().dims(), &[2, 16, 128]);
    }

    /// Quantized values never fall below -127.
    #[test]
    fn test_quantization_range() {
        let cpu = Device::Cpu;
        let cfg = BitNetConfig::default();

        // A ramp from -3.2 to 3.1 in steps of 0.1.
        let ramp: Vec<f32> = (0..64).map(|i| (i as f32 - 32.0) / 10.0).collect();
        let input = Tensor::from_vec(ramp, (1, 64), &cpu).unwrap();

        let packed = quantize_activations(&input, &cfg).unwrap();

        // The upper bound of 127 is enforced by the i8 type itself.
        packed
            .data
            .iter()
            .for_each(|&val| assert!(val >= -127, "value {val} below -127"));
    }

    /// The STE output keeps the input shape and stays close to the input.
    #[test]
    fn test_ste_passthrough() {
        let cpu = Device::Cpu;
        let cfg = BitNetConfig::training();

        let input = Tensor::randn(0.0f32, 1.0, (2, 64), &cpu).unwrap();
        let output = quantize_ste(&input, &cfg).unwrap();

        // Shape should be preserved
        assert_eq!(output.shape().dims(), input.shape().dims());

        // Values should be close (within quantization error)
        let before: Vec<f32> = input.flatten_all().unwrap().to_vec1().unwrap();
        let after: Vec<f32> = output.flatten_all().unwrap().to_vec1().unwrap();

        before.iter().zip(after.iter()).for_each(|(o, q)| {
            let error = (o - q).abs();
            // INT8 quantization error should be bounded
            assert!(error < 0.1, "error {error} too large");
        });
    }

    /// All-zero input quantizes to zeros and uses the fallback scale.
    #[test]
    fn test_zero_activations() {
        let cpu = Device::Cpu;
        let cfg = BitNetConfig::default();

        let input = Tensor::zeros(&[4, 64], candle_core::DType::F32, &cpu).unwrap();
        let packed = quantize_activations(&input, &cfg).unwrap();

        // All quantized values should be zero
        packed.data.iter().for_each(|&val| assert_eq!(val, 0));

        // Scales should be 1.0 (fallback)
        packed
            .scales
            .iter()
            .for_each(|&scale| assert!((scale - 1.0).abs() < 0.001));
    }

    /// Only 2D and 3D inputs are accepted.
    #[test]
    fn test_invalid_shape() {
        let cpu = Device::Cpu;
        let cfg = BitNetConfig::default();

        // 1D tensor should fail
        let one_d = Tensor::zeros(&[64], candle_core::DType::F32, &cpu).unwrap();
        assert!(quantize_activations(&one_d, &cfg).is_err());

        // 4D tensor should fail
        let four_d = Tensor::zeros(&[2, 4, 8, 16], candle_core::DType::F32, &cpu).unwrap();
        assert!(quantize_activations(&four_d, &cfg).is_err());
    }
}