oxibonsai_model/smoothquant.rs
1//! SmoothQuant per-channel FP8 calibrator and channel-aware quantization.
2//!
3//! SmoothQuant (Xiao et al. 2022) addresses the quantization difficulty mismatch
4//! between activations (which have large per-channel outliers) and weights (which
5//! are relatively smooth). It migrates the quantization difficulty from activations
6//! to weights by computing per-channel smoothing factors:
7//!
8//! `s_j = max(|A_j|)^α / max(|W_j|)^(1−α)`
9//!
10//! then rescaling: `Ã[i,j] = A[i,j] / s_j`, `W̃[i,j] = W[i,j] × s_j`.
11//!
12//! This module provides:
13//! - [`SmoothQuantCalibrator`]: online per-channel max-abs accumulator.
14//! - [`quantize_fp8_e4m3_smooth`]: quantize smoothed weights into E4M3FN blocks.
15//! - [`quantize_fp8_e5m2_smooth`]: quantize smoothed weights into E5M2 blocks.
16
17use std::collections::HashMap;
18
19use crate::dynamic_quant::{
20 compute_smooth_factors, smooth_weights, DynQuantError, SmoothQuantConfig,
21};
22use oxibonsai_core::quant_fp8::{BlockFP8E4M3, BlockFP8E5M2};
23
24// ─── Error ────────────────────────────────────────────────────────────────────
25
/// Errors produced by the SmoothQuant calibrator and channel-aware quantization.
///
/// Every variant carries only owned, comparable data (`String` / `usize`), so the
/// type derives `PartialEq`/`Eq` and can be compared directly by callers and tests.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SmoothQuantError {
    /// Calibrator has no recorded layers.
    EmptyCalibrator,
    /// The requested layer has not been recorded.
    LayerNotFound(String),
    /// A supplied input-feature count (for example the length of a
    /// smoothing-factor vector) doesn't match the expected `in_features`.
    InFeaturesMismatch {
        /// The `in_features` value the operation expected.
        expected: usize,
        /// The feature count that was actually supplied.
        got: usize,
    },
    /// An underlying quantization operation failed.
    QuantizationError(String),
}

impl std::fmt::Display for SmoothQuantError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::EmptyCalibrator => {
                f.write_str("SmoothQuant calibrator has no recorded layers")
            }
            Self::LayerNotFound(name) => {
                write!(f, "SmoothQuant calibrator: layer '{name}' not found")
            }
            Self::InFeaturesMismatch { expected, got } => write!(
                f,
                "SmoothQuant calibrator: in_features mismatch — expected {expected}, got {got}"
            ),
            Self::QuantizationError(msg) => {
                write!(f, "SmoothQuant quantization error: {msg}")
            }
        }
    }
}

impl std::error::Error for SmoothQuantError {}
58
59// ─── Internal per-layer channel statistics ────────────────────────────────────
60
/// Per-layer per-channel running statistics used by the calibrator.
struct ChannelStats {
    /// Number of input features (columns) for this layer.
    in_features: usize,
    /// Running maximum of |activation| across all batches, one slot per column.
    running_max_abs: Vec<f32>,
    /// Total number of activation batches recorded for this layer.
    sample_count: usize,
}

impl ChannelStats {
    /// Create zero-initialised statistics for a layer with `in_features` columns.
    fn new(in_features: usize) -> Self {
        Self {
            in_features,
            running_max_abs: vec![0.0_f32; in_features],
            sample_count: 0,
        }
    }

