pub mod storage;
#[cfg(test)]
mod tests;
use std::collections::HashMap;
use std::time::{SystemTime, UNIX_EPOCH};
use serde::{Deserialize, Serialize};
use storage::{TrackingBackend, TrackingStorageError};
/// Lifecycle state of a [`Run`].
///
/// The tracker creates every run as `Active`; when a run is ended, the status the caller
/// passes to `ExperimentTracker::end_run` is stored on the run unchanged.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum RunStatus {
/// The run has been started and has not been ended yet.
Active,
/// Caller-assigned final status via `end_run`; presumably a successful finish.
Completed,
/// Caller-assigned final status via `end_run`; presumably the run errored.
Failed,
/// Caller-assigned final status via `end_run`; presumably the run was stopped early.
Cancelled,
}
/// One tracked run: its identity, lifecycle timestamps, and everything logged to it.
///
/// Runs derive `Serialize`/`Deserialize`; presumably the storage backend relies on this
/// to persist them (TODO confirm). Because field order can matter to some serde formats,
/// avoid reordering fields.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Run {
/// Run id; ids issued by the tracker have the form `run-<n>`.
pub run_id: String,
/// Optional human-readable name supplied when the run was started.
pub run_name: Option<String>,
/// Name of the experiment the run belongs to.
pub experiment_name: String,
/// Current lifecycle state.
pub status: RunStatus,
/// Logged parameters; logging an existing key again overwrites its value.
pub params: HashMap<String, String>,
/// Metric series keyed by metric name; each entry is `(value, step)` in logging order.
pub metrics: HashMap<String, Vec<(f64, u64)>>,
/// Artifact paths in the order they were logged.
pub artifacts: Vec<String>,
/// Tags copied from the experiment when the run was started.
pub tags: HashMap<String, String>,
/// Start time in milliseconds since the Unix epoch; set when the run is created.
/// NOTE(review): `Option` presumably accommodates runs loaded from storage without it.
pub start_time_ms: Option<u64>,
/// End time in milliseconds since the Unix epoch; `None` until the run is ended.
pub end_time_ms: Option<u64>,
}
impl Run {
    /// Builds a fresh `Active` run stamped with the current wall-clock start time.
    ///
    /// Params, metrics, artifacts and tags all start empty and no end time is set.
    /// If the system clock reads earlier than the Unix epoch, the start time is 0.
    fn new(run_id: String, run_name: Option<String>, experiment_name: String) -> Self {
        let started_at = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .map(|elapsed| elapsed.as_millis() as u64)
            .unwrap_or(0);

        Self {
            start_time_ms: Some(started_at),
            end_time_ms: None,
            status: RunStatus::Active,
            run_id,
            run_name,
            experiment_name,
            params: HashMap::default(),
            metrics: HashMap::default(),
            artifacts: Vec::default(),
            tags: HashMap::default(),
        }
    }
}
/// Errors returned by [`ExperimentTracker`] operations.
#[derive(Debug, thiserror::Error)]
pub enum TrackingError {
/// The run could not be found. `end_run` returns this for ids that are not active;
/// `get_run` returns it when the backend lookup fails, with the backend's error text
/// appended to the id.
#[error("Run not found: {0}")]
RunNotFound(String),
/// The id does not refer to a currently active run, so nothing can be logged to it.
#[error("Run is not active: {0}")]
RunNotActive(String),
/// A storage backend operation failed; converted automatically via `?`.
#[error("Storage error: {0}")]
Storage(#[from] TrackingStorageError),
}
/// Shorthand for results of tracker operations.
pub type Result<T> = std::result::Result<T, TrackingError>;
/// Tracks runs for one named experiment.
///
/// Runs that have been started but not ended live in memory. Ending a run sends it to the
/// backend `B`, and the tracker no longer keeps it in memory.
#[derive(Debug)]
pub struct ExperimentTracker<B: TrackingBackend> {
/// Name recorded on every run this tracker starts.
experiment_name: String,
/// Experiment-level tags copied into each run at start time.
tags: HashMap<String, String>,
/// Storage that receives ended runs and serves lookups of runs that are not active.
backend: B,
/// Started-but-not-ended runs, keyed by run id.
active_runs: HashMap<String, Run>,
/// Numeric suffix of the next issued run id (`run-<n>`); starts at 1.
next_run_id: u64,
}
impl<B: TrackingBackend> ExperimentTracker<B> {
    /// Creates a tracker for `experiment_name` that saves ended runs to `backend`.
    ///
    /// The tracker starts with no tags and no active runs; the first run it starts is
    /// `run-1`.
    pub fn new(experiment_name: impl Into<String>, backend: B) -> Self {
        Self {
            experiment_name: experiment_name.into(),
            tags: HashMap::new(),
            backend,
            active_runs: HashMap::new(),
            next_run_id: 1,
        }
    }

