1use crate::recipe::{Hyperparameters, RecipeReference};
4use chrono::{DateTime, Utc};
5use serde::{Deserialize, Serialize};
6use std::collections::HashMap;
7use uuid::Uuid;
8
9#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
11pub struct RunId(Uuid);
12
13impl RunId {
14 #[must_use]
16 pub fn new() -> Self {
17 Self(Uuid::new_v4())
18 }
19
20 #[must_use]
22 pub fn from_uuid(uuid: Uuid) -> Self {
23 Self(uuid)
24 }
25
26 #[must_use]
28 pub fn as_uuid(&self) -> &Uuid {
29 &self.0
30 }
31}
32
33impl Default for RunId {
34 fn default() -> Self {
35 Self::new()
36 }
37}
38
39impl std::fmt::Display for RunId {
40 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
41 write!(f, "{}", self.0)
42 }
43}
44
45impl std::str::FromStr for RunId {
46 type Err = uuid::Error;
47
48 fn from_str(s: &str) -> Result<Self, Self::Err> {
49 Ok(Self(Uuid::parse_str(s)?))
50 }
51}
52
53#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
55#[serde(rename_all = "lowercase")]
56pub enum RunStatus {
57 Pending,
59 Running,
61 Completed,
63 Failed,
65 Cancelled,
67}
68
69impl std::fmt::Display for RunStatus {
70 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
71 let s = match self {
72 Self::Pending => "pending",
73 Self::Running => "running",
74 Self::Completed => "completed",
75 Self::Failed => "failed",
76 Self::Cancelled => "cancelled",
77 };
78 write!(f, "{s}")
79 }
80}
81
82#[derive(Debug, Clone, Default, Serialize, Deserialize)]
84pub struct HardwareInfo {
85 #[serde(skip_serializing_if = "Option::is_none")]
87 pub cpu_model: Option<String>,
88 #[serde(skip_serializing_if = "Option::is_none")]
90 pub cpu_cores: Option<usize>,
91 #[serde(skip_serializing_if = "Option::is_none")]
93 pub ram_gb: Option<usize>,
94 #[serde(skip_serializing_if = "Option::is_none")]
96 pub gpu_model: Option<String>,
97 #[serde(skip_serializing_if = "Option::is_none")]
99 pub gpu_count: Option<usize>,
100}
101
102#[derive(Debug, Clone, Serialize, Deserialize)]
104pub struct MetricRecord {
105 pub name: String,
107 pub value: f64,
109 pub step: u64,
111 pub timestamp: DateTime<Utc>,
113}
114
115impl MetricRecord {
116 #[must_use]
118 pub fn new(name: impl Into<String>, value: f64, step: u64) -> Self {
119 Self {
120 name: name.into(),
121 value,
122 step,
123 timestamp: Utc::now(),
124 }
125 }
126}
127
128#[derive(Debug, Clone, Serialize, Deserialize)]
130pub struct ArtifactReference {
131 pub artifact_type: String,
133 pub name: String,
135 pub content_hash: String,
137}
138
139#[derive(Debug, Clone, Serialize, Deserialize)]
141pub struct ExperimentRun {
142 pub run_id: RunId,
144 #[serde(skip_serializing_if = "Option::is_none")]
146 pub recipe: Option<RecipeReference>,
147 pub hyperparameters: Hyperparameters,
149
150 pub started_at: DateTime<Utc>,
152 #[serde(skip_serializing_if = "Option::is_none")]
154 pub finished_at: Option<DateTime<Utc>>,
155 pub status: RunStatus,
157 pub hardware: HardwareInfo,
159
160 #[serde(default)]
162 pub metrics: Vec<MetricRecord>,
163 #[serde(default)]
165 pub artifacts: Vec<ArtifactReference>,
166 #[serde(skip_serializing_if = "Option::is_none")]
168 pub log_uri: Option<String>,
169
170 #[serde(skip_serializing_if = "Option::is_none")]
172 pub git_commit: Option<String>,
173 #[serde(default)]
175 pub git_dirty: bool,
176
177 #[serde(skip_serializing_if = "Option::is_none")]
179 pub error_message: Option<String>,
180
181 #[serde(default)]
183 pub extra: HashMap<String, serde_json::Value>,
184}
185
186impl ExperimentRun {
187 #[must_use]
189 pub fn new(hyperparameters: Hyperparameters) -> Self {
190 Self {
191 run_id: RunId::new(),
192 recipe: None,
193 hyperparameters,
194 started_at: Utc::now(),
195 finished_at: None,
196 status: RunStatus::Pending,
197 hardware: HardwareInfo::default(),
198 metrics: Vec::new(),
199 artifacts: Vec::new(),
200 log_uri: None,
201 git_commit: None,
202 git_dirty: false,
203 error_message: None,
204 extra: HashMap::new(),
205 }
206 }
207
208 #[must_use]
210 pub fn from_recipe(recipe: RecipeReference, hyperparameters: Hyperparameters) -> Self {
211 let mut run = Self::new(hyperparameters);
212 run.recipe = Some(recipe);
213 run
214 }
215
216 pub fn start(&mut self) {
218 self.status = RunStatus::Running;
219 self.started_at = Utc::now();
220 }
221
222 pub fn complete(&mut self) {
224 self.status = RunStatus::Completed;
225 self.finished_at = Some(Utc::now());
226 }
227
228 pub fn fail(&mut self, error: impl Into<String>) {
230 self.status = RunStatus::Failed;
231 self.finished_at = Some(Utc::now());
232 self.error_message = Some(error.into());
233 }
234
235 pub fn cancel(&mut self) {
237 self.status = RunStatus::Cancelled;
238 self.finished_at = Some(Utc::now());
239 }
240
241 pub fn log_metric(&mut self, name: impl Into<String>, value: f64, step: u64) {
243 self.metrics.push(MetricRecord::new(name, value, step));
244 }
245
246 #[must_use]
248 pub fn get_metric(&self, name: &str) -> Option<f64> {
249 self.metrics
250 .iter()
251 .filter(|m| m.name == name)
252 .max_by_key(|m| m.step)
253 .map(|m| m.value)
254 }
255
256 #[must_use]
258 pub fn duration_secs(&self) -> Option<i64> {
259 self.finished_at
260 .map(|end| (end - self.started_at).num_seconds())
261 }
262
263 #[must_use]
265 pub fn is_finished(&self) -> bool {
266 matches!(
267 self.status,
268 RunStatus::Completed | RunStatus::Failed | RunStatus::Cancelled
269 )
270 }
271}
272
#[cfg(test)]
mod tests {
    use super::*;

