Skip to main content

batuta/experiment/
run.rs

1//! Experiment run tracking and storage.
2//!
3//! This module provides types for tracking individual experiment runs,
4//! including metrics, hyperparameters, and storage backends.
5
6use super::{
7    ComputeDevice, CostMetrics, CpuArchitecture, EnergyMetrics, ExperimentError, ModelParadigm,
8    PlatformEfficiency,
9};
10use serde::{Deserialize, Serialize};
11use std::collections::HashMap;
12
/// Experiment run with full tracking metadata.
///
/// A run is created with [`ExperimentRun::new`], which stamps `started_at` and
/// sets `status` to [`RunStatus::Running`]. Metrics and hyperparameters are
/// added through [`ExperimentRun::log_metric`] and [`ExperimentRun::log_param`],
/// and the run is closed with [`ExperimentRun::complete`] or
/// [`ExperimentRun::fail`]. All fields are public, so callers may also set
/// them directly.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExperimentRun {
    /// Unique run ID (used as the map key by `InMemoryExperimentStorage`).
    pub run_id: String,
    /// Experiment name; storage backends filter on this in `list_runs`.
    pub experiment_name: String,
    /// Model paradigm
    pub paradigm: ModelParadigm,
    /// Compute device used
    pub device: ComputeDevice,
    /// Platform efficiency class (`new` initialises this to `PlatformEfficiency::Server`).
    pub platform: PlatformEfficiency,
    /// Energy metrics (`None` on a freshly created run).
    pub energy: Option<EnergyMetrics>,
    /// Cost metrics (`None` on a freshly created run).
    pub cost: Option<CostMetrics>,
    /// Hyperparameters, keyed by name, holding arbitrary JSON values.
    pub hyperparameters: HashMap<String, serde_json::Value>,
    /// Metrics collected, keyed by name; one `f64` is kept per name.
    pub metrics: HashMap<String, f64>,
    /// Tags for organization (empty on a freshly created run).
    pub tags: Vec<String>,
    /// Start time as an RFC 3339 string, taken from the UTC clock in `new`.
    pub started_at: String,
    /// End time as an RFC 3339 string; `None` until `complete` or `fail` sets it.
    pub ended_at: Option<String>,
    /// Status
    pub status: RunStatus,
}
43
/// Run status
///
/// Lifecycle state recorded in [`ExperimentRun::status`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum RunStatus {
    /// Initial status assigned by `ExperimentRun::new`.
    Running,
    /// Set by `ExperimentRun::complete`.
    Completed,
    /// Set by `ExperimentRun::fail`.
    Failed,
    /// NOTE(review): no method in this module assigns this variant; presumably
    /// callers write it to the public `status` field directly - confirm.
    Cancelled,
}
52
53impl ExperimentRun {
54    /// Create a new experiment run
55    pub fn new(
56        run_id: impl Into<String>,
57        experiment_name: impl Into<String>,
58        paradigm: ModelParadigm,
59        device: ComputeDevice,
60    ) -> Self {
61        Self {
62            run_id: run_id.into(),
63            experiment_name: experiment_name.into(),
64            paradigm,
65            device,
66            platform: PlatformEfficiency::Server,
67            energy: None,
68            cost: None,
69            hyperparameters: HashMap::new(),
70            metrics: HashMap::new(),
71            tags: Vec::new(),
72            started_at: chrono::Utc::now().to_rfc3339(),
73            ended_at: None,
74            status: RunStatus::Running,
75        }
76    }
77
78    /// Log a metric
79    pub fn log_metric(&mut self, name: impl Into<String>, value: f64) {
80        self.metrics.insert(name.into(), value);
81    }
82
83    /// Log a hyperparameter
84    pub fn log_param(&mut self, name: impl Into<String>, value: serde_json::Value) {
85        self.hyperparameters.insert(name.into(), value);
86    }
87
88    /// Complete the run
89    pub fn complete(&mut self) {
90        self.ended_at = Some(chrono::Utc::now().to_rfc3339());
91        self.status = RunStatus::Completed;
92    }
93
94    /// Mark the run as failed
95    pub fn fail(&mut self) {
96        self.ended_at = Some(chrono::Utc::now().to_rfc3339());
97        self.status = RunStatus::Failed;
98    }
99}
100
/// Experiment storage backend trait
///
/// Every method takes `&self`, so a backend that mutates state has to use
/// interior mutability (the in-memory backend in this file wraps its map in an
/// `RwLock`). The `Send + Sync` bound requires implementors to be shareable
/// across threads.
///
/// Each method returns an [`ExperimentError`] when the backend cannot carry
/// out the operation; which variants are produced is up to the implementation.
pub trait ExperimentStorage: Send + Sync {
    /// Store an experiment run
    ///
    /// NOTE(review): the in-memory backend replaces any run already stored
    /// under the same `run_id`; confirm other backends are expected to match.
    fn store_run(&self, run: &ExperimentRun) -> Result<(), ExperimentError>;

