batuta/serve/banco/
experiment.rs1use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8use std::sync::{Arc, RwLock};
9
10#[derive(Debug, Clone, Serialize, Deserialize)]
12pub struct Experiment {
13 pub id: String,
14 pub name: String,
15 pub description: String,
16 pub run_ids: Vec<String>,
17 pub created_at: u64,
18}
19
20#[derive(Debug, Clone, Serialize, Deserialize)]
22pub struct RunComparison {
23 pub experiment_id: String,
24 pub runs: Vec<RunSummary>,
25 pub best_run: Option<String>,
26}
27
28#[derive(Debug, Clone, Serialize, Deserialize)]
30pub struct RunSummary {
31 pub id: String,
32 pub method: String,
33 pub status: String,
34 pub final_loss: Option<f32>,
35 pub total_steps: u64,
36}
37
38pub struct ExperimentStore {
40 experiments: RwLock<HashMap<String, Experiment>>,
41 counter: std::sync::atomic::AtomicU64,
42 data_dir: Option<std::path::PathBuf>,
43}
44
45impl ExperimentStore {
46 #[must_use]
48 pub fn new() -> Arc<Self> {
49 Arc::new(Self {
50 experiments: RwLock::new(HashMap::new()),
51 counter: std::sync::atomic::AtomicU64::new(0),
52 data_dir: None,
53 })
54 }
55
56 #[must_use]
58 pub fn with_data_dir(dir: std::path::PathBuf) -> Arc<Self> {
59 let _ = std::fs::create_dir_all(&dir);
60 let mut experiments = HashMap::new();
61
62 if let Ok(entries) = std::fs::read_dir(&dir) {
64 for entry in entries.flatten() {
65 if entry.path().extension().is_some_and(|e| e == "json") {
66 if let Ok(data) = std::fs::read_to_string(entry.path()) {
67 if let Ok(exp) = serde_json::from_str::<Experiment>(&data) {
68 experiments.insert(exp.id.clone(), exp);
69 }
70 }
71 }
72 }
73 }
74
75 let count = experiments.len() as u64;
76 if count > 0 {
77 eprintln!("[banco] Loaded {count} experiments from {}", dir.display());
78 }
79
80 Arc::new(Self {
81 experiments: RwLock::new(experiments),
82 counter: std::sync::atomic::AtomicU64::new(count),
83 data_dir: Some(dir),
84 })
85 }
86
87 pub fn create(&self, name: &str, description: &str) -> Experiment {
89 let seq = self.counter.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
90 let exp = Experiment {
91 id: format!("exp-{}-{seq}", epoch_secs()),
92 name: name.to_string(),
93 description: description.to_string(),
94 run_ids: Vec::new(),
95 created_at: epoch_secs(),
96 };
97 if let Ok(mut store) = self.experiments.write() {
98 store.insert(exp.id.clone(), exp.clone());
99 }
100 self.persist(&exp);
101 exp
102 }
103
104 pub fn add_run(&self, experiment_id: &str, run_id: &str) -> Result<(), ExperimentError> {
106 let mut store = self.experiments.write().map_err(|_| ExperimentError::LockPoisoned)?;
107 let exp = store
108 .get_mut(experiment_id)
109 .ok_or(ExperimentError::NotFound(experiment_id.to_string()))?;
110 if !exp.run_ids.contains(&run_id.to_string()) {
111 exp.run_ids.push(run_id.to_string());
112 }
113 let exp_clone = exp.clone();
114 drop(store);
115 self.persist(&exp_clone);
116 Ok(())
117 }
118
119 fn persist(&self, exp: &Experiment) {
121 if let Some(dir) = &self.data_dir {
122 let path = dir.join(format!("{}.json", exp.id));
123 if let Ok(json) = serde_json::to_string_pretty(exp) {
124 let _ = std::fs::write(path, json);
125 }
126 }
127 }
128
129 #[must_use]
131 pub fn list(&self) -> Vec<Experiment> {
132 let store = self.experiments.read().unwrap_or_else(|e| e.into_inner());
133 let mut exps: Vec<Experiment> = store.values().cloned().collect();
134 exps.sort_by(|a, b| b.created_at.cmp(&a.created_at));
135 exps
136 }
137
138 #[must_use]
140 pub fn get(&self, id: &str) -> Option<Experiment> {
141 self.experiments.read().unwrap_or_else(|e| e.into_inner()).get(id).cloned()
142 }
143
144 pub fn compare(
146 &self,
147 experiment_id: &str,
148 training: &super::training::TrainingStore,
149 ) -> Result<RunComparison, ExperimentError> {
150 let exp =
151 self.get(experiment_id).ok_or(ExperimentError::NotFound(experiment_id.to_string()))?;
152
153 let mut summaries = Vec::new();
154 let mut best_loss = f32::MAX;
155 let mut best_id = None;
156
157 for run_id in &exp.run_ids {
158 if let Some(run) = training.get(run_id) {
159 let final_loss = run.metrics.last().map(|m| m.loss);
160 let total_steps = run.metrics.last().map(|m| m.step).unwrap_or(0);
161
162 if let Some(loss) = final_loss {
163 if loss < best_loss {
164 best_loss = loss;
165 best_id = Some(run_id.clone());
166 }
167 }
168
169 summaries.push(RunSummary {
170 id: run_id.clone(),
171 method: format!("{:?}", run.method),
172 status: format!("{:?}", run.status),
173 final_loss,
174 total_steps,
175 });
176 }
177 }
178
179 Ok(RunComparison {
180 experiment_id: experiment_id.to_string(),
181 runs: summaries,
182 best_run: best_id,
183 })
184 }
185}
186
/// Errors returned by `ExperimentStore` operations.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum ExperimentError {
    /// No experiment has the given id; carries the id that was looked up.
    NotFound(String),
    /// The store's `RwLock` was poisoned by a panic in another thread.
    LockPoisoned,
}

impl std::fmt::Display for ExperimentError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::LockPoisoned => f.write_str("Internal lock error"),
            Self::NotFound(id) => write!(f, "Experiment not found: {id}"),
        }
    }
}

/// Allows `ExperimentError` to be boxed as `dyn Error` and propagated with `?`.
impl std::error::Error for ExperimentError {}
204
/// Current wall-clock time in whole seconds since the Unix epoch.
/// Returns 0 if the system clock reads earlier than the epoch.
fn epoch_secs() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs(),
        Err(_) => 0,
    }
}