1use super::{
7 ComputeDevice, CostMetrics, CpuArchitecture, EnergyMetrics, ExperimentError, ModelParadigm,
8 PlatformEfficiency,
9};
10use serde::{Deserialize, Serialize};
11use std::collections::HashMap;
12
/// A single tracked execution of an experiment: what was run, on which device, the
/// hyperparameters and metrics logged for it, and its lifecycle state.
///
/// No serde attributes are applied, so every field is (de)serialized under its Rust name.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExperimentRun {
    /// Identifier of this run. [`InMemoryExperimentStorage`] uses it as the map key, so
    /// storing another run with the same id replaces the earlier one.
    pub run_id: String,
    /// Name of the experiment this run belongs to; `list_runs` filters on it.
    pub experiment_name: String,
    /// Modelling paradigm of the run (supplied to [`ExperimentRun::new`]).
    pub paradigm: ModelParadigm,
    /// Compute device the run targets (supplied to [`ExperimentRun::new`]).
    pub device: ComputeDevice,
    /// Platform efficiency class; [`ExperimentRun::new`] sets `PlatformEfficiency::Server`.
    pub platform: PlatformEfficiency,
    /// Optional energy measurements; `None` until assigned by the caller.
    pub energy: Option<EnergyMetrics>,
    /// Optional cost measurements; `None` until assigned by the caller.
    pub cost: Option<CostMetrics>,
    /// Hyperparameters as arbitrary JSON values, recorded via `log_param`.
    pub hyperparameters: HashMap<String, serde_json::Value>,
    /// Scalar metrics recorded via `log_metric`; re-logging a name overwrites its value.
    pub metrics: HashMap<String, f64>,
    /// Free-form labels; empty when created by [`ExperimentRun::new`].
    pub tags: Vec<String>,
    /// Start time as an RFC 3339 string, stamped from `chrono::Utc::now()` by `new`.
    pub started_at: String,
    /// End time as an RFC 3339 string; `None` while the run has not been finished.
    pub ended_at: Option<String>,
    /// Current lifecycle state; starts as [`RunStatus::Running`].
    pub status: RunStatus,
}
43
/// Lifecycle state of an [`ExperimentRun`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum RunStatus {
    /// Initial state assigned by [`ExperimentRun::new`].
    Running,
    /// Assigned by [`ExperimentRun::complete`].
    Completed,
    /// Assigned by [`ExperimentRun::fail`].
    Failed,
    /// Presumably a run stopped before finishing (inferred from the name — confirm with
    /// callers).
    Cancelled,
}
52
53impl ExperimentRun {
54 pub fn new(
56 run_id: impl Into<String>,
57 experiment_name: impl Into<String>,
58 paradigm: ModelParadigm,
59 device: ComputeDevice,
60 ) -> Self {
61 Self {
62 run_id: run_id.into(),
63 experiment_name: experiment_name.into(),
64 paradigm,
65 device,
66 platform: PlatformEfficiency::Server,
67 energy: None,
68 cost: None,
69 hyperparameters: HashMap::new(),
70 metrics: HashMap::new(),
71 tags: Vec::new(),
72 started_at: chrono::Utc::now().to_rfc3339(),
73 ended_at: None,
74 status: RunStatus::Running,
75 }
76 }
77
78 pub fn log_metric(&mut self, name: impl Into<String>, value: f64) {
80 self.metrics.insert(name.into(), value);
81 }
82
83 pub fn log_param(&mut self, name: impl Into<String>, value: serde_json::Value) {
85 self.hyperparameters.insert(name.into(), value);
86 }
87
88 pub fn complete(&mut self) {
90 self.ended_at = Some(chrono::Utc::now().to_rfc3339());
91 self.status = RunStatus::Completed;
92 }
93
94 pub fn fail(&mut self) {
96 self.ended_at = Some(chrono::Utc::now().to_rfc3339());
97 self.status = RunStatus::Failed;
98 }
99}
100
/// Persistence backend for [`ExperimentRun`] records.
///
/// The `Send + Sync` supertraits allow an implementation to be shared across threads.
/// Every method takes `&self`, so implementations that mutate state need interior
/// mutability ([`InMemoryExperimentStorage`] uses a `std::sync::RwLock`).
pub trait ExperimentStorage: Send + Sync {
    /// Saves `run`.
    ///
    /// NOTE(review): the in-memory implementation overwrites any run with the same
    /// `run_id`; confirm other implementations follow the same upsert semantics.
    fn store_run(&self, run: &ExperimentRun) -> Result<(), ExperimentError>;

    /// Looks up a run by id; `Ok(None)` means no run with that id is stored.
    fn get_run(&self, run_id: &str) -> Result<Option<ExperimentRun>, ExperimentError>;

