Skip to main content

entrenar/tracking/
storage.rs

1//! Tracking storage backends
2//!
//! Provides the `TrackingBackend` trait, a JSON file-based implementation
//! for persisting experiment runs to disk, and an in-memory implementation
//! for tests.
5
6use std::collections::HashMap;
7use std::fs;
8use std::path::{Path, PathBuf};
9
10use serde::{Deserialize, Serialize};
11
12use super::{Run, RunStatus};
13
/// Errors from tracking storage operations
///
/// The `Io` and `Json` variants carry `#[from]` conversions, so the `?`
/// operator turns `std::io::Error` and `serde_json::Error` into this type.
#[derive(Debug, thiserror::Error)]
pub enum TrackingStorageError {
    /// A filesystem operation failed; wraps the underlying `std::io::Error`.
    #[error("I/O error: {0}")]
    Io(#[from] std::io::Error),

    /// Serializing or deserializing a run record as JSON failed.
    #[error("JSON serialization error: {0}")]
    Json(#[from] serde_json::Error),

    /// No run is stored under the requested ID; the payload is that ID.
    #[error("Run not found: {0}")]
    RunNotFound(String),
}
26
/// Result alias for tracking storage operations
///
/// Shorthand for `std::result::Result<T, TrackingStorageError>`.
pub type Result<T> = std::result::Result<T, TrackingStorageError>;
29
/// Serializable snapshot of a run for persistence
///
/// Carries the same ten fields as `Run`; the only reshaping is that each
/// metric point is stored as a named-field `MetricEntry` instead of a
/// `(value, step)` tuple.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RunRecord {
    /// Identifier of the run; the backends in this file use it as the storage key.
    pub run_id: String,
    /// Optional name of the run.
    pub run_name: Option<String>,
    /// Name of the experiment the run is recorded under.
    pub experiment_name: String,
    /// Status of the run at the time the snapshot was taken.
    pub status: RunStatus,
    /// String key/value parameters copied from the run.
    pub params: HashMap<String, String>,
    /// Recorded metric points, grouped under a string key
    /// (presumably the metric name — confirm against `Run`).
    pub metrics: HashMap<String, Vec<MetricEntry>>,
    /// Artifact entries copied from the run
    /// (presumably paths or URIs — confirm against `Run`).
    pub artifacts: Vec<String>,
    /// String key/value tags copied from the run.
    pub tags: HashMap<String, String>,
    /// Start timestamp, if set (name suggests milliseconds — epoch not shown here).
    pub start_time_ms: Option<u64>,
    /// End timestamp, if set (name suggests milliseconds — epoch not shown here).
    pub end_time_ms: Option<u64>,
}
44
/// A single metric data point for serialization
///
/// Named-field counterpart of the `(value, step)` tuples that `Run` stores.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct MetricEntry {
    /// Recorded metric value.
    pub value: f64,
    /// Step counter the value was recorded at.
    pub step: u64,
}
51
52impl From<&Run> for RunRecord {
53    fn from(run: &Run) -> Self {
54        Self {
55            run_id: run.run_id.clone(),
56            run_name: run.run_name.clone(),
57            experiment_name: run.experiment_name.clone(),
58            status: run.status,
59            params: run.params.clone(),
60            metrics: run
61                .metrics
62                .iter()
63                .map(|(k, v)| {
64                    (
65                        k.clone(),
66                        v.iter()
67                            .map(|(val, step)| MetricEntry { value: *val, step: *step })
68                            .collect(),
69                    )
70                })
71                .collect(),
72            artifacts: run.artifacts.clone(),
73            tags: run.tags.clone(),
74            start_time_ms: run.start_time_ms,
75            end_time_ms: run.end_time_ms,
76        }
77    }
78}
79
80impl RunRecord {
81    /// Convert back into a `Run`
82    pub fn into_run(self) -> Run {
83        Run {
84            run_id: self.run_id,
85            run_name: self.run_name,
86            experiment_name: self.experiment_name,
87            status: self.status,
88            params: self.params,
89            metrics: self
90                .metrics
91                .into_iter()
92                .map(|(k, v)| (k, v.into_iter().map(|e| (e.value, e.step)).collect()))
93                .collect(),
94            artifacts: self.artifacts,
95            tags: self.tags,
96            start_time_ms: self.start_time_ms,
97            end_time_ms: self.end_time_ms,
98        }
99    }
100}
101
/// Trait for tracking storage backends
///
/// Implementations persist and retrieve experiment runs.
///
/// Mutating operations take `&mut self`; lookups take `&self`. All methods
/// return the module's `Result`, so failures surface as
/// `TrackingStorageError`.
pub trait TrackingBackend {
    /// Save a run to the backend
    ///
    /// Both implementations in this file key the stored run by `run.run_id`,
    /// so saving again with the same ID replaces the earlier copy.
    fn save_run(&mut self, run: &Run) -> Result<()>;

