// bitnet_quantize/quantization/activation.rs
use candle_core::{Device, Tensor};
use serde::{Deserialize, Serialize};

use crate::config::BitNetConfig;
use crate::error::{BitNetError, Result};
11#[derive(Debug, Clone, Serialize, Deserialize)]
13pub struct QuantizedActivations {
14 pub data: Vec<i8>,
16
17 pub scales: Vec<f32>,
20
21 pub shape: Vec<usize>,
23
24 pub per_token: bool,
26}
27
28impl QuantizedActivations {
29 #[must_use]
31 pub fn batch_size(&self) -> usize {
32 self.shape.first().copied().unwrap_or(1)
33 }
34
35 #[must_use]
37 pub fn seq_len(&self) -> usize {
38 self.shape.get(1).copied().unwrap_or(1)
39 }
40
41 #[must_use]
43 pub fn hidden_dim(&self) -> usize {
44 self.shape.last().copied().unwrap_or(0)
45 }
46
47 #[must_use]
49 pub fn numel(&self) -> usize {
50 self.shape.iter().product()
51 }
52}
53
54pub fn quantize_activations(activations: &Tensor, config: &BitNetConfig) -> Result<QuantizedActivations> {
71 let shape = activations.shape().dims().to_vec();
72
73 let (batch_size, seq_len, hidden_dim) = match shape.len() {
75 2 => (shape[0], 1, shape[1]),
76 3 => (shape[0], shape[1], shape[2]),
77 _ => {
78 return Err(BitNetError::InvalidConfig(
79 "activations must be 2D or 3D".to_string(),
80 ))
81 }
82 };
83
84 let flat = activations.reshape((batch_size * seq_len, hidden_dim))?;
86 let flat_f32 = flat.to_dtype(candle_core::DType::F32)?.to_vec2::<f32>()?;
87
88 let max_val = (1 << (config.activation_bits - 1)) - 1; let max_val_f32 = max_val as f32;
90
91 let mut data = Vec::with_capacity(batch_size * seq_len * hidden_dim);
92 let mut scales = Vec::with_capacity(batch_size * seq_len);
93
94 for row in &flat_f32 {
95 let abs_max = row.iter().map(|x| x.abs()).fold(0.0f32, f32::max);
97 let scale = if abs_max > 0.0 {
98 abs_max / max_val_f32
99 } else {
100 1.0
101 };
102 scales.push(scale);
103
104 for &val in row {
106 let quantized = (val / scale).round().clamp(-max_val_f32, max_val_f32) as i8;
107 data.push(quantized);
108 }
109 }
110
111 Ok(QuantizedActivations {
112 data,
113 scales,
114 shape,
115 per_token: config.per_token_activation,
116 })
117}
118
119pub fn dequantize_activations(quantized: &QuantizedActivations, device: &Device) -> Result<Tensor> {
130 let shape = &quantized.shape;
131 let (batch_size, seq_len, hidden_dim) = match shape.len() {
132 2 => (shape[0], 1, shape[1]),
133 3 => (shape[0], shape[1], shape[2]),
134 _ => {
135 return Err(BitNetError::InvalidConfig(
136 "invalid shape for dequantization".to_string(),
137 ))
138 }
139 };
140
141 let mut output = vec![0.0f32; batch_size * seq_len * hidden_dim];
142
143 for token_idx in 0..(batch_size * seq_len) {
144 let scale = quantized.scales[token_idx];
145 let token_start = token_idx * hidden_dim;
146
147 for i in 0..hidden_dim {
148 let q_val = quantized.data[token_start + i];
149 output[token_start + i] = q_val as f32 * scale;
150 }
151 }
152
153 let tensor = Tensor::from_vec(output, shape.clone(), device)?;
154 Ok(tensor)
155}
156
157pub fn quantize_ste(activations: &Tensor, config: &BitNetConfig) -> Result<Tensor> {
171 let device = activations.device();
172 let quantized = quantize_activations(activations, config)?;
173 dequantize_activations(&quantized, device)
174}
175
#[cfg(test)]
mod tests {
    use super::*;

