use super::quant_type::QuantType;
use crate::brick::exec_graph::BrickId;
/// Error statistics for one brick's quantized output, measured against a
/// reference output of the same length.
///
/// Built with [`QuantizationErrorTrace::compute`]. The `is_*` predicates
/// classify a trace by cosine similarity into exactly one of
/// acceptable / warning / critical.
#[derive(Debug, Clone)]
pub struct QuantizationErrorTrace {
    /// Brick whose output was measured.
    pub brick_id: BrickId,
    /// Layer index supplied by the caller.
    pub layer_idx: usize,
    /// Mean squared error between quantized and reference values.
    pub mse: f32,
    /// Largest absolute element-wise difference (NaN differences are skipped
    /// by `f32::max`).
    pub max_abs_error: f32,
    /// Cosine similarity of the two vectors, clamped to `[-1, 1]`; can be NaN
    /// when the inputs contain non-finite values.
    pub cosine_similarity: f32,
    /// Reference power over error power, in decibels; `+inf` when the error
    /// power is below a fixed absolute floor.
    pub snr_db: f32,
    /// Quantization format recorded with this trace.
    pub quant_type: QuantType,
}

impl QuantizationErrorTrace {
    /// Cosine similarity strictly above this is acceptable.
    const ACCEPTABLE_COSINE: f32 = 0.995;
    /// Inclusive lower bound of the warning band; anything below is critical.
    const WARNING_COSINE: f32 = 0.99;
    /// Absolute noise-power floor at or below which SNR is reported as `+inf`.
    const NOISE_FLOOR: f64 = 1e-10;

    /// Measures how far `quantized` deviates from `reference`.
    ///
    /// All sums are accumulated in `f64` in a single pass to limit rounding
    /// error on long vectors. The per-element difference itself is taken in
    /// `f32`, matching the inputs.
    ///
    /// Degenerate inputs:
    /// - empty slices: zero error, cosine `1.0`, SNR `+inf`;
    /// - both vectors all zeros: treated as identical (cosine `1.0`);
    /// - exactly one vector all zeros: cosine `0.0` (direction undefined).
    ///
    /// The SNR floor is absolute rather than relative to the signal's
    /// magnitude, so very small-magnitude signals may report `+inf` SNR.
    ///
    /// # Panics
    ///
    /// Panics with "Length mismatch" if the two slices differ in length.
    pub fn compute(
        brick_id: BrickId,
        layer_idx: usize,
        quantized: &[f32],
        reference: &[f32],
        quant_type: QuantType,
    ) -> Self {
        assert_eq!(quantized.len(), reference.len(), "Length mismatch");
        let n = quantized.len();
        if n == 0 {
            return Self {
                brick_id,
                layer_idx,
                mse: 0.0,
                max_abs_error: 0.0,
                cosine_similarity: 1.0,
                snr_db: f32::INFINITY,
                quant_type,
            };
        }

        let mut sum_sq_error = 0.0f64;
        let mut max_abs_error = 0.0f32;
        let mut dot = 0.0f64;
        let mut norm_q = 0.0f64;
        let mut norm_r = 0.0f64;
        for (&q, &r) in quantized.iter().zip(reference) {
            let error = q - r;
            sum_sq_error += f64::from(error) * f64::from(error);
            max_abs_error = max_abs_error.max(error.abs());

            let (q, r) = (f64::from(q), f64::from(r));
            dot += q * r;
            norm_q += q * q;
            norm_r += r * r;
        }

        let len = n as f64;
        let mse = (sum_sq_error / len) as f32;

        let cosine_similarity = if norm_q > 0.0 && norm_r > 0.0 {
            // Rounding can push the ratio marginally outside [-1, 1].
            (dot / (norm_q.sqrt() * norm_r.sqrt())).clamp(-1.0, 1.0) as f32
        } else if norm_q == 0.0 && norm_r == 0.0 {
            // Two all-zero vectors are identical; agree with the empty case.
            1.0
        } else {
            // One side is zero (or a norm is NaN): no meaningful direction.
            0.0
        };

        let signal_power = norm_r / len;
        let noise_power = sum_sq_error / len;
        let snr_db = if noise_power > Self::NOISE_FLOOR {
            // EPSILON floor keeps log10 finite when the reference is all zeros.
            (10.0 * (signal_power / noise_power).max(f64::EPSILON).log10()) as f32
        } else {
            f32::INFINITY
        };

        Self {
            brick_id,
            layer_idx,
            mse,
            max_abs_error,
            cosine_similarity,
            snr_db,
            quant_type,
        }
    }

    /// `true` when cosine similarity is strictly above `0.995`.
    pub fn is_acceptable(&self) -> bool {
        self.cosine_similarity > Self::ACCEPTABLE_COSINE
    }

