trueno/brick/tracing/
quant_error.rs1use super::quant_type::QuantType;
6use crate::brick::exec_graph::BrickId;
7
8#[derive(Debug, Clone)]
12pub struct QuantizationErrorTrace {
13 pub brick_id: BrickId,
15 pub layer_idx: usize,
17 pub mse: f32,
19 pub max_abs_error: f32,
21 pub cosine_similarity: f32,
23 pub snr_db: f32,
25 pub quant_type: QuantType,
27}
28
29impl QuantizationErrorTrace {
30 pub fn compute(
32 brick_id: BrickId,
33 layer_idx: usize,
34 quantized: &[f32],
35 reference: &[f32],
36 quant_type: QuantType,
37 ) -> Self {
38 assert_eq!(quantized.len(), reference.len(), "Length mismatch");
39 let n = quantized.len();
40 if n == 0 {
41 return Self {
42 brick_id,
43 layer_idx,
44 mse: 0.0,
45 max_abs_error: 0.0,
46 cosine_similarity: 1.0, snr_db: f32::INFINITY,
48 quant_type,
49 };
50 }
51
52 let mut sum_sq_error = 0.0f64;
54 let mut max_abs_error = 0.0f32;
55 for (q, r) in quantized.iter().zip(reference.iter()) {
56 let error = q - r;
57 sum_sq_error += (error as f64) * (error as f64);
58 max_abs_error = max_abs_error.max(error.abs());
59 }
60 let mse = (sum_sq_error / n as f64) as f32;
61
62 let mut dot = 0.0f64;
64 let mut norm_q = 0.0f64;
65 let mut norm_r = 0.0f64;
66 for (q, r) in quantized.iter().zip(reference.iter()) {
67 dot += (*q as f64) * (*r as f64);
68 norm_q += (*q as f64) * (*q as f64);
69 norm_r += (*r as f64) * (*r as f64);
70 }
71 let cosine_similarity = if norm_q > 0.0 && norm_r > 0.0 {
72 (dot / (norm_q.sqrt() * norm_r.sqrt())) as f32
73 } else {
74 0.0
75 };
76
77 let signal_power = norm_r / n as f64;
79 let noise_power = sum_sq_error / n as f64;
80 let snr_db = if noise_power > 1e-10 {
81 (10.0 * (signal_power / noise_power).max(f64::EPSILON).log10()) as f32
82 } else {
83 f32::INFINITY
84 };
85
86 Self { brick_id, layer_idx, mse, max_abs_error, cosine_similarity, snr_db, quant_type }
87 }
88
89 pub fn is_acceptable(&self) -> bool {
91 self.cosine_similarity > 0.995
92 }
93
94 pub fn is_warning(&self) -> bool {
96 self.cosine_similarity > 0.99 && self.cosine_similarity <= 0.995
97 }
98
99 pub fn is_critical(&self) -> bool {
101 self.cosine_similarity < 0.99
102 }
103}
104
105#[derive(Debug, Clone, Default)]
107pub struct ModelQuantizationError {
108 pub brick_errors: Vec<QuantizationErrorTrace>,
110 pub logits_cosine: f32,
112 pub output_kl_divergence: f32,
114 pub perplexity_delta: f32,
116}
117
118impl ModelQuantizationError {
119 pub fn add_error(&mut self, trace: QuantizationErrorTrace) {
121 self.brick_errors.push(trace);
122 }
123
124 pub fn critical_count(&self) -> usize {
126 self.brick_errors.iter().filter(|e| e.is_critical()).count()
127 }
128
129 pub fn warning_count(&self) -> usize {
131 self.brick_errors.iter().filter(|e| e.is_warning()).count()
132 }
133
134 pub fn worst_brick(&self) -> Option<&QuantizationErrorTrace> {
136 self.brick_errors.iter().min_by(|a, b| {
137 a.cosine_similarity
138 .partial_cmp(&b.cosine_similarity)
139 .unwrap_or(std::cmp::Ordering::Equal)
140 })
141 }
142}
143
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance used for float comparisons below.
    const TOLERANCE: f32 = 1e-6;

    /// Computes a trace attributed to `BrickId::RmsNorm`, layer 0.
    fn rms_norm_trace(
        quantized: &[f32],
        reference: &[f32],
        quant_type: QuantType,
    ) -> QuantizationErrorTrace {
        QuantizationErrorTrace::compute(BrickId::RmsNorm, 0, quantized, reference, quant_type)
    }

    /// Identical inputs give zero error and an acceptable trace.
    #[test]
    fn test_quant_error_perfect_match() {
        let expected = [1.0f32, 2.0, 3.0];
        let actual = [1.0f32, 2.0, 3.0];

        let result = rms_norm_trace(&actual, &expected, QuantType::Q4_K);

        assert!(result.mse.abs() < TOLERANCE);
        assert!((result.cosine_similarity - 1.0).abs() < TOLERANCE);
        assert!(result.is_acceptable());
    }

    /// Perturbed inputs give non-zero error while staying above 0.99 cosine.
    #[test]
    fn test_quant_error_significant_difference() {
        let expected = [1.0f32, 2.0, 3.0];
        let actual = [1.5f32, 2.1, 2.9];

        let result = rms_norm_trace(&actual, &expected, QuantType::Q4_K);

        assert!(result.mse > 0.0);
        assert!(result.cosine_similarity < 1.0);
        assert!(result.cosine_similarity > 0.99);
    }

    /// A vector compared with itself must have cosine similarity 1.
    #[test]
    fn test_falsify_cosine_identical_vectors() {
        let samples = [1.0f32, 2.0, 3.0, 4.0, 5.0];

        let result = rms_norm_trace(&samples, &samples, QuantType::F32);
        let deviation = (result.cosine_similarity - 1.0).abs();

        assert!(
            deviation < TOLERANCE,
            "FALSIFICATION FAILED: identical vectors have cosine {} != 1.0",
            result.cosine_similarity
        );
    }

    /// Swapping the two inputs must not change the cosine similarity.
    #[test]
    fn test_falsify_cosine_symmetry() {
        let first = [1.0f32, 2.0, 3.0];
        let second = [4.0f32, 5.0, 6.0];

        let forward = rms_norm_trace(&first, &second, QuantType::F32);
        let backward = rms_norm_trace(&second, &first, QuantType::F32);
        let gap = (forward.cosine_similarity - backward.cosine_similarity).abs();

        assert!(
            gap < TOLERANCE,
            "FALSIFICATION FAILED: cosine(a,b) {} != cosine(b,a) {}",
            forward.cosine_similarity,
            backward.cosine_similarity
        );
    }
}