1use crate::error::{OptimError, Result};
7use scirs2_core::ndarray::{Array1, ScalarOperand};
8use scirs2_core::numeric::Float;
9use std::collections::HashMap;
10use std::fmt::Debug;
11
/// Configuration thresholds and limits for gradient flow analysis.
#[derive(Debug, Clone)]
pub struct GradientFlowConfig {
    /// A layer whose latest mean absolute gradient falls below this value is
    /// flagged as vanishing.
    pub vanishing_threshold: f64,
    /// A layer whose latest maximum absolute gradient exceeds this value is
    /// flagged as exploding.
    pub exploding_threshold: f64,
    /// Number of buckets used when histogramming absolute gradient values.
    pub histogram_bins: usize,
    /// Maximum number of statistics snapshots retained per layer; the oldest
    /// snapshot is evicted when the cap is exceeded.
    pub max_history: usize,
}
24
25impl Default for GradientFlowConfig {
26 fn default() -> Self {
27 Self {
28 vanishing_threshold: 1e-7,
29 exploding_threshold: 1e3,
30 histogram_bins: 50,
31 max_history: 100,
32 }
33 }
34}
35
/// Summary statistics for one recorded gradient snapshot of a single layer.
///
/// All statistics are computed over the absolute values of the gradient
/// entries.
#[derive(Debug, Clone)]
pub struct LayerGradientStats<A> {
    /// Name of the layer these statistics describe.
    pub layer_name: String,
    /// Mean absolute gradient value.
    pub mean_norm: A,
    /// Maximum absolute gradient value.
    pub max_norm: A,
    /// Minimum absolute gradient value.
    pub min_norm: A,
    /// Variance of the absolute gradient values (clamped at zero).
    pub variance: A,
    /// Fraction of entries whose absolute value falls below the configured
    /// vanishing threshold.
    pub sparsity: A,
    /// Counts of absolute values bucketed into fixed bins, normalized by the
    /// maximum absolute value.
    pub histogram: Vec<usize>,
}
54
/// Overall classification of gradient flow health across all layers.
#[derive(Debug, Clone, PartialEq)]
pub enum GradientHealth {
    /// No vanishing or exploding layers detected.
    Healthy,
    /// Some layers are vanishing (half or fewer of all recorded layers).
    Warning,
    /// Exploding gradients detected, or vanishing gradients in more than half
    /// of the recorded layers.
    Critical,
}
65
/// Aggregated gradient-health report across all recorded layers.
#[derive(Debug, Clone)]
pub struct GradientHealthReport {
    /// Layers whose latest mean absolute gradient is below the vanishing
    /// threshold (sorted by name).
    pub vanishing_layers: Vec<String>,
    /// Layers whose latest maximum absolute gradient exceeds the exploding
    /// threshold (sorted by name).
    pub exploding_layers: Vec<String>,
    /// Layers that are neither vanishing nor exploding, in recording order.
    pub healthy_layers: Vec<String>,
    /// Overall classification derived from the vanishing/exploding counts.
    pub overall_health: GradientHealth,
    /// Human-readable suggestions for addressing any detected issues.
    pub recommendations: Vec<String>,
}
80
/// Tracks per-layer gradient statistics over time and diagnoses vanishing or
/// exploding gradient problems.
pub struct GradientFlowAnalyzer<A> {
    /// Thresholds, histogram sizing, and history limits.
    config: GradientFlowConfig,
    /// Per-layer history of recorded statistics, newest last.
    layer_stats: HashMap<String, Vec<LayerGradientStats<A>>>,
    /// Layer names in first-recorded order (gives stable chart ordering).
    layer_order: Vec<String>,
}
90
91impl<A> GradientFlowAnalyzer<A>
92where
93 A: Float + ScalarOperand + Debug + std::iter::Sum,
94{
95 pub fn new(config: GradientFlowConfig) -> Self {
97 Self {
98 config,
99 layer_stats: HashMap::new(),
100 layer_order: Vec::new(),
101 }
102 }
103
104 pub fn record_gradients(
110 &mut self,
111 layer_name: &str,
112 gradients: &Array1<A>,
113 ) -> Result<LayerGradientStats<A>> {
114 let len = gradients.len();
115 if len == 0 {
116 return Err(OptimError::InvalidParameter(
117 "Gradients array must not be empty".to_string(),
118 ));
119 }
120
121 let len_a = A::from(len).ok_or_else(|| {
122 OptimError::ComputationError("Failed to convert length to float".to_string())
123 })?;
124
125 let abs_grads: Vec<A> = gradients.iter().map(|&g| g.abs()).collect();
127
128 let sum: A = abs_grads.iter().copied().sum();
130 let mean_norm = sum / len_a;
131
132 let max_norm = abs_grads
134 .iter()
135 .copied()
136 .fold(A::neg_infinity(), |a, b| if b > a { b } else { a });
137 let min_norm = abs_grads
138 .iter()
139 .copied()
140 .fold(A::infinity(), |a, b| if b < a { b } else { a });
141
142 let sum_sq: A = abs_grads.iter().map(|&g| g * g).sum();
144 let mean_sq = sum_sq / len_a;
145 let variance = mean_sq - mean_norm * mean_norm;
146 let variance = if variance < A::zero() {
148 A::zero()
149 } else {
150 variance
151 };
152
153 let vanishing_thresh = A::from(self.config.vanishing_threshold).ok_or_else(|| {
155 OptimError::ComputationError(
156 "Failed to convert vanishing threshold to float".to_string(),
157 )
158 })?;
159 let near_zero_count = abs_grads.iter().filter(|&&g| g < vanishing_thresh).count();
160 let sparsity = A::from(near_zero_count).ok_or_else(|| {
161 OptimError::ComputationError("Failed to convert count to float".to_string())
162 })? / len_a;
163
164 let histogram = self.compute_histogram(&abs_grads, max_norm)?;
166
167 let stats = LayerGradientStats {
168 layer_name: layer_name.to_string(),
169 mean_norm,
170 max_norm,
171 min_norm,
172 variance,
173 sparsity,
174 histogram,
175 };
176
177 if !self.layer_order.contains(&layer_name.to_string()) {
179 self.layer_order.push(layer_name.to_string());
180 }
181
182 let history = self.layer_stats.entry(layer_name.to_string()).or_default();
184 history.push(stats.clone());
185 if history.len() > self.config.max_history {
186 history.remove(0);
187 }
188
189 Ok(stats)
190 }
191
192 fn compute_histogram(&self, abs_grads: &[A], max_val: A) -> Result<Vec<usize>> {
194 let bins = self.config.histogram_bins;
195 let mut histogram = vec![0usize; bins];
196
197 if max_val <= A::zero() {
198 histogram[0] = abs_grads.len();
200 return Ok(histogram);
201 }
202
203 for &val in abs_grads {
204 let normalized = val / max_val;
205 let bin_idx = (normalized
206 * A::from(bins).ok_or_else(|| {
207 OptimError::ComputationError("Failed to convert bins to float".to_string())
208 })?)
209 .to_f64()
210 .ok_or_else(|| OptimError::ComputationError("Failed to convert to f64".to_string()))?;
211 let bin_idx = (bin_idx as usize).min(bins - 1);
212 histogram[bin_idx] += 1;
213 }
214
215 Ok(histogram)
216 }
217
218 pub fn detect_vanishing_gradients(&self) -> Vec<String> {
223 let threshold = self.config.vanishing_threshold;
224 let mut vanishing = Vec::new();
225
226 for (name, stats_history) in &self.layer_stats {
227 if let Some(latest) = stats_history.last() {
228 let mean_f64 = latest.mean_norm.to_f64().unwrap_or(0.0);
229 if mean_f64 < threshold {
230 vanishing.push(name.clone());
231 }
232 }
233 }
234
235 vanishing.sort();
236 vanishing
237 }
238
239 pub fn detect_exploding_gradients(&self) -> Vec<String> {
244 let threshold = self.config.exploding_threshold;
245 let mut exploding = Vec::new();
246
247 for (name, stats_history) in &self.layer_stats {
248 if let Some(latest) = stats_history.last() {
249 let max_f64 = latest.max_norm.to_f64().unwrap_or(0.0);
250 if max_f64 > threshold {
251 exploding.push(name.clone());
252 }
253 }
254 }
255
256 exploding.sort();
257 exploding
258 }
259
260 pub fn get_health_report(&self) -> GradientHealthReport {
267 let vanishing = self.detect_vanishing_gradients();
268 let exploding = self.detect_exploding_gradients();
269
270 let mut healthy = Vec::new();
271 for name in &self.layer_order {
272 if !vanishing.contains(name) && !exploding.contains(name) {
273 healthy.push(name.clone());
274 }
275 }
276
277 let overall_health = if !exploding.is_empty() {
278 GradientHealth::Critical
279 } else if !vanishing.is_empty() {
280 if vanishing.len() > self.layer_order.len() / 2 {
281 GradientHealth::Critical
282 } else {
283 GradientHealth::Warning
284 }
285 } else {
286 GradientHealth::Healthy
287 };
288
289 let mut recommendations = Vec::new();
290
291 if !vanishing.is_empty() {
292 recommendations.push(format!(
293 "Vanishing gradients detected in {} layer(s): consider using residual connections, \
294 batch normalization, or switching to ReLU-family activations.",
295 vanishing.len()
296 ));
297 recommendations
298 .push("Consider using gradient scaling or a smaller model depth.".to_string());
299 }
300
301 if !exploding.is_empty() {
302 recommendations.push(format!(
303 "Exploding gradients detected in {} layer(s): apply gradient clipping \
304 (e.g., max norm clipping) or reduce learning rate.",
305 exploding.len()
306 ));
307 recommendations.push(
308 "Consider weight initialization with smaller variance (e.g., He or Xavier init)."
309 .to_string(),
310 );
311 }
312
313 if vanishing.is_empty() && exploding.is_empty() {
314 recommendations.push("Gradient flow appears healthy across all layers.".to_string());
315 }
316
317 GradientHealthReport {
318 vanishing_layers: vanishing,
319 exploding_layers: exploding,
320 healthy_layers: healthy,
321 overall_health,
322 recommendations,
323 }
324 }
325
326 pub fn render_flow_chart(&self) -> Result<String> {
331 if self.layer_order.is_empty() {
332 return Err(OptimError::InvalidState(
333 "No gradient data recorded yet".to_string(),
334 ));
335 }
336
337 let vanishing = self.detect_vanishing_gradients();
338 let exploding = self.detect_exploding_gradients();
339
340 let bar_width = 40;
341 let bar_spacing = 10;
342 let margin_left = 150;
343 let margin_top = 40;
344 let chart_width = 400;
345 let num_layers = self.layer_order.len();
346 let total_height = margin_top + num_layers * (bar_width + bar_spacing) + 40;
347 let total_width = margin_left + chart_width + 60;
348
349 let mut svg = format!(
350 r#"<svg xmlns="http://www.w3.org/2000/svg" width="{}" height="{}" viewBox="0 0 {} {}">"#,
351 total_width, total_height, total_width, total_height
352 );
353 svg.push('\n');
354
355 svg.push_str(&format!(
357 r#" <text x="{}" y="25" text-anchor="middle" font-size="16" font-weight="bold">Gradient Flow Analysis</text>"#,
358 total_width / 2
359 ));
360 svg.push('\n');
361
362 let mut max_mean = 0.0f64;
364 for name in &self.layer_order {
365 if let Some(history) = self.layer_stats.get(name) {
366 if let Some(latest) = history.last() {
367 let val = latest.mean_norm.to_f64().unwrap_or(0.0);
368 if val > max_mean {
369 max_mean = val;
370 }
371 }
372 }
373 }
374 if max_mean <= 0.0 {
375 max_mean = 1.0;
376 }
377
378 for (i, name) in self.layer_order.iter().enumerate() {
379 let y = margin_top + i * (bar_width + bar_spacing);
380
381 let mean_val = self
382 .layer_stats
383 .get(name)
384 .and_then(|h| h.last())
385 .map(|s| s.mean_norm.to_f64().unwrap_or(0.0))
386 .unwrap_or(0.0);
387
388 let bar_len = ((mean_val / max_mean) * chart_width as f64).max(1.0) as usize;
389
390 let color = if exploding.contains(name) {
391 "#ff4444" } else if vanishing.contains(name) {
393 "#ffaa00" } else {
395 "#44bb44" };
397
398 svg.push_str(&format!(
400 r#" <text x="{}" y="{}" text-anchor="end" font-size="12" dominant-baseline="middle">{}</text>"#,
401 margin_left - 10,
402 y + bar_width / 2,
403 name
404 ));
405 svg.push('\n');
406
407 svg.push_str(&format!(
409 r#" <rect x="{}" y="{}" width="{}" height="{}" fill="{}" rx="3" ry="3"/>"#,
410 margin_left, y, bar_len, bar_width, color
411 ));
412 svg.push('\n');
413
414 svg.push_str(&format!(
416 r#" <text x="{}" y="{}" font-size="10" dominant-baseline="middle">{:.2e}</text>"#,
417 margin_left + bar_len + 5,
418 y + bar_width / 2,
419 mean_val
420 ));
421 svg.push('\n');
422 }
423
424 svg.push_str("</svg>");
425 Ok(svg)
426 }
427
428 pub fn get_layer_history(&self, layer_name: &str) -> Option<&Vec<LayerGradientStats<A>>> {
430 self.layer_stats.get(layer_name)
431 }
432
433 pub fn clear_history(&mut self) {
435 self.layer_stats.clear();
436 self.layer_order.clear();
437 }
438}
439
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array1;

    #[test]
    fn test_record_gradients_basic() {
        let mut analyzer = GradientFlowAnalyzer::<f64>::new(GradientFlowConfig::default());

        let grads = Array1::from_vec(vec![0.1, -0.2, 0.3, -0.4, 0.5]);
        let stats = analyzer
            .record_gradients("layer1", &grads)
            .expect("Should record gradients");

        // Statistics are computed over absolute values, so signs do not matter.
        assert_eq!(stats.layer_name, "layer1");
        assert!((stats.mean_norm - 0.3).abs() < 1e-10);
        assert!((stats.max_norm - 0.5).abs() < 1e-10);
        assert!((stats.min_norm - 0.1).abs() < 1e-10);
        assert!((stats.sparsity - 0.0).abs() < 1e-10);

        // Every gradient entry must land in exactly one histogram bin.
        let total: usize = stats.histogram.iter().sum();
        assert_eq!(total, 5);

        let history = analyzer.get_layer_history("layer1");
        assert!(history.is_some());
        assert_eq!(history.map(|h| h.len()).unwrap_or(0), 1);
    }

    #[test]
    fn test_detect_vanishing_gradients() {
        let mut analyzer = GradientFlowAnalyzer::<f64>::new(GradientFlowConfig {
            vanishing_threshold: 1e-7,
            ..Default::default()
        });

        let healthy = Array1::from_vec(vec![0.01, 0.02, 0.015, 0.008]);
        analyzer
            .record_gradients("healthy_layer", &healthy)
            .expect("Should record");

        let tiny = Array1::from_vec(vec![1e-9, 1e-10, 1e-8, 1e-11]);
        analyzer
            .record_gradients("vanishing_layer", &tiny)
            .expect("Should record");

        let flagged = analyzer.detect_vanishing_gradients();
        assert!(flagged.contains(&"vanishing_layer".to_string()));
        assert!(!flagged.contains(&"healthy_layer".to_string()));
    }

    #[test]
    fn test_detect_exploding_gradients() {
        let mut analyzer = GradientFlowAnalyzer::<f64>::new(GradientFlowConfig {
            exploding_threshold: 1e3,
            ..Default::default()
        });

        let moderate = Array1::from_vec(vec![0.5, 1.0, 0.3, 0.8]);
        analyzer
            .record_gradients("normal_layer", &moderate)
            .expect("Should record");

        let huge = Array1::from_vec(vec![5000.0, 10000.0, 3000.0, 8000.0]);
        analyzer
            .record_gradients("exploding_layer", &huge)
            .expect("Should record");

        let flagged = analyzer.detect_exploding_gradients();
        assert!(flagged.contains(&"exploding_layer".to_string()));
        assert!(!flagged.contains(&"normal_layer".to_string()));
    }

    #[test]
    fn test_health_report_generation() {
        let mut analyzer = GradientFlowAnalyzer::<f64>::new(GradientFlowConfig::default());

        // One healthy, one vanishing, and one exploding layer.
        for (layer, values) in [
            ("fc1", vec![0.01, 0.02, 0.015]),
            ("fc2", vec![1e-10, 1e-11, 1e-9]),
            ("fc3", vec![5000.0, 10000.0, 8000.0]),
        ] {
            analyzer
                .record_gradients(layer, &Array1::from_vec(values))
                .expect("Should record");
        }

        let report = analyzer.get_health_report();

        assert!(report.vanishing_layers.contains(&"fc2".to_string()));
        assert!(report.exploding_layers.contains(&"fc3".to_string()));
        assert!(report.healthy_layers.contains(&"fc1".to_string()));
        // Any exploding layer makes the overall status Critical.
        assert_eq!(report.overall_health, GradientHealth::Critical);
        assert!(!report.recommendations.is_empty());
    }

    #[test]
    fn test_render_flow_chart_svg() {
        let mut analyzer = GradientFlowAnalyzer::<f64>::new(GradientFlowConfig::default());

        for (layer, values) in [
            ("conv1", vec![0.01, 0.02, 0.015]),
            ("conv2", vec![0.005, 0.003, 0.004]),
            ("fc1", vec![0.1, 0.08, 0.12]),
        ] {
            analyzer
                .record_gradients(layer, &Array1::from_vec(values))
                .expect("Should record");
        }

        let svg = analyzer
            .render_flow_chart()
            .expect("Should render flow chart");

        assert!(svg.starts_with("<svg"));
        assert!(svg.ends_with("</svg>"));
        assert!(svg.contains("conv1"));
        assert!(svg.contains("conv2"));
        assert!(svg.contains("fc1"));
        assert!(svg.contains("Gradient Flow Analysis"));
    }
}