use crate::error::{QuantError, QuantResult};
/// Tunable settings for SmoothQuant scale migration.
#[derive(Debug, Clone, Copy)]
pub struct SmoothQuantConfig {
    /// Migration strength, typically in `[0, 1]`: larger values shift
    /// more quantization difficulty from activations onto weights.
    pub alpha: f32,
}
39
40impl Default for SmoothQuantConfig {
41 fn default() -> Self {
42 Self { alpha: 0.5 }
43 }
44}
45
46#[derive(Debug, Clone, Copy)]
54pub struct SmoothQuantMigrator {
55 pub config: SmoothQuantConfig,
57}
58
59impl SmoothQuantMigrator {
60 #[must_use]
62 pub fn new(alpha: f32) -> Self {
63 Self {
64 config: SmoothQuantConfig { alpha },
65 }
66 }
67
68 pub fn compute_migration_scales(
85 &self,
86 act_max: &[f32],
87 weight_max: &[f32],
88 ) -> QuantResult<Vec<f32>> {
89 if act_max.is_empty() {
90 return Err(QuantError::EmptyInput(
91 "SmoothQuantMigrator::compute_migration_scales",
92 ));
93 }
94 if act_max.len() != weight_max.len() {
95 return Err(QuantError::DimensionMismatch {
96 expected: act_max.len(),
97 got: weight_max.len(),
98 });
99 }
100 let alpha = self.config.alpha;
101 let scales = act_max
102 .iter()
103 .zip(weight_max.iter())
104 .map(|(&a_max, &w_max)| {
105 let a = a_max.abs().max(1e-8);
106 let w = w_max.abs().max(1e-8);
107 a.powf(alpha) / w.powf(1.0 - alpha)
108 })
109 .collect();
110 Ok(scales)
111 }
112
113 pub fn compute_act_stats(
126 acts: &[f32],
127 n_tokens: usize,
128 n_channels: usize,
129 ) -> QuantResult<Vec<f32>> {
130 if acts.is_empty() {
131 return Err(QuantError::EmptyInput(
132 "compute_act_stats: empty activations",
133 ));
134 }
135 if acts.len() != n_tokens * n_channels {
136 return Err(QuantError::DimensionMismatch {
137 expected: n_tokens * n_channels,
138 got: acts.len(),
139 });
140 }
141 let mut stats = vec![0.0_f32; n_channels];
142 for t in 0..n_tokens {
143 for j in 0..n_channels {
144 let v = acts[t * n_channels + j].abs();
145 if v > stats[j] {
146 stats[j] = v;
147 }
148 }
149 }
150 Ok(stats)
151 }
152
153 pub fn compute_weight_stats(
166 weights: &[f32],
167 n_out: usize,
168 n_channels: usize,
169 ) -> QuantResult<Vec<f32>> {
170 if weights.is_empty() {
171 return Err(QuantError::EmptyInput(
172 "compute_weight_stats: empty weights",
173 ));
174 }
175 if weights.len() != n_out * n_channels {
176 return Err(QuantError::DimensionMismatch {
177 expected: n_out * n_channels,
178 got: weights.len(),
179 });
180 }
181 let mut stats = vec![0.0_f32; n_channels];
182 for r in 0..n_out {
183 for j in 0..n_channels {
184 let v = weights[r * n_channels + j].abs();
185 if v > stats[j] {
186 stats[j] = v;
187 }
188 }
189 }
190 Ok(stats)
191 }
192
193 pub fn smooth_activations(
199 acts: &mut [f32],
200 scales: &[f32],
201 n_tokens: usize,
202 n_channels: usize,
203 ) -> QuantResult<()> {
204 if acts.len() != n_tokens * n_channels {
205 return Err(QuantError::DimensionMismatch {
206 expected: n_tokens * n_channels,
207 got: acts.len(),
208 });
209 }
210 if scales.len() != n_channels {
211 return Err(QuantError::DimensionMismatch {
212 expected: n_channels,
213 got: scales.len(),
214 });
215 }
216 for t in 0..n_tokens {
217 for j in 0..n_channels {
218 acts[t * n_channels + j] /= scales[j].max(1e-12);
219 }
220 }
221 Ok(())
222 }
223
224 pub fn smooth_weights(
232 weights: &mut [f32],
233 scales: &[f32],
234 n_out: usize,
235 n_channels: usize,
236 ) -> QuantResult<()> {
237 if weights.len() != n_out * n_channels {
238 return Err(QuantError::DimensionMismatch {
239 expected: n_out * n_channels,
240 got: weights.len(),
241 });
242 }
243 if scales.len() != n_channels {
244 return Err(QuantError::DimensionMismatch {
245 expected: n_channels,
246 got: scales.len(),
247 });
248 }
249 for r in 0..n_out {
250 for j in 0..n_channels {
251 weights[r * n_channels + j] *= scales[j];
252 }
253 }
254 Ok(())
255 }
256
257 pub fn smooth_layer(
275 &self,
276 acts: &mut [f32],
277 weights: &mut [f32],
278 n_tokens: usize,
279 n_channels: usize,
280 n_out: usize,
281 ) -> QuantResult<Vec<f32>> {
282 let act_stats = Self::compute_act_stats(acts, n_tokens, n_channels)?;
283 let weight_stats = Self::compute_weight_stats(weights, n_out, n_channels)?;
284 let scales = self.compute_migration_scales(&act_stats, &weight_stats)?;
285 Self::smooth_activations(acts, &scales, n_tokens, n_channels)?;
286 Self::smooth_weights(weights, &scales, n_out, n_channels)?;
287 Ok(scales)
288 }
289}
290
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    /// Reference product `y = x . w^T` for row-major `x` (`n_tok x n_ch`)
    /// and `w` (`n_out x n_ch`); used to check output preservation.
    fn matmul_nt(x: &[f32], w: &[f32], n_tok: usize, n_ch: usize, n_out: usize) -> Vec<f32> {
        let mut out = vec![0.0_f32; n_tok * n_out];
        for (t, x_row) in x.chunks(n_ch).enumerate() {
            for (o, w_row) in w.chunks(n_ch).enumerate() {
                out[t * n_out + o] = x_row.iter().zip(w_row).map(|(a, b)| a * b).sum();
            }
        }
        out
    }

    #[test]
    fn scale_alpha_half() {
        let migrator = SmoothQuantMigrator::new(0.5);
        let scales = migrator
            .compute_migration_scales(&[4.0, 1.0, 9.0], &[1.0, 4.0, 1.0])
            .unwrap();
        // alpha = 0.5: sqrt(act_max) / sqrt(weight_max).
        assert_abs_diff_eq!(scales[0], 2.0, epsilon = 1e-5);
        assert_abs_diff_eq!(scales[1], 0.5, epsilon = 1e-5);
        assert_abs_diff_eq!(scales[2], 3.0, epsilon = 1e-5);
    }

    #[test]
    fn scale_alpha_one_activations_only() {
        let migrator = SmoothQuantMigrator::new(1.0);
        let scales = migrator
            .compute_migration_scales(&[2.0, 5.0], &[3.0, 7.0])
            .unwrap();
        // alpha = 1 ignores the weight statistics entirely.
        assert_abs_diff_eq!(scales[0], 2.0, epsilon = 1e-5);
        assert_abs_diff_eq!(scales[1], 5.0, epsilon = 1e-5);
    }

    #[test]
    fn scale_alpha_zero_weights_only() {
        let migrator = SmoothQuantMigrator::new(0.0);
        let scales = migrator
            .compute_migration_scales(&[4.0, 1.0], &[2.0, 5.0])
            .unwrap();
        // alpha = 0 yields the reciprocal of the weight statistics.
        assert_abs_diff_eq!(scales[0], 1.0 / 2.0, epsilon = 1e-5);
        assert_abs_diff_eq!(scales[1], 1.0 / 5.0, epsilon = 1e-5);
    }

    #[test]
    fn smoothing_preserves_layer_output() {
        let migrator = SmoothQuantMigrator::new(0.5);
        let (n_tok, n_ch, n_out) = (3, 4, 2);
        let mut acts: Vec<f32> = (0..n_tok * n_ch).map(|i| i as f32 * 0.3 - 1.0).collect();
        let mut weights: Vec<f32> = (0..n_out * n_ch).map(|i| i as f32 * 0.2 - 0.5).collect();

        let before = matmul_nt(&acts, &weights, n_tok, n_ch, n_out);
        migrator
            .smooth_layer(&mut acts, &mut weights, n_tok, n_ch, n_out)
            .unwrap();
        let after = matmul_nt(&acts, &weights, n_tok, n_ch, n_out);

        for (&y0, &y1) in before.iter().zip(after.iter()) {
            assert_abs_diff_eq!(y0, y1, epsilon = 1e-4);
        }
    }

    #[test]
    fn activation_stats_max_per_channel() {
        // 2 tokens x 3 channels, row-major.
        let acts = [1.0_f32, -5.0, 2.0, -3.0, 4.0, 1.0];
        let stats = SmoothQuantMigrator::compute_act_stats(&acts, 2, 3).unwrap();
        assert_abs_diff_eq!(stats[0], 3.0, epsilon = 1e-6);
        assert_abs_diff_eq!(stats[1], 5.0, epsilon = 1e-6);
        assert_abs_diff_eq!(stats[2], 2.0, epsilon = 1e-6);
    }

    #[test]
    fn weight_stats_max_per_column() {
        // 2 output rows x 3 input channels, row-major.
        let weights = [0.5_f32, -2.0, 1.0, -1.5, 0.3, 3.0];
        let stats = SmoothQuantMigrator::compute_weight_stats(&weights, 2, 3).unwrap();
        assert_abs_diff_eq!(stats[0], 1.5, epsilon = 1e-6);
        assert_abs_diff_eq!(stats[1], 2.0, epsilon = 1e-6);
        assert_abs_diff_eq!(stats[2], 3.0, epsilon = 1e-6);
    }

    #[test]
    fn dimension_mismatch_error() {
        let migrator = SmoothQuantMigrator::new(0.5);
        let result = migrator.compute_migration_scales(&[1.0; 3], &[1.0; 4]);
        assert!(matches!(result, Err(QuantError::DimensionMismatch { .. })));
    }

    #[test]
    fn empty_input_error() {
        let migrator = SmoothQuantMigrator::new(0.5);
        let result = migrator.compute_migration_scales(&[], &[]);
        assert!(matches!(result, Err(QuantError::EmptyInput(_))));
    }

    #[test]
    fn smoothing_reduces_act_channel_range_imbalance() {
        let migrator = SmoothQuantMigrator::new(0.5);
        let (n_tok, n_ch, n_out) = (4, 2, 2);
        // Channel 0 carries a 100x outlier relative to channel 1.
        let mut acts = vec![100.0_f32, 1.0, -100.0, 1.0, 100.0, -1.0, -100.0, -1.0];
        let mut weights = vec![0.5_f32, 0.5, -0.5, 0.5];

        let scales = migrator
            .smooth_layer(&mut acts, &mut weights, n_tok, n_ch, n_out)
            .unwrap();

        let max_abs = |ch: usize| -> f32 {
            acts.iter()
                .skip(ch)
                .step_by(n_ch)
                .fold(0.0, |m, &v| m.max(v.abs()))
        };
        let ratio = max_abs(0) / max_abs(1).max(1e-8);

        assert!(scales[0] > 1.0, "scale[0] should be > 1 for outlier channel");
        assert!(
            ratio < 100.0,
            "channel range imbalance should decrease after smoothing"
        );
    }
}