Skip to main content

nuviz_cli/data/
aggregation.rs

1use std::collections::HashMap;
2
3use crate::data::experiment::Experiment;
4
/// Aggregated metrics across multiple experiment runs (mean +/- std).
///
/// As produced by `aggregate_experiments`, `mean` and `std` share the same
/// set of keys (one entry per metric name that had at least one finite value).
/// `Default` yields the empty aggregate (no metrics, `count == 0`).
#[derive(Debug, Clone, Default, PartialEq)]
// NOTE(review): presumably silences warnings while no non-test caller reads
// every field — confirm, and drop the attribute once the type is in use.
#[allow(dead_code)]
pub struct AggregatedMetrics {
    /// Arithmetic mean per metric name.
    pub mean: HashMap<String, f64>,
    /// Standard deviation per metric name.
    pub std: HashMap<String, f64>,
    /// Number of experiment runs that were aggregated.
    pub count: usize,
}
13
14/// Compute mean and standard deviation of best_metrics across experiments.
15#[allow(dead_code)]
16pub fn aggregate_experiments(experiments: &[&Experiment]) -> AggregatedMetrics {
17    let n = experiments.len();
18    if n == 0 {
19        return AggregatedMetrics {
20            mean: HashMap::new(),
21            std: HashMap::new(),
22            count: 0,
23        };
24    }
25
26    // Collect all metric values per name
27    let mut values_by_metric: HashMap<String, Vec<f64>> = HashMap::new();
28    for exp in experiments {
29        for (name, &value) in &exp.best_metrics {
30            if value.is_finite() {
31                values_by_metric
32                    .entry(name.clone())
33                    .or_default()
34                    .push(value);
35            }
36        }
37    }
38
39    let mut mean = HashMap::new();
40    let mut std = HashMap::new();
41
42    for (name, values) in &values_by_metric {
43        let count = values.len() as f64;
44        let m = values.iter().sum::<f64>() / count;
45        mean.insert(name.clone(), m);
46
47        if values.len() > 1 {
48            let variance = values.iter().map(|v| (v - m).powi(2)).sum::<f64>() / (count - 1.0);
49            std.insert(name.clone(), variance.sqrt());
50        } else {
51            std.insert(name.clone(), 0.0);
52        }
53    }
54
55    AggregatedMetrics {
56        mean,
57        std,
58        count: n,
59    }
60}
61
/// Render a mean/std pair with four decimal places.
///
/// Returns just the mean when `std` is smaller in magnitude than
/// `f64::EPSILON`; otherwise returns `mean±std`.
pub fn format_mean_std(mean: f64, std: f64) -> String {
    let negligible_spread = std.abs() < f64::EPSILON;
    if negligible_spread {
        return format!("{mean:.4}");
    }
    format!("{mean:.4}±{std:.4}")
}
70
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    /// Build a minimal `Experiment` whose only interesting field is `best_metrics`.
    fn make_experiment(name: &str, metrics: &[(&str, f64)]) -> Experiment {
        Experiment {
            name: name.into(),
            project: None,
            dir: PathBuf::from("/tmp"),
            status: "done".into(),
            total_steps: Some(100),
            best_metrics: metrics.iter().map(|(k, v)| (k.to_string(), *v)).collect(),
            start_time: None,
            end_time: None,
            seed: None,
            config_hash: None,
            config: None,
            tags: Vec::new(),
        }
    }

    #[test]
    fn test_aggregate_single() {
        let exp = make_experiment("e1", &[("psnr", 28.0), ("loss", 0.05)]);
        let result = aggregate_experiments(&[&exp]);
        assert_eq!(result.count, 1);
        assert!((result.mean["psnr"] - 28.0).abs() < f64::EPSILON);
        assert!((result.std["psnr"]).abs() < f64::EPSILON);
        // Every metric of the run is reported, each with zero spread.
        assert!((result.mean["loss"] - 0.05).abs() < f64::EPSILON);
        assert!((result.std["loss"]).abs() < f64::EPSILON);
    }

    #[test]
    fn test_aggregate_multiple() {
        let e1 = make_experiment("e1", &[("psnr", 28.0)]);
        let e2 = make_experiment("e2", &[("psnr", 30.0)]);
        let e3 = make_experiment("e3", &[("psnr", 29.0)]);
        let result = aggregate_experiments(&[&e1, &e2, &e3]);
        assert_eq!(result.count, 3);
        assert!((result.mean["psnr"] - 29.0).abs() < f64::EPSILON);
        // Sample std (n - 1 divisor): sqrt((1 + 1 + 0) / 2) == 1.
        assert!((result.std["psnr"] - 1.0).abs() < f64::EPSILON);
    }

    #[test]
    fn test_aggregate_empty() {
        let result = aggregate_experiments(&[]);
        assert_eq!(result.count, 0);
        assert!(result.mean.is_empty());
        assert!(result.std.is_empty());
    }

    #[test]
    fn test_aggregate_skips_non_finite() {
        let e1 = make_experiment("e1", &[("psnr", 28.0), ("loss", f64::NAN)]);
        let e2 = make_experiment(
            "e2",
            &[("psnr", f64::INFINITY), ("loss", f64::NEG_INFINITY)],
        );
        let result = aggregate_experiments(&[&e1, &e2]);
        // `count` reflects the runs passed in, not the values that survived filtering.
        assert_eq!(result.count, 2);
        assert!((result.mean["psnr"] - 28.0).abs() < f64::EPSILON);
        assert!((result.std["psnr"]).abs() < f64::EPSILON);
        // A metric with no finite value at all is omitted entirely.
        assert!(!result.mean.contains_key("loss"));
        assert!(!result.std.contains_key("loss"));
    }

    #[test]
    fn test_aggregate_partial_metrics() {
        let e1 = make_experiment("e1", &[("psnr", 28.0), ("loss", 0.5)]);
        let e2 = make_experiment("e2", &[("psnr", 30.0)]);
        let result = aggregate_experiments(&[&e1, &e2]);
        assert_eq!(result.count, 2);
        assert!((result.mean["psnr"] - 29.0).abs() < f64::EPSILON);
        assert!((result.std["psnr"] - 2.0_f64.sqrt()).abs() < 1e-12);
        // `loss` was reported by one run only, so it has zero spread.
        assert!((result.mean["loss"] - 0.5).abs() < f64::EPSILON);
        assert!((result.std["loss"]).abs() < f64::EPSILON);
    }

    #[test]
    fn test_format_mean_std() {
        assert_eq!(format_mean_std(28.0, 0.0), "28.0000");
        assert_eq!(format_mean_std(28.5, 1.2), "28.5000±1.2000");
        // Deviations below f64::EPSILON are treated as zero.
        assert_eq!(format_mean_std(28.0, 1e-20), "28.0000");
    }
}