    /// Sets an experiment-level tag, overwriting any previous value for `key`.
    ///
    /// Tags are copied into a run when it starts, so this only affects runs started
    /// after the call.
    pub fn add_tag(&mut self, key: impl Into<String>, value: impl Into<String>) {
        self.tags.insert(key.into(), value.into());
    }

    /// Name of the experiment this tracker records runs for.
    #[must_use]
    pub fn experiment_name(&self) -> &str {
        &self.experiment_name
    }

    /// Experiment-level tags that will be copied into newly started runs.
    #[must_use]
    pub fn tags(&self) -> &HashMap<String, String> {
        &self.tags
    }

    /// Starts a new active run and returns its id (`run-<n>`, increasing per tracker).
    ///
    /// The run gets the optional `run_name`, this tracker's experiment name, and a copy
    /// of the experiment tags as they are at this moment.
    ///
    /// NOTE(review): ids restart at `run-1` for every new tracker. A tracker opened on a
    /// backend that already holds runs can therefore reissue an existing id; confirm
    /// whether `save_run` overwrites in that case.
    ///
    /// # Errors
    /// This implementation never fails; the `Result` is part of the public signature.
    pub fn start_run(&mut self, run_name: Option<&str>) -> Result<String> {
        let run_id = format!("run-{}", self.next_run_id);
        self.next_run_id += 1;
        let mut run =
            Run::new(run_id.clone(), run_name.map(String::from), self.experiment_name.clone());
        run.tags.extend(self.tags.iter().map(|(k, v)| (k.clone(), v.clone())));
        self.active_runs.insert(run_id.clone(), run);
        Ok(run_id)
    }

    /// Ends an active run: records `status` and the end time, then saves the run to the
    /// backend and stops tracking it in memory.
    ///
    /// `status` is stored as given (even `RunStatus::Active`). If the backend fails to
    /// save, the run is left active with its previous status and end time, so the caller
    /// can retry without losing anything that was logged.
    ///
    /// # Errors
    /// - [`TrackingError::RunNotFound`] if `run_id` is not an active run (including runs
    ///   that were already ended).
    /// - The backend's save error, converted into [`TrackingError`].
    pub fn end_run(&mut self, run_id: &str, status: RunStatus) -> Result<()> {
        let run = self
            .active_runs
            .get_mut(run_id)
            .ok_or_else(|| TrackingError::RunNotFound(run_id.to_string()))?;
        let (previous_status, previous_end) = (run.status, run.end_time_ms);
        run.status = status;
        run.end_time_ms = Some(Self::now_ms());
        if let Err(err) = self.backend.save_run(run) {
            // Roll back so a failed save leaves the run exactly as it was, still active.
            run.status = previous_status;
            run.end_time_ms = previous_end;
            return Err(err.into());
        }
        // Only forget the run once it is safely persisted.
        self.active_runs.remove(run_id);
        Ok(())
    }

    /// Records a parameter on an active run, overwriting any previous value for `key`.
    ///
    /// # Errors
    /// [`TrackingError::RunNotActive`] if `run_id` is not an active run.
    pub fn log_param(&mut self, run_id: &str, key: &str, value: &str) -> Result<()> {
        let run = self.active_run_mut(run_id)?;
        run.params.insert(key.to_string(), value.to_string());
        Ok(())
    }

