bitnet_quantize/quantization/activation.rs

//! Per-token absmax quantization of activations to signed low-bit integers,
//! plus dequantization and a straight-through-estimator (STE) helper for
//! quantization-aware training.

use candle_core::{Device, Tensor};
use serde::{Deserialize, Serialize};

use crate::config::BitNetConfig;
use crate::error::{BitNetError, Result};

/// Activations quantized to signed integers, with one `f32` scale per token.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QuantizedActivations {
    /// Quantized values in row-major order, one `i8` per element.
    pub data: Vec<i8>,

    /// Absmax scale for each token (row), used to dequantize `data`.
    pub scales: Vec<f32>,

    /// Shape of the original tensor: `[batch, hidden]` or `[batch, seq, hidden]`.
    pub shape: Vec<usize>,

    /// Whether per-token quantization was configured.
    pub per_token: bool,
}

impl QuantizedActivations {
    /// Batch size (first shape dimension), defaulting to 1.
    #[must_use]
    pub fn batch_size(&self) -> usize {
        self.shape.first().copied().unwrap_or(1)
    }

    /// Sequence length (second shape dimension); 1 for 2D activations.
    #[must_use]
    pub fn seq_len(&self) -> usize {
        self.shape.get(1).copied().unwrap_or(1)
    }

    /// Hidden dimension (last shape dimension).
    #[must_use]
    pub fn hidden_dim(&self) -> usize {
        self.shape.last().copied().unwrap_or(0)
    }

    /// Total number of elements.
    #[must_use]
    pub fn numel(&self) -> usize {
        self.shape.iter().product()
    }
}

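/// Quantizes `activations` to signed integers with one absmax scale per
/// token (row). Accepts 2D `[batch, hidden]` or 3D `[batch, seq, hidden]`
/// input; any other rank is rejected with `BitNetError::InvalidConfig`.
///
/// Each row is scaled by `abs_max / (2^(activation_bits - 1) - 1)`, so the
/// row's largest magnitude maps to the integer extreme (127 for 8 bits).
///
/// # Example
///
/// A minimal round-trip sketch mirroring the tests below; the crate and
/// module paths are assumed from this file's location, so the snippet is
/// marked `ignore`.
///
/// ```ignore
/// use bitnet_quantize::config::BitNetConfig;
/// use bitnet_quantize::quantization::activation::{
///     dequantize_activations, quantize_activations,
/// };
/// use candle_core::{Device, Tensor};
///
/// let device = Device::Cpu;
/// let config = BitNetConfig::default();
/// let x = Tensor::randn(0.0f32, 1.0, (2, 16, 128), &device)?;
///
/// let q = quantize_activations(&x, &config)?;
/// assert_eq!(q.scales.len(), 2 * 16); // one scale per token
///
/// let restored = dequantize_activations(&q, &device)?; // approximately x
/// ```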
pub fn quantize_activations(
    activations: &Tensor,
    config: &BitNetConfig,
) -> Result<QuantizedActivations> {
    let shape = activations.shape().dims().to_vec();

    // Normalize to (batch, seq, hidden); 2D input is treated as seq_len = 1.
    let (batch_size, seq_len, hidden_dim) = match shape.len() {
        2 => (shape[0], 1, shape[1]),
        3 => (shape[0], shape[1], shape[2]),
        _ => {
            return Err(BitNetError::InvalidConfig(
                "activations must be 2D or 3D".to_string(),
            ))
        }
    };

    // One row per token, so each row gets its own scale.
    let flat = activations.reshape((batch_size * seq_len, hidden_dim))?;
    let flat_f32 = flat.to_dtype(candle_core::DType::F32)?.to_vec2::<f32>()?;

    // Largest representable magnitude for a signed integer of
    // `activation_bits` bits, e.g. 127 for 8 bits.
    let max_val = (1 << (config.activation_bits - 1)) - 1;
    let max_val_f32 = max_val as f32;

    let mut data = Vec::with_capacity(batch_size * seq_len * hidden_dim);
    let mut scales = Vec::with_capacity(batch_size * seq_len);

    for row in &flat_f32 {
        // Per-token absmax scale; fall back to 1.0 for an all-zero row to
        // avoid dividing by zero (its quantized values are all zero anyway).
        let abs_max = row.iter().map(|x| x.abs()).fold(0.0f32, f32::max);
        let scale = if abs_max > 0.0 {
            abs_max / max_val_f32
        } else {
            1.0
        };
        scales.push(scale);

        for &val in row {
            let quantized = (val / scale).round().clamp(-max_val_f32, max_val_f32) as i8;
            data.push(quantized);
        }
    }

    Ok(QuantizedActivations {
        data,
        scales,
        shape,
        per_token: config.per_token_activation,
    })
}

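/// Reconstructs an `f32` tensor of the original shape by multiplying each
/// token's integers by that token's stored scale; the inverse of
/// `quantize_activations` up to rounding error.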
pub fn dequantize_activations(quantized: &QuantizedActivations, device: &Device) -> Result<Tensor> {
    let shape = &quantized.shape;
    let (batch_size, seq_len, hidden_dim) = match shape.len() {
        2 => (shape[0], 1, shape[1]),
        3 => (shape[0], shape[1], shape[2]),
        _ => {
            return Err(BitNetError::InvalidConfig(
                "invalid shape for dequantization".to_string(),
            ))
        }
    };

    let mut output = vec![0.0f32; batch_size * seq_len * hidden_dim];

    // Rescale each token's integers by that token's stored scale.
    for token_idx in 0..(batch_size * seq_len) {
        let scale = quantized.scales[token_idx];
        let token_start = token_idx * hidden_dim;

        for i in 0..hidden_dim {
            let q_val = quantized.data[token_start + i];
            output[token_start + i] = q_val as f32 * scale;
        }
    }

    let tensor = Tensor::from_vec(output, shape.clone(), device)?;
    Ok(tensor)
}

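/// Fake quantization for quantization-aware training: quantize, then
/// immediately dequantize, returning an `f32` tensor on the input's device.
/// Under a straight-through estimator the forward pass sees these rounded
/// values while gradients treat the op as identity; note that this helper
/// builds a fresh tensor, so any gradient passthrough is assumed to be
/// wired up by the caller.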
pub fn quantize_ste(activations: &Tensor, config: &BitNetConfig) -> Result<Tensor> {
    let device = activations.device();
    let quantized = quantize_activations(activations, config)?;
    dequantize_activations(&quantized, device)
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_quantize_dequantize_roundtrip_2d() {
        let device = Device::Cpu;
        let config = BitNetConfig::default();

        let activations = Tensor::randn(0.0f32, 1.0, (4, 128), &device).unwrap();

        let quantized = quantize_activations(&activations, &config).unwrap();

        assert_eq!(quantized.shape, vec![4, 128]);
        assert_eq!(quantized.scales.len(), 4);
        assert_eq!(quantized.data.len(), 4 * 128);

        let restored = dequantize_activations(&quantized, &device).unwrap();
        assert_eq!(restored.shape().dims(), &[4, 128]);
    }

    #[test]
    fn test_quantize_dequantize_roundtrip_3d() {
        let device = Device::Cpu;
        let config = BitNetConfig::default();

        let activations = Tensor::randn(0.0f32, 1.0, (2, 16, 128), &device).unwrap();

        let quantized = quantize_activations(&activations, &config).unwrap();

        assert_eq!(quantized.shape, vec![2, 16, 128]);
        assert_eq!(quantized.scales.len(), 2 * 16);
        assert_eq!(quantized.data.len(), 2 * 16 * 128);

        let restored = dequantize_activations(&quantized, &device).unwrap();
        assert_eq!(restored.shape().dims(), &[2, 16, 128]);
    }

    #[test]
    fn test_quantization_range() {
        let device = Device::Cpu;
        let config = BitNetConfig::default();

        // Values spanning [-3.2, 3.1] to exercise both signs.
        let values: Vec<f32> = (0..64).map(|i| (i as f32 - 32.0) / 10.0).collect();
        let activations = Tensor::from_vec(values, (1, 64), &device).unwrap();

        let quantized = quantize_activations(&activations, &config).unwrap();

        // Symmetric 8-bit quantization must stay within [-127, 127].
        for &val in &quantized.data {
            assert!(val >= -127, "value {val} below -127");
            assert!(val <= 127, "value {val} above 127");
        }
    }

    #[test]
    fn test_ste_passthrough() {
        let device = Device::Cpu;
        let config = BitNetConfig::training();

        let activations = Tensor::randn(0.0f32, 1.0, (2, 64), &device).unwrap();

        let result = quantize_ste(&activations, &config).unwrap();

        // Fake quantization preserves shape...
        assert_eq!(result.shape().dims(), activations.shape().dims());

        // ...and stays close to the input values.
        let orig: Vec<f32> = activations.flatten_all().unwrap().to_vec1().unwrap();
        let quant: Vec<f32> = result.flatten_all().unwrap().to_vec1().unwrap();

        for (o, q) in orig.iter().zip(quant.iter()) {
            let error = (o - q).abs();
            assert!(error < 0.1, "error {error} too large");
        }
    }

    #[test]
    fn test_zero_activations() {
        let device = Device::Cpu;
        let config = BitNetConfig::default();

        let activations = Tensor::zeros(&[4, 64], candle_core::DType::F32, &device).unwrap();

        let quantized = quantize_activations(&activations, &config).unwrap();

        // All-zero input quantizes to all zeros...
        for &val in &quantized.data {
            assert_eq!(val, 0);
        }

        // ...with the fallback scale of 1.0.
        for &scale in &quantized.scales {
            assert!((scale - 1.0).abs() < 0.001);
        }
    }

    #[test]
    fn test_invalid_shape() {
        let device = Device::Cpu;
        let config = BitNetConfig::default();

        // 1D input is rejected...
        let activations = Tensor::zeros(&[64], candle_core::DType::F32, &device).unwrap();
        assert!(quantize_activations(&activations, &config).is_err());

        // ...as is 4D input.
        let activations = Tensor::zeros(&[2, 4, 8, 16], candle_core::DType::F32, &device).unwrap();
        assert!(quantize_activations(&activations, &config).is_err());
    }
}