Skip to main content

optirs_core/
gradient_flow.rs

1//! Gradient Flow Analysis Module
2//!
3//! Provides tools for analyzing gradient flow through neural network layers,
4//! detecting vanishing/exploding gradients, and generating visual reports.
5
6use crate::error::{OptimError, Result};
7use scirs2_core::ndarray::{Array1, ScalarOperand};
8use scirs2_core::numeric::Float;
9use std::collections::HashMap;
10use std::fmt::Debug;
11
/// Configuration for gradient flow analysis.
///
/// All fields are public; start from [`Default`] and override individual
/// fields with struct-update syntax. The type is comparable with `==` so
/// configurations can be checked in tests and assertions.
#[derive(Debug, Clone, PartialEq)]
pub struct GradientFlowConfig {
    /// Threshold below which gradients are considered vanishing.
    ///
    /// Compared against a layer's latest mean gradient magnitude when
    /// detecting vanishing layers, and against each individual magnitude
    /// when computing sparsity.
    pub vanishing_threshold: f64,
    /// Threshold above which gradients are considered exploding.
    ///
    /// Compared against a layer's latest maximum gradient magnitude.
    pub exploding_threshold: f64,
    /// Number of histogram bins for gradient magnitude distribution
    pub histogram_bins: usize,
    /// Maximum number of historical records to keep per layer; the oldest
    /// record is dropped once this limit is exceeded.
    pub max_history: usize,
}
24
25impl Default for GradientFlowConfig {
26    fn default() -> Self {
27        Self {
28            vanishing_threshold: 1e-7,
29            exploding_threshold: 1e3,
30            histogram_bins: 50,
31            max_history: 100,
32        }
33    }
34}
35
/// Per-layer gradient statistics.
///
/// One snapshot is produced per call to
/// `GradientFlowAnalyzer::record_gradients`. All "norm" fields are computed
/// over the absolute values of the individual gradient elements.
/// Snapshots can be compared with `==` when `A: PartialEq`.
#[derive(Debug, Clone, PartialEq)]
pub struct LayerGradientStats<A> {
    /// Name of the layer
    pub layer_name: String,
    /// Mean of the element-wise gradient magnitudes
    pub mean_norm: A,
    /// Largest element-wise gradient magnitude
    pub max_norm: A,
    /// Smallest element-wise gradient magnitude
    pub min_norm: A,
    /// Variance of gradient magnitudes
    pub variance: A,
    /// Fraction of near-zero elements (sparsity), i.e. elements whose
    /// magnitude is below the configured vanishing threshold
    pub sparsity: A,
    /// Histogram of gradient magnitude distribution: element counts per
    /// bin, with magnitudes scaled relative to the largest magnitude
    pub histogram: Vec<usize>,
}
54
/// Overall gradient health status.
///
/// A fieldless enum, so it is cheap to copy and usable as a hash-map key
/// or set element.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum GradientHealth {
    /// All layers have healthy gradient flow
    Healthy,
    /// Some layers show concerning gradient behavior
    Warning,
    /// Critical gradient flow issues detected
    Critical,
}
65
66/// Comprehensive gradient health report
67#[derive(Debug, Clone)]
68pub struct GradientHealthReport {
69    /// Layers with vanishing gradients
70    pub vanishing_layers: Vec<String>,
71    /// Layers with exploding gradients
72    pub exploding_layers: Vec<String>,
73    /// Layers with healthy gradient flow
74    pub healthy_layers: Vec<String>,
75    /// Overall health assessment
76    pub overall_health: GradientHealth,
77    /// Actionable recommendations
78    pub recommendations: Vec<String>,
79}
80
81/// Gradient flow analyzer for monitoring and diagnosing gradient behavior
82pub struct GradientFlowAnalyzer<A> {
83    /// Configuration parameters
84    config: GradientFlowConfig,
85    /// Historical statistics per layer
86    layer_stats: HashMap<String, Vec<LayerGradientStats<A>>>,
87    /// Ordering of layers for rendering
88    layer_order: Vec<String>,
89}
90
91impl<A> GradientFlowAnalyzer<A>
92where
93    A: Float + ScalarOperand + Debug + std::iter::Sum,
94{
95    /// Create a new gradient flow analyzer with the given configuration
96    pub fn new(config: GradientFlowConfig) -> Self {
97        Self {
98            config,
99            layer_stats: HashMap::new(),
100            layer_order: Vec::new(),
101        }
102    }
103
104    /// Record gradients for a layer and compute statistics
105    ///
106    /// Computes mean norm, max norm, min norm, variance, sparsity, and
107    /// a histogram of gradient magnitudes. The results are stored in the
108    /// internal history for later analysis.
109    pub fn record_gradients(
110        &mut self,
111        layer_name: &str,
112        gradients: &Array1<A>,
113    ) -> Result<LayerGradientStats<A>> {
114        let len = gradients.len();
115        if len == 0 {
116            return Err(OptimError::InvalidParameter(
117                "Gradients array must not be empty".to_string(),
118            ));
119        }
120
121        let len_a = A::from(len).ok_or_else(|| {
122            OptimError::ComputationError("Failed to convert length to float".to_string())
123        })?;
124
125        // Compute absolute values for magnitude analysis
126        let abs_grads: Vec<A> = gradients.iter().map(|&g| g.abs()).collect();
127
128        // Mean norm
129        let sum: A = abs_grads.iter().copied().sum();
130        let mean_norm = sum / len_a;
131
132        // Max and min norms
133        let max_norm = abs_grads
134            .iter()
135            .copied()
136            .fold(A::neg_infinity(), |a, b| if b > a { b } else { a });
137        let min_norm = abs_grads
138            .iter()
139            .copied()
140            .fold(A::infinity(), |a, b| if b < a { b } else { a });
141
142        // Variance: E[x^2] - (E[x])^2
143        let sum_sq: A = abs_grads.iter().map(|&g| g * g).sum();
144        let mean_sq = sum_sq / len_a;
145        let variance = mean_sq - mean_norm * mean_norm;
146        // Clamp to zero in case of floating point issues
147        let variance = if variance < A::zero() {
148            A::zero()
149        } else {
150            variance
151        };
152
153        // Sparsity: fraction of elements with magnitude below vanishing threshold
154        let vanishing_thresh = A::from(self.config.vanishing_threshold).ok_or_else(|| {
155            OptimError::ComputationError(
156                "Failed to convert vanishing threshold to float".to_string(),
157            )
158        })?;
159        let near_zero_count = abs_grads.iter().filter(|&&g| g < vanishing_thresh).count();
160        let sparsity = A::from(near_zero_count).ok_or_else(|| {
161            OptimError::ComputationError("Failed to convert count to float".to_string())
162        })? / len_a;
163
164        // Histogram of gradient magnitudes
165        let histogram = self.compute_histogram(&abs_grads, max_norm)?;
166
167        let stats = LayerGradientStats {
168            layer_name: layer_name.to_string(),
169            mean_norm,
170            max_norm,
171            min_norm,
172            variance,
173            sparsity,
174            histogram,
175        };
176
177        // Track layer ordering
178        if !self.layer_order.contains(&layer_name.to_string()) {
179            self.layer_order.push(layer_name.to_string());
180        }
181
182        // Store in history, respecting max_history
183        let history = self.layer_stats.entry(layer_name.to_string()).or_default();
184        history.push(stats.clone());
185        if history.len() > self.config.max_history {
186            history.remove(0);
187        }
188
189        Ok(stats)
190    }
191
192    /// Compute a histogram of gradient magnitudes
193    fn compute_histogram(&self, abs_grads: &[A], max_val: A) -> Result<Vec<usize>> {
194        let bins = self.config.histogram_bins;
195        let mut histogram = vec![0usize; bins];
196
197        if max_val <= A::zero() {
198            // All zeros, put everything in first bin
199            histogram[0] = abs_grads.len();
200            return Ok(histogram);
201        }
202
203        for &val in abs_grads {
204            let normalized = val / max_val;
205            let bin_idx = (normalized
206                * A::from(bins).ok_or_else(|| {
207                    OptimError::ComputationError("Failed to convert bins to float".to_string())
208                })?)
209            .to_f64()
210            .ok_or_else(|| OptimError::ComputationError("Failed to convert to f64".to_string()))?;
211            let bin_idx = (bin_idx as usize).min(bins - 1);
212            histogram[bin_idx] += 1;
213        }
214
215        Ok(histogram)
216    }
217
218    /// Detect layers with vanishing gradients
219    ///
220    /// Returns names of layers whose most recent mean gradient norm
221    /// is below the configured vanishing threshold.
222    pub fn detect_vanishing_gradients(&self) -> Vec<String> {
223        let threshold = self.config.vanishing_threshold;
224        let mut vanishing = Vec::new();
225
226        for (name, stats_history) in &self.layer_stats {
227            if let Some(latest) = stats_history.last() {
228                let mean_f64 = latest.mean_norm.to_f64().unwrap_or(0.0);
229                if mean_f64 < threshold {
230                    vanishing.push(name.clone());
231                }
232            }
233        }
234
235        vanishing.sort();
236        vanishing
237    }
238
239    /// Detect layers with exploding gradients
240    ///
241    /// Returns names of layers whose most recent max gradient norm
242    /// is above the configured exploding threshold.
243    pub fn detect_exploding_gradients(&self) -> Vec<String> {
244        let threshold = self.config.exploding_threshold;
245        let mut exploding = Vec::new();
246
247        for (name, stats_history) in &self.layer_stats {
248            if let Some(latest) = stats_history.last() {
249                let max_f64 = latest.max_norm.to_f64().unwrap_or(0.0);
250                if max_f64 > threshold {
251                    exploding.push(name.clone());
252                }
253            }
254        }
255
256        exploding.sort();
257        exploding
258    }
259
260    /// Generate a comprehensive gradient health report
261    ///
262    /// Analyzes all tracked layers and produces a report with:
263    /// - Lists of vanishing, exploding, and healthy layers
264    /// - An overall health assessment
265    /// - Actionable recommendations for fixing gradient issues
266    pub fn get_health_report(&self) -> GradientHealthReport {
267        let vanishing = self.detect_vanishing_gradients();
268        let exploding = self.detect_exploding_gradients();
269
270        let mut healthy = Vec::new();
271        for name in &self.layer_order {
272            if !vanishing.contains(name) && !exploding.contains(name) {
273                healthy.push(name.clone());
274            }
275        }
276
277        let overall_health = if !exploding.is_empty() {
278            GradientHealth::Critical
279        } else if !vanishing.is_empty() {
280            if vanishing.len() > self.layer_order.len() / 2 {
281                GradientHealth::Critical
282            } else {
283                GradientHealth::Warning
284            }
285        } else {
286            GradientHealth::Healthy
287        };
288
289        let mut recommendations = Vec::new();
290
291        if !vanishing.is_empty() {
292            recommendations.push(format!(
293                "Vanishing gradients detected in {} layer(s): consider using residual connections, \
294                 batch normalization, or switching to ReLU-family activations.",
295                vanishing.len()
296            ));
297            recommendations
298                .push("Consider using gradient scaling or a smaller model depth.".to_string());
299        }
300
301        if !exploding.is_empty() {
302            recommendations.push(format!(
303                "Exploding gradients detected in {} layer(s): apply gradient clipping \
304                 (e.g., max norm clipping) or reduce learning rate.",
305                exploding.len()
306            ));
307            recommendations.push(
308                "Consider weight initialization with smaller variance (e.g., He or Xavier init)."
309                    .to_string(),
310            );
311        }
312
313        if vanishing.is_empty() && exploding.is_empty() {
314            recommendations.push("Gradient flow appears healthy across all layers.".to_string());
315        }
316
317        GradientHealthReport {
318            vanishing_layers: vanishing,
319            exploding_layers: exploding,
320            healthy_layers: healthy,
321            overall_health,
322            recommendations,
323        }
324    }
325
326    /// Render an SVG flow chart showing gradient magnitudes across layers
327    ///
328    /// Produces an SVG string with bars representing mean gradient norms
329    /// for each layer, color-coded by health status.
330    pub fn render_flow_chart(&self) -> Result<String> {
331        if self.layer_order.is_empty() {
332            return Err(OptimError::InvalidState(
333                "No gradient data recorded yet".to_string(),
334            ));
335        }
336
337        let vanishing = self.detect_vanishing_gradients();
338        let exploding = self.detect_exploding_gradients();
339
340        let bar_width = 40;
341        let bar_spacing = 10;
342        let margin_left = 150;
343        let margin_top = 40;
344        let chart_width = 400;
345        let num_layers = self.layer_order.len();
346        let total_height = margin_top + num_layers * (bar_width + bar_spacing) + 40;
347        let total_width = margin_left + chart_width + 60;
348
349        let mut svg = format!(
350            r#"<svg xmlns="http://www.w3.org/2000/svg" width="{}" height="{}" viewBox="0 0 {} {}">"#,
351            total_width, total_height, total_width, total_height
352        );
353        svg.push('\n');
354
355        // Title
356        svg.push_str(&format!(
357            r#"  <text x="{}" y="25" text-anchor="middle" font-size="16" font-weight="bold">Gradient Flow Analysis</text>"#,
358            total_width / 2
359        ));
360        svg.push('\n');
361
362        // Find max mean_norm for scaling
363        let mut max_mean = 0.0f64;
364        for name in &self.layer_order {
365            if let Some(history) = self.layer_stats.get(name) {
366                if let Some(latest) = history.last() {
367                    let val = latest.mean_norm.to_f64().unwrap_or(0.0);
368                    if val > max_mean {
369                        max_mean = val;
370                    }
371                }
372            }
373        }
374        if max_mean <= 0.0 {
375            max_mean = 1.0;
376        }
377
378        for (i, name) in self.layer_order.iter().enumerate() {
379            let y = margin_top + i * (bar_width + bar_spacing);
380
381            let mean_val = self
382                .layer_stats
383                .get(name)
384                .and_then(|h| h.last())
385                .map(|s| s.mean_norm.to_f64().unwrap_or(0.0))
386                .unwrap_or(0.0);
387
388            let bar_len = ((mean_val / max_mean) * chart_width as f64).max(1.0) as usize;
389
390            let color = if exploding.contains(name) {
391                "#ff4444" // Red for exploding
392            } else if vanishing.contains(name) {
393                "#ffaa00" // Orange for vanishing
394            } else {
395                "#44bb44" // Green for healthy
396            };
397
398            // Layer label
399            svg.push_str(&format!(
400                r#"  <text x="{}" y="{}" text-anchor="end" font-size="12" dominant-baseline="middle">{}</text>"#,
401                margin_left - 10,
402                y + bar_width / 2,
403                name
404            ));
405            svg.push('\n');
406
407            // Bar
408            svg.push_str(&format!(
409                r#"  <rect x="{}" y="{}" width="{}" height="{}" fill="{}" rx="3" ry="3"/>"#,
410                margin_left, y, bar_len, bar_width, color
411            ));
412            svg.push('\n');
413
414            // Value label
415            svg.push_str(&format!(
416                r#"  <text x="{}" y="{}" font-size="10" dominant-baseline="middle">{:.2e}</text>"#,
417                margin_left + bar_len + 5,
418                y + bar_width / 2,
419                mean_val
420            ));
421            svg.push('\n');
422        }
423
424        svg.push_str("</svg>");
425        Ok(svg)
426    }
427
428    /// Get the history of gradient statistics for a specific layer
429    pub fn get_layer_history(&self, layer_name: &str) -> Option<&Vec<LayerGradientStats<A>>> {
430        self.layer_stats.get(layer_name)
431    }
432
433    /// Clear all recorded gradient history
434    pub fn clear_history(&mut self) {
435        self.layer_stats.clear();
436        self.layer_order.clear();
437    }
438}
439
440#[cfg(test)]
441mod tests {
442    use super::*;
443    use scirs2_core::ndarray::Array1;
444
445    #[test]
446    fn test_record_gradients_basic() {
447        let config = GradientFlowConfig::default();
448        let mut analyzer = GradientFlowAnalyzer::<f64>::new(config);
449
450        let gradients = Array1::from_vec(vec![0.1, -0.2, 0.3, -0.4, 0.5]);
451        let stats = analyzer
452            .record_gradients("layer1", &gradients)
453            .expect("Should record gradients");
454
455        assert_eq!(stats.layer_name, "layer1");
456        // Mean of |0.1, 0.2, 0.3, 0.4, 0.5| = 0.3
457        assert!((stats.mean_norm - 0.3).abs() < 1e-10);
458        // Max = 0.5
459        assert!((stats.max_norm - 0.5).abs() < 1e-10);
460        // Min = 0.1
461        assert!((stats.min_norm - 0.1).abs() < 1e-10);
462        // Sparsity should be 0 (no values below 1e-7)
463        assert!((stats.sparsity - 0.0).abs() < 1e-10);
464        // Histogram should sum to gradient count
465        let hist_sum: usize = stats.histogram.iter().sum();
466        assert_eq!(hist_sum, 5);
467
468        // Verify history was stored
469        let history = analyzer.get_layer_history("layer1");
470        assert!(history.is_some());
471        assert_eq!(history.map(|h| h.len()).unwrap_or(0), 1);
472    }
473
474    #[test]
475    fn test_detect_vanishing_gradients() {
476        let config = GradientFlowConfig {
477            vanishing_threshold: 1e-7,
478            ..Default::default()
479        };
480        let mut analyzer = GradientFlowAnalyzer::<f64>::new(config);
481
482        // Normal gradients
483        let normal_grads = Array1::from_vec(vec![0.01, 0.02, 0.015, 0.008]);
484        analyzer
485            .record_gradients("healthy_layer", &normal_grads)
486            .expect("Should record");
487
488        // Vanishing gradients
489        let tiny_grads = Array1::from_vec(vec![1e-9, 1e-10, 1e-8, 1e-11]);
490        analyzer
491            .record_gradients("vanishing_layer", &tiny_grads)
492            .expect("Should record");
493
494        let vanishing = analyzer.detect_vanishing_gradients();
495        assert!(vanishing.contains(&"vanishing_layer".to_string()));
496        assert!(!vanishing.contains(&"healthy_layer".to_string()));
497    }
498
499    #[test]
500    fn test_detect_exploding_gradients() {
501        let config = GradientFlowConfig {
502            exploding_threshold: 1e3,
503            ..Default::default()
504        };
505        let mut analyzer = GradientFlowAnalyzer::<f64>::new(config);
506
507        // Normal gradients
508        let normal_grads = Array1::from_vec(vec![0.5, 1.0, 0.3, 0.8]);
509        analyzer
510            .record_gradients("normal_layer", &normal_grads)
511            .expect("Should record");
512
513        // Exploding gradients
514        let huge_grads = Array1::from_vec(vec![5000.0, 10000.0, 3000.0, 8000.0]);
515        analyzer
516            .record_gradients("exploding_layer", &huge_grads)
517            .expect("Should record");
518
519        let exploding = analyzer.detect_exploding_gradients();
520        assert!(exploding.contains(&"exploding_layer".to_string()));
521        assert!(!exploding.contains(&"normal_layer".to_string()));
522    }
523
524    #[test]
525    fn test_health_report_generation() {
526        let config = GradientFlowConfig::default();
527        let mut analyzer = GradientFlowAnalyzer::<f64>::new(config);
528
529        // Healthy layer
530        let healthy = Array1::from_vec(vec![0.01, 0.02, 0.015]);
531        analyzer
532            .record_gradients("fc1", &healthy)
533            .expect("Should record");
534
535        // Vanishing layer
536        let vanishing = Array1::from_vec(vec![1e-10, 1e-11, 1e-9]);
537        analyzer
538            .record_gradients("fc2", &vanishing)
539            .expect("Should record");
540
541        // Exploding layer
542        let exploding = Array1::from_vec(vec![5000.0, 10000.0, 8000.0]);
543        analyzer
544            .record_gradients("fc3", &exploding)
545            .expect("Should record");
546
547        let report = analyzer.get_health_report();
548
549        assert!(report.vanishing_layers.contains(&"fc2".to_string()));
550        assert!(report.exploding_layers.contains(&"fc3".to_string()));
551        assert!(report.healthy_layers.contains(&"fc1".to_string()));
552        assert_eq!(report.overall_health, GradientHealth::Critical);
553        assert!(!report.recommendations.is_empty());
554    }
555
556    #[test]
557    fn test_render_flow_chart_svg() {
558        let config = GradientFlowConfig::default();
559        let mut analyzer = GradientFlowAnalyzer::<f64>::new(config);
560
561        let grads1 = Array1::from_vec(vec![0.01, 0.02, 0.015]);
562        let grads2 = Array1::from_vec(vec![0.005, 0.003, 0.004]);
563        let grads3 = Array1::from_vec(vec![0.1, 0.08, 0.12]);
564
565        analyzer
566            .record_gradients("conv1", &grads1)
567            .expect("Should record");
568        analyzer
569            .record_gradients("conv2", &grads2)
570            .expect("Should record");
571        analyzer
572            .record_gradients("fc1", &grads3)
573            .expect("Should record");
574
575        let svg = analyzer
576            .render_flow_chart()
577            .expect("Should render flow chart");
578
579        assert!(svg.starts_with("<svg"));
580        assert!(svg.ends_with("</svg>"));
581        assert!(svg.contains("conv1"));
582        assert!(svg.contains("conv2"));
583        assert!(svg.contains("fc1"));
584        assert!(svg.contains("Gradient Flow Analysis"));
585    }
586}