1use std::collections::HashMap;
34use std::fmt;
35
/// Accumulates numerical error measurements grouped by layer name, keeping
/// running totals so the aggregate sum/min/max are available without a rescan.
#[derive(Debug, Clone)]
pub struct ErrorAccumulator {
    /// Every recorded error, grouped by layer name, in recording order.
    layer_errors: HashMap<String, Vec<f64>>,
    /// Running sum of all recorded errors.
    total_error: f64,
    /// Smallest recorded error; `f64::INFINITY` while nothing has been recorded.
    min_error: f64,
    /// Largest recorded error; `f64::NEG_INFINITY` while nothing has been recorded.
    max_error: f64,
}

impl Default for ErrorAccumulator {
    /// Returns an empty accumulator identical to `ErrorAccumulator::new()`.
    ///
    /// Implemented by hand rather than derived: a derived `Default` would start
    /// `min_error`/`max_error` at 0.0, so e.g. recording only positive errors
    /// would leave the reported minimum stuck at 0.0.
    fn default() -> Self {
        Self {
            layer_errors: HashMap::new(),
            total_error: 0.0,
            min_error: f64::INFINITY,
            max_error: f64::NEG_INFINITY,
        }
    }
}
48
49impl ErrorAccumulator {
50 pub fn new() -> Self {
52 Self {
53 layer_errors: HashMap::new(),
54 total_error: 0.0,
55 min_error: f64::INFINITY,
56 max_error: f64::NEG_INFINITY,
57 }
58 }
59
60 pub fn record_error(&mut self, layer_name: &str, error: f64) {
67 self.layer_errors
68 .entry(layer_name.to_string())
69 .or_default()
70 .push(error);
71
72 self.total_error += error;
73 if error < self.min_error {
74 self.min_error = error;
75 }
76 if error > self.max_error {
77 self.max_error = error;
78 }
79 }
80
81 pub fn record_errors<I>(&mut self, layer_name: &str, errors: I)
88 where
89 I: IntoIterator<Item = f64>,
90 {
91 let layer_vec = self
92 .layer_errors
93 .entry(layer_name.to_string())
94 .or_default();
95
96 for error in errors {
97 layer_vec.push(error);
98 self.total_error += error;
99 if error < self.min_error {
100 self.min_error = error;
101 }
102 if error > self.max_error {
103 self.max_error = error;
104 }
105 }
106 }
107
108 pub fn total_error(&self) -> f64 {
110 self.total_error
111 }
112
113 pub fn min_error(&self) -> f64 {
115 self.min_error
116 }
117
118 pub fn max_error(&self) -> f64 {
120 self.max_error
121 }
122
123 pub fn num_layers(&self) -> usize {
125 self.layer_errors.len()
126 }
127
128 pub fn total_recordings(&self) -> usize {
130 self.layer_errors.values().map(|v| v.len()).sum()
131 }
132
133 pub fn all_layer_errors(&self) -> &HashMap<String, Vec<f64>> {
135 &self.layer_errors
136 }
137
138 pub fn get_layer_errors(&self, layer_name: &str) -> Option<&[f64]> {
140 self.layer_errors.get(layer_name).map(|v| v.as_slice())
141 }
142
143 pub fn compute_statistics(&self) -> ErrorStatistics {
145 let all_errors: Vec<f64> = self.layer_errors.values().flatten().copied().collect();
146
147 if all_errors.is_empty() {
148 return ErrorStatistics {
149 mean: 0.0,
150 std_dev: 0.0,
151 min: 0.0,
152 max: 0.0,
153 total: 0.0,
154 count: 0,
155 };
156 }
157
158 let count = all_errors.len();
159 let total = all_errors.iter().sum::<f64>();
160 let mean = total / count as f64;
161
162 let variance = all_errors
164 .iter()
165 .map(|&e| (e - mean).powi(2))
166 .sum::<f64>()
167 / count as f64;
168 let std_dev = variance.sqrt();
169
170 let min = all_errors
171 .iter()
172 .cloned()
173 .fold(f64::INFINITY, f64::min);
174 let max = all_errors
175 .iter()
176 .cloned()
177 .fold(f64::NEG_INFINITY, f64::max);
178
179 ErrorStatistics {
180 mean,
181 std_dev,
182 min,
183 max,
184 total,
185 count,
186 }
187 }
188
189 pub fn generate_report(&self) -> ErrorReport {
191 let global_stats = self.compute_statistics();
192
193 let mut layer_stats: Vec<LayerErrorStats> = self
195 .layer_errors
196 .iter()
197 .map(|(layer_name, errors)| {
198 let count = errors.len();
199 let total = errors.iter().sum::<f64>();
200 let mean = total / count as f64;
201 let max = errors.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
202 let min = errors.iter().cloned().fold(f64::INFINITY, f64::min);
203
204 let variance = errors
206 .iter()
207 .map(|&e| (e - mean).powi(2))
208 .sum::<f64>()
209 / count as f64;
210 let std_dev = variance.sqrt();
211
212 LayerErrorStats {
213 layer_name: layer_name.clone(),
214 mean,
215 std_dev,
216 min,
217 max,
218 count,
219 }
220 })
221 .collect();
222
223 layer_stats.sort_by(|a, b| b.max.partial_cmp(&a.max).unwrap_or(std::cmp::Ordering::Equal));
225
226 ErrorReport {
227 global_stats,
228 layer_stats,
229 }
230 }
231
232 pub fn reset(&mut self) {
234 self.layer_errors.clear();
235 self.total_error = 0.0;
236 self.min_error = f64::INFINITY;
237 self.max_error = f64::NEG_INFINITY;
238 }
239}
240
/// Aggregate statistics over a set of recorded error values.
///
/// `Default` yields the all-zero value, which is also what
/// `ErrorAccumulator::compute_statistics` returns when nothing was recorded.
#[derive(Debug, Clone, Default, PartialEq)]
pub struct ErrorStatistics {
    /// Arithmetic mean of the errors.
    pub mean: f64,
    /// Standard deviation of the errors; `ErrorAccumulator` computes the
    /// population form (dividing by the count).
    pub std_dev: f64,
    /// Smallest error.
    pub min: f64,
    /// Largest error.
    pub max: f64,
    /// Sum of all errors.
    pub total: f64,
    /// Number of error values summarised.
    pub count: usize,
}
257
258impl fmt::Display for ErrorStatistics {
259 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
260 writeln!(f, "Error Statistics:")?;
261 writeln!(f, " Count: {}", self.count)?;
262 writeln!(f, " Mean: {:.2e}", self.mean)?;
263 writeln!(f, " Std Dev: {:.2e}", self.std_dev)?;
264 writeln!(f, " Min: {:.2e}", self.min)?;
265 writeln!(f, " Max: {:.2e}", self.max)?;
266 writeln!(f, " Total: {:.2e}", self.total)
267 }
268}
269
/// Error statistics for a single layer.
#[derive(Debug, Clone, PartialEq)]
pub struct LayerErrorStats {
    /// Name under which the errors were recorded.
    pub layer_name: String,
    /// Arithmetic mean of this layer's errors.
    pub mean: f64,
    /// Standard deviation of this layer's errors; `ErrorAccumulator` computes
    /// the population form (dividing by the count).
    pub std_dev: f64,
    /// Smallest error recorded for this layer.
    pub min: f64,
    /// Largest error recorded for this layer.
    pub max: f64,
    /// Number of errors recorded for this layer.
    pub count: usize,
}
286
287impl fmt::Display for LayerErrorStats {
288 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
289 writeln!(f, " {}:", self.layer_name)?;
290 writeln!(f, " Count: {}", self.count)?;
291 writeln!(f, " Mean: {:.2e}", self.mean)?;
292 writeln!(f, " Std Dev: {:.2e}", self.std_dev)?;
293 writeln!(f, " Min: {:.2e}", self.min)?;
294 writeln!(f, " Max: {:.2e}", self.max)
295 }
296}
297
298#[derive(Debug, Clone)]
300pub struct ErrorReport {
301 pub global_stats: ErrorStatistics,
303 pub layer_stats: Vec<LayerErrorStats>,
305}
306
307impl fmt::Display for ErrorReport {
308 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
309 writeln!(f)?;
310 writeln!(f, "╔══════════════════════════════════════════════════════════╗")?;
311 writeln!(f, "║ ERROR ACCUMULATION REPORT ║")?;
312 writeln!(f, "╠══════════════════════════════════════════════════════════╣")?;
313 writeln!(f, "║ GLOBAL STATISTICS ║")?;
314 writeln!(f, "╟──────────────────────────────────────────────────────────╢")?;
315 writeln!(f, "║ Total Recordings: {:>20} ║", self.global_stats.count)?;
316 writeln!(f, "║ Mean Error: {:>20.2e} ║", self.global_stats.mean)?;
317 writeln!(f, "║ Std Dev: {:>20.2e} ║", self.global_stats.std_dev)?;
318 writeln!(f, "║ Min Error: {:>20.2e} ║", self.global_stats.min)?;
319 writeln!(f, "║ Max Error: {:>20.2e} ║", self.global_stats.max)?;
320 writeln!(f, "║ Total Error: {:>20.2e} ║", self.global_stats.total)?;
321 writeln!(f, "╟──────────────────────────────────────────────────────────╢")?;
322 writeln!(f, "║ PER-LAYER STATISTICS (sorted by max error) ║")?;
323 writeln!(f, "╟──────────────────────────────────────────────────────────╢")?;
324
325 let display_count = self.layer_stats.len().min(10);
327 for (i, layer) in self.layer_stats.iter().take(display_count).enumerate() {
328 writeln!(
329 f,
330 "║ {:2}. {:<28} {:.2e} ║",
331 i + 1,
332 truncate_name(&layer.layer_name, 28),
333 layer.max
334 )?;
335 }
336
337 if self.layer_stats.len() > display_count {
338 writeln!(
339 f,
340 "║ ... and {} more layers ║",
341 self.layer_stats.len() - display_count
342 )?;
343 }
344
345 writeln!(f, "╚══════════════════════════════════════════════════════════╝")?;
346 Ok(())
347 }
348}
349
/// Shortens `s` to at most `max_len` characters for fixed-width display.
///
/// Names that fit are returned unchanged. Longer names keep their tail behind
/// a `...` prefix, so the result is exactly `max_len` characters long.
/// Lengths are counted in `char`s rather than bytes, so multi-byte UTF-8 names
/// cannot panic on a non-char-boundary slice and line up with `{:<N}` padding
/// (which also counts `char`s). If `max_len` is too small to hold the full
/// ellipsis, `max_len` dots are returned instead of panicking.
fn truncate_name(s: &str, max_len: usize) -> String {
    const ELLIPSIS: &str = "...";

    let char_count = s.chars().count();
    if char_count <= max_len {
        return s.to_string();
    }

    let ellipsis_len = ELLIPSIS.len();
    if max_len <= ellipsis_len {
        return ".".repeat(max_len);
    }

    let keep = max_len - ellipsis_len;
    let tail: String = s.chars().skip(char_count - keep).collect();
    format!("{}{}", ELLIPSIS, tail)
}
358
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance for comparing error values of order 1e-14.
    const TOL: f64 = 1e-20;

    #[test]
    fn test_error_accumulator_basic() {
        let mut accumulator = ErrorAccumulator::new();

        accumulator.record_error("layer_0", 1.0e-14);
        accumulator.record_error("layer_0", 1.5e-14);
        accumulator.record_error("layer_1", 2.0e-14);

        assert_eq!(accumulator.num_layers(), 2);
        assert_eq!(accumulator.total_recordings(), 3);
        assert!((accumulator.total_error() - 4.5e-14).abs() < TOL);
        assert!((accumulator.min_error() - 1.0e-14).abs() < TOL);
        assert!((accumulator.max_error() - 2.0e-14).abs() < TOL);
    }

    #[test]
    fn test_error_accumulator_statistics() {
        let mut accumulator = ErrorAccumulator::new();

        let errors = vec![1.0e-14, 2.0e-14, 3.0e-14, 4.0e-14, 5.0e-14];
        for (i, &error) in errors.iter().enumerate() {
            accumulator.record_error(&format!("layer_{}", i), error);
        }

        let stats = accumulator.compute_statistics();

        assert!((stats.mean - 3.0e-14).abs() < TOL);
        // Population variance of 1..=5 is 2, so std-dev is sqrt(2) * 1e-14.
        assert!((stats.std_dev - 2.0f64.sqrt() * 1.0e-14).abs() < TOL);
        assert!((stats.min - 1.0e-14).abs() < TOL);
        assert!((stats.max - 5.0e-14).abs() < TOL);
        assert!((stats.total - 15.0e-14).abs() < TOL);
        assert_eq!(stats.count, 5);
    }

    #[test]
    fn test_error_accumulator_multiple_errors_per_layer() {
        let mut accumulator = ErrorAccumulator::new();

        accumulator.record_error("layer_0", 1.0e-14);
        accumulator.record_error("layer_0", 2.0e-14);
        accumulator.record_error("layer_0", 3.0e-14);

        let layer_errors = accumulator.get_layer_errors("layer_0").unwrap();
        assert_eq!(layer_errors.len(), 3);
        assert!((layer_errors[0] - 1.0e-14).abs() < TOL);
        assert!((layer_errors[1] - 2.0e-14).abs() < TOL);
        assert!((layer_errors[2] - 3.0e-14).abs() < TOL);
        assert!(accumulator.get_layer_errors("missing").is_none());
    }

    #[test]
    fn test_error_accumulator_record_multiple() {
        let mut accumulator = ErrorAccumulator::new();

        let errors = vec![1.0e-14, 2.0e-14, 3.0e-14];
        accumulator.record_errors("layer_0", errors);

        assert_eq!(accumulator.total_recordings(), 3);
        let layer_errors = accumulator.get_layer_errors("layer_0").unwrap();
        assert_eq!(layer_errors.len(), 3);
        assert!((accumulator.total_error() - 6.0e-14).abs() < TOL);
        assert!((accumulator.min_error() - 1.0e-14).abs() < TOL);
        assert!((accumulator.max_error() - 3.0e-14).abs() < TOL);
    }

    #[test]
    fn test_error_accumulator_empty() {
        let accumulator = ErrorAccumulator::new();

        let stats = accumulator.compute_statistics();
        assert_eq!(stats.count, 0);
        assert_eq!(stats.mean, 0.0);
        assert_eq!(stats.std_dev, 0.0);
        assert_eq!(stats.min, 0.0);
        assert_eq!(stats.max, 0.0);
        assert_eq!(stats.total, 0.0);

        // The running extremes use infinities as "nothing recorded" markers.
        assert_eq!(accumulator.min_error(), f64::INFINITY);
        assert_eq!(accumulator.max_error(), f64::NEG_INFINITY);

        let report = accumulator.generate_report();
        assert!(report.layer_stats.is_empty());
        assert_eq!(report.global_stats.count, 0);
    }

    #[test]
    fn test_error_accumulator_reset() {
        let mut accumulator = ErrorAccumulator::new();
        accumulator.record_error("layer_0", 1.0e-14);
        accumulator.record_error("layer_1", 2.0e-14);

        assert_eq!(accumulator.num_layers(), 2);
        assert_eq!(accumulator.total_recordings(), 2);

        accumulator.reset();

        assert_eq!(accumulator.num_layers(), 0);
        assert_eq!(accumulator.total_recordings(), 0);
        assert_eq!(accumulator.total_error(), 0.0);
        assert_eq!(accumulator.min_error(), f64::INFINITY);
        assert_eq!(accumulator.max_error(), f64::NEG_INFINITY);

        // Recording after a reset behaves like recording into a fresh accumulator.
        accumulator.record_error("layer_2", 3.0e-14);
        assert!((accumulator.min_error() - 3.0e-14).abs() < TOL);
        assert!((accumulator.max_error() - 3.0e-14).abs() < TOL);
    }

    #[test]
    fn test_error_report_generation() {
        let mut accumulator = ErrorAccumulator::new();

        accumulator.record_error("embeddings", 5.0e-15);
        accumulator.record_error("layer_0/q_proj", 1.2e-14);
        accumulator.record_error("layer_0/k_proj", 1.5e-14);
        accumulator.record_error("layer_0/v_proj", 1.1e-14);
        accumulator.record_error("layer_0/out_proj", 1.3e-14);
        accumulator.record_error("layer_1/q_proj", 2.1e-14);
        accumulator.record_error("layer_1/k_proj", 1.8e-14);
        accumulator.record_error("layer_1/v_proj", 2.3e-14);
        accumulator.record_error("layer_1/out_proj", 1.9e-14);
        accumulator.record_error("lm_head", 3.5e-14);

        let report = accumulator.generate_report();

        assert_eq!(report.global_stats.count, 10);
        assert_eq!(report.layer_stats.len(), 10);

        assert_eq!(report.layer_stats[0].layer_name, "lm_head");
        assert!((report.layer_stats[0].max - 3.5e-14).abs() < TOL);

        // Layers are sorted by descending maximum error.
        assert!(report
            .layer_stats
            .windows(2)
            .all(|pair| pair[0].max >= pair[1].max));
    }

    #[test]
    fn test_error_report_display() {
        let mut accumulator = ErrorAccumulator::new();
        accumulator.record_error("layer_0", 1.0e-14);
        accumulator.record_error("layer_1", 2.0e-14);

        let report = accumulator.generate_report();
        let display = format!("{}", report);

        assert!(display.contains("ERROR ACCUMULATION REPORT"));
        assert!(display.contains("GLOBAL STATISTICS"));
        assert!(display.contains("PER-LAYER STATISTICS"));
        assert!(display.contains("layer_0"));
        assert!(display.contains("layer_1"));
        assert!(!display.contains("more layers"));
    }

    #[test]
    fn test_error_report_display_truncates_and_limits_rows() {
        let mut accumulator = ErrorAccumulator::new();
        let long_name = format!("{}_tail", "x".repeat(30));
        // Largest error so the long name is guaranteed to be in the shown rows.
        accumulator.record_error(&long_name, 1.0e-12);
        for i in 0..11 {
            accumulator.record_error(&format!("layer_{}", i), 1.0e-14);
        }

        let display = format!("{}", accumulator.generate_report());

        let kept_tail = &long_name[long_name.len() - 25..];
        assert!(display.contains(&format!("...{}", kept_tail)));
        assert!(!display.contains(&long_name));
        assert!(display.contains("and 2 more layers"));
    }

    #[test]
    fn test_truncate_name_ascii() {
        assert_eq!(truncate_name("short", 28), "short");
        assert_eq!(truncate_name("abcdefghij", 10), "abcdefghij");
        assert_eq!(truncate_name("abcdefghij", 8), "...fghij");
        assert_eq!(truncate_name("abcd", 3), "...");
    }

    #[test]
    fn test_layer_error_stats_display() {
        let stats = LayerErrorStats {
            layer_name: "test_layer".to_string(),
            mean: 1.5e-14,
            std_dev: 0.5e-14,
            min: 1.0e-14,
            max: 2.0e-14,
            count: 3,
        };

        let display = format!("{}", stats);
        assert!(display.contains("test_layer"));
        assert!(display.contains("Count:"));
        assert!(display.contains("Mean:"));
    }

    #[test]
    fn test_error_statistics_display() {
        let stats = ErrorStatistics {
            mean: 1.5e-14,
            std_dev: 0.5e-14,
            min: 1.0e-14,
            max: 2.0e-14,
            total: 4.5e-14,
            count: 3,
        };

        let display = format!("{}", stats);
        assert!(display.contains("Error Statistics:"));
        assert!(display.contains("Count:"));
        assert!(display.contains("Mean:"));
    }
}