    /// Update running per-channel max-abs from one batch of activations.
    ///
    /// `activations` is `[num_tokens × in_features]` row-major. Only complete
    /// rows are consumed; trailing elements that do not fill a whole row are
    /// ignored. A NaN activation never replaces a slot, because `NaN > x` is
    /// always false. The batch counter is incremented once per call.
    ///
    /// # Panics
    ///
    /// Panics if `in_features` is zero (the public caller filters that case out).
    fn update(&mut self, activations: &[f32], in_features: usize) {
        debug_assert_eq!(in_features, self.in_features);
        // `chunks_exact` yields only full rows, so every element access is in
        // bounds by construction — no per-element index check is needed.
        for row in activations.chunks_exact(in_features) {
            for (slot, &a) in self.running_max_abs.iter_mut().zip(row) {
                let v = a.abs();
                if v > *slot {
                    *slot = v;
                }
            }
        }
        self.sample_count += 1;
    }
}
100
101// ─── SmoothQuantCalibrator ────────────────────────────────────────────────────
102
103/// Online per-channel activation calibrator for SmoothQuant.
104///
105/// Feed batches of activations for each named linear layer via
106/// [`record_activation`][Self::record_activation], then call
107/// [`smooth_factors`][Self::smooth_factors] to obtain the SmoothQuant
108/// smoothing vector for that layer's weight matrix.
109pub struct SmoothQuantCalibrator {
110 layers: HashMap<String, ChannelStats>,
111 config: SmoothQuantConfig,
112}
113
114impl SmoothQuantCalibrator {
115 /// Create a new calibrator using the given SmoothQuant config.
116 pub fn new(config: SmoothQuantConfig) -> Self {
117 Self {
118 layers: HashMap::new(),
119 config,
120 }
121 }
122
123 /// Record one batch of activations for a named layer.
124 ///
125 /// `activations` is a `[num_tokens × in_features]` row-major flat slice.
126 /// If this is the first call for `layer_name`, a new per-channel accumulator
127 /// is created. For subsequent calls the running per-channel max-abs is updated.
128 ///
129 /// # Panics
130 ///
131 /// Panics if `in_features` changes between calls for the same `layer_name`.
132 pub fn record_activation(&mut self, layer_name: &str, activations: &[f32], in_features: usize) {
133 if in_features == 0 || activations.is_empty() {
134 return;
135 }
136
137 let stats = self
138 .layers
139 .entry(layer_name.to_owned())
140 .or_insert_with(|| ChannelStats::new(in_features));
141
142 if stats.in_features != in_features {
143 panic!(
144 "SmoothQuantCalibrator::record_activation: in_features mismatch for layer '{}' \
145 — expected {}, got {}",
146 layer_name, stats.in_features, in_features
147 );
148 }
149
150 stats.update(activations, in_features);
151 }
152
153 /// Compute SmoothQuant smoothing factors for a named layer.
154 ///
155 /// Uses the running per-channel activation max accumulated via
156 /// [`record_activation`][Self::record_activation] together with the supplied
157 /// weight matrix to derive per-input-feature smoothing factors.
158 ///
159 /// `weights` is `[out_features × in_features]` row-major.
160 ///
161 /// Returns a `Vec<f32>` of length `in_features`.
162 pub fn smooth_factors(
163 &self,
164 layer_name: &str,
165 weights: &[f32],
166 out_features: usize,
167 ) -> Result<Vec<f32>, SmoothQuantError> {
168 let stats = self
169 .layers
170 .get(layer_name)
171 .ok_or_else(|| SmoothQuantError::LayerNotFound(layer_name.to_owned()))?;
172
173 let in_features = stats.in_features;
174
175 // We pass running_max_abs as a synthetic single-row "activation" matrix
176 // (shape [1 × in_features]). compute_smooth_factors will compute the
177 // per-column max across that one row, which gives exactly running_max_abs[j].
178 let factors = compute_smooth_factors(
179 &stats.running_max_abs,
180 weights,
181 in_features,
182 1, // tokens = 1 (running_max_abs is already the global max)
183 out_features,
184 &self.config,
185 );
186
187 Ok(factors)
188 }
189
190 /// Number of distinct layers recorded by this calibrator.
191 pub fn layer_count(&self) -> usize {
192 self.layers.len()
193 }
194
195 /// Whether the calibrator has any recorded data for the given layer name.
196 pub fn has_layer(&self, name: &str) -> bool {
197 self.layers.contains_key(name)
198 }
199}
200
201// ─── Channel-aware FP8 quantization ──────────────────────────────────────────
202
203/// Quantize a weight matrix using SmoothQuant scaling into FP8 E4M3FN blocks.
204///
205/// The smoothing factors obtained from [`SmoothQuantCalibrator::smooth_factors`]
206/// are applied to the weight matrix in-place (W̃ = W × s_j for column j), then
207/// the smoothed weights are quantized into [`BlockFP8E4M3`] blocks.
208///
209/// `weights` is `[out_features × in_features]` row-major.
210/// `smooth_factors` must have length `in_features`.
211///
212/// Returns the resulting FP8 block vector; total elements must be a multiple of
213/// `QK_FP8` (32), so `out_features × in_features` must be divisible by 32.
214pub fn quantize_fp8_e4m3_smooth(
215 weights: &[f32],
216 out_features: usize,
217 in_features: usize,
218 smooth_factors: &[f32],
219) -> Result<Vec<BlockFP8E4M3>, SmoothQuantError> {
220 if smooth_factors.len() != in_features {
221 return Err(SmoothQuantError::InFeaturesMismatch {
222 expected: in_features,
223 got: smooth_factors.len(),
224 });
225 }
226
227 // Clone and apply SmoothQuant weight scaling: W̃[i,j] = W[i,j] * s_j
228 let mut smoothed = weights.to_vec();
229 smooth_weights(&mut smoothed, smooth_factors, out_features, in_features)
230 .map_err(|e: DynQuantError| SmoothQuantError::QuantizationError(e.to_string()))?;
231
232 // Quantize the smoothed weights into FP8 E4M3FN blocks.
233 BlockFP8E4M3::quantize(&smoothed)
234 .map_err(|e| SmoothQuantError::QuantizationError(e.to_string()))
235}
236
237/// Quantize a weight matrix using SmoothQuant scaling into FP8 E5M2 blocks.
238///
239/// Mirrors [`quantize_fp8_e4m3_smooth`] for the E5M2 format.
240///
241/// `weights` is `[out_features × in_features]` row-major.
242/// `smooth_factors` must have length `in_features`.
243pub fn quantize_fp8_e5m2_smooth(
244 weights: &[f32],
245 out_features: usize,
246 in_features: usize,
247 smooth_factors: &[f32],
248) -> Result<Vec<BlockFP8E5M2>, SmoothQuantError> {
249 if smooth_factors.len() != in_features {
250 return Err(SmoothQuantError::InFeaturesMismatch {
251 expected: in_features,
252 got: smooth_factors.len(),
253 });
254 }
255
256 let mut smoothed = weights.to_vec();
257 smooth_weights(&mut smoothed, smooth_factors, out_features, in_features)
258 .map_err(|e: DynQuantError| SmoothQuantError::QuantizationError(e.to_string()))?;
259
260 BlockFP8E5M2::quantize(&smoothed)
261 .map_err(|e| SmoothQuantError::QuantizationError(e.to_string()))
262}