Skip to main content

pacha/experiment/
mod.rs

1//! Experiment tracking for training runs.
2
3use crate::recipe::{Hyperparameters, RecipeReference};
4use chrono::{DateTime, Utc};
5use serde::{Deserialize, Serialize};
6use std::collections::HashMap;
7use uuid::Uuid;
8
9/// Unique identifier for an experiment run.
10#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
11pub struct RunId(Uuid);
12
13impl RunId {
14    /// Create a new random run ID.
15    #[must_use]
16    pub fn new() -> Self {
17        Self(Uuid::new_v4())
18    }
19
20    /// Create from a UUID.
21    #[must_use]
22    pub fn from_uuid(uuid: Uuid) -> Self {
23        Self(uuid)
24    }
25
26    /// Get the underlying UUID.
27    #[must_use]
28    pub fn as_uuid(&self) -> &Uuid {
29        &self.0
30    }
31}
32
33impl Default for RunId {
34    fn default() -> Self {
35        Self::new()
36    }
37}
38
39impl std::fmt::Display for RunId {
40    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
41        write!(f, "{}", self.0)
42    }
43}
44
45impl std::str::FromStr for RunId {
46    type Err = uuid::Error;
47
48    fn from_str(s: &str) -> Result<Self, Self::Err> {
49        Ok(Self(Uuid::parse_str(s)?))
50    }
51}
52
53/// Status of an experiment run.
54#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
55#[serde(rename_all = "lowercase")]
56pub enum RunStatus {
57    /// Run is pending start.
58    Pending,
59    /// Run is currently executing.
60    Running,
61    /// Run completed successfully.
62    Completed,
63    /// Run failed with an error.
64    Failed,
65    /// Run was cancelled.
66    Cancelled,
67}
68
69impl std::fmt::Display for RunStatus {
70    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
71        let s = match self {
72            Self::Pending => "pending",
73            Self::Running => "running",
74            Self::Completed => "completed",
75            Self::Failed => "failed",
76            Self::Cancelled => "cancelled",
77        };
78        write!(f, "{s}")
79    }
80}
81
82/// Information about hardware used for a run.
83#[derive(Debug, Clone, Default, Serialize, Deserialize)]
84pub struct HardwareInfo {
85    /// CPU model.
86    #[serde(skip_serializing_if = "Option::is_none")]
87    pub cpu_model: Option<String>,
88    /// Number of CPU cores used.
89    #[serde(skip_serializing_if = "Option::is_none")]
90    pub cpu_cores: Option<usize>,
91    /// RAM in GB.
92    #[serde(skip_serializing_if = "Option::is_none")]
93    pub ram_gb: Option<usize>,
94    /// GPU model.
95    #[serde(skip_serializing_if = "Option::is_none")]
96    pub gpu_model: Option<String>,
97    /// Number of GPUs used.
98    #[serde(skip_serializing_if = "Option::is_none")]
99    pub gpu_count: Option<usize>,
100}
101
102/// A metric recorded during training.
103#[derive(Debug, Clone, Serialize, Deserialize)]
104pub struct MetricRecord {
105    /// Metric name.
106    pub name: String,
107    /// Metric value.
108    pub value: f64,
109    /// Training step.
110    pub step: u64,
111    /// Timestamp.
112    pub timestamp: DateTime<Utc>,
113}
114
115impl MetricRecord {
116    /// Create a new metric record.
117    #[must_use]
118    pub fn new(name: impl Into<String>, value: f64, step: u64) -> Self {
119        Self { name: name.into(), value, step, timestamp: Utc::now() }
120    }
121}
122
123/// Reference to an artifact produced by a run.
124#[derive(Debug, Clone, Serialize, Deserialize)]
125pub struct ArtifactReference {
126    /// Artifact type (e.g., "model", "checkpoint").
127    pub artifact_type: String,
128    /// Artifact name.
129    pub name: String,
130    /// Content hash.
131    pub content_hash: String,
132}
133
134/// An experiment run tracking a training execution.
135#[derive(Debug, Clone, Serialize, Deserialize)]
136pub struct ExperimentRun {
137    /// Unique run identifier.
138    pub run_id: RunId,
139    /// Recipe used for this run.
140    #[serde(skip_serializing_if = "Option::is_none")]
141    pub recipe: Option<RecipeReference>,
142    /// Actual hyperparameters used (may override recipe).
143    pub hyperparameters: Hyperparameters,
144
145    /// When the run started.
146    pub started_at: DateTime<Utc>,
147    /// When the run finished.
148    #[serde(skip_serializing_if = "Option::is_none")]
149    pub finished_at: Option<DateTime<Utc>>,
150    /// Current status.
151    pub status: RunStatus,
152    /// Hardware used.
153    pub hardware: HardwareInfo,
154
155    /// Metrics recorded during training.
156    #[serde(default)]
157    pub metrics: Vec<MetricRecord>,
158    /// Artifacts produced.
159    #[serde(default)]
160    pub artifacts: Vec<ArtifactReference>,
161    /// Log URI.
162    #[serde(skip_serializing_if = "Option::is_none")]
163    pub log_uri: Option<String>,
164
165    /// Git commit hash.
166    #[serde(skip_serializing_if = "Option::is_none")]
167    pub git_commit: Option<String>,
168    /// Whether the git working directory was dirty.
169    #[serde(default)]
170    pub git_dirty: bool,
171
172    /// Error message if failed.
173    #[serde(skip_serializing_if = "Option::is_none")]
174    pub error_message: Option<String>,
175
176    /// Additional metadata.
177    #[serde(default)]
178    pub extra: HashMap<String, serde_json::Value>,
179}
180
181impl ExperimentRun {
182    /// Create a new experiment run.
183    #[must_use]
184    pub fn new(hyperparameters: Hyperparameters) -> Self {
185        Self {
186            run_id: RunId::new(),
187            recipe: None,
188            hyperparameters,
189            started_at: Utc::now(),
190            finished_at: None,
191            status: RunStatus::Pending,
192            hardware: HardwareInfo::default(),
193            metrics: Vec::new(),
194            artifacts: Vec::new(),
195            log_uri: None,
196            git_commit: None,
197            git_dirty: false,
198            error_message: None,
199            extra: HashMap::new(),
200        }
201    }
202
203    /// Create a run from a recipe.
204    #[must_use]
205    pub fn from_recipe(recipe: RecipeReference, hyperparameters: Hyperparameters) -> Self {
206        let mut run = Self::new(hyperparameters);
207        run.recipe = Some(recipe);
208        run
209    }
210
211    /// Start the run.
212    pub fn start(&mut self) {
213        self.status = RunStatus::Running;
214        self.started_at = Utc::now();
215    }
216
217    /// Complete the run successfully.
218    pub fn complete(&mut self) {
219        self.status = RunStatus::Completed;
220        self.finished_at = Some(Utc::now());
221    }
222
223    /// Mark the run as failed.
224    pub fn fail(&mut self, error: impl Into<String>) {
225        self.status = RunStatus::Failed;
226        self.finished_at = Some(Utc::now());
227        self.error_message = Some(error.into());
228    }
229
230    /// Cancel the run.
231    pub fn cancel(&mut self) {
232        self.status = RunStatus::Cancelled;
233        self.finished_at = Some(Utc::now());
234    }
235
236    /// Log a metric.
237    pub fn log_metric(&mut self, name: impl Into<String>, value: f64, step: u64) {
238        self.metrics.push(MetricRecord::new(name, value, step));
239    }
240
241    /// Get the latest value for a metric.
242    #[must_use]
243    pub fn get_metric(&self, name: &str) -> Option<f64> {
244        self.metrics.iter().filter(|m| m.name == name).max_by_key(|m| m.step).map(|m| m.value)
245    }
246
247    /// Get duration in seconds.
248    #[must_use]
249    pub fn duration_secs(&self) -> Option<i64> {
250        self.finished_at.map(|end| (end - self.started_at).num_seconds())
251    }
252
253    /// Check if the run is finished.
254    #[must_use]
255    pub fn is_finished(&self) -> bool {
256        matches!(self.status, RunStatus::Completed | RunStatus::Failed | RunStatus::Cancelled)
257    }
258}
259
#[cfg(test)]
mod tests {
    use super::*;

