Skip to main content

god_graph/transformer/optimization/
error_analysis.rs

1//! Error Accumulation Analysis for Graph-Level Tensor Operations
2//!
3//! This module provides tools for tracking and analyzing numerical errors
4//! that accumulate during graph-level tensor operations like orthogonalization
5//! and decomposition.
6//!
7//! ## Purpose
8//!
9//! When performing operations like QR decomposition on weights across a graph,
10//! numerical errors can accumulate. This module helps:
11//!
12//! - Track per-layer error contributions
13//! - Generate statistical reports on error distributions
14//! - Identify problematic layers with high error
15//!
16//! ## Example
17//!
18//! ```no_run
19//! use god_gragh::transformer::optimization::error_analysis::ErrorAccumulator;
20//!
21//! let mut accumulator = ErrorAccumulator::new();
22//!
23//! // Record errors from each layer
24//! accumulator.record_error("layer_0/q_proj", 1.2e-14);
25//! accumulator.record_error("layer_0/k_proj", 1.5e-14);
26//! accumulator.record_error("layer_1/q_proj", 2.1e-14);
27//!
28//! // Generate and print report
29//! let report = accumulator.generate_report();
30//! println!("{}", report);
31//! ```
32
33use std::collections::HashMap;
34use std::fmt;
35
/// Accumulates numerical error measurements recorded during graph-level
/// tensor operations.
///
/// Measurements are stored per layer in recording order. A running total and
/// the global minimum / maximum are kept alongside them.
///
/// An accumulator without recordings holds `f64::INFINITY` as its minimum and
/// `f64::NEG_INFINITY` as its maximum; these sentinels are what `new()` uses
/// and what the `Default` implementation below reproduces.
#[derive(Debug, Clone)]
pub struct ErrorAccumulator {
    /// Recorded error values grouped by layer name
    layer_errors: HashMap<String, Vec<f64>>,
    /// Running sum of every recorded error
    total_error: f64,
    /// Smallest error recorded so far (`f64::INFINITY` while empty)
    min_error: f64,
    /// Largest error recorded so far (`f64::NEG_INFINITY` while empty)
    max_error: f64,
}

