1use crate::error::{OptimError, Result};
8use scirs2_core::ndarray::{Array1, Array2, ScalarOperand};
9use scirs2_core::numeric::Float;
10use std::fmt::Debug;
11
/// Strategy for choosing the two directions that span a loss-landscape slice.
///
/// NOTE(review): no method in this file reads the selected variant — the
/// analyzer receives its directions as explicit arguments — so the
/// per-variant descriptions below are inferred from the names alone. Confirm
/// them against the code that actually generates directions.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub enum DirectionMethod {
    /// Presumably randomly drawn directions. This is the default variant.
    #[default]
    Random,
    /// Presumably directions obtained from a principal component analysis —
    /// TODO confirm.
    PCA,
    /// Presumably filter-normalized directions — TODO confirm.
    FilterNormalized,
}
23
24#[derive(Debug, Clone)]
26pub struct LossLandscapeConfig<A> {
27 pub grid_resolution: usize,
29 pub perturbation_range: A,
31 pub direction_method: DirectionMethod,
33}
34
35impl Default for LossLandscapeConfig<f64> {
36 fn default() -> Self {
37 Self {
38 grid_resolution: 20,
39 perturbation_range: 1.0,
40 direction_method: DirectionMethod::Random,
41 }
42 }
43}
44
45impl Default for LossLandscapeConfig<f32> {
46 fn default() -> Self {
47 Self {
48 grid_resolution: 20,
49 perturbation_range: 1.0f32,
50 direction_method: DirectionMethod::Random,
51 }
52 }
53}
54
55#[derive(Debug, Clone)]
57pub struct LandscapeData<A> {
58 pub grid: Array2<A>,
60 pub x_range: (A, A),
62 pub y_range: (A, A),
64 pub center_loss: A,
66 pub min_loss: A,
68 pub max_loss: A,
70}
71
/// Location and loss of one saddle-point candidate on a landscape grid.
#[derive(Debug, Clone, PartialEq)]
pub struct SaddlePointInfo<A> {
    /// First grid index of the candidate (the index that follows direction 1).
    pub grid_x: usize,
    /// Second grid index of the candidate (the index that follows direction 2).
    pub grid_y: usize,
    /// Loss stored in the grid at `[grid_x, grid_y]`.
    pub loss_value: A,
}
82
83pub struct LossLandscapeAnalyzer<A> {
85 config: LossLandscapeConfig<A>,
87}
88
89impl<A> LossLandscapeAnalyzer<A>
90where
91 A: Float + ScalarOperand + Debug + std::iter::Sum,
92{
93 pub fn new(config: LossLandscapeConfig<A>) -> Self {
95 Self { config }
96 }
97
98 pub fn compute_landscape<F>(
109 &self,
110 params: &Array1<A>,
111 loss_fn: F,
112 dir1: &Array1<A>,
113 dir2: &Array1<A>,
114 ) -> Result<LandscapeData<A>>
115 where
116 F: Fn(&Array1<A>) -> Result<A>,
117 {
118 let n = self.config.grid_resolution;
119 if n == 0 {
120 return Err(OptimError::InvalidConfig(
121 "Grid resolution must be positive".to_string(),
122 ));
123 }
124 if params.len() != dir1.len() || params.len() != dir2.len() {
125 return Err(OptimError::DimensionMismatch(format!(
126 "Parameter dimension ({}) must match direction dimensions ({}, {})",
127 params.len(),
128 dir1.len(),
129 dir2.len()
130 )));
131 }
132
133 let range = self.config.perturbation_range;
134 let neg_range = A::zero() - range;
135
136 let mut grid = Array2::zeros((n, n));
137 let mut min_loss = A::infinity();
138 let mut max_loss = A::neg_infinity();
139 let mut center_loss = A::zero();
140
141 let n_minus_1 = if n > 1 {
142 A::from(n - 1).ok_or_else(|| {
143 OptimError::ComputationError("Failed to convert grid size".to_string())
144 })?
145 } else {
146 A::one()
147 };
148
149 let two = A::from(2.0).ok_or_else(|| {
150 OptimError::ComputationError("Failed to convert constant".to_string())
151 })?;
152
153 for i in 0..n {
154 let alpha = neg_range
155 + (A::from(i).ok_or_else(|| {
156 OptimError::ComputationError("Failed to convert index".to_string())
157 })? / n_minus_1)
158 * two
159 * range;
160
161 for j in 0..n {
162 let beta = neg_range
163 + (A::from(j).ok_or_else(|| {
164 OptimError::ComputationError("Failed to convert index".to_string())
165 })? / n_minus_1)
166 * two
167 * range;
168
169 let perturbed = params
171 .iter()
172 .zip(dir1.iter())
173 .zip(dir2.iter())
174 .map(|((&p, &d1), &d2)| p + alpha * d1 + beta * d2)
175 .collect::<Vec<A>>();
176 let perturbed = Array1::from_vec(perturbed);
177
178 let loss = loss_fn(&perturbed)?;
179 grid[[i, j]] = loss;
180
181 if loss < min_loss {
182 min_loss = loss;
183 }
184 if loss > max_loss {
185 max_loss = loss;
186 }
187
188 if (n > 1 && i == n / 2 && j == n / 2) || n == 1 {
190 center_loss = loss;
191 }
192 }
193 }
194
195 Ok(LandscapeData {
196 grid,
197 x_range: (neg_range, range),
198 y_range: (neg_range, range),
199 center_loss,
200 min_loss,
201 max_loss,
202 })
203 }
204
205 pub fn compute_sharpness<F>(&self, params: &Array1<A>, loss_fn: &F, epsilon: A) -> Result<A>
216 where
217 F: Fn(&Array1<A>) -> Result<A>,
218 {
219 let center_loss = loss_fn(params)?;
220 let dim = params.len();
221
222 if dim == 0 {
223 return Err(OptimError::InvalidParameter(
224 "Parameter array must not be empty".to_string(),
225 ));
226 }
227
228 let mut max_loss = center_loss;
229
230 for d in 0..dim {
232 let mut perturbed_pos = params.to_owned();
234 perturbed_pos[d] = perturbed_pos[d] + epsilon;
235 let loss_pos = loss_fn(&perturbed_pos)?;
236 if loss_pos > max_loss {
237 max_loss = loss_pos;
238 }
239
240 let mut perturbed_neg = params.to_owned();
242 perturbed_neg[d] = perturbed_neg[d] - epsilon;
243 let loss_neg = loss_fn(&perturbed_neg)?;
244 if loss_neg > max_loss {
245 max_loss = loss_neg;
246 }
247 }
248
249 let dim_f = A::from(dim).ok_or_else(|| {
252 OptimError::ComputationError("Failed to convert dimension".to_string())
253 })?;
254 let scaled_eps = epsilon / dim_f.sqrt();
255
256 let diag_pos: Array1<A> = params.mapv(|p| p + scaled_eps);
258 let loss_diag_pos = loss_fn(&diag_pos)?;
259 if loss_diag_pos > max_loss {
260 max_loss = loss_diag_pos;
261 }
262
263 let diag_neg: Array1<A> = params.mapv(|p| p - scaled_eps);
265 let loss_diag_neg = loss_fn(&diag_neg)?;
266 if loss_diag_neg > max_loss {
267 max_loss = loss_diag_neg;
268 }
269
270 Ok(max_loss - center_loss)
271 }
272
273 pub fn find_saddle_points(&self, landscape: &LandscapeData<A>) -> Vec<SaddlePointInfo<A>> {
282 let (rows, cols) = landscape.grid.dim();
283 let mut saddle_points = Vec::new();
284
285 for i in 1..rows.saturating_sub(1) {
287 for j in 1..cols.saturating_sub(1) {
288 let center = landscape.grid[[i, j]];
289
290 let neighbors = [
292 landscape.grid[[i - 1, j - 1]],
293 landscape.grid[[i - 1, j]],
294 landscape.grid[[i - 1, j + 1]],
295 landscape.grid[[i, j - 1]],
296 landscape.grid[[i, j + 1]],
297 landscape.grid[[i + 1, j - 1]],
298 landscape.grid[[i + 1, j]],
299 landscape.grid[[i + 1, j + 1]],
300 ];
301
302 let has_higher = neighbors.iter().any(|&n| n > center);
303 let has_lower = neighbors.iter().any(|&n| n < center);
304
305 if has_higher && has_lower {
308 let higher_count = neighbors.iter().filter(|&&n| n > center).count();
311 let lower_count = neighbors.iter().filter(|&&n| n < center).count();
312
313 if higher_count >= 2 && lower_count >= 2 {
316 saddle_points.push(SaddlePointInfo {
317 grid_x: i,
318 grid_y: j,
319 loss_value: center,
320 });
321 }
322 }
323 }
324 }
325
326 saddle_points
327 }
328
329 pub fn render_contour_plot(&self, landscape: &LandscapeData<A>) -> Result<String> {
334 let (rows, cols) = landscape.grid.dim();
335 if rows == 0 || cols == 0 {
336 return Err(OptimError::InvalidState(
337 "Landscape grid is empty".to_string(),
338 ));
339 }
340
341 let cell_size = 15;
342 let margin = 60;
343 let width = margin + cols * cell_size + margin;
344 let height = margin + rows * cell_size + margin;
345
346 let min_loss = landscape.min_loss.to_f64().unwrap_or(0.0);
347 let max_loss = landscape.max_loss.to_f64().unwrap_or(1.0);
348 let loss_range = if (max_loss - min_loss).abs() < 1e-15 {
349 1.0
350 } else {
351 max_loss - min_loss
352 };
353
354 let mut svg = format!(
355 r#"<svg xmlns="http://www.w3.org/2000/svg" width="{}" height="{}" viewBox="0 0 {} {}">"#,
356 width, height, width, height
357 );
358 svg.push('\n');
359
360 svg.push_str(&format!(
362 r#" <text x="{}" y="25" text-anchor="middle" font-size="16" font-weight="bold">Loss Landscape</text>"#,
363 width / 2
364 ));
365 svg.push('\n');
366
367 svg.push_str(&format!(
369 r#" <text x="{}" y="{}" text-anchor="middle" font-size="12">Direction 1</text>"#,
370 margin + cols * cell_size / 2,
371 height - 10
372 ));
373 svg.push('\n');
374
375 svg.push_str(&format!(
376 r#" <text x="15" y="{}" text-anchor="middle" font-size="12" transform="rotate(-90, 15, {})">Direction 2</text>"#,
377 margin + rows * cell_size / 2,
378 margin + rows * cell_size / 2
379 ));
380 svg.push('\n');
381
382 for i in 0..rows {
384 for j in 0..cols {
385 let val = landscape.grid[[i, j]].to_f64().unwrap_or(0.0);
386 let normalized = (val - min_loss) / loss_range;
387 let normalized = normalized.clamp(0.0, 1.0);
389
390 let color = loss_value_to_color(normalized);
391
392 let x = margin + j * cell_size;
393 let y = margin + i * cell_size;
394
395 svg.push_str(&format!(
396 r#" <rect x="{}" y="{}" width="{}" height="{}" fill="{}"/>"#,
397 x, y, cell_size, cell_size, color
398 ));
399 svg.push('\n');
400 }
401 }
402
403 let legend_x = margin + cols * cell_size + 10;
405 let legend_height = rows * cell_size;
406 let legend_steps = 10;
407 let step_height = legend_height / legend_steps;
408
409 for s in 0..legend_steps {
410 let normalized = 1.0 - (s as f64 / legend_steps as f64);
411 let color = loss_value_to_color(normalized);
412 let y = margin + s * step_height;
413
414 svg.push_str(&format!(
415 r#" <rect x="{}" y="{}" width="15" height="{}" fill="{}"/>"#,
416 legend_x, y, step_height, color
417 ));
418 svg.push('\n');
419 }
420
421 svg.push_str(&format!(
423 r#" <text x="{}" y="{}" font-size="9">{:.2e}</text>"#,
424 legend_x + 20,
425 margin + 10,
426 max_loss
427 ));
428 svg.push('\n');
429 svg.push_str(&format!(
430 r#" <text x="{}" y="{}" font-size="9">{:.2e}</text>"#,
431 legend_x + 20,
432 margin + legend_height,
433 min_loss
434 ));
435 svg.push('\n');
436
437 svg.push_str("</svg>");
438 Ok(svg)
439 }
440}
441
/// Maps a normalized loss in `[0, 1]` to an SVG `rgb(r,g,b)` colour string.
///
/// The ramp runs blue (`0.0`) -> green (`0.5`) -> red (`1.0`), interpolating
/// linearly within each half. Channel values are truncated, not rounded, when
/// converted to `u8`. Float-to-integer `as` casts saturate, so inputs outside
/// `[0, 1]` produce the nearest end colour and NaN produces `rgb(0,0,0)`.
fn loss_value_to_color(normalized: f64) -> String {
    let channels: [f64; 3] = if normalized < 0.5 {
        // Lower half: fade blue out while green fades in.
        let ramp = normalized * 2.0;
        [0.0, ramp, 1.0 - ramp]
    } else {
        // Upper half: fade green out while red fades in.
        let ramp = (normalized - 0.5) * 2.0;
        [ramp, 1.0 - ramp, 0.0]
    };

    let [red, green, blue] = channels.map(|channel| (channel * 255.0) as u8);
    format!("rgb({},{},{})", red, green, blue)
}
463
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array1;

    /// A 5 x 5 scan of `x^2 + y^2` over `[-1, 1]^2` peaks at the corners with
    /// a loss of 2 and is non-negative everywhere.
    #[test]
    fn test_compute_landscape_quadratic() {
        let config = LossLandscapeConfig {
            grid_resolution: 5,
            perturbation_range: 1.0,
            direction_method: DirectionMethod::Random,
        };
        let analyzer = LossLandscapeAnalyzer::<f64>::new(config);

        // Scan around the origin along the two coordinate axes.
        let params = Array1::from_vec(vec![0.0, 0.0]);
        let dir1 = Array1::from_vec(vec![1.0, 0.0]);
        let dir2 = Array1::from_vec(vec![0.0, 1.0]);

        let loss_fn = |p: &Array1<f64>| -> Result<f64> { Ok(p.iter().map(|&x| x * x).sum()) };

        let landscape = analyzer
            .compute_landscape(&params, loss_fn, &dir1, &dir2)
            .expect("Should compute landscape");

        assert_eq!(landscape.grid.dim(), (5, 5));
        assert!(landscape.center_loss >= 0.0);
        assert!(landscape.min_loss >= 0.0);
        // Corner samples sit at (+/-1, +/-1), where the loss is 1 + 1.
        assert!((landscape.max_loss - 2.0).abs() < 1e-10);
    }

    /// For `x^2 + y^2` at the origin every probe at distance `epsilon` has a
    /// loss of `epsilon^2`, so the sharpness is `0.01` for `epsilon = 0.1`.
    #[test]
    fn test_compute_sharpness() {
        let config: LossLandscapeConfig<f64> = LossLandscapeConfig::default();
        let analyzer = LossLandscapeAnalyzer::new(config);

        let params = Array1::from_vec(vec![0.0, 0.0]);
        let loss_fn = |p: &Array1<f64>| -> Result<f64> { Ok(p.iter().map(|&x| x * x).sum()) };

        let epsilon = 0.1;
        let sharpness = analyzer
            .compute_sharpness(&params, &loss_fn, epsilon)
            .expect("Should compute sharpness");

        assert!(sharpness > 0.0);
        assert!((sharpness - 0.01).abs() < 1e-10);
    }

    /// `x^2 - y^2` has a saddle at the origin, which is grid node `(5, 5)` of
    /// an 11 x 11 scan.
    #[test]
    fn test_find_saddle_points() {
        let config = LossLandscapeConfig {
            grid_resolution: 11,
            perturbation_range: 1.0,
            direction_method: DirectionMethod::Random,
        };
        let analyzer = LossLandscapeAnalyzer::<f64>::new(config);

        let params = Array1::from_vec(vec![0.0, 0.0]);
        let dir1 = Array1::from_vec(vec![1.0, 0.0]);
        let dir2 = Array1::from_vec(vec![0.0, 1.0]);

        let loss_fn = |p: &Array1<f64>| -> Result<f64> { Ok(p[0] * p[0] - p[1] * p[1]) };

        let landscape = analyzer
            .compute_landscape(&params, loss_fn, &dir1, &dir2)
            .expect("Should compute landscape");

        let saddle_points = analyzer.find_saddle_points(&landscape);

        assert!(
            !saddle_points.is_empty(),
            "Should detect saddle points in x^2 - y^2"
        );

        let has_center = saddle_points
            .iter()
            .any(|sp| sp.grid_x == 5 && sp.grid_y == 5);
        assert!(
            has_center,
            "Center of x^2 - y^2 landscape should be a saddle point"
        );
    }

    /// The rendered plot is a complete SVG document containing the title,
    /// both axis labels and at least one rectangle.
    #[test]
    fn test_render_contour_plot_svg() {
        let config = LossLandscapeConfig {
            grid_resolution: 5,
            perturbation_range: 1.0,
            direction_method: DirectionMethod::Random,
        };
        let analyzer = LossLandscapeAnalyzer::<f64>::new(config);

        let params = Array1::from_vec(vec![0.0, 0.0]);
        let dir1 = Array1::from_vec(vec![1.0, 0.0]);
        let dir2 = Array1::from_vec(vec![0.0, 1.0]);

        let loss_fn = |p: &Array1<f64>| -> Result<f64> { Ok(p.iter().map(|&x| x * x).sum()) };

        let landscape = analyzer
            .compute_landscape(&params, loss_fn, &dir1, &dir2)
            .expect("Should compute landscape");

        let svg = analyzer
            .render_contour_plot(&landscape)
            .expect("Should render contour plot");

        assert!(svg.starts_with("<svg"));
        assert!(svg.ends_with("</svg>"));
        assert!(svg.contains("Loss Landscape"));
        assert!(svg.contains("Direction 1"));
        assert!(svg.contains("Direction 2"));
        assert!(svg.contains("rect"));
    }

    /// Both provided `Default` implementations use a 20-point grid over a
    /// unit perturbation range with the `Random` direction method.
    #[test]
    fn test_landscape_config_defaults() {
        let config: LossLandscapeConfig<f64> = LossLandscapeConfig::default();
        assert_eq!(config.grid_resolution, 20);
        assert!((config.perturbation_range - 1.0).abs() < 1e-15);
        assert_eq!(config.direction_method, DirectionMethod::Random);

        let config32: LossLandscapeConfig<f32> = LossLandscapeConfig::default();
        assert_eq!(config32.grid_resolution, 20);
        assert!((config32.perturbation_range - 1.0f32).abs() < 1e-6);
    }
}