    /// Retrieve a run by ID
    ///
    /// The `Option` in the return type lets a backend report an unknown ID as
    /// `Ok(None)` rather than as an error (the in-memory backend does so).
    fn get_run(&self, run_id: &str) -> Result<Option<ExperimentRun>, ExperimentError>;

    /// List runs for an experiment
    ///
    /// NOTE(review): no ordering of the returned runs is specified here; the
    /// in-memory backend returns them in `HashMap` iteration order.
    fn list_runs(&self, experiment_name: &str) -> Result<Vec<ExperimentRun>, ExperimentError>;

    /// Delete a run
    ///
    /// NOTE(review): the in-memory backend returns `Ok(())` even when no run
    /// with `run_id` exists; confirm whether other backends should match.
    fn delete_run(&self, run_id: &str) -> Result<(), ExperimentError>;
}
115
/// In-memory experiment storage for testing
///
/// Holds owned copies of runs in a `HashMap` keyed by `run_id`, guarded by a
/// `std::sync::RwLock` so the `&self` methods of `ExperimentStorage` can read
/// and write it. Nothing is persisted; the contents live only as long as the
/// value does.
#[derive(Debug, Default)]
pub struct InMemoryExperimentStorage {
    // Run ID -> stored copy of the run. A poisoned lock is reported by the
    // trait methods as `ExperimentError::StorageError`.
    runs: std::sync::RwLock<HashMap<String, ExperimentRun>>,
}
121
122impl InMemoryExperimentStorage {
123    /// Create new in-memory storage
124    pub fn new() -> Self {
125        Self::default()
126    }
127}
128
129impl ExperimentStorage for InMemoryExperimentStorage {
130    fn store_run(&self, run: &ExperimentRun) -> Result<(), ExperimentError> {
131        let mut runs = self
132            .runs
133            .write()
134            .map_err(|e| ExperimentError::StorageError(format!("Lock error: {}", e)))?;
135        runs.insert(run.run_id.clone(), run.clone());
136        Ok(())
137    }
138
139    fn get_run(&self, run_id: &str) -> Result<Option<ExperimentRun>, ExperimentError> {
140        let runs = self
141            .runs
142            .read()
143            .map_err(|e| ExperimentError::StorageError(format!("Lock error: {}", e)))?;
144        Ok(runs.get(run_id).cloned())
145    }
146
147    fn list_runs(&self, experiment_name: &str) -> Result<Vec<ExperimentRun>, ExperimentError> {
148        let runs = self
149            .runs
150            .read()
151            .map_err(|e| ExperimentError::StorageError(format!("Lock error: {}", e)))?;
152        Ok(runs.values().filter(|r| r.experiment_name == experiment_name).cloned().collect())
153    }
154
155    fn delete_run(&self, run_id: &str) -> Result<(), ExperimentError> {
156        let mut runs = self
157            .runs
158            .write()
159            .map_err(|e| ExperimentError::StorageError(format!("Lock error: {}", e)))?;
160        runs.remove(run_id);
161        Ok(())
162    }
163}
164
#[cfg(test)]
mod lock_poison_tests {
    use super::*;
    use crate::experiment::{ComputeDevice, CpuArchitecture, ModelParadigm};

    /// Builds a storage whose `RwLock` has been poisoned by a panic raised
    /// while the write guard was held.
    fn poison_storage() -> InMemoryExperimentStorage {
        let storage = InMemoryExperimentStorage::new();
        let poisoner = std::panic::AssertUnwindSafe(|| {
            // The guard must stay bound so it is still alive when the panic unwinds.
            let _guard = storage.runs.write().expect("unexpected failure");
            panic!("intentional poison");
        });
        let _ = std::panic::catch_unwind(poisoner);
        storage
    }

    /// Minimal single-core CPU device used to construct runs in these tests.
    fn test_device() -> ComputeDevice {
        ComputeDevice::Cpu {
            cores: 1,
            threads_per_core: 1,
            architecture: CpuArchitecture::X86_64,
        }
    }

    /// Asserts that `result` is an `Err(StorageError)` whose message mentions the lock failure.
    fn assert_lock_error<T: std::fmt::Debug>(result: Result<T, ExperimentError>) {
        assert!(result.is_err());
        match result.unwrap_err() {
            ExperimentError::StorageError(msg) => assert!(msg.contains("Lock error")),
            other => panic!("Expected StorageError, got: {:?}", other),
        }
    }

    #[test]
    fn test_poisoned_lock_store_run() {
        let storage = poison_storage();
        let run = ExperimentRun::new("r1", "exp", ModelParadigm::TraditionalML, test_device());
        assert_lock_error(storage.store_run(&run));
    }

    #[test]
    fn test_poisoned_lock_get_run() {
        let storage = poison_storage();
        assert_lock_error(storage.get_run("any"));
    }

    #[test]
    fn test_poisoned_lock_list_runs() {
        let storage = poison_storage();
        assert_lock_error(storage.list_runs("exp"));
    }

    #[test]
    fn test_poisoned_lock_delete_run() {
        let storage = poison_storage();
        assert_lock_error(storage.delete_run("any"));
    }
}