    /// Returns the stored runs whose `experiment_name` equals `experiment_name`.
    fn list_runs(&self, experiment_name: &str) -> Result<Vec<ExperimentRun>, ExperimentError>;

    /// Removes the run identified by `run_id`.
    ///
    /// NOTE(review): the in-memory implementation returns `Ok(())` even when no such run
    /// exists; confirm whether that is the intended contract for all backends.
    fn delete_run(&self, run_id: &str) -> Result<(), ExperimentError>;
}
115
/// [`ExperimentStorage`] implementation that keeps runs in a `HashMap` inside this value.
///
/// Runs are keyed by `run_id` and guarded by a `std::sync::RwLock`. Nothing is persisted
/// elsewhere, so stored runs are dropped together with the storage. The derived `Default`
/// yields an empty store.
#[derive(Debug, Default)]
pub struct InMemoryExperimentStorage {
    /// Stored runs keyed by their `run_id`.
    runs: std::sync::RwLock<HashMap<String, ExperimentRun>>,
}
121
122impl InMemoryExperimentStorage {
123 pub fn new() -> Self {
125 Self::default()
126 }
127}
128
129impl ExperimentStorage for InMemoryExperimentStorage {
130 fn store_run(&self, run: &ExperimentRun) -> Result<(), ExperimentError> {
131 let mut runs = self
132 .runs
133 .write()
134 .map_err(|e| ExperimentError::StorageError(format!("Lock error: {}", e)))?;
135 runs.insert(run.run_id.clone(), run.clone());
136 Ok(())
137 }
138
139 fn get_run(&self, run_id: &str) -> Result<Option<ExperimentRun>, ExperimentError> {
140 let runs = self
141 .runs
142 .read()
143 .map_err(|e| ExperimentError::StorageError(format!("Lock error: {}", e)))?;
144 Ok(runs.get(run_id).cloned())
145 }
146
147 fn list_runs(&self, experiment_name: &str) -> Result<Vec<ExperimentRun>, ExperimentError> {
148 let runs = self
149 .runs
150 .read()
151 .map_err(|e| ExperimentError::StorageError(format!("Lock error: {}", e)))?;
152 Ok(runs.values().filter(|r| r.experiment_name == experiment_name).cloned().collect())
153 }
154
155 fn delete_run(&self, run_id: &str) -> Result<(), ExperimentError> {
156 let mut runs = self
157 .runs
158 .write()
159 .map_err(|e| ExperimentError::StorageError(format!("Lock error: {}", e)))?;
160 runs.remove(run_id);
161 Ok(())
162 }
163}
164
#[cfg(test)]
mod lock_poison_tests {
    use super::*;
    use crate::experiment::{ComputeDevice, CpuArchitecture, ModelParadigm};

    /// Returns a storage whose `RwLock` is poisoned: a write guard is held while a
    /// deliberate panic unwinds, and `catch_unwind` absorbs that panic.
    fn poison_storage() -> InMemoryExperimentStorage {
        let storage = InMemoryExperimentStorage::new();
        let poisoner = std::panic::AssertUnwindSafe(|| {
            let _held = storage.runs.write().expect("unexpected failure");
            panic!("intentional poison");
        });
        let _ = std::panic::catch_unwind(poisoner);
        storage
    }

    /// Minimal single-core CPU device for constructing test runs.
    fn test_device() -> ComputeDevice {
        ComputeDevice::Cpu {
            cores: 1,
            threads_per_core: 1,
            architecture: CpuArchitecture::X86_64,
        }
    }

    /// Asserts that `result` is a `StorageError` whose message mentions the lock failure.
    fn expect_lock_error<T: std::fmt::Debug>(result: Result<T, ExperimentError>) {
        assert!(result.is_err());
        match result.unwrap_err() {
            ExperimentError::StorageError(msg) => assert!(msg.contains("Lock error")),
            other => panic!("Expected StorageError, got: {:?}", other),
        }
    }

    #[test]
    fn test_poisoned_lock_store_run() {
        let run = ExperimentRun::new("r1", "exp", ModelParadigm::TraditionalML, test_device());
        expect_lock_error(poison_storage().store_run(&run));
    }

    #[test]
    fn test_poisoned_lock_get_run() {
        expect_lock_error(poison_storage().get_run("any"));
    }

    #[test]
    fn test_poisoned_lock_list_runs() {
        expect_lock_error(poison_storage().list_runs("exp"));
    }

    #[test]
    fn test_poisoned_lock_delete_run() {
        expect_lock_error(poison_storage().delete_run("any"));
    }
}