Skip to main content

oxibonsai_model/
smoothquant.rs

1//! SmoothQuant per-channel FP8 calibrator and channel-aware quantization.
2//!
3//! SmoothQuant (Xiao et al. 2022) addresses the quantization difficulty mismatch
4//! between activations (which have large per-channel outliers) and weights (which
5//! are relatively smooth). It migrates the quantization difficulty from activations
6//! to weights by computing per-channel smoothing factors:
7//!
8//!   `s_j = max(|A_j|)^α / max(|W_j|)^(1−α)`
9//!
10//! then rescaling: `Ã[i,j] = A[i,j] / s_j`, `W̃[i,j] = W[i,j] × s_j`.
11//!
12//! This module provides:
13//! - [`SmoothQuantCalibrator`]: online per-channel max-abs accumulator.
14//! - [`quantize_fp8_e4m3_smooth`]: quantize smoothed weights into E4M3FN blocks.
15//! - [`quantize_fp8_e5m2_smooth`]: quantize smoothed weights into E5M2 blocks.
16
17use std::collections::HashMap;
18
19use crate::dynamic_quant::{
20    compute_smooth_factors, smooth_weights, DynQuantError, SmoothQuantConfig,
21};
22use oxibonsai_core::quant_fp8::{BlockFP8E4M3, BlockFP8E5M2};
23
24// ─── Error ────────────────────────────────────────────────────────────────────
25
/// Failure modes of the SmoothQuant calibrator and of the channel-aware FP8
/// quantization helpers in this module.
#[derive(Debug, Clone)]
pub enum SmoothQuantError {
    /// The calibrator holds no recorded layers at all.
    EmptyCalibrator,
    /// No statistics have been recorded under the given layer name.
    LayerNotFound(String),
    /// A supplied width (`in_features`, or a smoothing-vector length) disagrees
    /// with the width that was expected.
    InFeaturesMismatch { expected: usize, got: usize },
    /// A smoothing or quantization step failed; carries the underlying message.
    QuantizationError(String),
}

impl std::fmt::Display for SmoothQuantError {
    /// Render a single-line, human-readable description of the error.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::EmptyCalibrator => {
                f.write_str("SmoothQuant calibrator has no recorded layers")
            }
            Self::LayerNotFound(layer) => write!(
                f,
                "SmoothQuant calibrator: layer '{}' not found",
                layer
            ),
            Self::InFeaturesMismatch { expected, got } => write!(
                f,
                "SmoothQuant calibrator: in_features mismatch — expected {}, got {}",
                expected, got
            ),
            Self::QuantizationError(detail) => {
                write!(f, "SmoothQuant quantization error: {}", detail)
            }
        }
    }
}

impl std::error::Error for SmoothQuantError {}
58
59// ─── Internal per-layer channel statistics ────────────────────────────────────
60
/// Per-layer per-channel running statistics used by the calibrator.
struct ChannelStats {
    /// Number of input features (columns) for this layer.
    in_features: usize,
    /// Running maximum of |activation| across all batches, one slot per column.
    running_max_abs: Vec<f32>,
    /// Total number of activation batches recorded for this layer.
    sample_count: usize,
}

impl ChannelStats {
    /// Create an accumulator for a layer with `in_features` input columns.
    ///
    /// Every per-column maximum starts at `0.0` and no batches are counted.
    fn new(in_features: usize) -> Self {
        Self {
            in_features,
            running_max_abs: vec![0.0_f32; in_features],
            sample_count: 0,
        }
    }

