Skip to main content

entrenar/storage/
memory.rs

1//! In-Memory Storage Backend
2//!
3//! Provides an in-memory implementation of `ExperimentStorage` for testing
4//! and environments where file-based storage is not available (e.g., WASM).
5
6use std::collections::HashMap;
7use std::sync::atomic::{AtomicU64, Ordering};
8
9use chrono::Utc;
10use sha2::{Digest, Sha256};
11
12use super::{ExperimentStorage, MetricPoint, Result, RunStatus, StorageError};
13
14/// In-memory experiment storage backend
15///
16/// Useful for testing, fuzzing, and WASM environments where
17/// file-based storage is not available.
18#[derive(Debug, Default)]
19pub struct InMemoryStorage {
20    experiments: HashMap<String, ExperimentData>,
21    runs: HashMap<String, RunData>,
22    metrics: HashMap<String, Vec<MetricData>>, // run_id:key -> metrics
23    artifacts: HashMap<String, Vec<u8>>,       // CAS hash -> data
24    next_exp_id: AtomicU64,
25    next_run_id: AtomicU64,
26}
27
28#[derive(Debug, Clone)]
29struct ExperimentData {
30    #[allow(dead_code)]
31    name: String,
32    #[allow(dead_code)]
33    config: Option<serde_json::Value>,
34}
35
36#[derive(Debug, Clone)]
37struct RunData {
38    #[allow(dead_code)]
39    experiment_id: String,
40    status: RunStatus,
41    span_id: Option<String>,
42}
43
44#[derive(Debug, Clone)]
45struct MetricData {
46    step: u64,
47    value: f64,
48    timestamp: chrono::DateTime<Utc>,
49}
50
51impl InMemoryStorage {
52    /// Create a new in-memory storage
53    pub fn new() -> Self {
54        Self::default()
55    }
56
57    /// Get the number of experiments
58    pub fn experiment_count(&self) -> usize {
59        self.experiments.len()
60    }
61
62    /// Get the number of runs
63    pub fn run_count(&self) -> usize {
64        self.runs.len()
65    }
66
67    /// Get the number of metric entries (run_id:key combinations)
68    pub fn metric_key_count(&self) -> usize {
69        self.metrics.len()
70    }
71
72    /// Get the number of artifacts
73    pub fn artifact_count(&self) -> usize {
74        self.artifacts.len()
75    }
76
77    /// Log an experiment assignment (MA-08: A/B test logging)
78    ///
79    /// Records which experiment variant a run was assigned to (experiment_log).
80    pub fn experiment_log(&self, experiment_id: &str) -> Vec<String> {
81        self.runs
82            .iter()
83            .filter(|(_, r)| r.experiment_id == experiment_id)
84            .map(|(id, _)| format!("assignment_log: run={id} experiment={experiment_id}"))
85            .collect()
86    }
87
88    /// Compute CAS hash for artifact data
89    fn compute_hash(data: &[u8]) -> String {
90        let mut hasher = Sha256::new();
91        hasher.update(data);
92        let result = hasher.finalize();
93        format!("sha256-{}", hex::encode(result.get(..16).unwrap_or(&result))) // Use first 16 bytes
94    }
95}
96
97impl ExperimentStorage for InMemoryStorage {
98    fn create_experiment(
99        &mut self,
100        name: &str,
101        config: Option<serde_json::Value>,
102    ) -> Result<String> {
103        let id = self.next_exp_id.fetch_add(1, Ordering::SeqCst);
104        let exp_id = format!("exp-{id}");
105
106        self.experiments.insert(exp_id.clone(), ExperimentData { name: name.to_string(), config });
107
108        Ok(exp_id)
109    }
110
111    fn create_run(&mut self, experiment_id: &str) -> Result<String> {
112        if !self.experiments.contains_key(experiment_id) {
113            return Err(StorageError::ExperimentNotFound(experiment_id.to_string()));
114        }
115
116        let id = self.next_run_id.fetch_add(1, Ordering::SeqCst);
117        let run_id = format!("run-{id}");
118
119        self.runs.insert(
120            run_id.clone(),
121            RunData {
122                experiment_id: experiment_id.to_string(),
123                status: RunStatus::Pending,
124                span_id: None,
125            },
126        );
127
128        Ok(run_id)
129    }
130
131    fn start_run(&mut self, run_id: &str) -> Result<()> {
132        let run = self
133            .runs
134            .get_mut(run_id)
135            .ok_or_else(|| StorageError::RunNotFound(run_id.to_string()))?;
136
137        if run.status != RunStatus::Pending {
138            return Err(StorageError::InvalidState(format!(
139                "Run {run_id} is not in Pending state"
140            )));
141        }
142
143        run.status = RunStatus::Running;
144        Ok(())
145    }
146
147    fn complete_run(&mut self, run_id: &str, status: RunStatus) -> Result<()> {
148        let run = self
149            .runs
150            .get_mut(run_id)
151            .ok_or_else(|| StorageError::RunNotFound(run_id.to_string()))?;
152
153        if run.status != RunStatus::Running {
154            return Err(StorageError::InvalidState(format!(
155                "Run {run_id} is not in Running state"
156            )));
157        }
158
159        run.status = status;
160        Ok(())
161    }
162
163    fn log_metric(&mut self, run_id: &str, key: &str, step: u64, value: f64) -> Result<()> {
164        if !self.runs.contains_key(run_id) {
165            return Err(StorageError::RunNotFound(run_id.to_string()));
166        }
167
168        let metric_key = format!("{run_id}:{key}");
169        let metrics = self.metrics.entry(metric_key).or_default();
170
171        metrics.push(MetricData { step, value, timestamp: Utc::now() });
172
173        Ok(())
174    }
175
176    fn log_artifact(&mut self, run_id: &str, key: &str, data: &[u8]) -> Result<String> {
177        if !self.runs.contains_key(run_id) {
178            return Err(StorageError::RunNotFound(run_id.to_string()));
179        }
180
181        let hash = Self::compute_hash(data);
182
183        // Store with composite key for retrieval
184        let artifact_key = format!("{run_id}:{key}:{hash}");
185        self.artifacts.insert(artifact_key, data.to_vec());
186
187        Ok(hash)
188    }
189
190    fn get_metrics(&self, run_id: &str, key: &str) -> Result<Vec<MetricPoint>> {
191        if !self.runs.contains_key(run_id) {
192            return Err(StorageError::RunNotFound(run_id.to_string()));
193        }
194
195        let metric_key = format!("{run_id}:{key}");
196        let metrics = self.metrics.get(&metric_key).cloned().unwrap_or_default();
197
198        let mut points: Vec<MetricPoint> = metrics
199            .into_iter()
200            .map(|m| MetricPoint::with_timestamp(m.step, m.value, m.timestamp))
201            .collect();
202
203        // Sort by step
204        points.sort_by_key(|p| p.step);
205
206        Ok(points)
207    }
208
209    fn get_run_status(&self, run_id: &str) -> Result<RunStatus> {
210        self.runs
211            .get(run_id)
212            .map(|r| r.status)
213            .ok_or_else(|| StorageError::RunNotFound(run_id.to_string()))
214    }
215
216    fn set_span_id(&mut self, run_id: &str, span_id: &str) -> Result<()> {
217        let run = self
218            .runs
219            .get_mut(run_id)
220            .ok_or_else(|| StorageError::RunNotFound(run_id.to_string()))?;
221
222        run.span_id = Some(span_id.to_string());
223        Ok(())
224    }
225
226    fn get_span_id(&self, run_id: &str) -> Result<Option<String>> {
227        self.runs
228            .get(run_id)
229            .map(|r| r.span_id.clone())
230            .ok_or_else(|| StorageError::RunNotFound(run_id.to_string()))
231    }
232}
233
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a storage holding one experiment with one freshly created
    /// (`Pending`) run, returning the storage and that run's id.
    fn storage_with_run() -> (InMemoryStorage, String) {
        let mut storage = InMemoryStorage::new();
        let exp_id = storage.create_experiment("test-exp", None).expect("operation should succeed");
        let run_id = storage.create_run(&exp_id).expect("operation should succeed");
        (storage, run_id)
    }

    #[test]
    fn test_in_memory_storage_new() {
        let storage = InMemoryStorage::new();
        assert_eq!(storage.experiment_count(), 0);
        assert_eq!(storage.run_count(), 0);
        assert_eq!(storage.metric_key_count(), 0);
        assert_eq!(storage.artifact_count(), 0);
    }

    #[test]
    fn test_create_experiment() {
        let mut storage = InMemoryStorage::new();
        let exp_id = storage.create_experiment("test-exp", None).expect("operation should succeed");

        assert!(exp_id.starts_with("exp-"));
        assert_eq!(storage.experiment_count(), 1);
    }

    #[test]
    fn test_create_experiment_with_config() {
        let mut storage = InMemoryStorage::new();
        let config = serde_json::json!({"learning_rate": 0.001});
        let exp_id =
            storage.create_experiment("test-exp", Some(config)).expect("config should be valid");

        assert!(exp_id.starts_with("exp-"));
        assert_eq!(storage.experiment_count(), 1);
    }

    #[test]
    fn test_create_run() {
        let (storage, run_id) = storage_with_run();

        assert!(run_id.starts_with("run-"));
        assert_eq!(storage.run_count(), 1);
        assert_eq!(
            storage.get_run_status(&run_id).expect("operation should succeed"),
            RunStatus::Pending
        );
    }

    #[test]
    fn test_create_run_invalid_experiment() {
        let mut storage = InMemoryStorage::new();

        match storage.create_run("fake-exp") {
            Err(StorageError::ExperimentNotFound(id)) => assert_eq!(id, "fake-exp"),
            other => panic!("Expected ExperimentNotFound, got {other:?}"),
        }
    }

    #[test]
    fn test_start_run() {
        let (mut storage, run_id) = storage_with_run();

        storage.start_run(&run_id).expect("operation should succeed");
        assert_eq!(
            storage.get_run_status(&run_id).expect("operation should succeed"),
            RunStatus::Running
        );
    }

    #[test]
    fn test_start_run_invalid_state() {
        let (mut storage, run_id) = storage_with_run();

        storage.start_run(&run_id).expect("operation should succeed");
        let result = storage.start_run(&run_id); // Already started

        assert!(
            matches!(result, Err(StorageError::InvalidState(_))),
            "Expected InvalidState, got {result:?}"
        );
    }

    #[test]
    fn test_complete_run() {
        let (mut storage, run_id) = storage_with_run();

        storage.start_run(&run_id).expect("operation should succeed");
        storage.complete_run(&run_id, RunStatus::Success).expect("operation should succeed");

        assert_eq!(
            storage.get_run_status(&run_id).expect("operation should succeed"),
            RunStatus::Success
        );
    }

    #[test]
    fn test_complete_run_failed() {
        let (mut storage, run_id) = storage_with_run();

        storage.start_run(&run_id).expect("operation should succeed");
        storage.complete_run(&run_id, RunStatus::Failed).expect("operation should succeed");

        assert_eq!(
            storage.get_run_status(&run_id).expect("operation should succeed"),
            RunStatus::Failed
        );
    }

    #[test]
    fn test_complete_run_invalid_state() {
        let (mut storage, run_id) = storage_with_run();

        // Try to complete without starting
        let result = storage.complete_run(&run_id, RunStatus::Success);

        assert!(
            matches!(result, Err(StorageError::InvalidState(_))),
            "Expected InvalidState, got {result:?}"
        );
    }

    #[test]
    fn test_log_metric() {
        let (mut storage, run_id) = storage_with_run();

        storage.log_metric(&run_id, "loss", 0, 0.5).expect("operation should succeed");
        storage.log_metric(&run_id, "loss", 1, 0.4).expect("operation should succeed");

        let metrics = storage.get_metrics(&run_id, "loss").expect("operation should succeed");
        assert_eq!(metrics.len(), 2);
        assert_eq!(metrics[0].step, 0);
        assert!((metrics[0].value - 0.5).abs() < f64::EPSILON);
        assert_eq!(metrics[1].step, 1);
        assert!((metrics[1].value - 0.4).abs() < f64::EPSILON);
    }

    #[test]
    fn test_log_metric_invalid_run() {
        let mut storage = InMemoryStorage::new();

        match storage.log_metric("fake-run", "loss", 0, 0.5) {
            Err(StorageError::RunNotFound(id)) => assert_eq!(id, "fake-run"),
            other => panic!("Expected RunNotFound, got {other:?}"),
        }
    }

    #[test]
    fn test_metric_key_count_tracks_distinct_keys() {
        let (mut storage, run_id) = storage_with_run();

        storage.log_metric(&run_id, "loss", 0, 0.5).expect("operation should succeed");
        storage.log_metric(&run_id, "loss", 1, 0.4).expect("operation should succeed");
        storage.log_metric(&run_id, "accuracy", 0, 0.9).expect("operation should succeed");

        // Two samples of "loss" share one entry; "accuracy" adds a second.
        assert_eq!(storage.metric_key_count(), 2);
    }

    #[test]
    fn test_get_metrics_ordering() {
        let (mut storage, run_id) = storage_with_run();

        // Log out of order
        storage.log_metric(&run_id, "loss", 2, 0.3).expect("operation should succeed");
        storage.log_metric(&run_id, "loss", 0, 0.5).expect("operation should succeed");
        storage.log_metric(&run_id, "loss", 1, 0.4).expect("operation should succeed");

        let metrics = storage.get_metrics(&run_id, "loss").expect("operation should succeed");
        let steps: Vec<u64> = metrics.iter().map(|p| p.step).collect();
        assert_eq!(steps, vec![0, 1, 2]);
    }

    #[test]
    fn test_get_metrics_empty() {
        let (storage, run_id) = storage_with_run();

        let metrics = storage.get_metrics(&run_id, "loss").expect("operation should succeed");
        assert!(metrics.is_empty());
    }

    #[test]
    fn test_log_artifact() {
        let (mut storage, run_id) = storage_with_run();

        let data = b"model weights data";
        let hash =
            storage.log_artifact(&run_id, "model.bin", data).expect("operation should succeed");

        assert!(hash.starts_with("sha256-"));
        assert_eq!(storage.artifact_count(), 1);
    }

    #[test]
    fn test_log_artifact_hash_and_keying() {
        let (mut storage, run_id) = storage_with_run();

        let first =
            storage.log_artifact(&run_id, "a.bin", b"same bytes").expect("operation should succeed");
        let second =
            storage.log_artifact(&run_id, "b.bin", b"same bytes").expect("operation should succeed");
        let other = storage
            .log_artifact(&run_id, "c.bin", b"other bytes")
            .expect("operation should succeed");

        // The hash depends only on the content.
        assert_eq!(first, second);
        assert_ne!(first, other);
        // "sha256-" prefix followed by 16 digest bytes in hex.
        assert_eq!(first.len(), "sha256-".len() + 32);
        // Entries are keyed per run/key, so identical bytes are stored twice.
        assert_eq!(storage.artifact_count(), 3);
    }

    #[test]
    fn test_log_artifact_invalid_run() {
        let mut storage = InMemoryStorage::new();

        match storage.log_artifact("fake-run", "model.bin", b"data") {
            Err(StorageError::RunNotFound(id)) => assert_eq!(id, "fake-run"),
            other => panic!("Expected RunNotFound, got {other:?}"),
        }
        assert_eq!(storage.artifact_count(), 0);
    }

    #[test]
    fn test_set_and_get_span_id() {
        let (mut storage, run_id) = storage_with_run();

        assert!(storage.get_span_id(&run_id).expect("operation should succeed").is_none());

        storage.set_span_id(&run_id, "span-12345").expect("operation should succeed");

        assert_eq!(
            storage.get_span_id(&run_id).expect("operation should succeed"),
            Some("span-12345".to_string())
        );
    }

    #[test]
    fn test_span_id_unknown_run() {
        let mut storage = InMemoryStorage::new();

        assert!(matches!(storage.get_span_id("fake-run"), Err(StorageError::RunNotFound(_))));
        assert!(matches!(
            storage.set_span_id("fake-run", "span-1"),
            Err(StorageError::RunNotFound(_))
        ));
    }

    #[test]
    fn test_experiment_log_lists_runs_of_experiment() {
        let mut storage = InMemoryStorage::new();
        let exp1 = storage.create_experiment("exp-1", None).expect("operation should succeed");
        let exp2 = storage.create_experiment("exp-2", None).expect("operation should succeed");

        let run1 = storage.create_run(&exp1).expect("operation should succeed");
        let run2 = storage.create_run(&exp1).expect("operation should succeed");
        storage.create_run(&exp2).expect("operation should succeed");

        // Compare order-insensitively; only the contents are asserted here.
        let mut log = storage.experiment_log(&exp1);
        log.sort();
        let mut expected = vec![
            format!("assignment_log: run={run1} experiment={exp1}"),
            format!("assignment_log: run={run2} experiment={exp1}"),
        ];
        expected.sort();
        assert_eq!(log, expected);

        assert!(storage.experiment_log("fake-exp").is_empty());
    }

    #[test]
    fn test_multiple_experiments_and_runs() {
        let mut storage = InMemoryStorage::new();

        let exp1 = storage.create_experiment("exp-1", None).expect("operation should succeed");
        let exp2 = storage.create_experiment("exp-2", None).expect("operation should succeed");

        let run1 = storage.create_run(&exp1).expect("operation should succeed");
        let run2 = storage.create_run(&exp1).expect("operation should succeed");
        let run3 = storage.create_run(&exp2).expect("operation should succeed");

        assert_eq!(storage.experiment_count(), 2);
        assert_eq!(storage.run_count(), 3);

        // Each run is independent
        storage.start_run(&run1).expect("operation should succeed");
        storage.start_run(&run2).expect("operation should succeed");

        assert_eq!(
            storage.get_run_status(&run1).expect("operation should succeed"),
            RunStatus::Running
        );
        assert_eq!(
            storage.get_run_status(&run2).expect("operation should succeed"),
            RunStatus::Running
        );
        assert_eq!(
            storage.get_run_status(&run3).expect("operation should succeed"),
            RunStatus::Pending
        );
    }
}