    /// Every `RunStatus` variant, for exhaustive Display/serde checks.
    const ALL_STATUSES: [RunStatus; 5] = [
        RunStatus::Pending,
        RunStatus::Running,
        RunStatus::Completed,
        RunStatus::Failed,
        RunStatus::Cancelled,
    ];

    #[test]
    fn test_run_id_generation() {
        let id1 = RunId::new();
        let id2 = RunId::new();
        assert_ne!(id1, id2);
    }

    #[test]
    fn test_run_id_display_parse_roundtrip() {
        let id = RunId::new();
        let parsed: RunId = id.to_string().parse().expect("Display output must parse back");
        assert_eq!(parsed, id);
        assert_eq!(parsed.as_uuid(), id.as_uuid());
    }

    #[test]
    fn test_run_id_rejects_invalid_string() {
        assert!("not-a-uuid".parse::<RunId>().is_err());
    }

    #[test]
    fn test_run_status_display() {
        assert_eq!(RunStatus::Pending.to_string(), "pending");
        assert_eq!(RunStatus::Running.to_string(), "running");
        assert_eq!(RunStatus::Completed.to_string(), "completed");
        assert_eq!(RunStatus::Failed.to_string(), "failed");
        assert_eq!(RunStatus::Cancelled.to_string(), "cancelled");
    }