    /// A pending run with default hyperparameters, shared by the lifecycle tests.
    fn fresh_run() -> ExperimentRun {
        ExperimentRun::new(Hyperparameters::default())
    }

    #[test]
    fn test_run_id_generation() {
        // Two freshly generated ids must not collide.
        assert_ne!(RunId::new(), RunId::new());
    }

    #[test]
    fn test_run_status_display() {
        let cases = [
            (RunStatus::Running, "running"),
            (RunStatus::Completed, "completed"),
            (RunStatus::Failed, "failed"),
        ];
        for (status, expected) in cases {
            assert_eq!(status.to_string(), expected);
        }
    }

    #[test]
    fn test_experiment_run_lifecycle() {
        let mut run = fresh_run();
        assert_eq!(run.status, RunStatus::Pending);
        assert!(!run.is_finished());

        run.start();
        assert_eq!(run.status, RunStatus::Running);

        let samples = [("loss", 0.5, 100), ("loss", 0.3, 200), ("accuracy", 0.8, 200)];
        for (name, value, step) in samples {
            run.log_metric(name, value, step);
        }

        // `get_metric` reports the value logged at the highest step.
        assert_eq!(run.get_metric("loss"), Some(0.3));
        assert_eq!(run.get_metric("accuracy"), Some(0.8));
        assert_eq!(run.get_metric("nonexistent"), None);

        run.complete();
        assert_eq!(run.status, RunStatus::Completed);
        assert!(run.is_finished());
        assert!(run.duration_secs().is_some());
    }

    #[test]
    fn test_experiment_run_failure() {
        let mut run = fresh_run();
        run.start();
        run.fail("Out of memory");

        assert_eq!(run.status, RunStatus::Failed);
        assert_eq!(run.error_message.as_deref(), Some("Out of memory"));
        assert!(run.is_finished());
    }

    #[test]
    fn test_experiment_run_cancel() {
        let mut run = fresh_run();
        run.start();
        run.cancel();

        assert_eq!(run.status, RunStatus::Cancelled);
        assert!(run.is_finished());
    }

    #[test]
    fn test_metric_record() {
        let record = MetricRecord::new("val_loss", 0.25, 1000);

        assert_eq!(record.name, "val_loss");
        assert!((record.value - 0.25).abs() < 1e-10);
        assert_eq!(record.step, 1000);
    }

    #[test]
    fn test_experiment_run_serialization() {
        let mut run = fresh_run();
        run.log_metric("loss", 0.5, 100);

        let json = serde_json::to_string(&run).unwrap();
        let restored: ExperimentRun = serde_json::from_str(&json).unwrap();

        // Round-tripping through JSON preserves identity and metric count.
        assert_eq!(restored.run_id, run.run_id);
        assert_eq!(restored.metrics.len(), run.metrics.len());
    }
}