Skip to main content

optirs_core/
loss_landscape.rs

//! Loss Landscape Analysis Module
//!
//! Provides tools for visualizing and analyzing the loss landscape around
//! a point in parameter space, including sharpness measurement, saddle point
//! detection, and contour plot rendering.
6
7use crate::error::{OptimError, Result};
8use scirs2_core::ndarray::{Array1, Array2, ScalarOperand};
9use scirs2_core::numeric::Float;
10use std::fmt::Debug;
11
/// Method for choosing perturbation directions in the landscape
///
/// This is a plain, field-less selector, so it is cheap to copy, compare and
/// use as a hash-map key.
///
/// NOTE(review): none of the analyzer methods visible in this file read this
/// setting; `compute_landscape` takes its two directions as explicit
/// arguments. Presumably the caller (or code elsewhere in the crate) uses it
/// to build those directions -- confirm.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub enum DirectionMethod {
    /// Random directions (normalized)
    #[default]
    Random,
    /// PCA-based directions from optimization trajectory
    PCA,
    /// Filter-normalized directions (Li et al., 2018)
    FilterNormalized,
}
23
/// Configuration for loss landscape analysis
///
/// `A` is the scalar type of the perturbation range; `Default` is provided
/// in this file for `f32` and `f64`.
#[derive(Debug, Clone)]
pub struct LossLandscapeConfig<A> {
    /// Resolution of the evaluation grid (grid_resolution x grid_resolution).
    /// `compute_landscape` rejects a value of zero.
    pub grid_resolution: usize,
    /// Range of perturbation along each direction: [-range, +range]
    pub perturbation_range: A,
    /// Method for choosing perturbation directions
    // NOTE(review): not read by any analyzer method visible in this file;
    // directions are passed to `compute_landscape` explicitly -- confirm usage.
    pub direction_method: DirectionMethod,
}
34
35impl Default for LossLandscapeConfig<f64> {
36    fn default() -> Self {
37        Self {
38            grid_resolution: 20,
39            perturbation_range: 1.0,
40            direction_method: DirectionMethod::Random,
41        }
42    }
43}
44
45impl Default for LossLandscapeConfig<f32> {
46    fn default() -> Self {
47        Self {
48            grid_resolution: 20,
49            perturbation_range: 1.0f32,
50            direction_method: DirectionMethod::Random,
51        }
52    }
53}
54
/// 2D loss landscape data
///
/// As filled in by `LossLandscapeAnalyzer::compute_landscape`, entry
/// `grid[[i, j]]` is the loss at the `i`-th sample along direction 1 (alpha)
/// and the `j`-th sample along direction 2 (beta).
#[derive(Debug, Clone)]
pub struct LandscapeData<A> {
    /// 2D grid of loss values (grid_resolution x grid_resolution)
    pub grid: Array2<A>,
    /// Range of perturbation in direction 1: (min_alpha, max_alpha)
    pub x_range: (A, A),
    /// Range of perturbation in direction 2: (min_beta, max_beta)
    pub y_range: (A, A),
    /// Loss value at the center point (no perturbation)
    pub center_loss: A,
    /// Minimum loss value in the grid
    pub min_loss: A,
    /// Maximum loss value in the grid
    pub max_loss: A,
}
71
/// Information about a detected saddle point in the landscape
///
/// Coordinates are grid indices, not perturbation values: as produced by
/// `find_saddle_points`, the cell is `landscape.grid[[grid_x, grid_y]]`.
#[derive(Debug, Clone)]
pub struct SaddlePointInfo<A> {
    /// Grid x-coordinate of the saddle point (first index into the grid)
    pub grid_x: usize,
    /// Grid y-coordinate of the saddle point (second index into the grid)
    pub grid_y: usize,
    /// Loss value at the saddle point
    pub loss_value: A,
}
82
83/// Loss landscape analyzer for understanding optimization surfaces
84pub struct LossLandscapeAnalyzer<A> {
85    /// Configuration parameters
86    config: LossLandscapeConfig<A>,
87}
88
89impl<A> LossLandscapeAnalyzer<A>
90where
91    A: Float + ScalarOperand + Debug + std::iter::Sum,
92{
93    /// Create a new loss landscape analyzer with the given configuration
94    pub fn new(config: LossLandscapeConfig<A>) -> Self {
95        Self { config }
96    }
97
98    /// Compute the loss landscape on a 2D grid
99    ///
100    /// Evaluates `loss_fn(params + alpha * dir1 + beta * dir2)` for a grid of
101    /// (alpha, beta) values in `[-perturbation_range, +perturbation_range]`.
102    ///
103    /// # Arguments
104    /// * `params` - Center point in parameter space
105    /// * `loss_fn` - Function that computes loss for given parameters
106    /// * `dir1` - First perturbation direction (should be normalized)
107    /// * `dir2` - Second perturbation direction (should be normalized)
108    pub fn compute_landscape<F>(
109        &self,
110        params: &Array1<A>,
111        loss_fn: F,
112        dir1: &Array1<A>,
113        dir2: &Array1<A>,
114    ) -> Result<LandscapeData<A>>
115    where
116        F: Fn(&Array1<A>) -> Result<A>,
117    {
118        let n = self.config.grid_resolution;
119        if n == 0 {
120            return Err(OptimError::InvalidConfig(
121                "Grid resolution must be positive".to_string(),
122            ));
123        }
124        if params.len() != dir1.len() || params.len() != dir2.len() {
125            return Err(OptimError::DimensionMismatch(format!(
126                "Parameter dimension ({}) must match direction dimensions ({}, {})",
127                params.len(),
128                dir1.len(),
129                dir2.len()
130            )));
131        }
132
133        let range = self.config.perturbation_range;
134        let neg_range = A::zero() - range;
135
136        let mut grid = Array2::zeros((n, n));
137        let mut min_loss = A::infinity();
138        let mut max_loss = A::neg_infinity();
139        let mut center_loss = A::zero();
140
141        let n_minus_1 = if n > 1 {
142            A::from(n - 1).ok_or_else(|| {
143                OptimError::ComputationError("Failed to convert grid size".to_string())
144            })?
145        } else {
146            A::one()
147        };
148
149        let two = A::from(2.0).ok_or_else(|| {
150            OptimError::ComputationError("Failed to convert constant".to_string())
151        })?;
152
153        for i in 0..n {
154            let alpha = neg_range
155                + (A::from(i).ok_or_else(|| {
156                    OptimError::ComputationError("Failed to convert index".to_string())
157                })? / n_minus_1)
158                    * two
159                    * range;
160
161            for j in 0..n {
162                let beta = neg_range
163                    + (A::from(j).ok_or_else(|| {
164                        OptimError::ComputationError("Failed to convert index".to_string())
165                    })? / n_minus_1)
166                        * two
167                        * range;
168
169                // perturbed = params + alpha * dir1 + beta * dir2
170                let perturbed = params
171                    .iter()
172                    .zip(dir1.iter())
173                    .zip(dir2.iter())
174                    .map(|((&p, &d1), &d2)| p + alpha * d1 + beta * d2)
175                    .collect::<Vec<A>>();
176                let perturbed = Array1::from_vec(perturbed);
177
178                let loss = loss_fn(&perturbed)?;
179                grid[[i, j]] = loss;
180
181                if loss < min_loss {
182                    min_loss = loss;
183                }
184                if loss > max_loss {
185                    max_loss = loss;
186                }
187
188                // Track center point (when alpha ~ 0 and beta ~ 0)
189                if (n > 1 && i == n / 2 && j == n / 2) || n == 1 {
190                    center_loss = loss;
191                }
192            }
193        }
194
195        Ok(LandscapeData {
196            grid,
197            x_range: (neg_range, range),
198            y_range: (neg_range, range),
199            center_loss,
200            min_loss,
201            max_loss,
202        })
203    }
204
205    /// Compute the sharpness of the loss surface around a point
206    ///
207    /// Sharpness is defined as the maximum loss in a neighborhood of radius
208    /// `epsilon` minus the loss at the center point. This measures how "sharp"
209    /// or "flat" the minimum is - flatter minima tend to generalize better.
210    ///
211    /// # Arguments
212    /// * `params` - Center point in parameter space
213    /// * `loss_fn` - Function that computes loss for given parameters
214    /// * `epsilon` - Radius of the neighborhood to search
215    pub fn compute_sharpness<F>(&self, params: &Array1<A>, loss_fn: &F, epsilon: A) -> Result<A>
216    where
217        F: Fn(&Array1<A>) -> Result<A>,
218    {
219        let center_loss = loss_fn(params)?;
220        let dim = params.len();
221
222        if dim == 0 {
223            return Err(OptimError::InvalidParameter(
224                "Parameter array must not be empty".to_string(),
225            ));
226        }
227
228        let mut max_loss = center_loss;
229
230        // Sample along each coordinate axis in both directions
231        for d in 0..dim {
232            // Positive perturbation
233            let mut perturbed_pos = params.to_owned();
234            perturbed_pos[d] = perturbed_pos[d] + epsilon;
235            let loss_pos = loss_fn(&perturbed_pos)?;
236            if loss_pos > max_loss {
237                max_loss = loss_pos;
238            }
239
240            // Negative perturbation
241            let mut perturbed_neg = params.to_owned();
242            perturbed_neg[d] = perturbed_neg[d] - epsilon;
243            let loss_neg = loss_fn(&perturbed_neg)?;
244            if loss_neg > max_loss {
245                max_loss = loss_neg;
246            }
247        }
248
249        // Also sample along diagonal directions for better coverage
250        // Diagonal: all dimensions perturbed by epsilon / sqrt(dim)
251        let dim_f = A::from(dim).ok_or_else(|| {
252            OptimError::ComputationError("Failed to convert dimension".to_string())
253        })?;
254        let scaled_eps = epsilon / dim_f.sqrt();
255
256        // All-positive diagonal
257        let diag_pos: Array1<A> = params.mapv(|p| p + scaled_eps);
258        let loss_diag_pos = loss_fn(&diag_pos)?;
259        if loss_diag_pos > max_loss {
260            max_loss = loss_diag_pos;
261        }
262
263        // All-negative diagonal
264        let diag_neg: Array1<A> = params.mapv(|p| p - scaled_eps);
265        let loss_diag_neg = loss_fn(&diag_neg)?;
266        if loss_diag_neg > max_loss {
267            max_loss = loss_diag_neg;
268        }
269
270        Ok(max_loss - center_loss)
271    }
272
273    /// Find saddle points in the loss landscape
274    ///
275    /// A saddle point is a grid cell where the gradient is approximately zero
276    /// (local extremum behavior) but the point is neither a strict local minimum
277    /// nor a strict local maximum -- it has both higher and lower neighbors.
278    ///
279    /// # Arguments
280    /// * `landscape` - Previously computed landscape data
281    pub fn find_saddle_points(&self, landscape: &LandscapeData<A>) -> Vec<SaddlePointInfo<A>> {
282        let (rows, cols) = landscape.grid.dim();
283        let mut saddle_points = Vec::new();
284
285        // Skip border cells as they don't have full neighborhoods
286        for i in 1..rows.saturating_sub(1) {
287            for j in 1..cols.saturating_sub(1) {
288                let center = landscape.grid[[i, j]];
289
290                // Collect all 8 neighbors
291                let neighbors = [
292                    landscape.grid[[i - 1, j - 1]],
293                    landscape.grid[[i - 1, j]],
294                    landscape.grid[[i - 1, j + 1]],
295                    landscape.grid[[i, j - 1]],
296                    landscape.grid[[i, j + 1]],
297                    landscape.grid[[i + 1, j - 1]],
298                    landscape.grid[[i + 1, j]],
299                    landscape.grid[[i + 1, j + 1]],
300                ];
301
302                let has_higher = neighbors.iter().any(|&n| n > center);
303                let has_lower = neighbors.iter().any(|&n| n < center);
304
305                // A saddle point has both higher and lower neighbors,
306                // and the differences are small enough to suggest a near-zero gradient
307                if has_higher && has_lower {
308                    // Check that the point is not strongly a minimum or maximum:
309                    // count how many neighbors are higher vs lower
310                    let higher_count = neighbors.iter().filter(|&&n| n > center).count();
311                    let lower_count = neighbors.iter().filter(|&&n| n < center).count();
312
313                    // Saddle-like: roughly balanced directional behavior
314                    // (not overwhelmingly a basin or a peak)
315                    if higher_count >= 2 && lower_count >= 2 {
316                        saddle_points.push(SaddlePointInfo {
317                            grid_x: i,
318                            grid_y: j,
319                            loss_value: center,
320                        });
321                    }
322                }
323            }
324        }
325
326        saddle_points
327    }
328
329    /// Render an SVG contour plot of the loss landscape
330    ///
331    /// Produces an SVG string with filled rectangles colored by loss value,
332    /// creating a heat-map style visualization of the landscape.
333    pub fn render_contour_plot(&self, landscape: &LandscapeData<A>) -> Result<String> {
334        let (rows, cols) = landscape.grid.dim();
335        if rows == 0 || cols == 0 {
336            return Err(OptimError::InvalidState(
337                "Landscape grid is empty".to_string(),
338            ));
339        }
340
341        let cell_size = 15;
342        let margin = 60;
343        let width = margin + cols * cell_size + margin;
344        let height = margin + rows * cell_size + margin;
345
346        let min_loss = landscape.min_loss.to_f64().unwrap_or(0.0);
347        let max_loss = landscape.max_loss.to_f64().unwrap_or(1.0);
348        let loss_range = if (max_loss - min_loss).abs() < 1e-15 {
349            1.0
350        } else {
351            max_loss - min_loss
352        };
353
354        let mut svg = format!(
355            r#"<svg xmlns="http://www.w3.org/2000/svg" width="{}" height="{}" viewBox="0 0 {} {}">"#,
356            width, height, width, height
357        );
358        svg.push('\n');
359
360        // Title
361        svg.push_str(&format!(
362            r#"  <text x="{}" y="25" text-anchor="middle" font-size="16" font-weight="bold">Loss Landscape</text>"#,
363            width / 2
364        ));
365        svg.push('\n');
366
367        // Axis labels
368        svg.push_str(&format!(
369            r#"  <text x="{}" y="{}" text-anchor="middle" font-size="12">Direction 1</text>"#,
370            margin + cols * cell_size / 2,
371            height - 10
372        ));
373        svg.push('\n');
374
375        svg.push_str(&format!(
376            r#"  <text x="15" y="{}" text-anchor="middle" font-size="12" transform="rotate(-90, 15, {})">Direction 2</text>"#,
377            margin + rows * cell_size / 2,
378            margin + rows * cell_size / 2
379        ));
380        svg.push('\n');
381
382        // Draw cells as colored rectangles
383        for i in 0..rows {
384            for j in 0..cols {
385                let val = landscape.grid[[i, j]].to_f64().unwrap_or(0.0);
386                let normalized = (val - min_loss) / loss_range;
387                // Clamp to [0, 1]
388                let normalized = normalized.clamp(0.0, 1.0);
389
390                let color = loss_value_to_color(normalized);
391
392                let x = margin + j * cell_size;
393                let y = margin + i * cell_size;
394
395                svg.push_str(&format!(
396                    r#"  <rect x="{}" y="{}" width="{}" height="{}" fill="{}"/>"#,
397                    x, y, cell_size, cell_size, color
398                ));
399                svg.push('\n');
400            }
401        }
402
403        // Color bar legend
404        let legend_x = margin + cols * cell_size + 10;
405        let legend_height = rows * cell_size;
406        let legend_steps = 10;
407        let step_height = legend_height / legend_steps;
408
409        for s in 0..legend_steps {
410            let normalized = 1.0 - (s as f64 / legend_steps as f64);
411            let color = loss_value_to_color(normalized);
412            let y = margin + s * step_height;
413
414            svg.push_str(&format!(
415                r#"  <rect x="{}" y="{}" width="15" height="{}" fill="{}"/>"#,
416                legend_x, y, step_height, color
417            ));
418            svg.push('\n');
419        }
420
421        // Legend labels
422        svg.push_str(&format!(
423            r#"  <text x="{}" y="{}" font-size="9">{:.2e}</text>"#,
424            legend_x + 20,
425            margin + 10,
426            max_loss
427        ));
428        svg.push('\n');
429        svg.push_str(&format!(
430            r#"  <text x="{}" y="{}" font-size="9">{:.2e}</text>"#,
431            legend_x + 20,
432            margin + legend_height,
433            min_loss
434        ));
435        svg.push('\n');
436
437        svg.push_str("</svg>");
438        Ok(svg)
439    }
440}
441
/// Map a normalized loss in `[0, 1]` onto an `rgb(r,g,b)` color string.
///
/// The scale runs blue (0.0) -> green (0.5) -> red (1.0), interpolating
/// linearly within each half. Channel values are truncated, not rounded, when
/// converted to 8-bit integers; out-of-range channel values saturate at 0 or
/// 255 and NaN becomes 0.
fn loss_value_to_color(normalized: f64) -> String {
    let channels: [f64; 3] = if normalized < 0.5 {
        // Lower half: fade blue out while fading green in.
        let ramp = normalized * 2.0;
        [0.0, ramp, 1.0 - ramp]
    } else {
        // Upper half: fade green out while fading red in.
        let ramp = (normalized - 0.5) * 2.0;
        [ramp, 1.0 - ramp, 0.0]
    };

    let [red, green, blue] = channels.map(|channel| (channel * 255.0) as u8);
    format!("rgb({},{},{})", red, green, blue)
}
463
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array1;

    /// A 5 x 5 grid over the bowl `f(x) = sum(x_i^2)` centered at the origin:
    /// the center and the minimum are both 0, the corners are 2.
    #[test]
    fn test_compute_landscape_quadratic() {
        let config = LossLandscapeConfig {
            grid_resolution: 5,
            perturbation_range: 1.0,
            direction_method: DirectionMethod::Random,
        };
        let analyzer = LossLandscapeAnalyzer::<f64>::new(config);

        // Simple quadratic loss: f(x) = sum(x_i^2)
        let params = Array1::from_vec(vec![0.0, 0.0]);
        let dir1 = Array1::from_vec(vec![1.0, 0.0]);
        let dir2 = Array1::from_vec(vec![0.0, 1.0]);

        let loss_fn = |p: &Array1<f64>| -> Result<f64> { Ok(p.iter().map(|&x| x * x).sum()) };

        let landscape = analyzer
            .compute_landscape(&params, loss_fn, &dir1, &dir2)
            .expect("Should compute landscape");

        assert_eq!(landscape.grid.dim(), (5, 5));
        // At center (0,0), loss should be 0
        assert!(landscape.center_loss.abs() < 1e-12);
        // Min should be at center, i.e. the middle cell of the odd-sized grid
        assert!(landscape.min_loss.abs() < 1e-12);
        assert!((landscape.grid[[2, 2]] - landscape.min_loss).abs() < 1e-12);
        // Max should be at corners (alpha=1, beta=1 => loss=2)
        assert!((landscape.max_loss - 2.0).abs() < 1e-10);
        assert!((landscape.grid[[4, 4]] - 2.0).abs() < 1e-10);
        // Reported ranges mirror the configured perturbation range
        assert_eq!(landscape.x_range, (-1.0, 1.0));
        assert_eq!(landscape.y_range, (-1.0, 1.0));
    }

    /// Sharpness of the quadratic bowl at its minimum equals `epsilon^2`.
    #[test]
    fn test_compute_sharpness() {
        let config: LossLandscapeConfig<f64> = LossLandscapeConfig::default();
        let analyzer = LossLandscapeAnalyzer::new(config);

        // Quadratic bowl: f(x) = x_1^2 + x_2^2
        let params = Array1::from_vec(vec![0.0, 0.0]);
        let loss_fn = |p: &Array1<f64>| -> Result<f64> { Ok(p.iter().map(|&x| x * x).sum()) };

        let epsilon = 0.1;
        let sharpness = analyzer
            .compute_sharpness(&params, &loss_fn, epsilon)
            .expect("Should compute sharpness");

        // At the origin, center_loss = 0
        // Max in neighborhood: moving epsilon along one axis => epsilon^2 = 0.01
        // Moving epsilon/sqrt(2) along diagonal => 2*(0.1/sqrt(2))^2 = 0.01
        // So sharpness = 0.01 - 0 = 0.01
        assert!(sharpness > 0.0);
        assert!((sharpness - 0.01).abs() < 1e-10);
    }

    /// Empty parameter vectors are rejected rather than analyzed.
    #[test]
    fn test_compute_sharpness_rejects_empty_params() {
        let config: LossLandscapeConfig<f64> = LossLandscapeConfig::default();
        let analyzer = LossLandscapeAnalyzer::new(config);

        let params: Array1<f64> = Array1::from_vec(vec![]);
        let loss_fn = |p: &Array1<f64>| -> Result<f64> { Ok(p.iter().map(|&x| x * x).sum()) };

        assert!(analyzer.compute_sharpness(&params, &loss_fn, 0.1).is_err());
    }

    /// `f(x, y) = x^2 - y^2` has a saddle at the origin, which is the middle
    /// cell of an 11 x 11 grid.
    #[test]
    fn test_find_saddle_points() {
        let config = LossLandscapeConfig {
            grid_resolution: 11,
            perturbation_range: 1.0,
            direction_method: DirectionMethod::Random,
        };
        let analyzer = LossLandscapeAnalyzer::<f64>::new(config);

        let params = Array1::from_vec(vec![0.0, 0.0]);
        let dir1 = Array1::from_vec(vec![1.0, 0.0]);
        let dir2 = Array1::from_vec(vec![0.0, 1.0]);

        // Saddle function: x^2 - y^2
        let loss_fn = |p: &Array1<f64>| -> Result<f64> { Ok(p[0] * p[0] - p[1] * p[1]) };

        let landscape = analyzer
            .compute_landscape(&params, loss_fn, &dir1, &dir2)
            .expect("Should compute landscape");

        let saddle_points = analyzer.find_saddle_points(&landscape);

        // Should find at least one saddle point near the center
        assert!(
            !saddle_points.is_empty(),
            "Should detect saddle points in x^2 - y^2"
        );

        // The center of the grid (5,5) should be among the saddle points,
        // and its recorded loss is the value of the function at the origin.
        let center = saddle_points
            .iter()
            .find(|sp| sp.grid_x == 5 && sp.grid_y == 5);
        assert!(
            center.is_some(),
            "Center of x^2 - y^2 landscape should be a saddle point"
        );
        assert!(center.map_or(false, |sp| sp.loss_value.abs() < 1e-12));

        // Border cells are never reported.
        assert!(saddle_points
            .iter()
            .all(|sp| (1..10).contains(&sp.grid_x) && (1..10).contains(&sp.grid_y)));
    }

    /// The rendered SVG is well-formed at the string level and contains one
    /// rectangle per grid cell in addition to the labels.
    #[test]
    fn test_render_contour_plot_svg() {
        let config = LossLandscapeConfig {
            grid_resolution: 5,
            perturbation_range: 1.0,
            direction_method: DirectionMethod::Random,
        };
        let analyzer = LossLandscapeAnalyzer::<f64>::new(config);

        let params = Array1::from_vec(vec![0.0, 0.0]);
        let dir1 = Array1::from_vec(vec![1.0, 0.0]);
        let dir2 = Array1::from_vec(vec![0.0, 1.0]);

        let loss_fn = |p: &Array1<f64>| -> Result<f64> { Ok(p.iter().map(|&x| x * x).sum()) };

        let landscape = analyzer
            .compute_landscape(&params, loss_fn, &dir1, &dir2)
            .expect("Should compute landscape");

        let svg = analyzer
            .render_contour_plot(&landscape)
            .expect("Should render contour plot");

        assert!(svg.starts_with("<svg"));
        assert!(svg.ends_with("</svg>"));
        assert!(svg.contains("Loss Landscape"));
        assert!(svg.contains("Direction 1"));
        assert!(svg.contains("Direction 2"));
        // At least one rectangle per grid cell (the legend adds more).
        assert!(svg.matches("<rect").count() >= 5 * 5);
        // The bowl's minimum (blue) and maximum (red) must both be drawn.
        assert!(svg.contains("rgb(0,0,255)"));
        assert!(svg.contains("rgb(255,0,0)"));
    }

    /// Both provided `Default` implementations agree on every field.
    #[test]
    fn test_landscape_config_defaults() {
        let config: LossLandscapeConfig<f64> = LossLandscapeConfig::default();
        assert_eq!(config.grid_resolution, 20);
        assert!((config.perturbation_range - 1.0).abs() < 1e-15);
        assert_eq!(config.direction_method, DirectionMethod::Random);

        let config32: LossLandscapeConfig<f32> = LossLandscapeConfig::default();
        assert_eq!(config32.grid_resolution, 20);
        assert!((config32.perturbation_range - 1.0f32).abs() < 1e-6);
        assert_eq!(config32.direction_method, DirectionMethod::Random);
    }
}