// nuviz_cli/data/aggregation.rs
use std::collections::HashMap;
2
3use crate::data::experiment::Experiment;
4
/// Summary statistics computed over the best metrics of a group of experiments.
///
/// `mean` and `std` are keyed by metric name. A metric only appears in them
/// when at least one experiment in the group reported a finite value for it.
///
/// The `Default` value is the empty aggregate: no metrics and a `count` of 0.
// NOTE(review): `dead_code` is allowed presumably because not every field is
// read by the rest of the crate yet — confirm and drop the attribute once it is.
#[derive(Debug, Clone, Default, PartialEq)]
#[allow(dead_code)]
pub struct AggregatedMetrics {
    /// Arithmetic mean of each metric over the experiments that reported it.
    pub mean: HashMap<String, f64>,
    /// Sample standard deviation (n - 1 denominator) of each metric;
    /// `0.0` when only a single value was available.
    pub std: HashMap<String, f64>,
    /// Number of experiments that were aggregated.
    pub count: usize,
}
13
14#[allow(dead_code)]
16pub fn aggregate_experiments(experiments: &[&Experiment]) -> AggregatedMetrics {
17 let n = experiments.len();
18 if n == 0 {
19 return AggregatedMetrics {
20 mean: HashMap::new(),
21 std: HashMap::new(),
22 count: 0,
23 };
24 }
25
26 let mut values_by_metric: HashMap<String, Vec<f64>> = HashMap::new();
28 for exp in experiments {
29 for (name, &value) in &exp.best_metrics {
30 if value.is_finite() {
31 values_by_metric
32 .entry(name.clone())
33 .or_default()
34 .push(value);
35 }
36 }
37 }
38
39 let mut mean = HashMap::new();
40 let mut std = HashMap::new();
41
42 for (name, values) in &values_by_metric {
43 let count = values.len() as f64;
44 let m = values.iter().sum::<f64>() / count;
45 mean.insert(name.clone(), m);
46
47 if values.len() > 1 {
48 let variance = values.iter().map(|v| (v - m).powi(2)).sum::<f64>() / (count - 1.0);
49 std.insert(name.clone(), variance.sqrt());
50 } else {
51 std.insert(name.clone(), 0.0);
52 }
53 }
54
55 AggregatedMetrics {
56 mean,
57 std,
58 count: n,
59 }
60}
61
/// Renders a mean and its standard deviation for display.
///
/// Both numbers are printed with four decimal places. When the absolute value
/// of `std` is below `f64::EPSILON` only the mean is shown; otherwise the
/// output has the form `mean±std`.
pub fn format_mean_std(mean: f64, std: f64) -> String {
    let negligible_spread = std.abs() < f64::EPSILON;
    if negligible_spread {
        return format!("{mean:.4}");
    }
    format!("{mean:.4}±{std:.4}")
}
70
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    /// Builds a finished experiment named `name` whose best metrics are `metrics`.
    fn make_experiment(name: &str, metrics: &[(&str, f64)]) -> Experiment {
        let best_metrics = metrics
            .iter()
            .map(|&(key, value)| (key.to_string(), value))
            .collect();

        Experiment {
            name: name.into(),
            project: None,
            dir: PathBuf::from("/tmp"),
            status: "done".into(),
            total_steps: Some(100),
            best_metrics,
            start_time: None,
            end_time: None,
            seed: None,
            config_hash: None,
            config: None,
            tags: Vec::new(),
        }
    }

    /// Asserts that `actual` is within `f64::EPSILON` of `expected`.
    fn assert_close(actual: f64, expected: f64) {
        assert!((actual - expected).abs() < f64::EPSILON);
    }

    #[test]
    fn test_aggregate_single() {
        let only = make_experiment("e1", &[("psnr", 28.0), ("loss", 0.05)]);

        let aggregated = aggregate_experiments(&[&only]);

        assert_eq!(aggregated.count, 1);
        assert_close(aggregated.mean["psnr"], 28.0);
        // A single sample has no spread.
        assert_close(aggregated.std["psnr"], 0.0);
    }

    #[test]
    fn test_aggregate_multiple() {
        let first = make_experiment("e1", &[("psnr", 28.0)]);
        let second = make_experiment("e2", &[("psnr", 30.0)]);
        let third = make_experiment("e3", &[("psnr", 29.0)]);

        let aggregated = aggregate_experiments(&[&first, &second, &third]);

        assert_eq!(aggregated.count, 3);
        assert_close(aggregated.mean["psnr"], 29.0);
        assert!(aggregated.std["psnr"] > 0.0);
    }

    #[test]
    fn test_aggregate_empty() {
        let aggregated = aggregate_experiments(&[]);

        assert_eq!(aggregated.count, 0);
        assert!(aggregated.mean.is_empty());
    }

    #[test]
    fn test_format_mean_std() {
        // Zero spread prints the mean alone; otherwise both are shown.
        assert_eq!(format_mean_std(28.0, 0.0), "28.0000");
        assert_eq!(format_mean_std(28.5, 1.2), "28.5000±1.2000");
    }
}