1use scirs2_core::ndarray::{Array2, Axis};
11use std::collections::HashMap;
12use std::fs::File;
13use std::io::Write;
14use thiserror::Error;
15
16use crate::sampler::SampleResult;
17
18#[derive(Error, Debug)]
20pub enum VisualizationError {
21 #[error("Data preparation error: {0}")]
23 DataError(String),
24
25 #[error("I/O error: {0}")]
27 IoError(#[from] std::io::Error),
28
29 #[error("Invalid parameters: {0}")]
31 InvalidParams(String),
32}
33
34pub type VisualizationResult<T> = Result<T, VisualizationError>;
36
/// Energy distribution of a batch of samples, ready for plotting.
#[derive(Clone, Debug)]
pub struct EnergyLandscapeData {
    /// Position in the input slice of each sample, ordered by ascending energy.
    pub indices: Vec<usize>,
    /// Sample energies in ascending order; `energies[k]` belongs to sample `indices[k]`.
    pub energies: Vec<f64>,
    /// Histogram bin edges: `num_bins + 1` evenly spaced values from the lowest
    /// to the highest energy.
    pub histogram_bins: Vec<f64>,
    /// Number of samples falling in each of the `num_bins` bins.
    pub histogram_counts: Vec<usize>,
    /// Grid points of the kernel density estimate, when one was requested.
    pub kde_x: Option<Vec<f64>>,
    /// Density values matching `kde_x`, when a KDE was requested.
    pub kde_y: Option<Vec<f64>>,
}
52
/// Tuning knobs for [`prepare_energy_landscape`].
#[derive(Clone, Debug)]
pub struct EnergyLandscapeConfig {
    /// Number of equal-width histogram bins.
    pub num_bins: usize,
    /// Whether to also compute a Gaussian kernel density estimate.
    pub compute_kde: bool,
    /// Number of grid points at which the density estimate is evaluated.
    pub kde_points: usize,
}

impl Default for EnergyLandscapeConfig {
    /// 50 histogram bins and a KDE evaluated on 200 points.
    fn default() -> Self {
        let num_bins = 50;
        let kde_points = 200;
        Self {
            num_bins,
            compute_kde: true,
            kde_points,
        }
    }
}
73
74pub fn prepare_energy_landscape(
76 results: &[SampleResult],
77 config: Option<EnergyLandscapeConfig>,
78) -> VisualizationResult<EnergyLandscapeData> {
79 let config = config.unwrap_or_default();
80
81 if results.is_empty() {
82 return Err(VisualizationError::DataError(
83 "No results to analyze".to_string(),
84 ));
85 }
86
87 let mut indexed_energies: Vec<(usize, f64)> = results
89 .iter()
90 .enumerate()
91 .map(|(i, r)| (i, r.energy))
92 .collect();
93 indexed_energies.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
94
95 let indices: Vec<usize> = indexed_energies.iter().map(|(i, _)| *i).collect();
96 let energies: Vec<f64> = indexed_energies.iter().map(|(_, e)| *e).collect();
97
98 let min_energy = energies[0];
100 let max_energy = energies[energies.len() - 1];
101 let bin_width = (max_energy - min_energy) / config.num_bins as f64;
102
103 let mut histogram_bins = Vec::new();
104 let mut histogram_counts = vec![0; config.num_bins];
105
106 for i in 0..=config.num_bins {
107 histogram_bins.push((i as f64).mul_add(bin_width, min_energy));
108 }
109
110 for &energy in &energies {
111 let bin_idx = ((energy - min_energy) / bin_width).floor() as usize;
112 let bin_idx = bin_idx.min(config.num_bins - 1);
113 histogram_counts[bin_idx] += 1;
114 }
115
116 let (kde_x, kde_y) = if config.compute_kde {
118 let kde_data = compute_kde(&energies, config.kde_points, min_energy, max_energy)?;
119 (Some(kde_data.0), Some(kde_data.1))
120 } else {
121 (None, None)
122 };
123
124 Ok(EnergyLandscapeData {
125 indices,
126 energies,
127 histogram_bins,
128 histogram_counts,
129 kde_x,
130 kde_y,
131 })
132}
133
134#[derive(Debug, Clone)]
136pub struct SolutionDistributionData {
137 pub variable_names: Vec<String>,
139 pub variable_frequencies: HashMap<String, f64>,
141 pub correlations: Option<HashMap<(String, String), f64>>,
143 pub pca_components: Option<Array2<f64>>,
145 pub pca_explained_variance: Option<Vec<f64>>,
146 pub solution_matrix: Array2<f64>,
148}
149
/// Options controlling which optional statistics `analyze_solution_distribution` computes.
#[derive(Clone, Debug)]
pub struct SolutionDistributionConfig {
    /// Whether to compute pairwise variable correlations.
    pub compute_correlations: bool,
    /// Whether to attempt a PCA projection (only done with more than one variable).
    pub compute_pca: bool,
    /// Number of principal components requested from the PCA.
    pub n_components: usize,
}

impl Default for SolutionDistributionConfig {
    /// Correlations and a two-component PCA are both enabled.
    fn default() -> Self {
        let n_components = 2;
        Self {
            compute_correlations: true,
            compute_pca: true,
            n_components,
        }
    }
}
170
171pub fn analyze_solution_distribution(
173 results: &[SampleResult],
174 config: Option<SolutionDistributionConfig>,
175) -> VisualizationResult<SolutionDistributionData> {
176 let config = config.unwrap_or_default();
177
178 if results.is_empty() {
179 return Err(VisualizationError::DataError(
180 "No results to analyze".to_string(),
181 ));
182 }
183
184 let mut variable_names: Vec<String> = results[0].assignments.keys().cloned().collect();
186 variable_names.sort();
187
188 let n_vars = variable_names.len();
189 let n_samples = results.len();
190
191 let mut solution_matrix = Array2::<f64>::zeros((n_samples, n_vars));
193
194 for (i, result) in results.iter().enumerate() {
195 for (j, var_name) in variable_names.iter().enumerate() {
196 if let Some(&value) = result.assignments.get(var_name) {
197 solution_matrix[[i, j]] = if value { 1.0 } else { 0.0 };
198 }
199 }
200 }
201
202 let mut variable_frequencies = HashMap::new();
204 for (j, var_name) in variable_names.iter().enumerate() {
205 let freq = solution_matrix.column(j).sum() / n_samples as f64;
206 variable_frequencies.insert(var_name.clone(), freq);
207 }
208
209 let correlations = if config.compute_correlations {
211 let mut corr_map = HashMap::new();
212 let corr_matrix = calculate_correlation_matrix(&solution_matrix)?;
213
214 for i in 0..n_vars {
215 for j in (i + 1)..n_vars {
216 let corr = corr_matrix[[i, j]];
217 if corr.abs() > 0.01 {
218 corr_map.insert((variable_names[i].clone(), variable_names[j].clone()), corr);
220 }
221 }
222 }
223 Some(corr_map)
224 } else {
225 None
226 };
227
228 let (pca_components, pca_explained_variance) = if config.compute_pca && n_vars > 1 {
230 match simple_pca(&solution_matrix, config.n_components) {
231 Ok((components, variance)) => (Some(components), Some(variance)),
232 Err(_) => (None, None),
233 }
234 } else {
235 (None, None)
236 };
237
238 Ok(SolutionDistributionData {
239 variable_names,
240 variable_frequencies,
241 correlations,
242 pca_components,
243 pca_explained_variance,
244 solution_matrix,
245 })
246}
247
/// Problem-specific data for rendering a decoded solution.
///
/// NOTE(review): no constructor for this enum is visible in this module; the
/// field descriptions below are inferred from the names and the helper
/// functions nearby — confirm against the producers.
#[derive(Clone, Debug)]
pub enum ProblemVisualizationData {
    /// Travelling-salesman solution.
    TSP {
        /// City coordinates.
        cities: Vec<(f64, f64)>,
        /// City indices in visiting order (e.g. from `extract_tsp_tour`).
        tour: Vec<usize>,
        /// Closed-tour length (e.g. from `calculate_tour_length`).
        tour_length: f64,
    },
    /// Graph-coloring solution.
    GraphColoring {
        /// 2-D drawing position of each node (e.g. from `spring_layout`).
        node_positions: Vec<(f64, f64)>,
        /// Color index assigned to each node.
        node_colors: Vec<usize>,
        /// Graph edges as node-index pairs.
        edges: Vec<(usize, usize)>,
        /// Edges whose endpoints share a color.
        conflicts: Vec<(usize, usize)>,
    },
    /// Max-cut solution.
    MaxCut {
        /// 2-D drawing position of each node.
        node_positions: Vec<(f64, f64)>,
        /// Side of the cut each node is on.
        partition: Vec<bool>,
        /// Graph edges as node-index pairs.
        edges: Vec<(usize, usize)>,
        /// Presumably the edges crossing the partition — confirm.
        cut_edges: Vec<(usize, usize)>,
        /// Presumably the number of cut edges — confirm.
        cut_size: usize,
    },
    /// Number-partitioning solution.
    NumberPartitioning {
        /// The numbers being partitioned.
        numbers: Vec<f64>,
        /// Indices into `numbers` assigned to the first subset.
        partition_0: Vec<usize>,
        /// Indices into `numbers` assigned to the second subset.
        partition_1: Vec<usize>,
        /// Sum of the first subset.
        sum_0: f64,
        /// Sum of the second subset.
        sum_1: f64,
        /// Presumably `|sum_0 - sum_1|` — confirm.
        difference: f64,
    },
}
282
283pub fn extract_tsp_tour(result: &SampleResult, n_cities: usize) -> VisualizationResult<Vec<usize>> {
285 let mut tour = Vec::new();
286 let mut visited = vec![false; n_cities];
287 let mut current = 0;
288
289 tour.push(current);
290 visited[current] = true;
291
292 while tour.len() < n_cities {
293 let mut next = None;
294
295 for (j, &is_visited) in visited.iter().enumerate().take(n_cities) {
297 if !is_visited {
298 let var_name = format!("x_{current}_{j}");
299 if let Some(&value) = result.assignments.get(&var_name) {
300 if value {
301 next = Some(j);
302 break;
303 }
304 }
305 }
306 }
307
308 if let Some(next_city) = next {
309 tour.push(next_city);
310 visited[next_city] = true;
311 current = next_city;
312 } else {
313 for (j, is_visited) in visited.iter_mut().enumerate().take(n_cities) {
315 if !*is_visited {
316 tour.push(j);
317 *is_visited = true;
318 current = j;
319 break;
320 }
321 }
322 }
323 }
324
325 Ok(tour)
326}
327
/// Length of a closed tour through `cities`.
///
/// Sums the Euclidean distance between each pair of consecutive cities in
/// `tour`, including the return leg from the last city back to the first.
/// An empty tour has length `0.0`, and a single-city tour also has length `0.0`.
///
/// # Panics
///
/// Panics if any entry of `tour` is not a valid index into `cities`.
pub fn calculate_tour_length(tour: &[usize], cities: &[(f64, f64)]) -> f64 {
    // Successor of position k is k + 1, wrapping the last city around to the first.
    let successors = tour.iter().skip(1).chain(tour.first());
    tour.iter()
        .zip(successors)
        .fold(0.0, |total, (&from, &to)| {
            let (x1, y1) = cities[from];
            let (x2, y2) = cities[to];
            total + (x2 - x1).hypot(y2 - y1)
        })
}
342
343pub fn extract_graph_coloring(
345 result: &SampleResult,
346 n_nodes: usize,
347 n_colors: usize,
348 edges: &[(usize, usize)],
349) -> VisualizationResult<(Vec<usize>, Vec<(usize, usize)>)> {
350 let mut node_colors = vec![0; n_nodes];
351
352 for (node, node_color) in node_colors.iter_mut().enumerate().take(n_nodes) {
354 for color in 0..n_colors {
355 let var_name = format!("x_{node}_{color}");
356 if let Some(&value) = result.assignments.get(&var_name) {
357 if value {
358 *node_color = color;
359 break;
360 }
361 }
362 }
363 }
364
365 let mut conflicts = Vec::new();
367 for &(u, v) in edges {
368 if node_colors[u] == node_colors[v] {
369 conflicts.push((u, v));
370 }
371 }
372
373 Ok((node_colors, conflicts))
374}
375
/// Per-iteration energy statistics describing how a run converged.
#[derive(Clone, Debug)]
pub struct ConvergenceData {
    /// Indices of the iterations that contained at least one sample.
    pub iterations: Vec<usize>,
    /// Lowest energy observed in each listed iteration.
    pub best_energies: Vec<f64>,
    /// Mean energy of each listed iteration.
    pub avg_energies: Vec<f64>,
    /// Population standard deviation of the energies in each listed iteration.
    pub std_devs: Vec<f64>,
    /// Moving average of `best_energies`, when a window was requested.
    pub ma_best: Option<Vec<f64>>,
    /// Moving average of `avg_energies`, when a window was requested.
    pub ma_avg: Option<Vec<f64>>,
}
391
392pub fn analyze_convergence(
394 iteration_results: &[Vec<SampleResult>],
395 ma_window: Option<usize>,
396) -> VisualizationResult<ConvergenceData> {
397 if iteration_results.is_empty() {
398 return Err(VisualizationError::DataError(
399 "No iteration data".to_string(),
400 ));
401 }
402
403 let mut iterations = Vec::new();
404 let mut best_energies = Vec::new();
405 let mut avg_energies = Vec::new();
406 let mut std_devs = Vec::new();
407
408 for (i, iter_results) in iteration_results.iter().enumerate() {
409 if iter_results.is_empty() {
410 continue;
411 }
412
413 iterations.push(i);
414
415 let energies: Vec<f64> = iter_results.iter().map(|r| r.energy).collect();
416
417 let best = energies.iter().fold(f64::INFINITY, |a, &b| a.min(b));
419 best_energies.push(best);
420
421 let (avg, std) = calculate_mean_std(&energies);
423 avg_energies.push(avg);
424 std_devs.push(std);
425 }
426
427 let (ma_best, ma_avg) = if let Some(window) = ma_window {
429 (
430 Some(moving_average(&best_energies, window)),
431 Some(moving_average(&avg_energies, window)),
432 )
433 } else {
434 (None, None)
435 };
436
437 Ok(ConvergenceData {
438 iterations,
439 best_energies,
440 avg_energies,
441 std_devs,
442 ma_best,
443 ma_avg,
444 })
445}
446
447pub fn export_to_csv(data: &EnergyLandscapeData, output_path: &str) -> VisualizationResult<()> {
449 let mut file = File::create(output_path)?;
450
451 writeln!(file, "index,original_index,energy")?;
453
454 for (i, (&idx, &energy)) in data.indices.iter().zip(&data.energies).enumerate() {
456 writeln!(file, "{i},{idx},{energy}")?;
457 }
458
459 Ok(())
460}
461
462pub fn export_solution_matrix(
464 data: &SolutionDistributionData,
465 output_path: &str,
466) -> VisualizationResult<()> {
467 let mut file = File::create(output_path)?;
468
469 write!(file, "sample")?;
471 for var_name in &data.variable_names {
472 write!(file, ",{var_name}")?;
473 }
474 writeln!(file)?;
475
476 for i in 0..data.solution_matrix.nrows() {
478 write!(file, "{i}")?;
479 for j in 0..data.solution_matrix.ncols() {
480 write!(file, ",{}", data.solution_matrix[[i, j]])?;
481 }
482 writeln!(file)?;
483 }
484
485 Ok(())
486}
487
488fn compute_kde(
492 values: &[f64],
493 n_points: usize,
494 min_val: f64,
495 max_val: f64,
496) -> VisualizationResult<(Vec<f64>, Vec<f64>)> {
497 let bandwidth = estimate_bandwidth(values);
498 let range = max_val - min_val;
499
500 let mut x_points = Vec::new();
501 let mut y_points = Vec::new();
502
503 for i in 0..n_points {
504 let x = (i as f64 / (n_points - 1) as f64).mul_add(range, min_val);
505 let mut density = 0.0;
506
507 for &val in values {
508 let u = (x - val) / bandwidth;
509 density += (-0.5 * u * u).exp() / (bandwidth * (2.0 * std::f64::consts::PI).sqrt());
511 }
512
513 density /= values.len() as f64;
514
515 x_points.push(x);
516 y_points.push(density);
517 }
518
519 Ok((x_points, y_points))
520}
521
522fn estimate_bandwidth(values: &[f64]) -> f64 {
524 let n = values.len() as f64;
525 let (_, std) = calculate_mean_std(values);
526 1.06 * std * n.powf(-1.0 / 5.0)
527}
528
/// Returns the arithmetic mean and the population standard deviation
/// (dividing by `n`, not `n - 1`) of `values`; `(0.0, 0.0)` when empty.
fn calculate_mean_std(values: &[f64]) -> (f64, f64) {
    match values.len() {
        0 => (0.0, 0.0),
        len => {
            let count = len as f64;
            let mean = values.iter().sum::<f64>() / count;
            let squared_deviations: f64 = values.iter().map(|&v| (v - mean).powi(2)).sum();
            (mean, (squared_deviations / count).sqrt())
        }
    }
}
540
541fn calculate_correlation_matrix(data: &Array2<f64>) -> VisualizationResult<Array2<f64>> {
543 let n_vars = data.ncols();
544 let mut corr_matrix = Array2::<f64>::zeros((n_vars, n_vars));
545
546 for i in 0..n_vars {
547 for j in 0..n_vars {
548 let col_i = data.column(i);
549 let col_j = data.column(j);
550
551 let col_i_vec: Vec<f64> = col_i.to_vec();
553 let col_j_vec: Vec<f64> = col_j.to_vec();
554
555 let (mean_i, std_i) = calculate_mean_std(&col_i_vec);
556 let (mean_j, std_j) = calculate_mean_std(&col_j_vec);
557
558 if std_i > 0.0 && std_j > 0.0 {
559 let cov: f64 = col_i_vec
560 .iter()
561 .zip(col_j_vec.iter())
562 .map(|(&x, &y)| (x - mean_i) * (y - mean_j))
563 .sum::<f64>()
564 / data.nrows() as f64;
565
566 corr_matrix[[i, j]] = cov / (std_i * std_j);
567 } else {
568 corr_matrix[[i, j]] = if i == j { 1.0 } else { 0.0 };
569 }
570 }
571 }
572
573 Ok(corr_matrix)
574}
575
576fn simple_pca(
578 data: &Array2<f64>,
579 n_components: usize,
580) -> VisualizationResult<(Array2<f64>, Vec<f64>)> {
581 let n_samples = data.nrows();
582 let n_features = data.ncols();
583
584 if n_components > n_features.min(n_samples) {
585 return Err(VisualizationError::InvalidParams(
586 "Number of components exceeds data dimensions".to_string(),
587 ));
588 }
589
590 let mean = data
592 .mean_axis(Axis(0))
593 .ok_or_else(|| VisualizationError::DataError("Failed to compute mean".to_string()))?;
594 let centered = data - &mean;
595
596 let _cov = centered.t().dot(¢ered) / (n_samples - 1) as f64;
598
599 let components = Array2::<f64>::zeros((n_samples, n_components));
602 let explained_variance = vec![1.0 / n_components as f64; n_components];
603
604 Ok((components, explained_variance))
605}
606
/// Simple moving average over a sliding window of `window` consecutive values.
///
/// Produces `values.len() - window + 1` averages; the k-th covers
/// `values[k..k + window]`. Returns an empty vector when `window` is zero or
/// longer than `values`.
fn moving_average(values: &[f64], window: usize) -> Vec<f64> {
    if window == 0 || window > values.len() {
        return Vec::new();
    }
    let divisor = window as f64;
    values
        .windows(window)
        .map(|chunk| chunk.iter().sum::<f64>() / divisor)
        .collect()
}
622
623pub fn spring_layout(n_nodes: usize, edges: &[(usize, usize)]) -> Vec<(f64, f64)> {
625 use scirs2_core::random::prelude::*;
627
628 let mut rng = thread_rng();
629
630 let mut positions: Vec<(f64, f64)> = (0..n_nodes)
632 .map(|_| (rng.gen_range(-1.0..1.0), rng.gen_range(-1.0..1.0)))
633 .collect();
634
635 let iterations = 50;
637 let k = 1.0 / (n_nodes as f64).sqrt();
638
639 for _ in 0..iterations {
640 let mut forces = vec![(0.0, 0.0); n_nodes];
641
642 for i in 0..n_nodes {
644 for j in (i + 1)..n_nodes {
645 let dx = positions[i].0 - positions[j].0;
646 let dy = positions[i].1 - positions[j].1;
647 let dist = dx.hypot(dy).max(0.01);
648
649 let force = k * k / dist;
650 forces[i].0 += force * dx / dist;
651 forces[i].1 += force * dy / dist;
652 forces[j].0 -= force * dx / dist;
653 forces[j].1 -= force * dy / dist;
654 }
655 }
656
657 for &(u, v) in edges {
659 let dx = positions[u].0 - positions[v].0;
660 let dy = positions[u].1 - positions[v].1;
661 let dist = dx.hypot(dy);
662
663 let force = dist / k;
664 forces[u].0 -= force * dx / dist;
665 forces[u].1 -= force * dy / dist;
666 forces[v].0 += force * dx / dist;
667 forces[v].1 += force * dy / dist;
668 }
669
670 for i in 0..n_nodes {
672 positions[i].0 += forces[i].0 * 0.1;
673 positions[i].1 += forces[i].1 * 0.1;
674 }
675 }
676
677 let mut min_x = f64::INFINITY;
679 let mut max_x = f64::NEG_INFINITY;
680 let mut min_y = f64::INFINITY;
681 let mut max_y = f64::NEG_INFINITY;
682
683 for &(x, y) in &positions {
684 min_x = min_x.min(x);
685 max_x = max_x.max(x);
686 min_y = min_y.min(y);
687 max_y = max_y.max(y);
688 }
689
690 let scale_x = if max_x > min_x {
691 0.9 / (max_x - min_x)
692 } else {
693 1.0
694 };
695 let scale_y = if max_y > min_y {
696 0.9 / (max_y - min_y)
697 } else {
698 1.0
699 };
700
701 positions
702 .iter()
703 .map(|&(x, y)| {
704 (
705 (x - min_x).mul_add(scale_x, 0.05),
706 (y - min_y).mul_add(scale_y, 0.05),
707 )
708 })
709 .collect()
710}
711
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_mean_std_calculation() {
        let values = [1.0, 2.0, 3.0, 4.0, 5.0];
        let (mean, std) = calculate_mean_std(&values);
        assert!((mean - 3.0).abs() < 1e-10);
        assert!((std - std::f64::consts::SQRT_2).abs() < 1e-5);
    }

    #[test]
    fn test_mean_std_empty_input() {
        assert_eq!(calculate_mean_std(&[]), (0.0, 0.0));
    }

    #[test]
    fn test_moving_average() {
        let values = [1.0, 2.0, 3.0, 4.0, 5.0];
        let ma = moving_average(&values, 3);
        assert_eq!(ma, vec![2.0, 3.0, 4.0]);
    }

    #[test]
    fn test_moving_average_degenerate_windows() {
        let values = [1.0, 2.0];
        // A zero-width or over-long window yields no averages rather than panicking.
        assert!(moving_average(&values, 0).is_empty());
        assert!(moving_average(&values, 3).is_empty());
        assert_eq!(moving_average(&values, 1), vec![1.0, 2.0]);
    }

    #[test]
    fn test_kde_bandwidth() {
        let values = [1.0, 2.0, 3.0, 4.0, 5.0];
        let bandwidth = estimate_bandwidth(&values);
        assert!(bandwidth > 0.0);
    }

    #[test]
    fn test_tour_length_closes_the_loop() {
        // 3-4-5 right triangle: 3 + 4 + 5, including the return leg.
        let cities = [(0.0, 0.0), (3.0, 0.0), (3.0, 4.0)];
        let length = calculate_tour_length(&[0, 1, 2], &cities);
        assert!((length - 12.0).abs() < 1e-12);
        assert_eq!(calculate_tour_length(&[], &cities), 0.0);
    }
}
737}