    /// Update running per-channel max-abs from one batch of activations.
    ///
    /// `activations` is `[num_tokens × in_features]` row-major. Only complete
    /// rows are consumed: when `activations.len()` is not a multiple of
    /// `in_features`, the trailing partial row is ignored. A NaN activation
    /// never replaces a stored maximum, because `NaN > x` is always false.
    ///
    /// `sample_count` is incremented once per call, whatever the number of rows.
    ///
    /// `in_features` must equal `self.in_features` (checked in debug builds)
    /// and must be non-zero (`chunks_exact(0)` panics). The only caller,
    /// `SmoothQuantCalibrator`, guarantees both.
    fn update(&mut self, activations: &[f32], in_features: usize) {
        debug_assert_eq!(in_features, self.in_features);
        // `chunks_exact` yields exactly `len / in_features` full rows, so no
        // per-element bounds check is needed.
        for row in activations.chunks_exact(in_features) {
            for (slot, &value) in self.running_max_abs.iter_mut().zip(row) {
                let magnitude = value.abs();
                if magnitude > *slot {
                    *slot = magnitude;
                }
            }
        }
        self.sample_count += 1;
    }
}
100
101// ─── SmoothQuantCalibrator ────────────────────────────────────────────────────
102
103/// Online per-channel activation calibrator for SmoothQuant.
104///
105/// Feed batches of activations for each named linear layer via
106/// [`record_activation`][Self::record_activation], then call
107/// [`smooth_factors`][Self::smooth_factors] to obtain the SmoothQuant
108/// smoothing vector for that layer's weight matrix.
109pub struct SmoothQuantCalibrator {
110    layers: HashMap<String, ChannelStats>,
111    config: SmoothQuantConfig,
112}
113
114impl SmoothQuantCalibrator {
115    /// Create a new calibrator using the given SmoothQuant config.
116    pub fn new(config: SmoothQuantConfig) -> Self {
117        Self {
118            layers: HashMap::new(),
119            config,
120        }
121    }
122
123    /// Record one batch of activations for a named layer.
124    ///
125    /// `activations` is a `[num_tokens × in_features]` row-major flat slice.
126    /// If this is the first call for `layer_name`, a new per-channel accumulator
127    /// is created. For subsequent calls the running per-channel max-abs is updated.
128    ///
129    /// # Panics
130    ///
131    /// Panics if `in_features` changes between calls for the same `layer_name`.
132    pub fn record_activation(&mut self, layer_name: &str, activations: &[f32], in_features: usize) {
133        if in_features == 0 || activations.is_empty() {
134            return;
135        }
136
137        let stats = self
138            .layers
139            .entry(layer_name.to_owned())
140            .or_insert_with(|| ChannelStats::new(in_features));
141
142        if stats.in_features != in_features {
143            panic!(
144                "SmoothQuantCalibrator::record_activation: in_features mismatch for layer '{}' \
145                 — expected {}, got {}",
146                layer_name, stats.in_features, in_features
147            );
148        }
149
150        stats.update(activations, in_features);
151    }
152
153    /// Compute SmoothQuant smoothing factors for a named layer.
154    ///
155    /// Uses the running per-channel activation max accumulated via
156    /// [`record_activation`][Self::record_activation] together with the supplied
157    /// weight matrix to derive per-input-feature smoothing factors.
158    ///
159    /// `weights` is `[out_features × in_features]` row-major.
160    ///
161    /// Returns a `Vec<f32>` of length `in_features`.
162    pub fn smooth_factors(
163        &self,
164        layer_name: &str,
165        weights: &[f32],
166        out_features: usize,
167    ) -> Result<Vec<f32>, SmoothQuantError> {
168        let stats = self
169            .layers
170            .get(layer_name)
171            .ok_or_else(|| SmoothQuantError::LayerNotFound(layer_name.to_owned()))?;
172
173        let in_features = stats.in_features;
174
175        // We pass running_max_abs as a synthetic single-row "activation" matrix
176        // (shape [1 × in_features]).  compute_smooth_factors will compute the
177        // per-column max across that one row, which gives exactly running_max_abs[j].
178        let factors = compute_smooth_factors(
179            &stats.running_max_abs,
180            weights,
181            in_features,
182            1, // tokens = 1 (running_max_abs is already the global max)
183            out_features,
184            &self.config,
185        );
186
187        Ok(factors)
188    }
189
190    /// Number of distinct layers recorded by this calibrator.
191    pub fn layer_count(&self) -> usize {
192        self.layers.len()
193    }
194
195    /// Whether the calibrator has any recorded data for the given layer name.
196    pub fn has_layer(&self, name: &str) -> bool {
197        self.layers.contains_key(name)
198    }
199}
200
201// ─── Channel-aware FP8 quantization ──────────────────────────────────────────
202
203/// Quantize a weight matrix using SmoothQuant scaling into FP8 E4M3FN blocks.
204///
205/// The smoothing factors obtained from [`SmoothQuantCalibrator::smooth_factors`]
206/// are applied to the weight matrix in-place (W̃ = W × s_j for column j), then
207/// the smoothed weights are quantized into [`BlockFP8E4M3`] blocks.
208///
209/// `weights` is `[out_features × in_features]` row-major.
210/// `smooth_factors` must have length `in_features`.
211///
212/// Returns the resulting FP8 block vector; total elements must be a multiple of
213/// `QK_FP8` (32), so `out_features × in_features` must be divisible by 32.
214pub fn quantize_fp8_e4m3_smooth(
215    weights: &[f32],
216    out_features: usize,
217    in_features: usize,
218    smooth_factors: &[f32],
219) -> Result<Vec<BlockFP8E4M3>, SmoothQuantError> {
220    if smooth_factors.len() != in_features {
221        return Err(SmoothQuantError::InFeaturesMismatch {
222            expected: in_features,
223            got: smooth_factors.len(),
224        });
225    }
226
227    // Clone and apply SmoothQuant weight scaling: W̃[i,j] = W[i,j] * s_j
228    let mut smoothed = weights.to_vec();
229    smooth_weights(&mut smoothed, smooth_factors, out_features, in_features)
230        .map_err(|e: DynQuantError| SmoothQuantError::QuantizationError(e.to_string()))?;
231
232    // Quantize the smoothed weights into FP8 E4M3FN blocks.
233    BlockFP8E4M3::quantize(&smoothed)
234        .map_err(|e| SmoothQuantError::QuantizationError(e.to_string()))
235}
236
237/// Quantize a weight matrix using SmoothQuant scaling into FP8 E5M2 blocks.
238///
239/// Mirrors [`quantize_fp8_e4m3_smooth`] for the E5M2 format.
240///
241/// `weights` is `[out_features × in_features]` row-major.
242/// `smooth_factors` must have length `in_features`.
243pub fn quantize_fp8_e5m2_smooth(
244    weights: &[f32],
245    out_features: usize,
246    in_features: usize,
247    smooth_factors: &[f32],
248) -> Result<Vec<BlockFP8E5M2>, SmoothQuantError> {
249    if smooth_factors.len() != in_features {
250        return Err(SmoothQuantError::InFeaturesMismatch {
251            expected: in_features,
252            got: smooth_factors.len(),
253        });
254    }
255
256    let mut smoothed = weights.to_vec();
257    smooth_weights(&mut smoothed, smooth_factors, out_features, in_features)
258        .map_err(|e: DynQuantError| SmoothQuantError::QuantizationError(e.to_string()))?;
259
260    BlockFP8E5M2::quantize(&smoothed)
261        .map_err(|e| SmoothQuantError::QuantizationError(e.to_string()))
262}