    #[test]
    fn test_run_status_serde_matches_display() {
        for &status in &ALL_STATUSES {
            let json = serde_json::to_string(&status).unwrap();
            assert_eq!(json, format!("\"{status}\""));
            let back: RunStatus = serde_json::from_str(&json).unwrap();
            assert_eq!(back, status);
        }
    }

    #[test]
    fn test_new_run_defaults() {
        let run = ExperimentRun::new(Hyperparameters::default());
        assert_eq!(run.status, RunStatus::Pending);
        assert!(run.recipe.is_none());
        assert!(run.finished_at.is_none());
        assert!(run.duration_secs().is_none());
        assert!(run.metrics.is_empty());
        assert!(run.artifacts.is_empty());
        assert!(!run.git_dirty);
        assert!(!run.is_finished());
    }

    #[test]
    fn test_experiment_run_lifecycle() {
        let params = Hyperparameters::default();
        let mut run = ExperimentRun::new(params);

        assert_eq!(run.status, RunStatus::Pending);
        assert!(!run.is_finished());

        run.start();
        assert_eq!(run.status, RunStatus::Running);

        run.log_metric("loss", 0.5, 100);
        run.log_metric("loss", 0.3, 200);
        run.log_metric("accuracy", 0.8, 200);

        assert_eq!(run.get_metric("loss"), Some(0.3));
        assert_eq!(run.get_metric("accuracy"), Some(0.8));
        assert_eq!(run.get_metric("nonexistent"), None);

        run.complete();
        assert_eq!(run.status, RunStatus::Completed);
        assert!(run.is_finished());
        assert!(run.duration_secs().is_some());
    }

    #[test]
    fn test_get_metric_prefers_highest_step_not_last_logged() {
        let mut run = ExperimentRun::new(Hyperparameters::default());
        run.log_metric("loss", 0.1, 300);
        run.log_metric("loss", 0.9, 100);
        assert_eq!(run.get_metric("loss"), Some(0.1));
    }

    #[test]
    fn test_get_metric_tie_on_step_uses_last_logged() {
        let mut run = ExperimentRun::new(Hyperparameters::default());
        run.log_metric("loss", 0.4, 50);
        run.log_metric("loss", 0.2, 50);
        assert_eq!(run.get_metric("loss"), Some(0.2));
    }

    #[test]
    fn test_experiment_run_failure() {
        let params = Hyperparameters::default();
        let mut run = ExperimentRun::new(params);

        run.start();
        run.fail("Out of memory");

        assert_eq!(run.status, RunStatus::Failed);
        assert_eq!(run.error_message, Some("Out of memory".to_string()));
        assert!(run.is_finished());
    }

    #[test]
    fn test_experiment_run_cancel() {
        let params = Hyperparameters::default();
        let mut run = ExperimentRun::new(params);

        run.start();
        run.cancel();

        assert_eq!(run.status, RunStatus::Cancelled);
        assert!(run.is_finished());
    }

    #[test]
    fn test_metric_record() {
        let metric = MetricRecord::new("val_loss", 0.25, 1000);
        assert_eq!(metric.name, "val_loss");
        assert!((metric.value - 0.25).abs() < 1e-10);
        assert_eq!(metric.step, 1000);
    }

    #[test]
    fn test_experiment_run_serialization() {
        let params = Hyperparameters::default();
        let mut run = ExperimentRun::new(params);
        run.log_metric("loss", 0.5, 100);

        let json = serde_json::to_string(&run).unwrap();
        let deserialized: ExperimentRun = serde_json::from_str(&json).unwrap();

        assert_eq!(run.run_id, deserialized.run_id);
        assert_eq!(run.status, deserialized.status);
        assert_eq!(run.metrics.len(), deserialized.metrics.len());
    }
}