    /// Load a run by its ID
    ///
    /// Both implementations in this file return
    /// `TrackingStorageError::RunNotFound` when nothing is stored under `run_id`.
    fn load_run(&self, run_id: &str) -> Result<Run>;

    /// List all stored runs
    ///
    /// Both implementations in this file return the runs sorted by `run_id`.
    fn list_runs(&self) -> Result<Vec<Run>>;

    /// Delete a run by its ID
    ///
    /// Both implementations in this file return
    /// `TrackingStorageError::RunNotFound` when nothing is stored under `run_id`.
    fn delete_run(&mut self, run_id: &str) -> Result<()>;
}
118
/// JSON file-based tracking backend
///
/// Stores each run as a separate JSON file in a directory.
/// File names are `{run_id}.json`.
///
/// The directory is not touched at construction time; `save_run` creates it
/// on demand.
///
/// # Example
///
/// ```no_run
/// use entrenar::tracking::storage::JsonFileBackend;
/// use entrenar::tracking::storage::TrackingBackend;
///
/// let mut backend = JsonFileBackend::new("/tmp/runs");
/// ```
#[derive(Debug)]
pub struct JsonFileBackend {
    /// Directory that holds the per-run `{run_id}.json` files.
    dir: PathBuf,
}
136
137impl JsonFileBackend {
138    /// Create a new JSON file backend, creating the directory if it does not exist
139    pub fn new(dir: impl AsRef<Path>) -> Self {
140        Self { dir: dir.as_ref().to_path_buf() }
141    }
142
143    fn run_path(&self, run_id: &str) -> PathBuf {
144        self.dir.join(format!("{run_id}.json"))
145    }
146
147    fn ensure_dir(&self) -> Result<()> {
148        if !self.dir.exists() {
149            fs::create_dir_all(&self.dir)?;
150        }
151        Ok(())
152    }
153}
154
155impl TrackingBackend for JsonFileBackend {
156    fn save_run(&mut self, run: &Run) -> Result<()> {
157        self.ensure_dir()?;
158        let record = RunRecord::from(run);
159        let json = serde_json::to_string_pretty(&record)?;
160        fs::write(self.run_path(&run.run_id), json)?;
161        Ok(())
162    }
163
164    fn load_run(&self, run_id: &str) -> Result<Run> {
165        let path = self.run_path(run_id);
166        if !path.exists() {
167            return Err(TrackingStorageError::RunNotFound(run_id.to_string()));
168        }
169        let json = fs::read_to_string(path)?;
170        let record: RunRecord = serde_json::from_str(&json)?;
171        Ok(record.into_run())
172    }
173
174    fn list_runs(&self) -> Result<Vec<Run>> {
175        if !self.dir.exists() {
176            return Ok(Vec::new());
177        }
178        let mut runs = Vec::new();
179        for entry in fs::read_dir(&self.dir)? {
180            let entry = entry?;
181            let path = entry.path();
182            if path.extension().and_then(|e| e.to_str()) == Some("json") {
183                let json = fs::read_to_string(&path)?;
184                let record: RunRecord = serde_json::from_str(&json)?;
185                runs.push(record.into_run());
186            }
187        }
188        runs.sort_by(|a, b| a.run_id.cmp(&b.run_id));
189        Ok(runs)
190    }
191
192    fn delete_run(&mut self, run_id: &str) -> Result<()> {
193        let path = self.run_path(run_id);
194        if !path.exists() {
195            return Err(TrackingStorageError::RunNotFound(run_id.to_string()));
196        }
197        fs::remove_file(path)?;
198        Ok(())
199    }
200}
201
/// In-memory tracking backend for testing
///
/// Stores runs in a `HashMap`. No persistence.
#[derive(Debug, Default)]
pub struct InMemoryBackend {
    /// Stored run snapshots, keyed by `run_id`.
    runs: HashMap<String, RunRecord>,
}
209
210impl InMemoryBackend {
211    #[must_use]
212    pub fn new() -> Self {
213        Self::default()
214    }
215}
216
217impl TrackingBackend for InMemoryBackend {
218    fn save_run(&mut self, run: &Run) -> Result<()> {
219        self.runs.insert(run.run_id.clone(), RunRecord::from(run));
220        Ok(())
221    }
222
223    fn load_run(&self, run_id: &str) -> Result<Run> {
224        self.runs
225            .get(run_id)
226            .map(|r| r.clone().into_run())
227            .ok_or_else(|| TrackingStorageError::RunNotFound(run_id.to_string()))
228    }
229
230    fn list_runs(&self) -> Result<Vec<Run>> {
231        let mut runs: Vec<Run> = self.runs.values().map(|r| r.clone().into_run()).collect();
232        runs.sort_by(|a, b| a.run_id.cmp(&b.run_id));
233        Ok(runs)
234    }
235
236    fn delete_run(&mut self, run_id: &str) -> Result<()> {
237        self.runs
238            .remove(run_id)
239            .map(|_| ())
240            .ok_or_else(|| TrackingStorageError::RunNotFound(run_id.to_string()))
241    }
242}