    /// `true` when cosine similarity lies in the closed band `[0.99, 0.995]`.
    pub fn is_warning(&self) -> bool {
        (Self::WARNING_COSINE..=Self::ACCEPTABLE_COSINE).contains(&self.cosine_similarity)
    }

    /// `true` when cosine similarity is below `0.99` or is NaN.
    ///
    /// Defined as "neither acceptable nor warning" so the three predicates
    /// always partition every possible value, with no gap at `0.99` and NaN
    /// reported rather than silently ignored.
    pub fn is_critical(&self) -> bool {
        !self.is_acceptable() && !self.is_warning()
    }
}
/// Model-wide quantization error: per-brick traces plus model-level output
/// metrics.
#[derive(Debug, Clone, Default)]
pub struct ModelQuantizationError {
    /// Per-brick error traces, in the order they were added.
    pub brick_errors: Vec<QuantizationErrorTrace>,
    /// Cosine similarity of the final logits; not populated by any method here.
    pub logits_cosine: f32,
    /// KL divergence of the model output; not populated by any method here.
    pub output_kl_divergence: f32,
    /// Perplexity change; not populated by any method here (sign convention
    /// is up to the caller — TODO confirm).
    pub perplexity_delta: f32,
}

impl ModelQuantizationError {
    /// Appends one brick's trace.
    pub fn add_error(&mut self, trace: QuantizationErrorTrace) {
        self.brick_errors.push(trace);
    }

    /// Number of traces whose `is_critical()` is true.
    pub fn critical_count(&self) -> usize {
        self.brick_errors.iter().filter(|e| e.is_critical()).count()
    }

    /// Number of traces whose `is_warning()` is true.
    pub fn warning_count(&self) -> usize {
        self.brick_errors.iter().filter(|e| e.is_warning()).count()
    }

    /// Returns the trace with the lowest cosine similarity, or `None` when no
    /// traces were added.
    ///
    /// A NaN cosine similarity ranks as the worst possible value. This way a
    /// corrupted brick is always reported, whatever its position; previously
    /// NaN compared "equal" to everything and the result depended on
    /// insertion order. Ties resolve to the earliest-added trace.
    pub fn worst_brick(&self) -> Option<&QuantizationErrorTrace> {
        fn rank(trace: &QuantizationErrorTrace) -> f32 {
            if trace.cosine_similarity.is_nan() {
                f32::NEG_INFINITY
            } else {
                trace.cosine_similarity
            }
        }
        self.brick_errors
            .iter()
            .min_by(|a, b| rank(a).total_cmp(&rank(b)))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs `QuantizationErrorTrace::compute` for an `RmsNorm` brick at layer 0.
    fn rms_norm_trace(
        quantized: &[f32],
        reference: &[f32],
        quant_type: QuantType,
    ) -> QuantizationErrorTrace {
        QuantizationErrorTrace::compute(BrickId::RmsNorm, 0, quantized, reference, quant_type)
    }

    #[test]
    fn test_quant_error_perfect_match() {
        let reference: [f32; 3] = [1.0, 2.0, 3.0];
        let quantized = reference;
        let trace = rms_norm_trace(&quantized, &reference, QuantType::Q4_K);

        assert!(trace.mse.abs() < 1e-6);
        assert!((1.0 - trace.cosine_similarity).abs() < 1e-6);
        assert!(trace.is_acceptable());
    }

    #[test]
    fn test_quant_error_significant_difference() {
        let reference: [f32; 3] = [1.0, 2.0, 3.0];
        let quantized: [f32; 3] = [1.5, 2.1, 2.9];
        let trace = rms_norm_trace(&quantized, &reference, QuantType::Q4_K);

        assert!(trace.mse > 0.0);
        let cos = trace.cosine_similarity;
        assert!(cos < 1.0);
        assert!(cos > 0.99);
    }

    #[test]
    fn test_falsify_cosine_identical_vectors() {
        let data: [f32; 5] = [1.0, 2.0, 3.0, 4.0, 5.0];
        let cosine = rms_norm_trace(&data, &data, QuantType::F32).cosine_similarity;
        assert!(
            (cosine - 1.0).abs() < 1e-6,
            "FALSIFICATION FAILED: identical vectors have cosine {} != 1.0",
            cosine
        );
    }

    #[test]
    fn test_falsify_cosine_symmetry() {
        let a: [f32; 3] = [1.0, 2.0, 3.0];
        let b: [f32; 3] = [4.0, 5.0, 6.0];
        let forward = rms_norm_trace(&a, &b, QuantType::F32).cosine_similarity;
        let backward = rms_norm_trace(&b, &a, QuantType::F32).cosine_similarity;
        assert!(
            (forward - backward).abs() < 1e-6,
            "FALSIFICATION FAILED: cosine(a,b) {} != cosine(b,a) {}",
            forward,
            backward
        );
    }
}