pacha/experiment/
mod.rs

1//! Experiment tracking for training runs.
2
3use crate::recipe::{Hyperparameters, RecipeReference};
4use chrono::{DateTime, Utc};
5use serde::{Deserialize, Serialize};
6use std::collections::HashMap;
7use uuid::Uuid;
8
9/// Unique identifier for an experiment run.
10#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
11pub struct RunId(Uuid);
12
13impl RunId {
14    /// Create a new random run ID.
15    #[must_use]
16    pub fn new() -> Self {
17        Self(Uuid::new_v4())
18    }
19
20    /// Create from a UUID.
21    #[must_use]
22    pub fn from_uuid(uuid: Uuid) -> Self {
23        Self(uuid)
24    }
25
26    /// Get the underlying UUID.
27    #[must_use]
28    pub fn as_uuid(&self) -> &Uuid {
29        &self.0
30    }
31}
32
33impl Default for RunId {
34    fn default() -> Self {
35        Self::new()
36    }
37}
38
39impl std::fmt::Display for RunId {
40    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
41        write!(f, "{}", self.0)
42    }
43}
44
45impl std::str::FromStr for RunId {
46    type Err = uuid::Error;
47
48    fn from_str(s: &str) -> Result<Self, Self::Err> {
49        Ok(Self(Uuid::parse_str(s)?))
50    }
51}
52
53/// Status of an experiment run.
54#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
55#[serde(rename_all = "lowercase")]
56pub enum RunStatus {
57    /// Run is pending start.
58    Pending,
59    /// Run is currently executing.
60    Running,
61    /// Run completed successfully.
62    Completed,
63    /// Run failed with an error.
64    Failed,
65    /// Run was cancelled.
66    Cancelled,
67}
68
69impl std::fmt::Display for RunStatus {
70    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
71        let s = match self {
72            Self::Pending => "pending",
73            Self::Running => "running",
74            Self::Completed => "completed",
75            Self::Failed => "failed",
76            Self::Cancelled => "cancelled",
77        };
78        write!(f, "{s}")
79    }
80}
81
82/// Information about hardware used for a run.
83#[derive(Debug, Clone, Default, Serialize, Deserialize)]
84pub struct HardwareInfo {
85    /// CPU model.
86    #[serde(skip_serializing_if = "Option::is_none")]
87    pub cpu_model: Option<String>,
88    /// Number of CPU cores used.
89    #[serde(skip_serializing_if = "Option::is_none")]
90    pub cpu_cores: Option<usize>,
91    /// RAM in GB.
92    #[serde(skip_serializing_if = "Option::is_none")]
93    pub ram_gb: Option<usize>,
94    /// GPU model.
95    #[serde(skip_serializing_if = "Option::is_none")]
96    pub gpu_model: Option<String>,
97    /// Number of GPUs used.
98    #[serde(skip_serializing_if = "Option::is_none")]
99    pub gpu_count: Option<usize>,
100}
101
102/// A metric recorded during training.
103#[derive(Debug, Clone, Serialize, Deserialize)]
104pub struct MetricRecord {
105    /// Metric name.
106    pub name: String,
107    /// Metric value.
108    pub value: f64,
109    /// Training step.
110    pub step: u64,
111    /// Timestamp.
112    pub timestamp: DateTime<Utc>,
113}
114
115impl MetricRecord {
116    /// Create a new metric record.
117    #[must_use]
118    pub fn new(name: impl Into<String>, value: f64, step: u64) -> Self {
119        Self {
120            name: name.into(),
121            value,
122            step,
123            timestamp: Utc::now(),
124        }
125    }
126}
127
128/// Reference to an artifact produced by a run.
129#[derive(Debug, Clone, Serialize, Deserialize)]
130pub struct ArtifactReference {
131    /// Artifact type (e.g., "model", "checkpoint").
132    pub artifact_type: String,
133    /// Artifact name.
134    pub name: String,
135    /// Content hash.
136    pub content_hash: String,
137}
138
139/// An experiment run tracking a training execution.
140#[derive(Debug, Clone, Serialize, Deserialize)]
141pub struct ExperimentRun {
142    /// Unique run identifier.
143    pub run_id: RunId,
144    /// Recipe used for this run.
145    #[serde(skip_serializing_if = "Option::is_none")]
146    pub recipe: Option<RecipeReference>,
147    /// Actual hyperparameters used (may override recipe).
148    pub hyperparameters: Hyperparameters,
149
150    /// When the run started.
151    pub started_at: DateTime<Utc>,
152    /// When the run finished.
153    #[serde(skip_serializing_if = "Option::is_none")]
154    pub finished_at: Option<DateTime<Utc>>,
155    /// Current status.
156    pub status: RunStatus,
157    /// Hardware used.
158    pub hardware: HardwareInfo,
159
160    /// Metrics recorded during training.
161    #[serde(default)]
162    pub metrics: Vec<MetricRecord>,
163    /// Artifacts produced.
164    #[serde(default)]
165    pub artifacts: Vec<ArtifactReference>,
166    /// Log URI.
167    #[serde(skip_serializing_if = "Option::is_none")]
168    pub log_uri: Option<String>,
169
170    /// Git commit hash.
171    #[serde(skip_serializing_if = "Option::is_none")]
172    pub git_commit: Option<String>,
173    /// Whether the git working directory was dirty.
174    #[serde(default)]
175    pub git_dirty: bool,
176
177    /// Error message if failed.
178    #[serde(skip_serializing_if = "Option::is_none")]
179    pub error_message: Option<String>,
180
181    /// Additional metadata.
182    #[serde(default)]
183    pub extra: HashMap<String, serde_json::Value>,
184}
185
186impl ExperimentRun {
187    /// Create a new experiment run.
188    #[must_use]
189    pub fn new(hyperparameters: Hyperparameters) -> Self {
190        Self {
191            run_id: RunId::new(),
192            recipe: None,
193            hyperparameters,
194            started_at: Utc::now(),
195            finished_at: None,
196            status: RunStatus::Pending,
197            hardware: HardwareInfo::default(),
198            metrics: Vec::new(),
199            artifacts: Vec::new(),
200            log_uri: None,
201            git_commit: None,
202            git_dirty: false,
203            error_message: None,
204            extra: HashMap::new(),
205        }
206    }
207
208    /// Create a run from a recipe.
209    #[must_use]
210    pub fn from_recipe(recipe: RecipeReference, hyperparameters: Hyperparameters) -> Self {
211        let mut run = Self::new(hyperparameters);
212        run.recipe = Some(recipe);
213        run
214    }
215
216    /// Start the run.
217    pub fn start(&mut self) {
218        self.status = RunStatus::Running;
219        self.started_at = Utc::now();
220    }
221
222    /// Complete the run successfully.
223    pub fn complete(&mut self) {
224        self.status = RunStatus::Completed;
225        self.finished_at = Some(Utc::now());
226    }
227
228    /// Mark the run as failed.
229    pub fn fail(&mut self, error: impl Into<String>) {
230        self.status = RunStatus::Failed;
231        self.finished_at = Some(Utc::now());
232        self.error_message = Some(error.into());
233    }
234
235    /// Cancel the run.
236    pub fn cancel(&mut self) {
237        self.status = RunStatus::Cancelled;
238        self.finished_at = Some(Utc::now());
239    }
240
241    /// Log a metric.
242    pub fn log_metric(&mut self, name: impl Into<String>, value: f64, step: u64) {
243        self.metrics.push(MetricRecord::new(name, value, step));
244    }
245
246    /// Get the latest value for a metric.
247    #[must_use]
248    pub fn get_metric(&self, name: &str) -> Option<f64> {
249        self.metrics
250            .iter()
251            .filter(|m| m.name == name)
252            .max_by_key(|m| m.step)
253            .map(|m| m.value)
254    }
255
256    /// Get duration in seconds.
257    #[must_use]
258    pub fn duration_secs(&self) -> Option<i64> {
259        self.finished_at
260            .map(|end| (end - self.started_at).num_seconds())
261    }
262
263    /// Check if the run is finished.
264    #[must_use]
265    pub fn is_finished(&self) -> bool {
266        matches!(
267            self.status,
268            RunStatus::Completed | RunStatus::Failed | RunStatus::Cancelled
269        )
270    }
271}
272
#[cfg(test)]
mod tests {
    use super::*;