impl Default for ErrorAccumulator {
    /// Returns an empty accumulator in the same state as `ErrorAccumulator::new()`.
    ///
    /// A derived `Default` would zero-initialise `min_error` and `max_error`,
    /// which would make `0.0` look like an observed minimum/maximum; the
    /// sentinels are therefore set explicitly.
    fn default() -> Self {
        Self {
            layer_errors: HashMap::new(),
            total_error: 0.0,
            min_error: f64::INFINITY,
            max_error: f64::NEG_INFINITY,
        }
    }
}
48
49impl ErrorAccumulator {
50    /// Create a new error accumulator
51    pub fn new() -> Self {
52        Self {
53            layer_errors: HashMap::new(),
54            total_error: 0.0,
55            min_error: f64::INFINITY,
56            max_error: f64::NEG_INFINITY,
57        }
58    }
59
60    /// Record an error for a specific layer
61    ///
62    /// # Arguments
63    ///
64    /// * `layer_name` - Name/identifier of the layer
65    /// * `error` - Numerical error value (should be positive)
66    pub fn record_error(&mut self, layer_name: &str, error: f64) {
67        self.layer_errors
68            .entry(layer_name.to_string())
69            .or_default()
70            .push(error);
71
72        self.total_error += error;
73        if error < self.min_error {
74            self.min_error = error;
75        }
76        if error > self.max_error {
77            self.max_error = error;
78        }
79    }
80
81    /// Record multiple errors for a layer at once
82    ///
83    /// # Arguments
84    ///
85    /// * `layer_name` - Name/identifier of the layer
86    /// * `errors` - Iterator of error values
87    pub fn record_errors<I>(&mut self, layer_name: &str, errors: I)
88    where
89        I: IntoIterator<Item = f64>,
90    {
91        let layer_vec = self
92            .layer_errors
93            .entry(layer_name.to_string())
94            .or_default();
95
96        for error in errors {
97            layer_vec.push(error);
98            self.total_error += error;
99            if error < self.min_error {
100                self.min_error = error;
101            }
102            if error > self.max_error {
103                self.max_error = error;
104            }
105        }
106    }
107
108    /// Get the total accumulated error
109    pub fn total_error(&self) -> f64 {
110        self.total_error
111    }
112
113    /// Get the minimum error observed
114    pub fn min_error(&self) -> f64 {
115        self.min_error
116    }
117
118    /// Get the maximum error observed
119    pub fn max_error(&self) -> f64 {
120        self.max_error
121    }
122
123    /// Get the number of layers tracked
124    pub fn num_layers(&self) -> usize {
125        self.layer_errors.len()
126    }
127
128    /// Get the total number of error recordings
129    pub fn total_recordings(&self) -> usize {
130        self.layer_errors.values().map(|v| v.len()).sum()
131    }
132
133    /// Get all layer errors
134    pub fn all_layer_errors(&self) -> &HashMap<String, Vec<f64>> {
135        &self.layer_errors
136    }
137
138    /// Get errors for a specific layer
139    pub fn get_layer_errors(&self, layer_name: &str) -> Option<&[f64]> {
140        self.layer_errors.get(layer_name).map(|v| v.as_slice())
141    }
142
143    /// Compute global error statistics
144    pub fn compute_statistics(&self) -> ErrorStatistics {
145        let all_errors: Vec<f64> = self.layer_errors.values().flatten().copied().collect();
146
147        if all_errors.is_empty() {
148            return ErrorStatistics {
149                mean: 0.0,
150                std_dev: 0.0,
151                min: 0.0,
152                max: 0.0,
153                total: 0.0,
154                count: 0,
155            };
156        }
157
158        let count = all_errors.len();
159        let total = all_errors.iter().sum::<f64>();
160        let mean = total / count as f64;
161
162        // Compute standard deviation
163        let variance = all_errors
164            .iter()
165            .map(|&e| (e - mean).powi(2))
166            .sum::<f64>()
167            / count as f64;
168        let std_dev = variance.sqrt();
169
170        let min = all_errors
171            .iter()
172            .cloned()
173            .fold(f64::INFINITY, f64::min);
174        let max = all_errors
175            .iter()
176            .cloned()
177            .fold(f64::NEG_INFINITY, f64::max);
178
179        ErrorStatistics {
180            mean,
181            std_dev,
182            min,
183            max,
184            total,
185            count,
186        }
187    }
188
189    /// Generate a detailed error report
190    pub fn generate_report(&self) -> ErrorReport {
191        let global_stats = self.compute_statistics();
192
193        // Compute per-layer statistics
194        let mut layer_stats: Vec<LayerErrorStats> = self
195            .layer_errors
196            .iter()
197            .map(|(layer_name, errors)| {
198                let count = errors.len();
199                let total = errors.iter().sum::<f64>();
200                let mean = total / count as f64;
201                let max = errors.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
202                let min = errors.iter().cloned().fold(f64::INFINITY, f64::min);
203
204                // Compute layer standard deviation
205                let variance = errors
206                    .iter()
207                    .map(|&e| (e - mean).powi(2))
208                    .sum::<f64>()
209                    / count as f64;
210                let std_dev = variance.sqrt();
211
212                LayerErrorStats {
213                    layer_name: layer_name.clone(),
214                    mean,
215                    std_dev,
216                    min,
217                    max,
218                    count,
219                }
220            })
221            .collect();
222
223        // Sort by max error (descending) to highlight problematic layers
224        layer_stats.sort_by(|a, b| b.max.partial_cmp(&a.max).unwrap_or(std::cmp::Ordering::Equal));
225
226        ErrorReport {
227            global_stats,
228            layer_stats,
229        }
230    }
231
232    /// Reset the accumulator (clear all recorded errors)
233    pub fn reset(&mut self) {
234        self.layer_errors.clear();
235        self.total_error = 0.0;
236        self.min_error = f64::INFINITY;
237        self.max_error = f64::NEG_INFINITY;
238    }
239}
240
/// Summary statistics over a collection of recorded errors
#[derive(Debug, Clone)]
pub struct ErrorStatistics {
    /// Arithmetic mean of the recorded errors
    pub mean: f64,
    /// Standard deviation of the recorded errors
    pub std_dev: f64,
    /// Smallest recorded error
    pub min: f64,
    /// Largest recorded error
    pub max: f64,
    /// Sum of all recorded errors
    pub total: f64,
    /// How many errors were recorded
    pub count: usize,
}