    /// 2D round-trip: shape, per-token scale count, and data length survive.
    #[test]
    fn test_quantize_dequantize_roundtrip_2d() {
        let device = Device::Cpu;
        let config = BitNetConfig::default();

        let activations = Tensor::randn(0.0f32, 1.0, (4, 128), &device).unwrap();

        let quantized = quantize_activations(&activations, &config).unwrap();

        assert_eq!(quantized.shape, vec![4, 128]);
        assert_eq!(quantized.scales.len(), 4);
        assert_eq!(quantized.data.len(), 4 * 128);

        let restored = dequantize_activations(&quantized, &device).unwrap();
        assert_eq!(restored.shape().dims(), &[4, 128]);
    }

    /// 3D round-trip: one scale per (batch, seq) token.
    #[test]
    fn test_quantize_dequantize_roundtrip_3d() {
        let device = Device::Cpu;
        let config = BitNetConfig::default();

        let activations = Tensor::randn(0.0f32, 1.0, (2, 16, 128), &device).unwrap();

        let quantized = quantize_activations(&activations, &config).unwrap();

        assert_eq!(quantized.shape, vec![2, 16, 128]);
        assert_eq!(quantized.scales.len(), 2 * 16);
        assert_eq!(quantized.data.len(), 2 * 16 * 128);

        let restored = dequantize_activations(&quantized, &device).unwrap();
        assert_eq!(restored.shape().dims(), &[2, 16, 128]);
    }

    /// Quantized values must stay within the symmetric range [-127, 127]
    /// (in particular, -128 must never be produced).
    #[test]
    fn test_quantization_range() {
        let device = Device::Cpu;
        let config = BitNetConfig::default();

        let values: Vec<f32> = (0..64).map(|i| (i as f32 - 32.0) / 10.0).collect();
        let activations = Tensor::from_vec(values, (1, 64), &device).unwrap();

        let quantized = quantize_activations(&activations, &config).unwrap();

        for &val in &quantized.data {
            assert!(val >= -127, "value {val} below -127");
            assert!(val <= 127, "value {val} above 127");
        }
    }

    /// The quantize-dequantize pass preserves shape and stays close to the
    /// original values (small quantization error).
    #[test]
    fn test_ste_passthrough() {
        let device = Device::Cpu;
        let config = BitNetConfig::training();

        let activations = Tensor::randn(0.0f32, 1.0, (2, 64), &device).unwrap();

        let result = quantize_ste(&activations, &config).unwrap();

        assert_eq!(result.shape().dims(), activations.shape().dims());

        let orig: Vec<f32> = activations.flatten_all().unwrap().to_vec1().unwrap();
        let quant: Vec<f32> = result.flatten_all().unwrap().to_vec1().unwrap();

        for (o, q) in orig.iter().zip(quant.iter()) {
            let error = (o - q).abs();
            assert!(error < 0.1, "error {error} too large");
        }
    }

    /// All-zero input quantizes to all zeros with the unit fallback scale.
    #[test]
    fn test_zero_activations() {
        let device = Device::Cpu;
        let config = BitNetConfig::default();

        let activations = Tensor::zeros(&[4, 64], candle_core::DType::F32, &device).unwrap();

        let quantized = quantize_activations(&activations, &config).unwrap();

        for &val in &quantized.data {
            assert_eq!(val, 0);
        }

        for &scale in &quantized.scales {
            assert!((scale - 1.0).abs() < 0.001);
        }
    }

    /// 1D and 4D inputs are rejected with an error, not a panic.
    #[test]
    fn test_invalid_shape() {
        let device = Device::Cpu;
        let config = BitNetConfig::default();

        let activations = Tensor::zeros(&[64], candle_core::DType::F32, &device).unwrap();
        assert!(quantize_activations(&activations, &config).is_err());

        let activations = Tensor::zeros(&[2, 4, 8, 16], candle_core::DType::F32, &device).unwrap();
        assert!(quantize_activations(&activations, &config).is_err());
    }
}