use std::collections::HashMap;
use crate::dynamic_quant::{
compute_smooth_factors, smooth_weights, DynQuantError, SmoothQuantConfig,
};
use oxibonsai_core::quant_fp8::{BlockFP8E4M3, BlockFP8E5M2};
/// Errors reported by the SmoothQuant calibration and quantization helpers.
#[derive(Debug, Clone)]
pub enum SmoothQuantError {
    /// The calibrator has no recorded layers.
    EmptyCalibrator,
    /// No statistics were recorded under the given layer name.
    LayerNotFound(String),
    /// A per-channel length did not match the expected `in_features`.
    InFeaturesMismatch {
        /// The `in_features` value the operation required.
        expected: usize,
        /// The length that was actually supplied.
        got: usize,
    },
    /// An underlying smoothing or quantization step failed; carries that
    /// error's rendered message.
    QuantizationError(String),
}

impl std::fmt::Display for SmoothQuantError {
    /// Renders the human-readable message for each variant.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            // No interpolation needed, so the literal is written directly.
            SmoothQuantError::EmptyCalibrator => {
                f.write_str("SmoothQuant calibrator has no recorded layers")
            }
            SmoothQuantError::LayerNotFound(name) => {
                write!(f, "SmoothQuant calibrator: layer '{name}' not found")
            }
            SmoothQuantError::InFeaturesMismatch { expected, got } => {
                write!(
                    f,
                    "SmoothQuant calibrator: in_features mismatch — expected {expected}, got {got}"
                )
            }
            SmoothQuantError::QuantizationError(msg) => {
                write!(f, "SmoothQuant quantization error: {msg}")
            }
        }
    }
}

// The default `source()` (none) is sufficient: wrapped errors are stored as text.
impl std::error::Error for SmoothQuantError {}
/// Running per-channel activation statistics for a single layer.
struct ChannelStats {
    /// Number of input channels; also the length of `running_max_abs`.
    in_features: usize,
    /// Largest absolute activation observed so far for each channel.
    running_max_abs: Vec<f32>,
    /// Number of `update` calls folded in (one per call, not per token).
    sample_count: usize,
}

impl ChannelStats {
    /// Creates zero-initialised statistics for `in_features` channels.
    fn new(in_features: usize) -> Self {
        Self {
            in_features,
            running_max_abs: vec![0.0_f32; in_features],
            sample_count: 0,
        }
    }

    /// Folds one batch of activations into the per-channel running maxima.
    ///
    /// `activations` is read as consecutive tokens of `in_features` values
    /// each; a trailing partial token is ignored. NaN activations never
    /// replace a recorded maximum.
    ///
    /// # Panics
    ///
    /// Panics if `in_features` is zero. In builds with debug assertions it
    /// also panics when `in_features` differs from the width these statistics
    /// were created with.
    fn update(&mut self, activations: &[f32], in_features: usize) {
        debug_assert_eq!(in_features, self.in_features);
        // `chunks_exact` yields only complete tokens, and `zip` stops at the
        // shorter side, so a channel slot can never read past its own token.
        for token in activations.chunks_exact(in_features) {
            for (peak, value) in self.running_max_abs.iter_mut().zip(token) {
                // `f32::max` returns the non-NaN operand, so NaN inputs are skipped.
                *peak = (*peak).max(value.abs());
            }
        }
        self.sample_count += 1;
    }
}
/// Collects per-layer activation statistics and derives smoothing factors
/// from them.
pub struct SmoothQuantCalibrator {
/// Channel statistics for each recorded layer, keyed by layer name.
layers: HashMap<String, ChannelStats>,
/// Configuration handed to `compute_smooth_factors` when factors are requested.
config: SmoothQuantConfig,
}
impl SmoothQuantCalibrator {
    /// Creates a calibrator with no recorded layers that uses `config` when
    /// smoothing factors are computed.
    pub fn new(config: SmoothQuantConfig) -> Self {
        let layers = HashMap::new();
        Self { layers, config }
    }

    /// Folds one batch of activations into the statistics kept for `layer_name`.
    ///
    /// `activations` is treated as consecutive rows of `in_features` values.
    /// The first recorded batch for a layer creates its statistics entry with
    /// that width. Calls with `in_features == 0` or an empty `activations`
    /// slice are ignored.
    ///
    /// # Panics
    ///
    /// Panics if `in_features` differs from the width already recorded for
    /// `layer_name`.
    pub fn record_activation(&mut self, layer_name: &str, activations: &[f32], in_features: usize) {
        let nothing_to_record = in_features == 0 || activations.is_empty();
        if nothing_to_record {
            return;
        }

        let key = layer_name.to_owned();
        let stats = self
            .layers
            .entry(key)
            .or_insert_with(|| ChannelStats::new(in_features));

        // A layer's width is fixed by its first batch; a later disagreement
        // is treated as a caller bug.
        assert!(
            stats.in_features == in_features,
            "SmoothQuantCalibrator::record_activation: in_features mismatch for layer '{}' \
             — expected {}, got {}",
            layer_name, stats.in_features, in_features
        );

        stats.update(activations, in_features);
    }

