use crate::error::{QuantError, QuantResult};
/// Configuration for SmoothQuant difficulty migration.
#[derive(Debug, Clone, Copy)]
pub struct SmoothQuantConfig {
    /// Migration strength, expected in `[0, 1]`: higher values move more of
    /// the quantization difficulty from activations onto the weights.
    pub alpha: f32,
}
impl Default for SmoothQuantConfig {
fn default() -> Self {
Self { alpha: 0.5 }
}
}
/// Applies SmoothQuant-style scale migration to activation/weight pairs
/// of a linear layer, in place.
#[derive(Debug, Clone, Copy)]
pub struct SmoothQuantMigrator {
    /// Migration settings (currently just the strength `alpha`).
    pub config: SmoothQuantConfig,
}
impl SmoothQuantMigrator {
    /// Creates a migrator with migration strength `alpha`.
    ///
    /// `alpha` is expected in `[0, 1]`: `0.0` derives scales from weight
    /// statistics only, `1.0` from activation statistics only, `0.5`
    /// balances both. Out-of-range values are not rejected here.
    #[must_use]
    pub fn new(alpha: f32) -> Self {
        Self {
            config: SmoothQuantConfig { alpha },
        }
    }

    /// Computes the per-channel migration scales
    /// `s_j = max|a_j|^alpha / max|w_j|^(1 - alpha)`.
    ///
    /// Dividing activations by `s_j` while multiplying the matching weight
    /// column by `s_j` keeps the layer output unchanged but moves dynamic
    /// range from activation outlier channels into the weights.
    ///
    /// # Errors
    ///
    /// * `QuantError::EmptyInput` if `act_max` is empty.
    /// * `QuantError::DimensionMismatch` if the two slices differ in length.
    pub fn compute_migration_scales(
        &self,
        act_max: &[f32],
        weight_max: &[f32],
    ) -> QuantResult<Vec<f32>> {
        if act_max.is_empty() {
            return Err(QuantError::EmptyInput(
                "SmoothQuantMigrator::compute_migration_scales",
            ));
        }
        if act_max.len() != weight_max.len() {
            return Err(QuantError::DimensionMismatch {
                expected: act_max.len(),
                got: weight_max.len(),
            });
        }
        let alpha = self.config.alpha;
        let scales = act_max
            .iter()
            .zip(weight_max.iter())
            .map(|(&a_max, &w_max)| {
                // Clamp magnitudes away from zero so `powf` never sees 0
                // and the quotient stays finite for dead channels.
                let a = a_max.abs().max(1e-8);
                let w = w_max.abs().max(1e-8);
                a.powf(alpha) / w.powf(1.0 - alpha)
            })
            .collect();
        Ok(scales)
    }

    /// Per-channel maximum absolute activation over a row-major
    /// `[n_tokens, n_channels]` matrix.
    ///
    /// # Errors
    ///
    /// * `QuantError::EmptyInput` if `acts` is empty.
    /// * `QuantError::DimensionMismatch` if
    ///   `acts.len() != n_tokens * n_channels`.
    pub fn compute_act_stats(
        acts: &[f32],
        n_tokens: usize,
        n_channels: usize,
    ) -> QuantResult<Vec<f32>> {
        Self::column_abs_max(
            acts,
            n_tokens,
            n_channels,
            "compute_act_stats: empty activations",
        )
    }

    /// Per-input-channel (column-wise) maximum absolute weight over a
    /// row-major `[n_out, n_channels]` matrix.
    ///
    /// # Errors
    ///
    /// * `QuantError::EmptyInput` if `weights` is empty.
    /// * `QuantError::DimensionMismatch` if
    ///   `weights.len() != n_out * n_channels`.
    pub fn compute_weight_stats(
        weights: &[f32],
        n_out: usize,
        n_channels: usize,
    ) -> QuantResult<Vec<f32>> {
        Self::column_abs_max(
            weights,
            n_out,
            n_channels,
            "compute_weight_stats: empty weights",
        )
    }

    /// Shared kernel for the two `compute_*_stats` entry points:
    /// column-wise absolute maximum of a row-major `n_rows x n_cols`
    /// matrix. `empty_ctx` labels the `EmptyInput` error for the caller.
    fn column_abs_max(
        data: &[f32],
        n_rows: usize,
        n_cols: usize,
        empty_ctx: &'static str,
    ) -> QuantResult<Vec<f32>> {
        if data.is_empty() {
            return Err(QuantError::EmptyInput(empty_ctx));
        }
        if data.len() != n_rows * n_cols {
            return Err(QuantError::DimensionMismatch {
                expected: n_rows * n_cols,
                got: data.len(),
            });
        }
        // `data` is non-empty and exactly n_rows * n_cols long, so
        // n_cols > 0 here and `chunks_exact` cannot panic.
        let mut stats = vec![0.0_f32; n_cols];
        for row in data.chunks_exact(n_cols) {
            for (s, &v) in stats.iter_mut().zip(row) {
                // f32::max ignores NaN operands, matching the previous
                // `if v > stats[j]` comparison semantics.
                *s = s.max(v.abs());
            }
        }
        Ok(stats)
    }

