
oxibonsai_model/dynamic_quant.rs

//! Dynamic activation quantization for W8A8 / W4A8 inference.
//!
//! Unlike static quantization (which uses pre-computed scales), dynamic
//! quantization computes the quantization scale from the current activation
//! values at inference time. This adds work per inference step but is
//! typically more accurate.
//!
//! # Supported formats
//! - `Int8PerTensor`: per-tensor symmetric INT8 (one scale for the whole tensor)
//! - `Int8PerRow`: per-row symmetric INT8 (one scale per row of a 2D tensor)
//! - `Int4PerTensor`: per-tensor symmetric INT4 (values in [-7, 7], stored as i8)
//!
//! SmoothQuant activation-weight smoothing helpers are also provided to reduce
//! activation quantization error before quantizing.
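//!
//! # Example
//!
//! A minimal round trip, quantize then dequantize (imports omitted, hence `ignore`):
//!
//! ```ignore
//! let acts = vec![0.5_f32, -1.25, 3.0];
//! let q = dynamic_quantize_int8(&acts, DynamicScaleMode::MaxAbs);
//! assert_eq!(q.dequantize().len(), acts.len());
//! ```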

/// How to compute the dynamic quantization scale.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum DynamicScaleMode {
    /// Use the max absolute value: scale = max(|x|) / clip_val
    MaxAbs,
    /// Use a percentile of absolute values (more robust to outliers).
    /// The percentile is in (0, 1]; e.g. 0.99 = the 99th percentile.
    Percentile(f32),
}

/// Format of a dynamically quantized tensor.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum DynQuantFormat {
    /// Per-tensor INT8: one scale for the whole tensor.
    Int8PerTensor,
    /// Per-row INT8: one scale per row in a 2D tensor.
    Int8PerRow,
    /// Per-tensor INT4: values clamped to [-7, 7], stored unpacked as one i8 per value.
    Int4PerTensor,
}

/// A dynamically quantized tensor.
#[derive(Debug, Clone)]
pub struct DynQuantTensor {
    /// Quantized values (i8 storage for both INT8 and INT4).
    pub data: Vec<i8>,
    /// Scales, one per quantization group.
    pub scales: Vec<f32>,
    /// Shape of the original tensor.
    pub shape: Vec<usize>,
    /// Quantization format.
    pub format: DynQuantFormat,
}

impl DynQuantTensor {
    /// Dequantize back to f32.
    pub fn dequantize(&self) -> Vec<f32> {
        match self.format {
            DynQuantFormat::Int8PerTensor => {
                let scale = self.scales.first().copied().unwrap_or(0.0);
                self.data.iter().map(|&q| q as f32 * scale).collect()
            }
            DynQuantFormat::Int8PerRow => {
                if self.scales.is_empty() || self.data.is_empty() {
                    return Vec::new();
                }
                let rows = self.scales.len();
                let cols = self.data.len() / rows.max(1);
                let mut out = Vec::with_capacity(self.data.len());
                for (r, &scale) in self.scales.iter().enumerate() {
                    let start = r * cols;
                    let end = (start + cols).min(self.data.len());
                    for &q in &self.data[start..end] {
                        out.push(q as f32 * scale);
                    }
                }
                out
            }
            DynQuantFormat::Int4PerTensor => {
                let scale = self.scales.first().copied().unwrap_or(0.0);
                self.data.iter().map(|&q| q as f32 * scale).collect()
            }
        }
    }

    /// Memory in bytes (data + scales).
    pub fn memory_bytes(&self) -> usize {
        self.data.len() + self.scales.len() * core::mem::size_of::<f32>()
    }

    /// Compression ratio vs f32: original f32 bytes divided by quantized bytes (data + scales).
    pub fn compression_ratio(&self) -> f32 {
        let original_bytes = self.data.len() * core::mem::size_of::<f32>();
        let quantized_bytes = self.memory_bytes();
        if quantized_bytes == 0 {
            return 1.0;
        }
        original_bytes as f32 / quantized_bytes as f32
    }

    /// Number of elements.
    pub fn element_count(&self) -> usize {
        self.data.len()
    }
}

// ─── Scale computation ────────────────────────────────────────────────────────

/// Compute the quantization scale for a slice.
///
/// - `MaxAbs`: `scale = max(|x|) / clip_val`
/// - `Percentile(p)`: sort absolute values, use the p-th percentile value / clip_val
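///
/// # Example
///
/// A worked illustration of both modes:
///
/// ```text
/// data = [1.0, -2.0, 0.5], clip_val = 127.0
/// MaxAbs:          max(|x|) = 2.0, so scale = 2.0 / 127.0 ≈ 0.01575
/// Percentile(0.5): sorted |x| = [0.5, 1.0, 2.0], idx = ceil(0.5 * 3) - 1 = 1,
///                  so scale = 1.0 / 127.0 ≈ 0.00787
/// ```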
pub fn compute_scale(data: &[f32], clip_val: f32, mode: DynamicScaleMode) -> f32 {
    if data.is_empty() {
        return 0.0;
    }

    let abs_max = match mode {
        DynamicScaleMode::MaxAbs => data.iter().map(|x| x.abs()).fold(0.0_f32, f32::max),
        DynamicScaleMode::Percentile(p) => {
            let p_clamped = p.clamp(0.0, 1.0);
            let mut abs_vals: Vec<f32> = data.iter().map(|x| x.abs()).collect();
            // Sort ascending
            abs_vals.sort_by(|a, b| a.partial_cmp(b).unwrap_or(core::cmp::Ordering::Equal));
            let len = abs_vals.len();
            // Compute index: ceiling of p * len, then subtract 1, clamped
            let idx = ((p_clamped * len as f32).ceil() as usize)
                .saturating_sub(1)
                .min(len - 1);
            abs_vals[idx]
        }
    };

    if abs_max == 0.0 {
        return 0.0;
    }

    abs_max / clip_val
}

// ─── INT8 per-tensor ──────────────────────────────────────────────────────────

/// Dynamically quantize a 1D activation tensor to INT8 (per-tensor).
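///
/// # Example
///
/// A minimal sketch (imports omitted, hence `ignore`); max(|x|) is 2.0, so -2.0 maps to -127:
///
/// ```ignore
/// let q = dynamic_quantize_int8(&[1.0, -2.0, 0.5], DynamicScaleMode::MaxAbs);
/// assert_eq!(q.data[1], -127);
/// assert_eq!(q.scales.len(), 1);
/// ```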
pub fn dynamic_quantize_int8(data: &[f32], mode: DynamicScaleMode) -> DynQuantTensor {
    const CLIP_VAL: f32 = 127.0;

    if data.is_empty() {
        return DynQuantTensor {
            data: Vec::new(),
            scales: vec![0.0],
            shape: vec![0],
            format: DynQuantFormat::Int8PerTensor,
        };
    }

    let scale = compute_scale(data, CLIP_VAL, mode);

    let quantized: Vec<i8> = if scale == 0.0 {
        vec![0i8; data.len()]
    } else {
        data.iter()
            .map(|&x| (x / scale).round().clamp(-127.0, 127.0) as i8)
            .collect()
    };

    DynQuantTensor {
        data: quantized,
        scales: vec![scale],
        shape: vec![data.len()],
        format: DynQuantFormat::Int8PerTensor,
    }
}

// ─── INT8 per-row ─────────────────────────────────────────────────────────────

/// Dynamically quantize a 2D activation tensor to INT8, one scale per row.
///
/// `data` is row-major with shape `[rows, cols]`.
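///
/// # Example
///
/// A minimal sketch (imports omitted, hence `ignore`); rows with very different ranges get
/// independent scales:
///
/// ```ignore
/// let data = [0.1_f32, -0.2, 0.3, 10.0, -20.0, 30.0];
/// let q = dynamic_quantize_int8_per_row(&data, 2, 3, DynamicScaleMode::MaxAbs);
/// assert_eq!(q.scales.len(), 2);
/// assert!(q.scales[1] > q.scales[0]);
/// ```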
pub fn dynamic_quantize_int8_per_row(
    data: &[f32],
    rows: usize,
    cols: usize,
    mode: DynamicScaleMode,
) -> DynQuantTensor {
    const CLIP_VAL: f32 = 127.0;

    if data.is_empty() || rows == 0 || cols == 0 {
        return DynQuantTensor {
            data: Vec::new(),
            scales: Vec::new(),
            shape: vec![rows, cols],
            format: DynQuantFormat::Int8PerRow,
        };
    }

    let total = rows * cols;
    let actual_len = data.len().min(total);

    let mut quantized = Vec::with_capacity(actual_len);
    let mut scales = Vec::with_capacity(rows);

    for r in 0..rows {
        let start = r * cols;
        let end = (start + cols).min(data.len());
        if start >= data.len() {
            // Pad with zeros if row is out of bounds
            quantized.extend(vec![0i8; cols]);
            scales.push(0.0_f32);
            continue;
        }
        let row = &data[start..end];
        let scale = compute_scale(row, CLIP_VAL, mode);
        scales.push(scale);
        if scale == 0.0 {
            quantized.extend(vec![0i8; row.len()]);
        } else {
            for &x in row {
                quantized.push((x / scale).round().clamp(-127.0, 127.0) as i8);
            }
        }
    }

    DynQuantTensor {
        data: quantized,
        scales,
        shape: vec![rows, cols],
        format: DynQuantFormat::Int8PerRow,
    }
}

// ─── INT4 per-tensor ──────────────────────────────────────────────────────────

/// INT4 quantization: clamp to [-7, 7], stored as i8.
///
/// `scale = max(|x|) / 7.0`
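///
/// # Example
///
/// ```text
/// data = [0.7, -1.4, 2.1]  =>  scale = 2.1 / 7.0 = 0.3
/// quantized   = [2, -5, 7]
/// dequantized ≈ [0.6, -1.5, 2.1]
/// ```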
pub fn dynamic_quantize_int4(data: &[f32], mode: DynamicScaleMode) -> DynQuantTensor {
    const CLIP_VAL: f32 = 7.0;

    if data.is_empty() {
        return DynQuantTensor {
            data: Vec::new(),
            scales: vec![0.0],
            shape: vec![0],
            format: DynQuantFormat::Int4PerTensor,
        };
    }

    let scale = compute_scale(data, CLIP_VAL, mode);

    let quantized: Vec<i8> = if scale == 0.0 {
        vec![0i8; data.len()]
    } else {
        data.iter()
            .map(|&x| (x / scale).round().clamp(-7.0, 7.0) as i8)
            .collect()
    };

    DynQuantTensor {
        data: quantized,
        scales: vec![scale],
        shape: vec![data.len()],
        format: DynQuantFormat::Int4PerTensor,
    }
}

// ─── Error metrics ────────────────────────────────────────────────────────────

/// Mean absolute quantization error between original f32 data and a quantized tensor.
pub fn quantization_mae(original: &[f32], quantized: &DynQuantTensor) -> f32 {
    let reconstructed = quantized.dequantize();
    let n = original.len().min(reconstructed.len());
    if n == 0 {
        return 0.0;
    }
    let sum_abs_err: f32 = original[..n]
        .iter()
        .zip(reconstructed[..n].iter())
        .map(|(&o, &r)| (o - r).abs())
        .sum();
    sum_abs_err / n as f32
}

// ─── SmoothQuant ─────────────────────────────────────────────────────────────

/// SmoothQuant configuration: redistribute quantization difficulty from activations to weights.
///
/// Smoothing factor: `s_j = max(|A_j|)^α / max(|W_j|)^(1-α)`
/// Then: `Ã = A / s`, `W̃ = W * s`
#[derive(Debug, Clone)]
pub struct SmoothQuantConfig {
    /// Balance factor in [0, 1]. Typically 0.5.
    pub alpha: f32,
    /// Floor for scale values to avoid division by zero.
    pub epsilon: f32,
}

impl SmoothQuantConfig {
    /// Create a new config with the given alpha (clamped to [0, 1]).
    pub fn new(alpha: f32) -> Self {
        Self {
            alpha: alpha.clamp(0.0, 1.0),
            epsilon: 1e-5,
        }
    }

    /// Default config with alpha = 0.5.
    pub fn default_alpha() -> Self {
        Self::new(0.5)
    }
}

/// Compute SmoothQuant smoothing factors (one per input feature).
///
/// - `activations`: shape `[tokens, in_features]` (row-major)
/// - `weights`: shape `[out_features, in_features]` (row-major)
/// - Returns: smoothing factors of length `in_features`
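///
/// # Example
///
/// A minimal sketch with one token, two input features, and one output feature
/// (imports omitted, hence `ignore`):
///
/// ```ignore
/// let cfg = SmoothQuantConfig::default_alpha();
/// let s = compute_smooth_factors(&[4.0, 0.5], &[0.5, 2.0], 2, 1, 1, &cfg);
/// // With alpha = 0.5, s[0] ≈ sqrt(4.0 / 0.5) ≈ 2.83: feature 0 has large
/// // activations relative to its weights, so it gets scaled down the most.
/// assert_eq!(s.len(), 2);
/// ```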
pub fn compute_smooth_factors(
    activations: &[f32],
    weights: &[f32],
    in_features: usize,
    tokens: usize,
    out_features: usize,
    config: &SmoothQuantConfig,
) -> Vec<f32> {
    if in_features == 0 {
        return Vec::new();
    }

    let alpha = config.alpha.clamp(0.0, 1.0);
    let epsilon = config.epsilon.max(1e-10);

    // Compute per-column max abs of activations: shape [tokens, in_features]
    let mut act_max = vec![0.0_f32; in_features];
    for t in 0..tokens {
        for (j, slot) in act_max.iter_mut().enumerate() {
            let idx = t * in_features + j;
            if idx < activations.len() {
                let v = activations[idx].abs();
                if v > *slot {
                    *slot = v;
                }
            }
        }
    }

    // Compute per-column max abs of weights: shape [out_features, in_features]
    let mut w_max = vec![0.0_f32; in_features];
    for o in 0..out_features {
        for (j, slot) in w_max.iter_mut().enumerate() {
            let idx = o * in_features + j;
            if idx < weights.len() {
                let v = weights[idx].abs();
                if v > *slot {
                    *slot = v;
                }
            }
        }
    }

    // s_j = max(|A_j|)^alpha / max(|W_j|)^(1 - alpha)
    (0..in_features)
        .map(|j| {
            let a = (act_max[j] + epsilon).powf(alpha);
            let w = (w_max[j] + epsilon).powf(1.0 - alpha);
            (a / w).max(epsilon)
        })
        .collect()
}

/// Apply smoothing factors to activations in-place: `A_smooth[i,j] = A[i,j] / s[j]`.
pub fn smooth_activations(
    activations: &mut [f32],
    smooth_factors: &[f32],
    tokens: usize,
    in_features: usize,
) -> Result<(), DynQuantError> {
    if smooth_factors.len() != in_features {
        return Err(DynQuantError::FeatureDimMismatch {
            in_features,
            sf_len: smooth_factors.len(),
        });
    }
    let expected = tokens * in_features;
    if activations.len() != expected {
        return Err(DynQuantError::ShapeMismatch {
            expected,
            actual: activations.len(),
        });
    }
    for t in 0..tokens {
        for (j, &sf) in smooth_factors.iter().enumerate() {
            let idx = t * in_features + j;
            activations[idx] /= sf;
        }
    }
    Ok(())
}

/// Apply smoothing factors to weights in-place: `W_smooth[i,j] = W[i,j] * s[j]`.
pub fn smooth_weights(
    weights: &mut [f32],
    smooth_factors: &[f32],
    out_features: usize,
    in_features: usize,
) -> Result<(), DynQuantError> {
    if smooth_factors.len() != in_features {
        return Err(DynQuantError::FeatureDimMismatch {
            in_features,
            sf_len: smooth_factors.len(),
        });
    }
    let expected = out_features * in_features;
    if weights.len() != expected {
        return Err(DynQuantError::ShapeMismatch {
            expected,
            actual: weights.len(),
        });
    }
    for o in 0..out_features {
        for (j, &sf) in smooth_factors.iter().enumerate() {
            let idx = o * in_features + j;
            weights[idx] *= sf;
        }
    }
    Ok(())
}

// ─── W8A8 GEMV ────────────────────────────────────────────────────────────────

/// W8A8 matrix-vector multiply: quantize activation on-the-fly, then perform INT8 GEMV.
///
/// - `weight_i8`: shape `[out_size, in_size]` pre-quantized INT8 (row-major)
/// - `weight_scales`: shape `[out_size]` per-row dequant scales
/// - `activation`: shape `[in_size]`, dynamically quantized per-tensor on the fly
/// - Returns: shape `[out_size]` as f32
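///
/// # Example
///
/// A sketch of the call shape (inputs assumed to be defined elsewhere; imports omitted,
/// hence `ignore`). Weights can be pre-quantized with `dynamic_quantize_int8_per_row`:
///
/// ```ignore
/// let wq = dynamic_quantize_int8_per_row(&weights_f32, out_size, in_size, DynamicScaleMode::MaxAbs);
/// let y = w8a8_matvec(&wq.data, &wq.scales, &activation, out_size, in_size)?;
/// assert_eq!(y.len(), out_size);
/// ```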
pub fn w8a8_matvec(
    weight_i8: &[i8],
    weight_scales: &[f32],
    activation: &[f32],
    out_size: usize,
    in_size: usize,
) -> Result<Vec<f32>, DynQuantError> {
    if activation.is_empty() {
        return Err(DynQuantError::EmptyInput);
    }
    if activation.len() != in_size {
        return Err(DynQuantError::ShapeMismatch {
            expected: in_size,
            actual: activation.len(),
        });
    }
    let expected_w = out_size * in_size;
    if weight_i8.len() != expected_w {
        return Err(DynQuantError::ShapeMismatch {
            expected: expected_w,
            actual: weight_i8.len(),
        });
    }
    if weight_scales.len() != out_size {
        return Err(DynQuantError::ShapeMismatch {
            expected: out_size,
            actual: weight_scales.len(),
        });
    }

    // Dynamically quantize activation per-tensor
    let act_quant = dynamic_quantize_int8(activation, DynamicScaleMode::MaxAbs);
    let act_scale = act_quant.scales.first().copied().unwrap_or(0.0);
    let act_i8 = &act_quant.data;

    let mut output = vec![0.0_f32; out_size];

    for o in 0..out_size {
        let row_start = o * in_size;
        let row_end = row_start + in_size;
        let row = &weight_i8[row_start..row_end];

        let mut acc = 0_i32;
        for (&w, &a) in row.iter().zip(act_i8.iter()) {
            acc += w as i32 * a as i32;
        }

        // Dequantize: result = acc * w_scale * act_scale
        output[o] = acc as f32 * weight_scales[o] * act_scale;
    }

    Ok(output)
}

// ─── Calibration statistics ───────────────────────────────────────────────────

/// Calibration statistics for choosing static quantization scales.
#[derive(Debug, Clone)]
pub struct CalibStats {
    /// Minimum value across all batches.
    pub min: f32,
    /// Maximum value across all batches.
    pub max: f32,
    /// Mean value across all batches.
    pub mean: f32,
    /// Standard deviation across all batches.
    pub std_dev: f32,
    /// 99th percentile of absolute values.
    pub p99: f32,
    /// Suggested quantization scale (p99 / 127.0 for INT8).
    pub suggested_scale: f32,
}

impl CalibStats {
    /// Collect calibration statistics from a batch of activation vectors.
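    ///
    /// # Example
    ///
    /// A minimal sketch (imports omitted, hence `ignore`):
    ///
    /// ```ignore
    /// let batches = vec![vec![1.0_f32, -2.0, 3.0], vec![0.5, 4.0, -1.0]];
    /// let stats = CalibStats::collect(&batches);
    /// assert!(stats.suggested_scale > 0.0);
    /// ```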
    pub fn collect(batches: &[Vec<f32>]) -> Self {
        let all_values: Vec<f32> = batches.iter().flat_map(|b| b.iter().copied()).collect();

        if all_values.is_empty() {
            return Self {
                min: 0.0,
                max: 0.0,
                mean: 0.0,
                std_dev: 0.0,
                p99: 0.0,
                suggested_scale: 0.0,
            };
        }

        let n = all_values.len();

        // min and max
        let min_val = all_values.iter().copied().fold(f32::INFINITY, f32::min);
        let max_val = all_values.iter().copied().fold(f32::NEG_INFINITY, f32::max);

        // mean
        let sum: f32 = all_values.iter().sum();
        let mean_val = sum / n as f32;

        // std dev
        let variance: f32 = all_values
            .iter()
            .map(|&x| {
                let d = x - mean_val;
                d * d
            })
            .sum::<f32>()
            / n as f32;
        let std_dev_val = variance.sqrt();

        // p99 of absolute values
        let mut abs_vals: Vec<f32> = all_values.iter().map(|x| x.abs()).collect();
        abs_vals.sort_by(|a, b| a.partial_cmp(b).unwrap_or(core::cmp::Ordering::Equal));
        let p99_idx = ((0.99_f32 * n as f32).ceil() as usize)
            .saturating_sub(1)
            .min(n - 1);
        let p99_val = abs_vals[p99_idx];

        let suggested = if p99_val > 0.0 {
            p99_val / 127.0
        } else {
            // Fallback: use max abs
            let max_abs = abs_vals.last().copied().unwrap_or(0.0);
            if max_abs > 0.0 {
                max_abs / 127.0
            } else {
                1.0 / 127.0
            }
        };

        Self {
            min: min_val,
            max: max_val,
            mean: mean_val,
            std_dev: std_dev_val,
            p99: p99_val,
            suggested_scale: suggested,
        }
    }
}

// ─── Errors ───────────────────────────────────────────────────────────────────

/// Errors from dynamic quantization operations.
#[derive(Debug, thiserror::Error)]
pub enum DynQuantError {
    /// Shape mismatch between expected and actual sizes.
    #[error("shape mismatch: expected {expected}, got {actual}")]
    ShapeMismatch { expected: usize, actual: usize },

    /// Input tensor is empty.
    #[error("empty input")]
    EmptyInput,

    /// Alpha value is out of the valid [0, 1] range.
    #[error("invalid alpha {0}: must be in [0, 1]")]
    InvalidAlpha(f32),

    /// Input feature dimension doesn't match smooth factors length.
    #[error("dimension mismatch: in_features {in_features}, smooth_factors {sf_len}")]
    FeatureDimMismatch { in_features: usize, sf_len: usize },
}

// ─── Tests ────────────────────────────────────────────────────────────────────

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_compute_scale_max_abs_basic() {
        let data = [1.0_f32, -2.0, 0.5];
        let scale = compute_scale(&data, 127.0, DynamicScaleMode::MaxAbs);
        let expected = 2.0 / 127.0;
        assert!(
            (scale - expected).abs() < 1e-6,
            "scale={scale}, expected={expected}"
        );
    }

    #[test]
    fn test_compute_scale_zeros() {
        let data = [0.0_f32; 8];
        let scale = compute_scale(&data, 127.0, DynamicScaleMode::MaxAbs);
        assert_eq!(scale, 0.0);
    }

    #[test]
    fn test_dequantize_roundtrip_int8() {
        let data: Vec<f32> = (0..256).map(|i| (i as f32 - 128.0) * 0.1).collect();
        let qt = dynamic_quantize_int8(&data, DynamicScaleMode::MaxAbs);
        let recon = qt.dequantize();
        let mae = quantization_mae(&data, &qt);
        let max_abs = data.iter().map(|x| x.abs()).fold(0.0_f32, f32::max);
        assert!(
            mae < 0.005 * max_abs,
            "MAE {mae} >= 0.5% of max_abs {max_abs}"
        );
        assert_eq!(recon.len(), data.len());
    }

    #[test]
    fn test_int4_range() {
        let data: Vec<f32> = (-50..=50).map(|i| i as f32 * 0.3).collect();
        let qt = dynamic_quantize_int4(&data, DynamicScaleMode::MaxAbs);
        for &q in &qt.data {
            assert!((-7..=7).contains(&q), "INT4 value {q} out of range [-7, 7]");
        }
    }

    #[test]
    fn test_smooth_quant_config_new() {
        let cfg = SmoothQuantConfig::new(0.7);
        assert!((cfg.alpha - 0.7).abs() < 1e-6);
    }

    #[test]
    fn test_smooth_quant_config_default_alpha() {
        let cfg = SmoothQuantConfig::default_alpha();
        assert!((cfg.alpha - 0.5).abs() < 1e-6);
    }

    #[test]
    fn test_calib_stats_basic() {
        let batches = vec![vec![1.0_f32, 2.0, 3.0], vec![-1.0_f32, 0.0, 4.0]];
        let stats = CalibStats::collect(&batches);
        assert!(stats.min <= stats.mean);
        assert!(stats.mean <= stats.max);
        assert!(stats.suggested_scale > 0.0);
    }
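
    // Usage-oriented tests for the per-row, SmoothQuant, and W8A8 paths; they only
    // exercise the public API defined in this module.

    #[test]
    fn test_int8_per_row_scales_and_shape() {
        // Rows with very different magnitudes should get different scales.
        let data = [0.1_f32, -0.2, 0.3, 10.0, -20.0, 30.0];
        let qt = dynamic_quantize_int8_per_row(&data, 2, 3, DynamicScaleMode::MaxAbs);
        assert_eq!(qt.scales.len(), 2);
        assert!(qt.scales[1] > qt.scales[0]);
        assert_eq!(qt.dequantize().len(), data.len());
    }

    #[test]
    fn test_smooth_factors_preserve_dot_product() {
        // SmoothQuant invariance: (A / s) dot (W * s) equals A dot W in f32.
        let mut acts = vec![4.0_f32, 0.5];
        let mut weights = vec![0.5_f32, 2.0];
        let baseline: f32 = acts.iter().zip(weights.iter()).map(|(a, w)| a * w).sum();

        let cfg = SmoothQuantConfig::default_alpha();
        let factors = compute_smooth_factors(&acts, &weights, 2, 1, 1, &cfg);
        smooth_activations(&mut acts, &factors, 1, 2).unwrap();
        smooth_weights(&mut weights, &factors, 1, 2).unwrap();

        let smoothed: f32 = acts.iter().zip(weights.iter()).map(|(a, w)| a * w).sum();
        assert!((baseline - smoothed).abs() < 1e-4);
    }

    #[test]
    fn test_w8a8_matvec_close_to_f32_reference() {
        // The INT8 kernel should stay close to a plain f32 matvec.
        let (out_size, in_size) = (2, 4);
        let weights_f32 = [0.5_f32, -1.0, 0.25, 2.0, 1.5, 0.0, -0.75, 1.0];
        let activation = [1.0_f32, -2.0, 3.0, 0.5];

        let wq = dynamic_quantize_int8_per_row(&weights_f32, out_size, in_size, DynamicScaleMode::MaxAbs);
        let out = w8a8_matvec(&wq.data, &wq.scales, &activation, out_size, in_size).unwrap();

        for o in 0..out_size {
            let reference: f32 = (0..in_size)
                .map(|i| weights_f32[o * in_size + i] * activation[i])
                .sum();
            assert!(
                (out[o] - reference).abs() < 0.1,
                "row {o}: got {}, expected {reference}",
                out[o]
            );
        }
    }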
}