    /// Computes smoothing factors for `layer_name` from its recorded
    /// per-channel maxima and the supplied `weights`.
    ///
    /// The recorded maxima, `weights`, the layer's `in_features`,
    /// `out_features` and the calibrator's configuration are forwarded to
    /// `compute_smooth_factors`.
    ///
    /// # Errors
    ///
    /// Returns [`SmoothQuantError::LayerNotFound`] when nothing has been
    /// recorded under `layer_name`.
    pub fn smooth_factors(
        &self,
        layer_name: &str,
        weights: &[f32],
        out_features: usize,
    ) -> Result<Vec<f32>, SmoothQuantError> {
        let stats = match self.layers.get(layer_name) {
            Some(found) => found,
            None => return Err(SmoothQuantError::LayerNotFound(layer_name.to_owned())),
        };

        // NOTE(review): the literal `1` is presumably the row/token count of
        // the already-reduced maxima vector — confirm against the signature
        // of `compute_smooth_factors`.
        Ok(compute_smooth_factors(
            &stats.running_max_abs,
            weights,
            stats.in_features,
            1,
            out_features,
            &self.config,
        ))
    }

    /// Returns how many distinct layers have recorded statistics.
    pub fn layer_count(&self) -> usize {
        self.layers.len()
    }

    /// Returns `true` when statistics exist for the layer called `name`.
    pub fn has_layer(&self, name: &str) -> bool {
        self.layers.contains_key(name)
    }
}
/// Smooths a copy of `weights` with `smooth_factors` and quantizes the result
/// into `BlockFP8E4M3` blocks.
///
/// The caller's `weights` slice is not modified; `smooth_weights` is applied
/// to a private copy, which is then passed to `BlockFP8E4M3::quantize`.
///
/// # Errors
///
/// * [`SmoothQuantError::InFeaturesMismatch`] when `smooth_factors.len()`
///   is not equal to `in_features`.
/// * [`SmoothQuantError::QuantizationError`] carrying the message of any
///   error reported by `smooth_weights` or by `BlockFP8E4M3::quantize`.
pub fn quantize_fp8_e4m3_smooth(
    weights: &[f32],
    out_features: usize,
    in_features: usize,
    smooth_factors: &[f32],
) -> Result<Vec<BlockFP8E4M3>, SmoothQuantError> {
    let factor_count = smooth_factors.len();
    if factor_count != in_features {
        return Err(SmoothQuantError::InFeaturesMismatch {
            expected: in_features,
            got: factor_count,
        });
    }

    let wrap_smoothing_error =
        |err: DynQuantError| SmoothQuantError::QuantizationError(err.to_string());

    // Smooth a private copy so the caller's buffer stays untouched.
    let mut scaled = weights.to_vec();
    smooth_weights(&mut scaled, smooth_factors, out_features, in_features)
        .map_err(wrap_smoothing_error)?;

    match BlockFP8E4M3::quantize(&scaled) {
        Ok(blocks) => Ok(blocks),
        Err(err) => Err(SmoothQuantError::QuantizationError(err.to_string())),
    }
}
/// Smooths a copy of `weights` with `smooth_factors` and quantizes the result
/// into `BlockFP8E5M2` blocks.
///
/// The caller's `weights` slice is not modified; `smooth_weights` is applied
/// to a private copy, which is then passed to `BlockFP8E5M2::quantize`.
///
/// # Errors
///
/// * [`SmoothQuantError::InFeaturesMismatch`] when `smooth_factors.len()`
///   is not equal to `in_features`.
/// * [`SmoothQuantError::QuantizationError`] carrying the message of any
///   error reported by `smooth_weights` or by `BlockFP8E5M2::quantize`.
pub fn quantize_fp8_e5m2_smooth(
    weights: &[f32],
    out_features: usize,
    in_features: usize,
    smooth_factors: &[f32],
) -> Result<Vec<BlockFP8E5M2>, SmoothQuantError> {
    let factor_count = smooth_factors.len();
    if factor_count != in_features {
        return Err(SmoothQuantError::InFeaturesMismatch {
            expected: in_features,
            got: factor_count,
        });
    }

    let wrap_smoothing_error =
        |err: DynQuantError| SmoothQuantError::QuantizationError(err.to_string());

    // Smooth a private copy so the caller's buffer stays untouched.
    let mut scaled = weights.to_vec();
    smooth_weights(&mut scaled, smooth_factors, out_features, in_features)
        .map_err(wrap_smoothing_error)?;

    match BlockFP8E5M2::quantize(&scaled) {
        Ok(blocks) => Ok(blocks),
        Err(err) => Err(SmoothQuantError::QuantizationError(err.to_string())),
    }
}