    /// Divides each activation channel by its migration scale, in place.
    ///
    /// Scales are clamped to a `1e-12` floor to avoid division by zero.
    ///
    /// # Errors
    ///
    /// `QuantError::DimensionMismatch` if `acts` or `scales` does not
    /// match the `n_tokens` x `n_channels` shape.
    pub fn smooth_activations(
        acts: &mut [f32],
        scales: &[f32],
        n_tokens: usize,
        n_channels: usize,
    ) -> QuantResult<()> {
        if acts.len() != n_tokens * n_channels {
            return Err(QuantError::DimensionMismatch {
                expected: n_tokens * n_channels,
                got: acts.len(),
            });
        }
        if scales.len() != n_channels {
            return Err(QuantError::DimensionMismatch {
                expected: n_channels,
                got: scales.len(),
            });
        }
        if n_channels == 0 {
            // Zero-width rows: nothing to scale (and a chunk size of 0
            // would panic below).
            return Ok(());
        }
        for row in acts.chunks_exact_mut(n_channels) {
            for (x, &s) in row.iter_mut().zip(scales) {
                *x /= s.max(1e-12);
            }
        }
        Ok(())
    }

    /// Multiplies each weight column by its migration scale, in place.
    ///
    /// The scale is clamped to the same `1e-12` floor used by
    /// `smooth_activations`. Previously the raw scale was used here, so a
    /// degenerate (<= 1e-12) scale broke the output-preservation
    /// invariant `(X / s) * (W * s)^T == X * W^T`; clamping both sides
    /// identically restores it. Scales produced by
    /// `compute_migration_scales` are unaffected in practice.
    ///
    /// # Errors
    ///
    /// `QuantError::DimensionMismatch` if `weights` or `scales` does not
    /// match the `n_out` x `n_channels` shape.
    pub fn smooth_weights(
        weights: &mut [f32],
        scales: &[f32],
        n_out: usize,
        n_channels: usize,
    ) -> QuantResult<()> {
        if weights.len() != n_out * n_channels {
            return Err(QuantError::DimensionMismatch {
                expected: n_out * n_channels,
                got: weights.len(),
            });
        }
        if scales.len() != n_channels {
            return Err(QuantError::DimensionMismatch {
                expected: n_channels,
                got: scales.len(),
            });
        }
        if n_channels == 0 {
            // Zero-width rows: nothing to scale (and a chunk size of 0
            // would panic below).
            return Ok(());
        }
        for row in weights.chunks_exact_mut(n_channels) {
            for (w, &s) in row.iter_mut().zip(scales) {
                // Mirror the clamp in `smooth_activations` so the layer
                // output is preserved even for degenerate scales.
                *w *= s.max(1e-12);
            }
        }
        Ok(())
    }

