// entrenar/tracking/mod.rs

1//! Experiment Tracking Module (GH-31)
2//!
3//! Provides high-level experiment tracking with parameter logging, metric
4//! recording, and artifact management. Backed by pluggable storage via the
5//! [`TrackingBackend`](storage::TrackingBackend) trait.
6//!
7//! # Architecture
8//!
9//! - **`ExperimentTracker`**: Top-level handle that manages runs for a named experiment
10//! - **`Run`**: A single training run with parameters, metrics, and artifacts
11//! - **`TrackingBackend`**: Pluggable persistence (JSON files, in-memory)
12//!
13//! # Example
14//!
15//! ```
16//! use entrenar::tracking::{ExperimentTracker, RunStatus};
17//! use entrenar::tracking::storage::InMemoryBackend;
18//!
19//! # fn main() -> std::result::Result<(), Box<dyn std::error::Error>> {
20//! let backend = InMemoryBackend::new();
21//! let mut tracker = ExperimentTracker::new("my-experiment", backend);
22//! tracker.add_tag("team", "ml-infra");
23//!
24//! let run_id = tracker.start_run(Some("baseline-v1"))?;
25//! tracker.log_param(&run_id, "lr", "0.001")?;
26//! tracker.log_metric(&run_id, "loss", 0.5, 1)?;
27//! tracker.log_metric(&run_id, "loss", 0.3, 2)?;
28//! tracker.log_artifact(&run_id, "model.safetensors")?;
29//! tracker.end_run(&run_id, RunStatus::Completed)?;
30//!
31//! let run = tracker.get_run(&run_id)?;
32//! assert_eq!(run.params.get("lr").unwrap_or(&String::new()), "0.001");
33//!
34//! let all = tracker.list_runs()?;
35//! assert_eq!(all.len(), 1);
36//! # Ok(())
37//! # }
38//! ```
39
40pub mod storage;
41
42#[cfg(test)]
43mod tests;
44
45use std::collections::HashMap;
46use std::time::{SystemTime, UNIX_EPOCH};
47
48use serde::{Deserialize, Serialize};
49
50use storage::{TrackingBackend, TrackingStorageError};
51
52/// Status of a tracking run
53#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
54pub enum RunStatus {
55    /// Run is actively recording
56    Active,
57    /// Run completed successfully
58    Completed,
59    /// Run failed
60    Failed,
61    /// Run was cancelled
62    Cancelled,
63}
64
65/// A single experiment run
66///
67/// Tracks parameters (hyperparameters), metrics (per-step values),
68/// artifacts (file paths), and tags (key-value metadata).
69#[derive(Debug, Clone, Serialize, Deserialize)]
70pub struct Run {
71    /// Unique identifier for the run
72    pub run_id: String,
73    /// Optional human-readable name
74    pub run_name: Option<String>,
75    /// Parent experiment name
76    pub experiment_name: String,
77    /// Current status
78    pub status: RunStatus,
79    /// Hyperparameters: key -> value (string-encoded)
80    pub params: HashMap<String, String>,
81    /// Metrics: key -> list of (value, step)
82    pub metrics: HashMap<String, Vec<(f64, u64)>>,
83    /// Artifact paths
84    pub artifacts: Vec<String>,
85    /// Tags: key -> value
86    pub tags: HashMap<String, String>,
87    /// Unix timestamp (ms) when the run started
88    pub start_time_ms: Option<u64>,
89    /// Unix timestamp (ms) when the run ended
90    pub end_time_ms: Option<u64>,
91}
92
93impl Run {
94    fn new(run_id: String, run_name: Option<String>, experiment_name: String) -> Self {
95        let now_ms =
96            SystemTime::now().duration_since(UNIX_EPOCH).unwrap_or_default().as_millis() as u64;
97
98        Self {
99            run_id,
100            run_name,
101            experiment_name,
102            status: RunStatus::Active,
103            params: HashMap::new(),
104            metrics: HashMap::new(),
105            artifacts: Vec::new(),
106            tags: HashMap::new(),
107            start_time_ms: Some(now_ms),
108            end_time_ms: None,
109        }
110    }
111}
112
113/// Errors from experiment tracking operations
114#[derive(Debug, thiserror::Error)]
115pub enum TrackingError {
116    #[error("Run not found: {0}")]
117    RunNotFound(String),
118
119    #[error("Run is not active: {0}")]
120    RunNotActive(String),
121
122    #[error("Storage error: {0}")]
123    Storage(#[from] TrackingStorageError),
124}
125
126/// Result alias for tracking operations
127pub type Result<T> = std::result::Result<T, TrackingError>;
128
129/// Experiment tracker
130///
131/// Manages multiple runs under a single experiment name. Persists run data
132/// through a pluggable [`TrackingBackend`].
133#[derive(Debug)]
134pub struct ExperimentTracker<B: TrackingBackend> {
135    experiment_name: String,
136    tags: HashMap<String, String>,
137    backend: B,
138    /// Active runs held in memory for fast mutation
139    active_runs: HashMap<String, Run>,
140    next_run_id: u64,
141}
142
143impl<B: TrackingBackend> ExperimentTracker<B> {
144    /// Create a new tracker for the given experiment name
145    pub fn new(experiment_name: impl Into<String>, backend: B) -> Self {
146        Self {
147            experiment_name: experiment_name.into(),
148            tags: HashMap::new(),
149            backend,
150            active_runs: HashMap::new(),
151            next_run_id: 1,
152        }
153    }
154
155    /// Add an experiment-level tag
156    pub fn add_tag(&mut self, key: impl Into<String>, value: impl Into<String>) {
157        self.tags.insert(key.into(), value.into());
158    }
159
160    /// Get the experiment name
161    #[must_use]
162    pub fn experiment_name(&self) -> &str {
163        &self.experiment_name
164    }
165
166    /// Get experiment-level tags
167    #[must_use]
168    pub fn tags(&self) -> &HashMap<String, String> {
169        &self.tags
170    }
171
172    /// Start a new run, optionally with a human-readable name
173    ///
174    /// Returns the run ID.
175    pub fn start_run(&mut self, run_name: Option<&str>) -> Result<String> {
176        let run_id = format!("run-{}", self.next_run_id);
177        self.next_run_id += 1;
178
179        let mut run =
180            Run::new(run_id.clone(), run_name.map(String::from), self.experiment_name.clone());
181        // Inherit experiment-level tags
182        for (k, v) in &self.tags {
183            run.tags.insert(k.clone(), v.clone());
184        }
185
186        self.active_runs.insert(run_id.clone(), run);
187        Ok(run_id)
188    }
189
190    /// End a run with the given status, persisting it to the backend
191    pub fn end_run(&mut self, run_id: &str, status: RunStatus) -> Result<()> {
192        let mut run = self
193            .active_runs
194            .remove(run_id)
195            .ok_or_else(|| TrackingError::RunNotFound(run_id.to_string()))?;
196
197        let now_ms =
198            SystemTime::now().duration_since(UNIX_EPOCH).unwrap_or_default().as_millis() as u64;
199
200        run.status = status;
201        run.end_time_ms = Some(now_ms);
202
203        self.backend.save_run(&run)?;
204        Ok(())
205    }
206
207    /// Log a single parameter (hyperparameter)
208    pub fn log_param(&mut self, run_id: &str, key: &str, value: &str) -> Result<()> {
209        let run = self
210            .active_runs
211            .get_mut(run_id)
212            .ok_or_else(|| TrackingError::RunNotActive(run_id.to_string()))?;
213
214        run.params.insert(key.to_string(), value.to_string());
215        Ok(())
216    }
217
218    /// Log multiple parameters at once
219    pub fn log_params(&mut self, run_id: &str, params: &HashMap<String, String>) -> Result<()> {
220        let run = self
221            .active_runs
222            .get_mut(run_id)
223            .ok_or_else(|| TrackingError::RunNotActive(run_id.to_string()))?;
224
225        for (k, v) in params {
226            run.params.insert(k.clone(), v.clone());
227        }
228        Ok(())
229    }
230
231    /// Log a metric value at a given step
232    pub fn log_metric(&mut self, run_id: &str, key: &str, value: f64, step: u64) -> Result<()> {
233        let run = self
234            .active_runs
235            .get_mut(run_id)
236            .ok_or_else(|| TrackingError::RunNotActive(run_id.to_string()))?;
237
238        run.metrics.entry(key.to_string()).or_default().push((value, step));
239        Ok(())
240    }
241
242    /// Log an artifact path
243    pub fn log_artifact(&mut self, run_id: &str, path: &str) -> Result<()> {
244        let run = self
245            .active_runs
246            .get_mut(run_id)
247            .ok_or_else(|| TrackingError::RunNotActive(run_id.to_string()))?;
248
249        run.artifacts.push(path.to_string());
250        Ok(())
251    }
252
253    /// Retrieve a run by ID
254    ///
255    /// Checks active (in-memory) runs first, then falls back to the backend.
256    pub fn get_run(&self, run_id: &str) -> Result<Run> {
257        if let Some(run) = self.active_runs.get(run_id) {
258            return Ok(run.clone());
259        }
260        self.backend
261            .load_run(run_id)
262            .map_err(|e| TrackingError::RunNotFound(format!("{run_id}: {e}")))
263    }
264
265    /// List all runs (active + persisted)
266    pub fn list_runs(&self) -> Result<Vec<Run>> {
267        let mut runs: Vec<Run> = self.active_runs.values().cloned().collect();
268        let persisted = self.backend.list_runs()?;
269        // Avoid duplicates: only add persisted runs whose IDs are not active
270        for r in persisted {
271            if !self.active_runs.contains_key(&r.run_id) {
272                runs.push(r);
273            }
274        }
275        runs.sort_by(|a, b| a.run_id.cmp(&b.run_id));
276        Ok(runs)
277    }
278}