1pub mod storage;
41
42#[cfg(test)]
43mod tests;
44
45use std::collections::HashMap;
46use std::time::{SystemTime, UNIX_EPOCH};
47
48use serde::{Deserialize, Serialize};
49
50use storage::{TrackingBackend, TrackingStorageError};
51
52#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
54pub enum RunStatus {
55 Active,
57 Completed,
59 Failed,
61 Cancelled,
63}
64
65#[derive(Debug, Clone, Serialize, Deserialize)]
70pub struct Run {
71 pub run_id: String,
73 pub run_name: Option<String>,
75 pub experiment_name: String,
77 pub status: RunStatus,
79 pub params: HashMap<String, String>,
81 pub metrics: HashMap<String, Vec<(f64, u64)>>,
83 pub artifacts: Vec<String>,
85 pub tags: HashMap<String, String>,
87 pub start_time_ms: Option<u64>,
89 pub end_time_ms: Option<u64>,
91}
92
93impl Run {
94 fn new(run_id: String, run_name: Option<String>, experiment_name: String) -> Self {
95 let now_ms =
96 SystemTime::now().duration_since(UNIX_EPOCH).unwrap_or_default().as_millis() as u64;
97
98 Self {
99 run_id,
100 run_name,
101 experiment_name,
102 status: RunStatus::Active,
103 params: HashMap::new(),
104 metrics: HashMap::new(),
105 artifacts: Vec::new(),
106 tags: HashMap::new(),
107 start_time_ms: Some(now_ms),
108 end_time_ms: None,
109 }
110 }
111}
112
113#[derive(Debug, thiserror::Error)]
115pub enum TrackingError {
116 #[error("Run not found: {0}")]
117 RunNotFound(String),
118
119 #[error("Run is not active: {0}")]
120 RunNotActive(String),
121
122 #[error("Storage error: {0}")]
123 Storage(#[from] TrackingStorageError),
124}
125
126pub type Result<T> = std::result::Result<T, TrackingError>;
128
129#[derive(Debug)]
134pub struct ExperimentTracker<B: TrackingBackend> {
135 experiment_name: String,
136 tags: HashMap<String, String>,
137 backend: B,
138 active_runs: HashMap<String, Run>,
140 next_run_id: u64,
141}
142
143impl<B: TrackingBackend> ExperimentTracker<B> {
144 pub fn new(experiment_name: impl Into<String>, backend: B) -> Self {
146 Self {
147 experiment_name: experiment_name.into(),
148 tags: HashMap::new(),
149 backend,
150 active_runs: HashMap::new(),
151 next_run_id: 1,
152 }
153 }
154
155 pub fn add_tag(&mut self, key: impl Into<String>, value: impl Into<String>) {
157 self.tags.insert(key.into(), value.into());
158 }
159
160 #[must_use]
162 pub fn experiment_name(&self) -> &str {
163 &self.experiment_name
164 }
165
166 #[must_use]
168 pub fn tags(&self) -> &HashMap<String, String> {
169 &self.tags
170 }
171
172 pub fn start_run(&mut self, run_name: Option<&str>) -> Result<String> {
176 let run_id = format!("run-{}", self.next_run_id);
177 self.next_run_id += 1;
178
179 let mut run =
180 Run::new(run_id.clone(), run_name.map(String::from), self.experiment_name.clone());
181 for (k, v) in &self.tags {
183 run.tags.insert(k.clone(), v.clone());
184 }
185
186 self.active_runs.insert(run_id.clone(), run);
187 Ok(run_id)
188 }
189
190 pub fn end_run(&mut self, run_id: &str, status: RunStatus) -> Result<()> {
192 let mut run = self
193 .active_runs
194 .remove(run_id)
195 .ok_or_else(|| TrackingError::RunNotFound(run_id.to_string()))?;
196
197 let now_ms =
198 SystemTime::now().duration_since(UNIX_EPOCH).unwrap_or_default().as_millis() as u64;
199
200 run.status = status;
201 run.end_time_ms = Some(now_ms);
202
203 self.backend.save_run(&run)?;
204 Ok(())
205 }
206
207 pub fn log_param(&mut self, run_id: &str, key: &str, value: &str) -> Result<()> {
209 let run = self
210 .active_runs
211 .get_mut(run_id)
212 .ok_or_else(|| TrackingError::RunNotActive(run_id.to_string()))?;
213
214 run.params.insert(key.to_string(), value.to_string());
215 Ok(())
216 }
217
218 pub fn log_params(&mut self, run_id: &str, params: &HashMap<String, String>) -> Result<()> {
220 let run = self
221 .active_runs
222 .get_mut(run_id)
223 .ok_or_else(|| TrackingError::RunNotActive(run_id.to_string()))?;
224
225 for (k, v) in params {
226 run.params.insert(k.clone(), v.clone());
227 }
228 Ok(())
229 }
230
231 pub fn log_metric(&mut self, run_id: &str, key: &str, value: f64, step: u64) -> Result<()> {
233 let run = self
234 .active_runs
235 .get_mut(run_id)
236 .ok_or_else(|| TrackingError::RunNotActive(run_id.to_string()))?;
237
238 run.metrics.entry(key.to_string()).or_default().push((value, step));
239 Ok(())
240 }
241
242 pub fn log_artifact(&mut self, run_id: &str, path: &str) -> Result<()> {
244 let run = self
245 .active_runs
246 .get_mut(run_id)
247 .ok_or_else(|| TrackingError::RunNotActive(run_id.to_string()))?;
248
249 run.artifacts.push(path.to_string());
250 Ok(())
251 }
252
253 pub fn get_run(&self, run_id: &str) -> Result<Run> {
257 if let Some(run) = self.active_runs.get(run_id) {
258 return Ok(run.clone());
259 }
260 self.backend
261 .load_run(run_id)
262 .map_err(|e| TrackingError::RunNotFound(format!("{run_id}: {e}")))
263 }
264
265 pub fn list_runs(&self) -> Result<Vec<Run>> {
267 let mut runs: Vec<Run> = self.active_runs.values().cloned().collect();
268 let persisted = self.backend.list_runs()?;
269 for r in persisted {
271 if !self.active_runs.contains_key(&r.run_id) {
272 runs.push(r);
273 }
274 }
275 runs.sort_by(|a, b| a.run_id.cmp(&b.run_id));
276 Ok(runs)
277 }
278}