Skip to main content

bitnet_quantize/quantization/
activation.rs

1//! Activation quantization for BitNet.
2//!
3//! Implements per-token AbsMax quantization to INT8.
4
5use candle_core::{Device, Tensor};
6use serde::{Deserialize, Serialize};
7
8use crate::config::BitNetConfig;
9use crate::error::{BitNetError, Result};
10
11/// Quantized activations with per-token scales.
12#[derive(Debug, Clone, Serialize, Deserialize)]
13pub struct QuantizedActivations {
14    /// Quantized INT8 values (stored as i8 in a Vec).
15    pub data: Vec<i8>,
16
17    /// Per-token scale factors.
18    /// Shape depends on `per_token`: [batch, seq_len] or [batch].
19    pub scales: Vec<f32>,
20
21    /// Original shape [batch, seq_len, hidden_dim].
22    pub shape: Vec<usize>,
23
24    /// Whether per-token scaling was used.
25    pub per_token: bool,
26}
27
28impl QuantizedActivations {
29    /// Get the batch size.
30    #[must_use]
31    pub fn batch_size(&self) -> usize {
32        self.shape.first().copied().unwrap_or(1)
33    }
34
35    /// Get the sequence length.
36    #[must_use]
37    pub fn seq_len(&self) -> usize {
38        self.shape.get(1).copied().unwrap_or(1)
39    }
40
41    /// Get the hidden dimension.
42    #[must_use]
43    pub fn hidden_dim(&self) -> usize {
44        self.shape.last().copied().unwrap_or(0)
45    }
46
47    /// Get the total number of elements.
48    #[must_use]
49    pub fn numel(&self) -> usize {
50        self.shape.iter().product()
51    }
52}
53
54/// Quantize activations using per-token AbsMax scaling to INT8.
55///
56/// # Algorithm
57///
58/// For each token (row):
59/// 1. Compute `scale = max(|X|) / 127`
60/// 2. Compute `X_q = round(X / scale)` clamped to [-127, 127]
61///
62/// # Arguments
63///
64/// * `activations` - Input tensor [batch, seq_len, hidden_dim] or [batch, hidden_dim]
65/// * `config` - BitNet configuration
66///
67/// # Errors
68///
69/// Returns error if quantization fails.
70pub fn quantize_activations(activations: &Tensor, config: &BitNetConfig) -> Result<QuantizedActivations> {
71    let shape = activations.shape().dims().to_vec();
72
73    // Handle both 2D [batch, hidden] and 3D [batch, seq, hidden] inputs
74    let (batch_size, seq_len, hidden_dim) = match shape.len() {
75        2 => (shape[0], 1, shape[1]),
76        3 => (shape[0], shape[1], shape[2]),
77        _ => {
78            return Err(BitNetError::InvalidConfig(
79                "activations must be 2D or 3D".to_string(),
80            ))
81        }
82    };
83
84    // Reshape to [batch * seq_len, hidden_dim] for uniform processing
85    let flat = activations.reshape((batch_size * seq_len, hidden_dim))?;
86    let flat_f32 = flat.to_dtype(candle_core::DType::F32)?.to_vec2::<f32>()?;
87
88    let max_val = (1 << (config.activation_bits - 1)) - 1; // 127 for 8-bit
89    let max_val_f32 = max_val as f32;
90
91    let mut data = Vec::with_capacity(batch_size * seq_len * hidden_dim);
92    let mut scales = Vec::with_capacity(batch_size * seq_len);
93
94    for row in &flat_f32 {
95        // Compute AbsMax for this token
96        let abs_max = row.iter().map(|x| x.abs()).fold(0.0f32, f32::max);
97        let scale = if abs_max > 0.0 {
98            abs_max / max_val_f32
99        } else {
100            1.0
101        };
102        scales.push(scale);
103
104        // Quantize
105        for &val in row {
106            let quantized = (val / scale).round().clamp(-max_val_f32, max_val_f32) as i8;
107            data.push(quantized);
108        }
109    }
110
111    Ok(QuantizedActivations {
112        data,
113        scales,
114        shape,
115        per_token: config.per_token_activation,
116    })
117}
118
119/// Dequantize INT8 activations back to float tensor.
120///
121/// # Arguments
122///
123/// * `quantized` - Quantized activations
124/// * `device` - Device to create output tensor on
125///
126/// # Errors
127///
128/// Returns error if tensor creation fails.
129pub fn dequantize_activations(quantized: &QuantizedActivations, device: &Device) -> Result<Tensor> {
130    let shape = &quantized.shape;
131    let (batch_size, seq_len, hidden_dim) = match shape.len() {
132        2 => (shape[0], 1, shape[1]),
133        3 => (shape[0], shape[1], shape[2]),
134        _ => {
135            return Err(BitNetError::InvalidConfig(
136                "invalid shape for dequantization".to_string(),
137            ))
138        }
139    };
140
141    let mut output = vec![0.0f32; batch_size * seq_len * hidden_dim];
142
143    for token_idx in 0..(batch_size * seq_len) {
144        let scale = quantized.scales[token_idx];
145        let token_start = token_idx * hidden_dim;
146
147        for i in 0..hidden_dim {
148            let q_val = quantized.data[token_start + i];
149            output[token_start + i] = q_val as f32 * scale;
150        }
151    }
152
153    let tensor = Tensor::from_vec(output, shape.clone(), device)?;
154    Ok(tensor)
155}
156
157/// Apply quantization in a differentiable way using Straight-Through Estimator.
158///
159/// During forward pass: quantize -> dequantize
160/// During backward pass: gradients flow through unchanged
161///
162/// # Arguments
163///
164/// * `activations` - Input tensor
165/// * `config` - BitNet configuration
166///
167/// # Errors
168///
169/// Returns error if quantization fails.
170pub fn quantize_ste(activations: &Tensor, config: &BitNetConfig) -> Result<Tensor> {
171    let device = activations.device();
172    let quantized = quantize_activations(activations, config)?;
173    dequantize_activations(&quantized, device)
174}
175
#[cfg(test)]
mod tests {
    use super::*;