    /// End-to-end smoothing of one linear layer: gathers activation and
    /// weight statistics, derives migration scales, then rescales both
    /// tensors in place so the layer output is (numerically) unchanged.
    ///
    /// Returns the per-channel scales so callers can fold them into a
    /// preceding operation or persist them for inference.
    ///
    /// # Errors
    ///
    /// Propagates any shape or empty-input error from the helpers.
    pub fn smooth_layer(
        &self,
        acts: &mut [f32],
        weights: &mut [f32],
        n_tokens: usize,
        n_channels: usize,
        n_out: usize,
    ) -> QuantResult<Vec<f32>> {
        let act_stats = Self::compute_act_stats(acts, n_tokens, n_channels)?;
        let weight_stats = Self::compute_weight_stats(weights, n_out, n_channels)?;
        let scales = self.compute_migration_scales(&act_stats, &weight_stats)?;
        Self::smooth_activations(acts, &scales, n_tokens, n_channels)?;
        Self::smooth_weights(weights, &scales, n_out, n_channels)?;
        Ok(scales)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    /// Reference matmul `y = x * w^T` for row-major `x: [n_tok, n_ch]`
    /// and `w: [n_out, n_ch]`; used to verify output preservation.
    fn matmul_nt(x: &[f32], w: &[f32], n_tok: usize, n_ch: usize, n_out: usize) -> Vec<f32> {
        let mut y = vec![0.0_f32; n_tok * n_out];
        for (t, x_row) in x.chunks(n_ch).enumerate() {
            for (o, w_row) in w.chunks(n_ch).enumerate() {
                y[t * n_out + o] = x_row.iter().zip(w_row).map(|(a, b)| a * b).sum();
            }
        }
        y
    }

    #[test]
    fn scale_alpha_half() {
        let migrator = SmoothQuantMigrator::new(0.5);
        let scales = migrator
            .compute_migration_scales(&[4.0, 1.0, 9.0], &[1.0, 4.0, 1.0])
            .unwrap();
        // alpha = 0.5: s_j = sqrt(a_j) / sqrt(w_j).
        for (&got, want) in scales.iter().zip([2.0_f32, 0.5, 3.0]) {
            assert_abs_diff_eq!(got, want, epsilon = 1e-5);
        }
    }

    #[test]
    fn scale_alpha_one_activations_only() {
        let migrator = SmoothQuantMigrator::new(1.0);
        let scales = migrator
            .compute_migration_scales(&[2.0, 5.0], &[3.0, 7.0])
            .unwrap();
        // alpha = 1: scales reduce to the activation maxima.
        assert_abs_diff_eq!(scales[0], 2.0, epsilon = 1e-5);
        assert_abs_diff_eq!(scales[1], 5.0, epsilon = 1e-5);
    }

    #[test]
    fn scale_alpha_zero_weights_only() {
        let migrator = SmoothQuantMigrator::new(0.0);
        let scales = migrator
            .compute_migration_scales(&[4.0, 1.0], &[2.0, 5.0])
            .unwrap();
        // alpha = 0: scales reduce to the reciprocal weight maxima.
        assert_abs_diff_eq!(scales[0], 0.5, epsilon = 1e-5);
        assert_abs_diff_eq!(scales[1], 0.2, epsilon = 1e-5);
    }

    #[test]
    fn smoothing_preserves_layer_output() {
        let migrator = SmoothQuantMigrator::new(0.5);
        let (n_tok, n_ch, n_out) = (3_usize, 4_usize, 2_usize);
        let mut x: Vec<f32> = (0..n_tok * n_ch).map(|i| i as f32 * 0.3 - 1.0).collect();
        let mut w: Vec<f32> = (0..n_out * n_ch).map(|i| i as f32 * 0.2 - 0.5).collect();
        let before = matmul_nt(&x, &w, n_tok, n_ch, n_out);
        migrator
            .smooth_layer(&mut x, &mut w, n_tok, n_ch, n_out)
            .unwrap();
        let after = matmul_nt(&x, &w, n_tok, n_ch, n_out);
        for (&y0, &y1) in before.iter().zip(after.iter()) {
            assert_abs_diff_eq!(y0, y1, epsilon = 1e-4);
        }
    }

    #[test]
    fn activation_stats_max_per_channel() {
        // Two tokens, three channels, row-major.
        let acts = [1.0_f32, -5.0, 2.0, -3.0, 4.0, 1.0];
        let stats = SmoothQuantMigrator::compute_act_stats(&acts, 2, 3).unwrap();
        for (&got, want) in stats.iter().zip([3.0_f32, 5.0, 2.0]) {
            assert_abs_diff_eq!(got, want, epsilon = 1e-6);
        }
    }

    #[test]
    fn weight_stats_max_per_column() {
        let weights = [0.5_f32, -2.0, 1.0, -1.5, 0.3, 3.0];
        let stats = SmoothQuantMigrator::compute_weight_stats(&weights, 2, 3).unwrap();
        for (&got, want) in stats.iter().zip([1.5_f32, 2.0, 3.0]) {
            assert_abs_diff_eq!(got, want, epsilon = 1e-6);
        }
    }

    #[test]
    fn dimension_mismatch_error() {
        let migrator = SmoothQuantMigrator::new(0.5);
        let result = migrator.compute_migration_scales(&[1.0; 3], &[1.0; 4]);
        assert!(matches!(result, Err(QuantError::DimensionMismatch { .. })));
    }

    #[test]
    fn empty_input_error() {
        let migrator = SmoothQuantMigrator::new(0.5);
        let result = migrator.compute_migration_scales(&[], &[]);
        assert!(matches!(result, Err(QuantError::EmptyInput(_))));
    }

    #[test]
    fn smoothing_reduces_act_channel_range_imbalance() {
        let migrator = SmoothQuantMigrator::new(0.5);
        let (n_tok, n_ch, n_out) = (4_usize, 2_usize, 2_usize);
        // Channel 0 carries a 100x outlier relative to channel 1.
        let mut x = vec![100.0_f32, 1.0, -100.0, 1.0, 100.0, -1.0, -100.0, -1.0];
        let mut w = vec![0.5_f32, 0.5, -0.5, 0.5];
        let scales = migrator
            .smooth_layer(&mut x, &mut w, n_tok, n_ch, n_out)
            .unwrap();
        // Absolute peak of one interleaved channel after smoothing.
        let peak = |channel: usize| -> f32 {
            x.iter()
                .skip(channel)
                .step_by(n_ch)
                .fold(0.0_f32, |m, v| m.max(v.abs()))
        };
        assert!(
            scales[0] > 1.0,
            "scale[0] should be > 1 for outlier channel"
        );
        assert!(
            peak(0) / peak(1).max(1e-8) < 100.0,
            "channel range imbalance should decrease after smoothing"
        );
    }
}