impl fmt::Display for ErrorStatistics {
    /// Writes a heading line followed by one indented line per statistic.
    /// Floating-point values use scientific notation with two decimals, and
    /// every line (including the last) ends with a newline.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self {
            mean,
            std_dev,
            min,
            max,
            total,
            count,
        } = self;

        writeln!(f, "Error Statistics:")?;
        writeln!(f, "  Count:   {}", count)?;
        writeln!(f, "  Mean:    {:.2e}", mean)?;
        writeln!(f, "  Std Dev: {:.2e}", std_dev)?;
        writeln!(f, "  Min:     {:.2e}", min)?;
        writeln!(f, "  Max:     {:.2e}", max)?;
        writeln!(f, "  Total:   {:.2e}", total)?;
        Ok(())
    }
}
269
/// Summary statistics for the errors recorded on a single layer
#[derive(Debug, Clone)]
pub struct LayerErrorStats {
    /// Name of the layer these statistics belong to
    pub layer_name: String,
    /// Arithmetic mean of the layer's errors
    pub mean: f64,
    /// Standard deviation of the layer's errors
    pub std_dev: f64,
    /// Smallest error recorded for the layer
    pub min: f64,
    /// Largest error recorded for the layer
    pub max: f64,
    /// How many errors were recorded for the layer
    pub count: usize,
}