    /// A 2D `[batch, hidden]` input keeps its shape and yields one scale per row.
    #[test]
    fn test_quantize_dequantize_roundtrip_2d() {
        let cpu = Device::Cpu;
        let cfg = BitNetConfig::default();
        let input = Tensor::randn(0.0f32, 1.0, (4, 128), &cpu).unwrap();

        let packed = quantize_activations(&input, &cfg).unwrap();
        assert_eq!(packed.shape, vec![4, 128]);
        assert_eq!(packed.scales.len(), 4); // one scale per batch item
        assert_eq!(packed.data.len(), 4 * 128);

        let unpacked = dequantize_activations(&packed, &cpu).unwrap();
        assert_eq!(unpacked.shape().dims(), &[4, 128]);
    }

    /// A 3D `[batch, seq, hidden]` input keeps its shape and yields one scale per token.
    #[test]
    fn test_quantize_dequantize_roundtrip_3d() {
        let cpu = Device::Cpu;
        let cfg = BitNetConfig::default();
        let input = Tensor::randn(0.0f32, 1.0, (2, 16, 128), &cpu).unwrap();

        let packed = quantize_activations(&input, &cfg).unwrap();
        assert_eq!(packed.shape, vec![2, 16, 128]);
        assert_eq!(packed.scales.len(), 2 * 16); // one scale per token
        assert_eq!(packed.data.len(), 2 * 16 * 128);

        let unpacked = dequantize_activations(&packed, &cpu).unwrap();
        assert_eq!(unpacked.shape().dims(), &[2, 16, 128]);
    }

    /// Quantized values never fall below -127 (the `i8` type caps the top end).
    #[test]
    fn test_quantization_range() {
        let cpu = Device::Cpu;
        let cfg = BitNetConfig::default();

        // Evenly spaced values from -3.2 to 3.1.
        let ramp: Vec<f32> = (0..64).map(|i| (i as f32 - 32.0) / 10.0).collect();
        let input = Tensor::from_vec(ramp, (1, 64), &cpu).unwrap();

        let packed = quantize_activations(&input, &cfg).unwrap();
        packed
            .data
            .iter()
            .for_each(|&val| assert!(val >= -127, "value {val} below -127"));
    }

    /// The STE output has the input's shape and stays within quantization error.
    #[test]
    fn test_ste_passthrough() {
        let cpu = Device::Cpu;
        let cfg = BitNetConfig::training();
        let input = Tensor::randn(0.0f32, 1.0, (2, 64), &cpu).unwrap();

        let output = quantize_ste(&input, &cfg).unwrap();
        assert_eq!(output.shape().dims(), input.shape().dims());

        let before: Vec<f32> = input.flatten_all().unwrap().to_vec1().unwrap();
        let after: Vec<f32> = output.flatten_all().unwrap().to_vec1().unwrap();

        // INT8 quantization error should be bounded.
        before.iter().zip(after.iter()).for_each(|(o, q)| {
            let error = (o - q).abs();
            assert!(error < 0.1, "error {error} too large");
        });
    }

    /// All-zero input quantizes to zeros and uses the fallback scale of 1.0.
    #[test]
    fn test_zero_activations() {
        let cpu = Device::Cpu;
        let cfg = BitNetConfig::default();
        let input = Tensor::zeros(&[4, 64], candle_core::DType::F32, &cpu).unwrap();

        let packed = quantize_activations(&input, &cfg).unwrap();

        packed.data.iter().for_each(|&val| assert_eq!(val, 0));
        packed
            .scales
            .iter()
            .for_each(|&scale| assert!((scale - 1.0).abs() < 0.001));
    }

    /// Inputs that are neither 2D nor 3D are rejected.
    #[test]
    fn test_invalid_shape() {
        let cpu = Device::Cpu;
        let cfg = BitNetConfig::default();

        let rank_one = Tensor::zeros(&[64], candle_core::DType::F32, &cpu).unwrap();
        assert!(quantize_activations(&rank_one, &cfg).is_err());

        let rank_four = Tensor::zeros(&[2, 4, 8, 16], candle_core::DType::F32, &cpu).unwrap();
        assert!(quantize_activations(&rank_four, &cfg).is_err());
    }
}