Skip to main content

oxicuda_quant/scheme/
smooth_quant.rs

1//! # SmoothQuant — Activation–Weight Quantization Migration
2//!
3//! Xiao et al. (2022): "SmoothQuant: Accurate and Efficient Post-Training
4//! Quantization for Large Language Models" <https://arxiv.org/abs/2211.10438>
5//!
6//! LLM activations often contain large per-channel outliers that make INT8
7//! quantization difficult, while weights are typically well-behaved.
8//! SmoothQuant migrates the quantization difficulty from activations to weights
9//! via a mathematically equivalent per-channel rescaling.
10//!
11//! ## Migration
12//!
13//! ```text
14//! s_j = max|X_j|^α / max|W_j|^(1−α)    (per-channel scale)
15//!
16//! X_smooth[:,j] = X[:,j] / s_j          (activations ÷ s)
17//! W_smooth[:,j] = W[:,j] × s_j          (weights × s, column = input channel)
18//!
19//! Y = X W^T = X_smooth W_smooth^T       (output unchanged)
20//! ```
21//!
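//! `α = 0.5` balances difficulty equally.  `α → 1` pushes all difficulty to
//! weights (`s_j → max|X_j|`, so every smoothed activation channel peaks at 1);
//! `α → 0` leaves the difficulty on activations (`s_j → 1 / max|W_j|` depends
//! only on the weights and merely normalizes each weight column).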
24
25use crate::error::{QuantError, QuantResult};
26
27// ─── Config ───────────────────────────────────────────────────────────────────
28
/// SmoothQuant migration configuration.
///
/// The only knob is the migration strength `alpha`. It decides how the
/// per-channel scale `s_j = max|X_j|^alpha / max|W_j|^(1 - alpha)` splits
/// quantization difficulty between activations and weights.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct SmoothQuantConfig {
    /// Migration strength α, intended to lie in `[0, 1]`.
    ///
    /// The value is not validated; out-of-range or non-finite values are passed
    /// straight into `powf` when scales are computed.
    ///
    /// * `0.5` — equal difficulty between activations and weights (the `Default`).
    /// * `1.0` — all difficulty migrates to weights: `s_j = max|X_j|`, so every
    ///   smoothed activation channel has a max magnitude of 1.
    /// * `0.0` — no activation difficulty migrates: `s_j = 1 / max|W_j|` depends
    ///   only on the weights and just normalizes each weight column.
    pub alpha: f32,
}
39
40impl Default for SmoothQuantConfig {
41    fn default() -> Self {
42        Self { alpha: 0.5 }
43    }
44}
45
46// ─── SmoothQuantMigrator ─────────────────────────────────────────────────────
47
48/// Applies per-channel scaling to balance quantization difficulty.
49///
50/// The migrator operates on linear layers:
51/// * **Activations** `X` — shape `[n_tokens, n_channels]`
52/// * **Weights** `W` — shape `[n_out, n_channels]` (transposed: `Y = X W^T`)
53#[derive(Debug, Clone, Copy)]
54pub struct SmoothQuantMigrator {
55    /// Migration configuration.
56    pub config: SmoothQuantConfig,
57}
58
59impl SmoothQuantMigrator {
60    /// Create a migrator with the given `alpha`.
61    #[must_use]
62    pub fn new(alpha: f32) -> Self {
63        Self {
64            config: SmoothQuantConfig { alpha },
65        }
66    }
67
68    /// Compute per-channel migration scales from pre-aggregated statistics.
69    ///
70    /// # Parameters
71    ///
72    /// * `act_max`    — per-channel max absolute value of activations (length `n_ch`).
73    /// * `weight_max` — per-channel (column) max absolute value of weights (length `n_ch`).
74    ///
75    /// # Returns
76    ///
77    /// Scale vector `s` of length `n_ch` where
78    /// `s[j] = act_max[j]^alpha / weight_max[j]^(1−alpha)`.
79    ///
80    /// # Errors
81    ///
82    /// * [`QuantError::DimensionMismatch`] — `act_max` and `weight_max` differ in length.
83    /// * [`QuantError::EmptyInput`] — either slice is empty.
84    pub fn compute_migration_scales(
85        &self,
86        act_max: &[f32],
87        weight_max: &[f32],
88    ) -> QuantResult<Vec<f32>> {
89        if act_max.is_empty() {
90            return Err(QuantError::EmptyInput(
91                "SmoothQuantMigrator::compute_migration_scales",
92            ));
93        }
94        if act_max.len() != weight_max.len() {
95            return Err(QuantError::DimensionMismatch {
96                expected: act_max.len(),
97                got: weight_max.len(),
98            });
99        }
100        let alpha = self.config.alpha;
101        let scales = act_max
102            .iter()
103            .zip(weight_max.iter())
104            .map(|(&a_max, &w_max)| {
105                let a = a_max.abs().max(1e-8);
106                let w = w_max.abs().max(1e-8);
107                a.powf(alpha) / w.powf(1.0 - alpha)
108            })
109            .collect();
110        Ok(scales)
111    }
112
113    /// Compute per-channel max absolute values from an activation tensor.
114    ///
115    /// # Parameters
116    ///
117    /// * `acts`       — row-major activation matrix `[n_tokens, n_channels]`.
118    /// * `n_tokens`   — number of tokens (rows).
119    /// * `n_channels` — hidden dimension (columns).
120    ///
121    /// # Errors
122    ///
123    /// * [`QuantError::DimensionMismatch`] — slice length ≠ `n_tokens × n_channels`.
124    /// * [`QuantError::EmptyInput`] — either dimension is 0.
125    pub fn compute_act_stats(
126        acts: &[f32],
127        n_tokens: usize,
128        n_channels: usize,
129    ) -> QuantResult<Vec<f32>> {
130        if acts.is_empty() {
131            return Err(QuantError::EmptyInput(
132                "compute_act_stats: empty activations",
133            ));
134        }
135        if acts.len() != n_tokens * n_channels {
136            return Err(QuantError::DimensionMismatch {
137                expected: n_tokens * n_channels,
138                got: acts.len(),
139            });
140        }
141        let mut stats = vec![0.0_f32; n_channels];
142        for t in 0..n_tokens {
143            for j in 0..n_channels {
144                let v = acts[t * n_channels + j].abs();
145                if v > stats[j] {
146                    stats[j] = v;
147                }
148            }
149        }
150        Ok(stats)
151    }
152
153    /// Compute per-column (input-channel) max absolute values from a weight matrix.
154    ///
155    /// # Parameters
156    ///
157    /// * `weights`    — row-major weight matrix `[n_out, n_channels]`.
158    /// * `n_out`      — number of output features (rows).
159    /// * `n_channels` — number of input features (columns).
160    ///
161    /// # Errors
162    ///
163    /// * [`QuantError::DimensionMismatch`] — slice length ≠ `n_out × n_channels`.
164    /// * [`QuantError::EmptyInput`] — either dimension is 0.
165    pub fn compute_weight_stats(
166        weights: &[f32],
167        n_out: usize,
168        n_channels: usize,
169    ) -> QuantResult<Vec<f32>> {
170        if weights.is_empty() {
171            return Err(QuantError::EmptyInput(
172                "compute_weight_stats: empty weights",
173            ));
174        }
175        if weights.len() != n_out * n_channels {
176            return Err(QuantError::DimensionMismatch {
177                expected: n_out * n_channels,
178                got: weights.len(),
179            });
180        }
181        let mut stats = vec![0.0_f32; n_channels];
182        for r in 0..n_out {
183            for j in 0..n_channels {
184                let v = weights[r * n_channels + j].abs();
185                if v > stats[j] {
186                    stats[j] = v;
187                }
188            }
189        }
190        Ok(stats)
191    }
192
193    /// Divide each activation channel j by `scales[j]` in-place.
194    ///
195    /// # Errors
196    ///
197    /// * [`QuantError::DimensionMismatch`] — inconsistent lengths.
198    pub fn smooth_activations(
199        acts: &mut [f32],
200        scales: &[f32],
201        n_tokens: usize,
202        n_channels: usize,
203    ) -> QuantResult<()> {
204        if acts.len() != n_tokens * n_channels {
205            return Err(QuantError::DimensionMismatch {
206                expected: n_tokens * n_channels,
207                got: acts.len(),
208            });
209        }
210        if scales.len() != n_channels {
211            return Err(QuantError::DimensionMismatch {
212                expected: n_channels,
213                got: scales.len(),
214            });
215        }
216        for t in 0..n_tokens {
217            for j in 0..n_channels {
218                acts[t * n_channels + j] /= scales[j].max(1e-12);
219            }
220        }
221        Ok(())
222    }
223
224    /// Multiply each weight column j (input channel) by `scales[j]` in-place.
225    ///
226    /// Weights are assumed to have shape `[n_out, n_channels]`.
227    ///
228    /// # Errors
229    ///
230    /// * [`QuantError::DimensionMismatch`] — inconsistent lengths.
231    pub fn smooth_weights(
232        weights: &mut [f32],
233        scales: &[f32],
234        n_out: usize,
235        n_channels: usize,
236    ) -> QuantResult<()> {
237        if weights.len() != n_out * n_channels {
238            return Err(QuantError::DimensionMismatch {
239                expected: n_out * n_channels,
240                got: weights.len(),
241            });
242        }
243        if scales.len() != n_channels {
244            return Err(QuantError::DimensionMismatch {
245                expected: n_channels,
246                got: scales.len(),
247            });
248        }
249        for r in 0..n_out {
250            for j in 0..n_channels {
251                weights[r * n_channels + j] *= scales[j];
252            }
253        }
254        Ok(())
255    }
256
257    /// Smooth a complete linear layer: compute scales, apply to activations and weights.
258    ///
259    /// # Parameters
260    ///
261    /// * `acts`       — mutable activation matrix `[n_tokens, n_channels]`.
262    /// * `weights`    — mutable weight matrix `[n_out, n_channels]`.
263    /// * `n_tokens`   — token (batch) dimension.
264    /// * `n_channels` — input feature dimension.
265    /// * `n_out`      — output feature dimension.
266    ///
267    /// # Returns
268    ///
269    /// The per-channel migration scales used (length `n_channels`).
270    ///
271    /// # Errors
272    ///
273    /// Propagates all dimension and empty-input errors from sub-operations.
274    pub fn smooth_layer(
275        &self,
276        acts: &mut [f32],
277        weights: &mut [f32],
278        n_tokens: usize,
279        n_channels: usize,
280        n_out: usize,
281    ) -> QuantResult<Vec<f32>> {
282        let act_stats = Self::compute_act_stats(acts, n_tokens, n_channels)?;
283        let weight_stats = Self::compute_weight_stats(weights, n_out, n_channels)?;
284        let scales = self.compute_migration_scales(&act_stats, &weight_stats)?;
285        Self::smooth_activations(acts, &scales, n_tokens, n_channels)?;
286        Self::smooth_weights(weights, &scales, n_out, n_channels)?;
287        Ok(scales)
288    }
289}
290
291// ─── Tests ───────────────────────────────────────────────────────────────────
292
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    /// Reference dense `Y = X W^T` for row-major `X: [n_tok, n_ch]` and
    /// `W: [n_out, n_ch]`, used to confirm smoothing leaves a layer's output intact.
    fn reference_linear(x: &[f32], w: &[f32], n_tok: usize, n_ch: usize, n_out: usize) -> Vec<f32> {
        let mut y: Vec<f32> = Vec::with_capacity(n_tok * n_out);
        for x_row in x.chunks(n_ch).take(n_tok) {
            for w_row in w.chunks(n_ch).take(n_out) {
                y.push(x_row.iter().zip(w_row).map(|(a, b)| a * b).sum::<f32>());
            }
        }
        y
    }

    #[test]
    fn scale_alpha_half() {
        let migrator = SmoothQuantMigrator::new(0.5);
        let scales = migrator
            .compute_migration_scales(&[4.0, 1.0, 9.0], &[1.0, 4.0, 1.0])
            .unwrap();
        // s_j = sqrt(a_j) / sqrt(w_j): sqrt(4)/1 = 2, 1/sqrt(4) = 0.5, sqrt(9)/1 = 3.
        assert_eq!(scales.len(), 3);
        for (&got, want) in scales.iter().zip([2.0_f32, 0.5, 3.0]) {
            assert_abs_diff_eq!(got, want, epsilon = 1e-5);
        }
    }

    #[test]
    fn scale_alpha_one_activations_only() {
        // alpha = 1 → s_j = a_j / w_j^0 = a_j: the weight statistics drop out.
        let migrator = SmoothQuantMigrator::new(1.0);
        let act_max = [2.0_f32, 5.0];
        let scales = migrator
            .compute_migration_scales(&act_max, &[3.0, 7.0])
            .unwrap();
        assert_eq!(scales.len(), act_max.len());
        for (&got, &want) in scales.iter().zip(act_max.iter()) {
            assert_abs_diff_eq!(got, want, epsilon = 1e-5);
        }
    }

    #[test]
    fn scale_alpha_zero_weights_only() {
        // alpha = 0 → s_j = a_j^0 / w_j = 1 / w_j: the activation statistics drop out.
        let migrator = SmoothQuantMigrator::new(0.0);
        let weight_max = [2.0_f32, 5.0];
        let scales = migrator
            .compute_migration_scales(&[4.0, 1.0], &weight_max)
            .unwrap();
        assert_eq!(scales.len(), weight_max.len());
        for (&got, &w) in scales.iter().zip(weight_max.iter()) {
            assert_abs_diff_eq!(got, 1.0 / w, epsilon = 1e-5);
        }
    }

    #[test]
    fn smoothing_preserves_layer_output() {
        let (n_tok, n_ch, n_out) = (3, 4, 2);
        let mut acts: Vec<f32> = (0..n_tok * n_ch).map(|i| i as f32 * 0.3 - 1.0).collect();
        let mut weights: Vec<f32> = (0..n_out * n_ch).map(|i| i as f32 * 0.2 - 0.5).collect();

        let before = reference_linear(&acts, &weights, n_tok, n_ch, n_out);
        SmoothQuantMigrator::new(0.5)
            .smooth_layer(&mut acts, &mut weights, n_tok, n_ch, n_out)
            .unwrap();
        let after = reference_linear(&acts, &weights, n_tok, n_ch, n_out);

        // The rescaling is mathematically exact, so only rounding may differ.
        assert_eq!(before.len(), after.len());
        for (a, b) in before.iter().zip(&after) {
            assert_abs_diff_eq!(a, b, epsilon = 1e-4);
        }
    }

    #[test]
    fn activation_stats_max_per_channel() {
        // acts (2 tokens × 3 channels) = [[1, -5, 2], [-3, 4, 1]] → peaks [3, 5, 2]
        let acts = [1.0_f32, -5.0, 2.0, -3.0, 4.0, 1.0];
        let stats = SmoothQuantMigrator::compute_act_stats(&acts, 2, 3).unwrap();
        assert_eq!(stats.len(), 3);
        for (&got, want) in stats.iter().zip([3.0_f32, 5.0, 2.0]) {
            assert_abs_diff_eq!(got, want, epsilon = 1e-6);
        }
    }

    #[test]
    fn weight_stats_max_per_column() {
        // weights (2 out × 3 in) = [[0.5, -2.0, 1.0], [-1.5, 0.3, 3.0]] → peaks [1.5, 2, 3]
        let w = [0.5_f32, -2.0, 1.0, -1.5, 0.3, 3.0];
        let stats = SmoothQuantMigrator::compute_weight_stats(&w, 2, 3).unwrap();
        assert_eq!(stats.len(), 3);
        for (&got, want) in stats.iter().zip([1.5_f32, 2.0, 3.0]) {
            assert_abs_diff_eq!(got, want, epsilon = 1e-6);
        }
    }

    #[test]
    fn dimension_mismatch_error() {
        let result =
            SmoothQuantMigrator::new(0.5).compute_migration_scales(&[1.0; 3], &[1.0; 4]);
        assert!(matches!(result, Err(QuantError::DimensionMismatch { .. })));
    }

    #[test]
    fn empty_input_error() {
        let result = SmoothQuantMigrator::new(0.5).compute_migration_scales(&[], &[]);
        assert!(matches!(result, Err(QuantError::EmptyInput(_))));
    }

    #[test]
    fn smoothing_reduces_act_channel_range_imbalance() {
        // Channel 0 carries ±100 outliers; channel 1 stays at ±1.
        let (n_tok, n_ch, n_out) = (4, 2, 2);
        let mut acts = vec![100.0_f32, 1.0, -100.0, 1.0, 100.0, -1.0, -100.0, -1.0];
        let mut weights = vec![0.5_f32, 0.5, -0.5, 0.5];

        let scales = SmoothQuantMigrator::new(0.5)
            .smooth_layer(&mut acts, &mut weights, n_tok, n_ch, n_out)
            .unwrap();

        let channel_peak = |ch: usize| {
            acts.chunks_exact(n_ch)
                .map(|row| row[ch].abs())
                .fold(0.0_f32, f32::max)
        };
        // The peak ratio started at 100:1 and should shrink.
        let ratio = channel_peak(0) / channel_peak(1).max(1e-8);

        // A scale above 1 means channel 0 was divided down.
        assert!(
            scales[0] > 1.0,
            "scale[0] should be > 1 for outlier channel"
        );
        assert!(
            ratio < 100.0,
            "channel range imbalance should decrease after smoothing"
        );
    }
}
445            ratio < 100.0,
446            "channel range imbalance should decrease after smoothing"
447        );
448    }
449}