impl fmt::Display for LayerErrorStats {
    /// Writes the layer name as an indented heading followed by one further
    /// indented line per statistic. Floating-point values use scientific
    /// notation with two decimals; every line ends with a newline.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self {
            layer_name,
            mean,
            std_dev,
            min,
            max,
            count,
        } = self;

        writeln!(f, "  {}:", layer_name)?;
        writeln!(f, "    Count:   {}", count)?;
        writeln!(f, "    Mean:    {:.2e}", mean)?;
        writeln!(f, "    Std Dev: {:.2e}", std_dev)?;
        writeln!(f, "    Min:     {:.2e}", min)?;
        writeln!(f, "    Max:     {:.2e}", max)?;
        Ok(())
    }
}
297
298/// Comprehensive error report
299#[derive(Debug, Clone)]
300pub struct ErrorReport {
301    /// Global statistics across all layers
302    pub global_stats: ErrorStatistics,
303    /// Per-layer statistics (sorted by max error, descending)
304    pub layer_stats: Vec<LayerErrorStats>,
305}
306
307impl fmt::Display for ErrorReport {
308    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
309        writeln!(f)?;
310        writeln!(f, "╔══════════════════════════════════════════════════════════╗")?;
311        writeln!(f, "║           ERROR ACCUMULATION REPORT                      ║")?;
312        writeln!(f, "╠══════════════════════════════════════════════════════════╣")?;
313        writeln!(f, "║ GLOBAL STATISTICS                                        ║")?;
314        writeln!(f, "╟──────────────────────────────────────────────────────────╢")?;
315        writeln!(f, "║   Total Recordings: {:>20} ║", self.global_stats.count)?;
316        writeln!(f, "║   Mean Error:       {:>20.2e} ║", self.global_stats.mean)?;
317        writeln!(f, "║   Std Dev:          {:>20.2e} ║", self.global_stats.std_dev)?;
318        writeln!(f, "║   Min Error:        {:>20.2e} ║", self.global_stats.min)?;
319        writeln!(f, "║   Max Error:        {:>20.2e} ║", self.global_stats.max)?;
320        writeln!(f, "║   Total Error:      {:>20.2e} ║", self.global_stats.total)?;
321        writeln!(f, "╟──────────────────────────────────────────────────────────╢")?;
322        writeln!(f, "║ PER-LAYER STATISTICS (sorted by max error)             ║")?;
323        writeln!(f, "╟──────────────────────────────────────────────────────────╢")?;
324
325        // Show top 10 layers by max error
326        let display_count = self.layer_stats.len().min(10);
327        for (i, layer) in self.layer_stats.iter().take(display_count).enumerate() {
328            writeln!(
329                f,
330                "║   {:2}. {:<28} {:.2e} ║",
331                i + 1,
332                truncate_name(&layer.layer_name, 28),
333                layer.max
334            )?;
335        }
336
337        if self.layer_stats.len() > display_count {
338            writeln!(
339                f,
340                "║   ... and {} more layers                              ║",
341                self.layer_stats.len() - display_count
342            )?;
343        }
344
345        writeln!(f, "╚══════════════════════════════════════════════════════════╝")?;
346        Ok(())
347    }
348}
349
/// Truncate a name to at most `max_len` characters, keeping its tail.
///
/// Names that already fit are returned unchanged. Longer names become
/// `"..."` followed by their last `max_len - 3` characters, so the result is
/// exactly `max_len` characters long. When `max_len` is too small to hold the
/// ellipsis, only the last `max_len` characters are returned.
///
/// Lengths are measured in `char`s rather than bytes, so a name containing
/// multi-byte UTF-8 characters is never split inside a character (which
/// would panic when slicing).
fn truncate_name(s: &str, max_len: usize) -> String {
    const ELLIPSIS: &str = "...";

    let char_count = s.chars().count();
    if char_count <= max_len {
        return s.to_string();
    }

    // Without room for the ellipsis, fall back to the bare tail.
    let (prefix, keep) = match max_len.checked_sub(ELLIPSIS.len()) {
        Some(keep) => (ELLIPSIS, keep),
        None => ("", max_len),
    };

    // Byte offset of the first kept character; `None` means nothing is kept.
    let start = s
        .char_indices()
        .nth(char_count - keep)
        .map_or(s.len(), |(idx, _)| idx);

    format!("{}{}", prefix, &s[start..])
}
358
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance for floating-point comparisons in these tests.
    const TOLERANCE: f64 = 1e-20;

    /// Returns `true` when `actual` lies within `TOLERANCE` of `expected`.
    fn approx_eq(actual: f64, expected: f64) -> bool {
        (actual - expected).abs() < TOLERANCE
    }

    /// Layer and recording counts plus the running total after a few records.
    #[test]
    fn test_error_accumulator_basic() {
        let mut acc = ErrorAccumulator::new();

        acc.record_error("layer_0", 1.0e-14);
        acc.record_error("layer_0", 1.5e-14);
        acc.record_error("layer_1", 2.0e-14);

        assert_eq!(acc.num_layers(), 2);
        assert_eq!(acc.total_recordings(), 3);
        assert!(approx_eq(acc.total_error(), 4.5e-14));
    }

    /// Global statistics over five layers with one known value each.
    #[test]
    fn test_error_accumulator_statistics() {
        let mut acc = ErrorAccumulator::new();

        // Record known errors, one per layer
        let known = [1.0e-14, 2.0e-14, 3.0e-14, 4.0e-14, 5.0e-14];
        for (idx, value) in known.iter().enumerate() {
            let layer = format!("layer_{}", idx);
            acc.record_error(&layer, *value);
        }

        let stats = acc.compute_statistics();

        // Mean should be (1+2+3+4+5)/5 = 3.0e-14
        assert!(approx_eq(stats.mean, 3.0e-14));
        // Min should be 1.0e-14
        assert!(approx_eq(stats.min, 1.0e-14));
        // Max should be 5.0e-14
        assert!(approx_eq(stats.max, 5.0e-14));
        // Total should be 15.0e-14
        assert!(approx_eq(stats.total, 15.0e-14));
        assert_eq!(stats.count, 5);
    }

    /// Several recordings on one layer are kept in recording order.
    #[test]
    fn test_error_accumulator_multiple_errors_per_layer() {
        let mut acc = ErrorAccumulator::new();

        // Record multiple errors for the same layer
        acc.record_error("layer_0", 1.0e-14);
        acc.record_error("layer_0", 2.0e-14);
        acc.record_error("layer_0", 3.0e-14);

        let recorded = acc.get_layer_errors("layer_0").unwrap();
        assert_eq!(recorded.len(), 3);
        assert!(approx_eq(recorded[0], 1.0e-14));
        assert!(approx_eq(recorded[1], 2.0e-14));
        assert!(approx_eq(recorded[2], 3.0e-14));
    }

    /// Batch recording stores every value of the batch.
    #[test]
    fn test_error_accumulator_record_multiple() {
        let mut acc = ErrorAccumulator::new();

        acc.record_errors("layer_0", vec![1.0e-14, 2.0e-14, 3.0e-14]);

        assert_eq!(acc.total_recordings(), 3);
        let recorded = acc.get_layer_errors("layer_0").unwrap();
        assert_eq!(recorded.len(), 3);
    }

    /// Statistics of an accumulator without recordings are all zero.
    #[test]
    fn test_error_accumulator_empty() {
        let acc = ErrorAccumulator::new();

        let stats = acc.compute_statistics();
        assert_eq!(stats.count, 0);
        assert_eq!(stats.mean, 0.0);
        assert_eq!(stats.total, 0.0);
    }

    /// `reset` drops all layers and recordings and zeroes the total.
    #[test]
    fn test_error_accumulator_reset() {
        let mut acc = ErrorAccumulator::new();
        acc.record_error("layer_0", 1.0e-14);
        acc.record_error("layer_1", 2.0e-14);

        assert_eq!(acc.num_layers(), 2);
        assert_eq!(acc.total_recordings(), 2);

        acc.reset();

        assert_eq!(acc.num_layers(), 0);
        assert_eq!(acc.total_recordings(), 0);
        assert_eq!(acc.total_error(), 0.0);
    }

    /// The report lists every layer, with the largest max error first.
    #[test]
    fn test_error_report_generation() {
        // Simulate errors from a small transformer graph
        let recordings = [
            ("embeddings", 5.0e-15),
            ("layer_0/q_proj", 1.2e-14),
            ("layer_0/k_proj", 1.5e-14),
            ("layer_0/v_proj", 1.1e-14),
            ("layer_0/out_proj", 1.3e-14),
            ("layer_1/q_proj", 2.1e-14),
            ("layer_1/k_proj", 1.8e-14),
            ("layer_1/v_proj", 2.3e-14),
            ("layer_1/out_proj", 1.9e-14),
            ("lm_head", 3.5e-14),
        ];

        let mut acc = ErrorAccumulator::new();
        for (layer, value) in recordings.iter() {
            acc.record_error(layer, *value);
        }

        let report = acc.generate_report();

        // Verify global stats
        assert_eq!(report.global_stats.count, 10);
        assert_eq!(report.layer_stats.len(), 10);

        // lm_head should have the highest max error
        let worst = &report.layer_stats[0];
        assert_eq!(worst.layer_name, "lm_head");
        assert!(approx_eq(worst.max, 3.5e-14));
    }

    /// The rendered report contains its section headings and layer names.
    #[test]
    fn test_error_report_display() {
        let mut acc = ErrorAccumulator::new();
        acc.record_error("layer_0", 1.0e-14);
        acc.record_error("layer_1", 2.0e-14);

        let rendered = acc.generate_report().to_string();

        for needle in [
            "ERROR ACCUMULATION REPORT",
            "GLOBAL STATISTICS",
            "PER-LAYER STATISTICS",
            "layer_0",
            "layer_1",
        ]
        .iter()
        {
            assert!(rendered.contains(needle));
        }
    }

    /// `LayerErrorStats` renders its name and field labels.
    #[test]
    fn test_layer_error_stats_display() {
        let rendered = LayerErrorStats {
            layer_name: "test_layer".to_string(),
            mean: 1.5e-14,
            std_dev: 0.5e-14,
            min: 1.0e-14,
            max: 2.0e-14,
            count: 3,
        }
        .to_string();

        assert!(rendered.contains("test_layer"));
        assert!(rendered.contains("Count:"));
        assert!(rendered.contains("Mean:"));
    }

    /// `ErrorStatistics` renders its heading and field labels.
    #[test]
    fn test_error_statistics_display() {
        let rendered = ErrorStatistics {
            mean: 1.5e-14,
            std_dev: 0.5e-14,
            min: 1.0e-14,
            max: 2.0e-14,
            total: 4.5e-14,
            count: 3,
        }
        .to_string();

        assert!(rendered.contains("Error Statistics:"));
        assert!(rendered.contains("Count:"));
        assert!(rendered.contains("Mean:"));
    }
}