    /// Fresh run with default hyperparameters, shared by the lifecycle tests.
    fn fresh_run() -> ExperimentRun {
        ExperimentRun::new(Hyperparameters::default())
    }

    #[test]
    fn test_run_id_generation() {
        let (first, second) = (RunId::new(), RunId::new());
        assert_ne!(first, second);
    }

    #[test]
    fn test_run_status_display() {
        let cases = [
            (RunStatus::Running, "running"),
            (RunStatus::Completed, "completed"),
            (RunStatus::Failed, "failed"),
        ];
        for (status, expected) in cases {
            assert_eq!(status.to_string(), expected);
        }
    }

    #[test]
    fn test_experiment_run_lifecycle() {
        let mut run = fresh_run();

        assert_eq!(run.status, RunStatus::Pending);
        assert!(!run.is_finished());

        run.start();
        assert_eq!(run.status, RunStatus::Running);

        let observations = [("loss", 0.5, 100), ("loss", 0.3, 200), ("accuracy", 0.8, 200)];
        for (name, value, step) in observations {
            run.log_metric(name, value, step);
        }

        assert_eq!(run.get_metric("loss"), Some(0.3));
        assert_eq!(run.get_metric("accuracy"), Some(0.8));
        assert_eq!(run.get_metric("nonexistent"), None);

        run.complete();
        assert_eq!(run.status, RunStatus::Completed);
        assert!(run.is_finished());
        assert!(run.duration_secs().is_some());
    }

    #[test]
    fn test_experiment_run_failure() {
        let mut run = fresh_run();
        run.start();
        run.fail("Out of memory");

        assert_eq!(run.status, RunStatus::Failed);
        assert_eq!(run.error_message.as_deref(), Some("Out of memory"));
        assert!(run.is_finished());
    }

    #[test]
    fn test_experiment_run_cancel() {
        let mut run = fresh_run();
        run.start();
        run.cancel();

        assert_eq!(run.status, RunStatus::Cancelled);
        assert!(run.is_finished());
    }

    #[test]
    fn test_metric_record() {
        let record = MetricRecord::new("val_loss", 0.25, 1000);
        assert_eq!(record.name, "val_loss");
        assert!((record.value - 0.25).abs() < 1e-10);
        assert_eq!(record.step, 1000);
    }

    #[test]
    fn test_experiment_run_serialization() {
        let mut original = fresh_run();
        original.log_metric("loss", 0.5, 100);

        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: ExperimentRun = serde_json::from_str(&encoded).unwrap();

        assert_eq!(original.run_id, decoded.run_id);
        assert_eq!(original.metrics.len(), decoded.metrics.len());
    }
}
361}