    /// Records several parameters on an active run; existing keys are overwritten.
    ///
    /// # Errors
    /// [`TrackingError::RunNotActive`] if `run_id` is not an active run.
    pub fn log_params(&mut self, run_id: &str, params: &HashMap<String, String>) -> Result<()> {
        let run = self.active_run_mut(run_id)?;
        run.params.extend(params.iter().map(|(k, v)| (k.clone(), v.clone())));
        Ok(())
    }

    /// Appends `(value, step)` to the metric series `key` of an active run.
    ///
    /// Steps are not checked for ordering or uniqueness.
    ///
    /// # Errors
    /// [`TrackingError::RunNotActive`] if `run_id` is not an active run.
    pub fn log_metric(&mut self, run_id: &str, key: &str, value: f64, step: u64) -> Result<()> {
        let run = self.active_run_mut(run_id)?;
        run.metrics.entry(key.to_string()).or_default().push((value, step));
        Ok(())
    }

    /// Appends an artifact path to an active run.
    ///
    /// # Errors
    /// [`TrackingError::RunNotActive`] if `run_id` is not an active run.
    pub fn log_artifact(&mut self, run_id: &str, path: &str) -> Result<()> {
        let run = self.active_run_mut(run_id)?;
        run.artifacts.push(path.to_string());
        Ok(())
    }

    /// Returns a snapshot (clone) of a run. Active runs are checked first, then the
    /// backend.
    ///
    /// # Errors
    /// [`TrackingError::RunNotFound`] if the run is not active and the backend lookup
    /// fails. Any backend failure is reported this way (not only a missing run); the
    /// backend's message is appended to the id.
    pub fn get_run(&self, run_id: &str) -> Result<Run> {
        if let Some(run) = self.active_runs.get(run_id) {
            return Ok(run.clone());
        }
        self.backend
            .load_run(run_id)
            .map_err(|e| TrackingError::RunNotFound(format!("{run_id}: {e}")))
    }

    /// Returns every run the tracker knows about: active runs plus runs stored in the
    /// backend.
    ///
    /// If an active run and a stored run share an id, only the active run is returned.
    /// Runs are ordered by id, comparing any trailing number numerically, so `run-2`
    /// comes before `run-10`.
    ///
    /// # Errors
    /// Propagates the backend's error from listing runs, converted into [`TrackingError`].
    pub fn list_runs(&self) -> Result<Vec<Run>> {
        let mut runs: Vec<Run> = self.active_runs.values().cloned().collect();
        runs.extend(
            self.backend
                .list_runs()?
                .into_iter()
                .filter(|r| !self.active_runs.contains_key(&r.run_id)),
        );
        runs.sort_by(|a, b| Self::run_order_key(&a.run_id).cmp(&Self::run_order_key(&b.run_id)));
        Ok(runs)
    }

    /// Looks up an active run for mutation, or reports it as not active.
    fn active_run_mut(&mut self, run_id: &str) -> Result<&mut Run> {
        self.active_runs
            .get_mut(run_id)
            .ok_or_else(|| TrackingError::RunNotActive(run_id.to_string()))
    }

    /// Current wall-clock time in milliseconds since the Unix epoch (0 if the clock reads
    /// earlier than the epoch).
    fn now_ms() -> u64 {
        // The `as` cast cannot truncate in practice: u64 milliseconds span ~584 million years.
        SystemTime::now().duration_since(UNIX_EPOCH).unwrap_or_default().as_millis() as u64
    }

    /// Natural-order sort key for a run id.
    ///
    /// Ids are compared first by the text before any trailing digits, then by those digits
    /// as a number (so `run-2` < `run-10`), and finally by the full id. The final
    /// comparison breaks ties such as `run-01` vs `run-1`, and digit runs too long for
    /// `u64`, which get `None`.
    fn run_order_key(run_id: &str) -> (&str, Option<u64>, &str) {
        let prefix_len = run_id.trim_end_matches(|c: char| c.is_ascii_digit()).len();
        // Safe split: the trimmed suffix is ASCII, so `prefix_len` is a char boundary.
        let (prefix, digits) = run_id.split_at(prefix_len);
        (prefix, digits.parse().ok(), run_id)
    }
}