// batuta/serve/banco/experiment.rs

//! Experiment tracking — group training runs, compare metrics.
//!
//! An experiment is a named collection of training runs that can be compared.
//! This enables the iterate loop: train → eval → compare → retrain.
5
6use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8use std::sync::{Arc, RwLock};
9
10/// An experiment (group of training runs).
11#[derive(Debug, Clone, Serialize, Deserialize)]
12pub struct Experiment {
13    pub id: String,
14    pub name: String,
15    pub description: String,
16    pub run_ids: Vec<String>,
17    pub created_at: u64,
18}
19
20/// Comparison of runs within an experiment.
21#[derive(Debug, Clone, Serialize, Deserialize)]
22pub struct RunComparison {
23    pub experiment_id: String,
24    pub runs: Vec<RunSummary>,
25    pub best_run: Option<String>,
26}
27
28/// Summary of a single run for comparison.
29#[derive(Debug, Clone, Serialize, Deserialize)]
30pub struct RunSummary {
31    pub id: String,
32    pub method: String,
33    pub status: String,
34    pub final_loss: Option<f32>,
35    pub total_steps: u64,
36}
37
38/// Experiment store with optional disk persistence.
39pub struct ExperimentStore {
40    experiments: RwLock<HashMap<String, Experiment>>,
41    counter: std::sync::atomic::AtomicU64,
42    data_dir: Option<std::path::PathBuf>,
43}
44
45impl ExperimentStore {
46    /// Create in-memory experiment store.
47    #[must_use]
48    pub fn new() -> Arc<Self> {
49        Arc::new(Self {
50            experiments: RwLock::new(HashMap::new()),
51            counter: std::sync::atomic::AtomicU64::new(0),
52            data_dir: None,
53        })
54    }
55
56    /// Create experiment store with disk persistence.
57    #[must_use]
58    pub fn with_data_dir(dir: std::path::PathBuf) -> Arc<Self> {
59        let _ = std::fs::create_dir_all(&dir);
60        let mut experiments = HashMap::new();
61
62        // Load existing experiments from disk
63        if let Ok(entries) = std::fs::read_dir(&dir) {
64            for entry in entries.flatten() {
65                if entry.path().extension().is_some_and(|e| e == "json") {
66                    if let Ok(data) = std::fs::read_to_string(entry.path()) {
67                        if let Ok(exp) = serde_json::from_str::<Experiment>(&data) {
68                            experiments.insert(exp.id.clone(), exp);
69                        }
70                    }
71                }
72            }
73        }
74
75        let count = experiments.len() as u64;
76        if count > 0 {
77            eprintln!("[banco] Loaded {count} experiments from {}", dir.display());
78        }
79
80        Arc::new(Self {
81            experiments: RwLock::new(experiments),
82            counter: std::sync::atomic::AtomicU64::new(count),
83            data_dir: Some(dir),
84        })
85    }
86
87    /// Create a new experiment.
88    pub fn create(&self, name: &str, description: &str) -> Experiment {
89        let seq = self.counter.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
90        let exp = Experiment {
91            id: format!("exp-{}-{seq}", epoch_secs()),
92            name: name.to_string(),
93            description: description.to_string(),
94            run_ids: Vec::new(),
95            created_at: epoch_secs(),
96        };
97        if let Ok(mut store) = self.experiments.write() {
98            store.insert(exp.id.clone(), exp.clone());
99        }
100        self.persist(&exp);
101        exp
102    }
103
104    /// Add a run to an experiment.
105    pub fn add_run(&self, experiment_id: &str, run_id: &str) -> Result<(), ExperimentError> {
106        let mut store = self.experiments.write().map_err(|_| ExperimentError::LockPoisoned)?;
107        let exp = store
108            .get_mut(experiment_id)
109            .ok_or(ExperimentError::NotFound(experiment_id.to_string()))?;
110        if !exp.run_ids.contains(&run_id.to_string()) {
111            exp.run_ids.push(run_id.to_string());
112        }
113        let exp_clone = exp.clone();
114        drop(store);
115        self.persist(&exp_clone);
116        Ok(())
117    }
118
119    /// Persist an experiment to disk (if data_dir is set).
120    fn persist(&self, exp: &Experiment) {
121        if let Some(dir) = &self.data_dir {
122            let path = dir.join(format!("{}.json", exp.id));
123            if let Ok(json) = serde_json::to_string_pretty(exp) {
124                let _ = std::fs::write(path, json);
125            }
126        }
127    }
128
129    /// List all experiments.
130    #[must_use]
131    pub fn list(&self) -> Vec<Experiment> {
132        let store = self.experiments.read().unwrap_or_else(|e| e.into_inner());
133        let mut exps: Vec<Experiment> = store.values().cloned().collect();
134        exps.sort_by(|a, b| b.created_at.cmp(&a.created_at));
135        exps
136    }
137
138    /// Get experiment by ID.
139    #[must_use]
140    pub fn get(&self, id: &str) -> Option<Experiment> {
141        self.experiments.read().unwrap_or_else(|e| e.into_inner()).get(id).cloned()
142    }
143
144    /// Compare runs in an experiment using the training store.
145    pub fn compare(
146        &self,
147        experiment_id: &str,
148        training: &super::training::TrainingStore,
149    ) -> Result<RunComparison, ExperimentError> {
150        let exp =
151            self.get(experiment_id).ok_or(ExperimentError::NotFound(experiment_id.to_string()))?;
152
153        let mut summaries = Vec::new();
154        let mut best_loss = f32::MAX;
155        let mut best_id = None;
156
157        for run_id in &exp.run_ids {
158            if let Some(run) = training.get(run_id) {
159                let final_loss = run.metrics.last().map(|m| m.loss);
160                let total_steps = run.metrics.last().map(|m| m.step).unwrap_or(0);
161
162                if let Some(loss) = final_loss {
163                    if loss < best_loss {
164                        best_loss = loss;
165                        best_id = Some(run_id.clone());
166                    }
167                }
168
169                summaries.push(RunSummary {
170                    id: run_id.clone(),
171                    method: format!("{:?}", run.method),
172                    status: format!("{:?}", run.status),
173                    final_loss,
174                    total_steps,
175                });
176            }
177        }
178
179        Ok(RunComparison {
180            experiment_id: experiment_id.to_string(),
181            runs: summaries,
182            best_run: best_id,
183        })
184    }
185}
186
/// Errors returned by experiment store operations.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ExperimentError {
    /// No experiment exists with the contained ID.
    NotFound(String),
    /// The store's internal lock was poisoned.
    LockPoisoned,
}

impl std::fmt::Display for ExperimentError {
    /// Write the user-facing message for this error.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::LockPoisoned => f.write_str("Internal lock error"),
            Self::NotFound(id) => write!(f, "Experiment not found: {id}"),
        }
    }
}

impl std::error::Error for ExperimentError {}
204
/// Current time in whole seconds since the Unix epoch.
///
/// Returns `0` if the system clock reports a time before the epoch.
fn epoch_secs() -> u64 {
    match std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs(),
        Err(_) => 0,
    }
}