Skip to main content

trueno/brick/tracing/
quant_error.rs

1// ============================================================================
2// E.11.5: QuantizationErrorTrace (MLT-04)
3// ============================================================================
4
5use super::quant_type::QuantType;
6use crate::brick::exec_graph::BrickId;
7
8/// Quantization error measurement for a single operation.
9///
10/// Compares quantized computation against FP32 reference using multiple metrics.
11#[derive(Debug, Clone)]
12pub struct QuantizationErrorTrace {
13    /// Brick type being measured
14    pub brick_id: BrickId,
15    /// Layer index
16    pub layer_idx: usize,
17    /// Mean squared error vs FP32 reference
18    pub mse: f32,
19    /// Maximum absolute error
20    pub max_abs_error: f32,
21    /// Cosine similarity (1.0 = perfect match)
22    pub cosine_similarity: f32,
23    /// Signal-to-noise ratio in dB
24    pub snr_db: f32,
25    /// Quantization type used
26    pub quant_type: QuantType,
27}
28
29impl QuantizationErrorTrace {
30    /// Compute error metrics between quantized and reference outputs.
31    pub fn compute(
32        brick_id: BrickId,
33        layer_idx: usize,
34        quantized: &[f32],
35        reference: &[f32],
36        quant_type: QuantType,
37    ) -> Self {
38        assert_eq!(quantized.len(), reference.len(), "Length mismatch");
39        let n = quantized.len();
40        if n == 0 {
41            return Self {
42                brick_id,
43                layer_idx,
44                mse: 0.0,
45                max_abs_error: 0.0,
46                cosine_similarity: 1.0, // Perfect match when both empty
47                snr_db: f32::INFINITY,
48                quant_type,
49            };
50        }
51
52        // MSE and max abs error
53        let mut sum_sq_error = 0.0f64;
54        let mut max_abs_error = 0.0f32;
55        for (q, r) in quantized.iter().zip(reference.iter()) {
56            let error = q - r;
57            sum_sq_error += (error as f64) * (error as f64);
58            max_abs_error = max_abs_error.max(error.abs());
59        }
60        let mse = (sum_sq_error / n as f64) as f32;
61
62        // Cosine similarity
63        let mut dot = 0.0f64;
64        let mut norm_q = 0.0f64;
65        let mut norm_r = 0.0f64;
66        for (q, r) in quantized.iter().zip(reference.iter()) {
67            dot += (*q as f64) * (*r as f64);
68            norm_q += (*q as f64) * (*q as f64);
69            norm_r += (*r as f64) * (*r as f64);
70        }
71        let cosine_similarity = if norm_q > 0.0 && norm_r > 0.0 {
72            (dot / (norm_q.sqrt() * norm_r.sqrt())) as f32
73        } else {
74            0.0
75        };
76
77        // SNR in dB: 10 * log10(signal_power / noise_power)
78        let signal_power = norm_r / n as f64;
79        let noise_power = sum_sq_error / n as f64;
80        let snr_db = if noise_power > 1e-10 {
81            (10.0 * (signal_power / noise_power).max(f64::EPSILON).log10()) as f32
82        } else {
83            f32::INFINITY
84        };
85
86        Self { brick_id, layer_idx, mse, max_abs_error, cosine_similarity, snr_db, quant_type }
87    }
88
89    /// Check if error is acceptable (cosine > 0.995).
90    pub fn is_acceptable(&self) -> bool {
91        self.cosine_similarity > 0.995
92    }
93
94    /// Check if error is in warning zone (0.99 < cosine < 0.995).
95    pub fn is_warning(&self) -> bool {
96        self.cosine_similarity > 0.99 && self.cosine_similarity <= 0.995
97    }
98
99    /// Check if error is critical (cosine < 0.99).
100    pub fn is_critical(&self) -> bool {
101        self.cosine_similarity < 0.99
102    }
103}
104
105/// Cumulative quantization error across an entire model.
106#[derive(Debug, Clone, Default)]
107pub struct ModelQuantizationError {
108    /// Per-brick error traces
109    pub brick_errors: Vec<QuantizationErrorTrace>,
110    /// Overall cosine similarity of final logits
111    pub logits_cosine: f32,
112    /// KL divergence of output probability distributions
113    pub output_kl_divergence: f32,
114    /// Perplexity difference (PPL_quant - PPL_fp32)
115    pub perplexity_delta: f32,
116}
117
118impl ModelQuantizationError {
119    /// Add a brick error trace.
120    pub fn add_error(&mut self, trace: QuantizationErrorTrace) {
121        self.brick_errors.push(trace);
122    }
123
124    /// Get count of critical errors.
125    pub fn critical_count(&self) -> usize {
126        self.brick_errors.iter().filter(|e| e.is_critical()).count()
127    }
128
129    /// Get count of warning errors.
130    pub fn warning_count(&self) -> usize {
131        self.brick_errors.iter().filter(|e| e.is_warning()).count()
132    }
133
134    /// Get worst brick by cosine similarity.
135    pub fn worst_brick(&self) -> Option<&QuantizationErrorTrace> {
136        self.brick_errors.iter().min_by(|a, b| {
137            a.cosine_similarity
138                .partial_cmp(&b.cosine_similarity)
139                .unwrap_or(std::cmp::Ordering::Equal)
140        })
141    }
142}
143
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance for the floating-point comparisons below.
    const TOLERANCE: f32 = 1e-6;

    /// Builds a layer-0 `RmsNorm` trace for the given outputs.
    fn rms_norm_trace(
        quantized: &[f32],
        reference: &[f32],
        quant_type: QuantType,
    ) -> QuantizationErrorTrace {
        QuantizationErrorTrace::compute(BrickId::RmsNorm, 0, quantized, reference, quant_type)
    }

    /// Identical outputs give zero MSE, unit cosine and an acceptable verdict.
    #[test]
    fn test_quant_error_perfect_match() {
        let reference = [1.0f32, 2.0, 3.0];
        let quantized = reference;

        let trace = rms_norm_trace(&quantized, &reference, QuantType::Q4_K);

        assert!(trace.mse.abs() < TOLERANCE);
        assert!((trace.cosine_similarity - 1.0).abs() < TOLERANCE);
        assert!(trace.is_acceptable());
    }

    /// Uneven per-element offsets change the direction slightly but not much.
    #[test]
    fn test_quant_error_significant_difference() {
        let reference = [1.0f32, 2.0, 3.0];
        // Offsets are not proportional to the reference, so the direction shifts.
        let quantized = [1.5f32, 2.1, 2.9];

        let trace = rms_norm_trace(&quantized, &reference, QuantType::Q4_K);

        assert!(trace.mse > 0.0);
        assert!(trace.cosine_similarity < 1.0);
        // The vectors are still close, so the cosine remains high.
        assert!(trace.cosine_similarity > 0.99);
    }

    /// FALSIFICATION TEST: a vector compared with itself must have cosine 1.0.
    #[test]
    fn test_falsify_cosine_identical_vectors() {
        let data = [1.0f32, 2.0, 3.0, 4.0, 5.0];

        let trace = rms_norm_trace(&data, &data, QuantType::F32);
        let deviation = (trace.cosine_similarity - 1.0).abs();

        assert!(
            deviation < TOLERANCE,
            "FALSIFICATION FAILED: identical vectors have cosine {} != 1.0",
            trace.cosine_similarity
        );
    }

    /// FALSIFICATION TEST: swapping the arguments must not change the cosine.
    #[test]
    fn test_falsify_cosine_symmetry() {
        let a = [1.0f32, 2.0, 3.0];
        let b = [4.0f32, 5.0, 6.0];

        let trace_ab = rms_norm_trace(&a, &b, QuantType::F32);
        let trace_ba = rms_norm_trace(&b, &a, QuantType::F32);
        let gap = (trace_ab.cosine_similarity - trace_ba.cosine_similarity).abs();

        assert!(
            gap < TOLERANCE,
            "FALSIFICATION FAILED: cosine(a,b) {} != cosine(b,a) {}",
            trace_ab.cosine_similarity,
            trace_ba.cosine_similarity
        );
    }
}