entrenar/tracking/
storage.rs1use std::collections::HashMap;
7use std::fs;
8use std::path::{Path, PathBuf};
9
10use serde::{Deserialize, Serialize};
11
12use super::{Run, RunStatus};
13
14#[derive(Debug, thiserror::Error)]
16pub enum TrackingStorageError {
17 #[error("I/O error: {0}")]
18 Io(#[from] std::io::Error),
19
20 #[error("JSON serialization error: {0}")]
21 Json(#[from] serde_json::Error),
22
23 #[error("Run not found: {0}")]
24 RunNotFound(String),
25}
26
27pub type Result<T> = std::result::Result<T, TrackingStorageError>;
29
30#[derive(Debug, Clone, Serialize, Deserialize)]
32pub struct RunRecord {
33 pub run_id: String,
34 pub run_name: Option<String>,
35 pub experiment_name: String,
36 pub status: RunStatus,
37 pub params: HashMap<String, String>,
38 pub metrics: HashMap<String, Vec<MetricEntry>>,
39 pub artifacts: Vec<String>,
40 pub tags: HashMap<String, String>,
41 pub start_time_ms: Option<u64>,
42 pub end_time_ms: Option<u64>,
43}
44
45#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
47pub struct MetricEntry {
48 pub value: f64,
49 pub step: u64,
50}
51
52impl From<&Run> for RunRecord {
53 fn from(run: &Run) -> Self {
54 Self {
55 run_id: run.run_id.clone(),
56 run_name: run.run_name.clone(),
57 experiment_name: run.experiment_name.clone(),
58 status: run.status,
59 params: run.params.clone(),
60 metrics: run
61 .metrics
62 .iter()
63 .map(|(k, v)| {
64 (
65 k.clone(),
66 v.iter()
67 .map(|(val, step)| MetricEntry { value: *val, step: *step })
68 .collect(),
69 )
70 })
71 .collect(),
72 artifacts: run.artifacts.clone(),
73 tags: run.tags.clone(),
74 start_time_ms: run.start_time_ms,
75 end_time_ms: run.end_time_ms,
76 }
77 }
78}
79
80impl RunRecord {
81 pub fn into_run(self) -> Run {
83 Run {
84 run_id: self.run_id,
85 run_name: self.run_name,
86 experiment_name: self.experiment_name,
87 status: self.status,
88 params: self.params,
89 metrics: self
90 .metrics
91 .into_iter()
92 .map(|(k, v)| (k, v.into_iter().map(|e| (e.value, e.step)).collect()))
93 .collect(),
94 artifacts: self.artifacts,
95 tags: self.tags,
96 start_time_ms: self.start_time_ms,
97 end_time_ms: self.end_time_ms,
98 }
99 }
100}
101
102pub trait TrackingBackend {
106 fn save_run(&mut self, run: &Run) -> Result<()>;
108
109 fn load_run(&self, run_id: &str) -> Result<Run>;
111
112 fn list_runs(&self) -> Result<Vec<Run>>;
114
115 fn delete_run(&mut self, run_id: &str) -> Result<()>;
117}
118
/// Tracking backend that keeps each run in its own JSON file inside a
/// single directory.
#[derive(Debug)]
pub struct JsonFileBackend {
    /// Directory that holds the per-run JSON files.
    dir: PathBuf,
}
136
137impl JsonFileBackend {
138 pub fn new(dir: impl AsRef<Path>) -> Self {
140 Self { dir: dir.as_ref().to_path_buf() }
141 }
142
143 fn run_path(&self, run_id: &str) -> PathBuf {
144 self.dir.join(format!("{run_id}.json"))
145 }
146
147 fn ensure_dir(&self) -> Result<()> {
148 if !self.dir.exists() {
149 fs::create_dir_all(&self.dir)?;
150 }
151 Ok(())
152 }
153}
154
155impl TrackingBackend for JsonFileBackend {
156 fn save_run(&mut self, run: &Run) -> Result<()> {
157 self.ensure_dir()?;
158 let record = RunRecord::from(run);
159 let json = serde_json::to_string_pretty(&record)?;
160 fs::write(self.run_path(&run.run_id), json)?;
161 Ok(())
162 }
163
164 fn load_run(&self, run_id: &str) -> Result<Run> {
165 let path = self.run_path(run_id);
166 if !path.exists() {
167 return Err(TrackingStorageError::RunNotFound(run_id.to_string()));
168 }
169 let json = fs::read_to_string(path)?;
170 let record: RunRecord = serde_json::from_str(&json)?;
171 Ok(record.into_run())
172 }
173
174 fn list_runs(&self) -> Result<Vec<Run>> {
175 if !self.dir.exists() {
176 return Ok(Vec::new());
177 }
178 let mut runs = Vec::new();
179 for entry in fs::read_dir(&self.dir)? {
180 let entry = entry?;
181 let path = entry.path();
182 if path.extension().and_then(|e| e.to_str()) == Some("json") {
183 let json = fs::read_to_string(&path)?;
184 let record: RunRecord = serde_json::from_str(&json)?;
185 runs.push(record.into_run());
186 }
187 }
188 runs.sort_by(|a, b| a.run_id.cmp(&b.run_id));
189 Ok(runs)
190 }
191
192 fn delete_run(&mut self, run_id: &str) -> Result<()> {
193 let path = self.run_path(run_id);
194 if !path.exists() {
195 return Err(TrackingStorageError::RunNotFound(run_id.to_string()));
196 }
197 fs::remove_file(path)?;
198 Ok(())
199 }
200}
201
202#[derive(Debug, Default)]
206pub struct InMemoryBackend {
207 runs: HashMap<String, RunRecord>,
208}
209
210impl InMemoryBackend {
211 #[must_use]
212 pub fn new() -> Self {
213 Self::default()
214 }
215}
216
217impl TrackingBackend for InMemoryBackend {
218 fn save_run(&mut self, run: &Run) -> Result<()> {
219 self.runs.insert(run.run_id.clone(), RunRecord::from(run));
220 Ok(())
221 }
222
223 fn load_run(&self, run_id: &str) -> Result<Run> {
224 self.runs
225 .get(run_id)
226 .map(|r| r.clone().into_run())
227 .ok_or_else(|| TrackingStorageError::RunNotFound(run_id.to_string()))
228 }
229
230 fn list_runs(&self) -> Result<Vec<Run>> {
231 let mut runs: Vec<Run> = self.runs.values().map(|r| r.clone().into_run()).collect();
232 runs.sort_by(|a, b| a.run_id.cmp(&b.run_id));
233 Ok(runs)
234 }
235
236 fn delete_run(&mut self, run_id: &str) -> Result<()> {
237 self.runs
238 .remove(run_id)
239 .map(|_| ())
240 .ok_or_else(|| TrackingStorageError::RunNotFound(run_id.to_string()))
241 }
242}