1use std::collections::HashMap;
7use std::sync::atomic::{AtomicU64, Ordering};
8
9use chrono::Utc;
10use sha2::{Digest, Sha256};
11
12use super::{ExperimentStorage, MetricPoint, Result, RunStatus, StorageError};
13
14#[derive(Debug, Default)]
19pub struct InMemoryStorage {
20 experiments: HashMap<String, ExperimentData>,
21 runs: HashMap<String, RunData>,
22 metrics: HashMap<String, Vec<MetricData>>, artifacts: HashMap<String, Vec<u8>>, next_exp_id: AtomicU64,
25 next_run_id: AtomicU64,
26}
27
28#[derive(Debug, Clone)]
29struct ExperimentData {
30 #[allow(dead_code)]
31 name: String,
32 #[allow(dead_code)]
33 config: Option<serde_json::Value>,
34}
35
36#[derive(Debug, Clone)]
37struct RunData {
38 #[allow(dead_code)]
39 experiment_id: String,
40 status: RunStatus,
41 span_id: Option<String>,
42}
43
44#[derive(Debug, Clone)]
45struct MetricData {
46 step: u64,
47 value: f64,
48 timestamp: chrono::DateTime<Utc>,
49}
50
51impl InMemoryStorage {
52 pub fn new() -> Self {
54 Self::default()
55 }
56
57 pub fn experiment_count(&self) -> usize {
59 self.experiments.len()
60 }
61
62 pub fn run_count(&self) -> usize {
64 self.runs.len()
65 }
66
67 pub fn metric_key_count(&self) -> usize {
69 self.metrics.len()
70 }
71
72 pub fn artifact_count(&self) -> usize {
74 self.artifacts.len()
75 }
76
77 pub fn experiment_log(&self, experiment_id: &str) -> Vec<String> {
81 self.runs
82 .iter()
83 .filter(|(_, r)| r.experiment_id == experiment_id)
84 .map(|(id, _)| format!("assignment_log: run={id} experiment={experiment_id}"))
85 .collect()
86 }
87
88 fn compute_hash(data: &[u8]) -> String {
90 let mut hasher = Sha256::new();
91 hasher.update(data);
92 let result = hasher.finalize();
93 format!("sha256-{}", hex::encode(result.get(..16).unwrap_or(&result))) }
95}
96
97impl ExperimentStorage for InMemoryStorage {
98 fn create_experiment(
99 &mut self,
100 name: &str,
101 config: Option<serde_json::Value>,
102 ) -> Result<String> {
103 let id = self.next_exp_id.fetch_add(1, Ordering::SeqCst);
104 let exp_id = format!("exp-{id}");
105
106 self.experiments.insert(exp_id.clone(), ExperimentData { name: name.to_string(), config });
107
108 Ok(exp_id)
109 }
110
111 fn create_run(&mut self, experiment_id: &str) -> Result<String> {
112 if !self.experiments.contains_key(experiment_id) {
113 return Err(StorageError::ExperimentNotFound(experiment_id.to_string()));
114 }
115
116 let id = self.next_run_id.fetch_add(1, Ordering::SeqCst);
117 let run_id = format!("run-{id}");
118
119 self.runs.insert(
120 run_id.clone(),
121 RunData {
122 experiment_id: experiment_id.to_string(),
123 status: RunStatus::Pending,
124 span_id: None,
125 },
126 );
127
128 Ok(run_id)
129 }
130
131 fn start_run(&mut self, run_id: &str) -> Result<()> {
132 let run = self
133 .runs
134 .get_mut(run_id)
135 .ok_or_else(|| StorageError::RunNotFound(run_id.to_string()))?;
136
137 if run.status != RunStatus::Pending {
138 return Err(StorageError::InvalidState(format!(
139 "Run {run_id} is not in Pending state"
140 )));
141 }
142
143 run.status = RunStatus::Running;
144 Ok(())
145 }
146
147 fn complete_run(&mut self, run_id: &str, status: RunStatus) -> Result<()> {
148 let run = self
149 .runs
150 .get_mut(run_id)
151 .ok_or_else(|| StorageError::RunNotFound(run_id.to_string()))?;
152
153 if run.status != RunStatus::Running {
154 return Err(StorageError::InvalidState(format!(
155 "Run {run_id} is not in Running state"
156 )));
157 }
158
159 run.status = status;
160 Ok(())
161 }
162
163 fn log_metric(&mut self, run_id: &str, key: &str, step: u64, value: f64) -> Result<()> {
164 if !self.runs.contains_key(run_id) {
165 return Err(StorageError::RunNotFound(run_id.to_string()));
166 }
167
168 let metric_key = format!("{run_id}:{key}");
169 let metrics = self.metrics.entry(metric_key).or_default();
170
171 metrics.push(MetricData { step, value, timestamp: Utc::now() });
172
173 Ok(())
174 }
175
176 fn log_artifact(&mut self, run_id: &str, key: &str, data: &[u8]) -> Result<String> {
177 if !self.runs.contains_key(run_id) {
178 return Err(StorageError::RunNotFound(run_id.to_string()));
179 }
180
181 let hash = Self::compute_hash(data);
182
183 let artifact_key = format!("{run_id}:{key}:{hash}");
185 self.artifacts.insert(artifact_key, data.to_vec());
186
187 Ok(hash)
188 }
189
190 fn get_metrics(&self, run_id: &str, key: &str) -> Result<Vec<MetricPoint>> {
191 if !self.runs.contains_key(run_id) {
192 return Err(StorageError::RunNotFound(run_id.to_string()));
193 }
194
195 let metric_key = format!("{run_id}:{key}");
196 let metrics = self.metrics.get(&metric_key).cloned().unwrap_or_default();
197
198 let mut points: Vec<MetricPoint> = metrics
199 .into_iter()
200 .map(|m| MetricPoint::with_timestamp(m.step, m.value, m.timestamp))
201 .collect();
202
203 points.sort_by_key(|p| p.step);
205
206 Ok(points)
207 }
208
209 fn get_run_status(&self, run_id: &str) -> Result<RunStatus> {
210 self.runs
211 .get(run_id)
212 .map(|r| r.status)
213 .ok_or_else(|| StorageError::RunNotFound(run_id.to_string()))
214 }
215
216 fn set_span_id(&mut self, run_id: &str, span_id: &str) -> Result<()> {
217 let run = self
218 .runs
219 .get_mut(run_id)
220 .ok_or_else(|| StorageError::RunNotFound(run_id.to_string()))?;
221
222 run.span_id = Some(span_id.to_string());
223 Ok(())
224 }
225
226 fn get_span_id(&self, run_id: &str) -> Result<Option<String>> {
227 self.runs
228 .get(run_id)
229 .map(|r| r.span_id.clone())
230 .ok_or_else(|| StorageError::RunNotFound(run_id.to_string()))
231 }
232}
233
#[cfg(test)]
mod tests {
    use super::*;

    /// Creates a storage holding one experiment with a single `Pending` run.
    fn storage_with_run() -> (InMemoryStorage, String, String) {
        let mut storage = InMemoryStorage::new();
        let exp_id = storage.create_experiment("test-exp", None).expect("operation should succeed");
        let run_id = storage.create_run(&exp_id).expect("operation should succeed");
        (storage, exp_id, run_id)
    }

    /// Asserts that `result` failed with `RunNotFound` carrying exactly `run_id`.
    fn assert_run_not_found<T: std::fmt::Debug>(
        result: std::result::Result<T, StorageError>,
        run_id: &str,
    ) {
        match result {
            Err(StorageError::RunNotFound(id)) => assert_eq!(id, run_id),
            other => panic!("Expected RunNotFound, got {other:?}"),
        }
    }

    /// Asserts that `result` failed with `InvalidState`.
    fn assert_invalid_state<T: std::fmt::Debug>(result: std::result::Result<T, StorageError>) {
        match result {
            Err(StorageError::InvalidState(_)) => {}
            other => panic!("Expected InvalidState, got {other:?}"),
        }
    }

    #[test]
    fn test_in_memory_storage_new() {
        let storage = InMemoryStorage::new();
        assert_eq!(storage.experiment_count(), 0);
        assert_eq!(storage.run_count(), 0);
        assert_eq!(storage.metric_key_count(), 0);
        assert_eq!(storage.artifact_count(), 0);
    }

    #[test]
    fn test_create_experiment() {
        let mut storage = InMemoryStorage::new();
        let exp_id = storage.create_experiment("test-exp", None).expect("operation should succeed");

        assert!(exp_id.starts_with("exp-"));
        assert_eq!(storage.experiment_count(), 1);
    }

    #[test]
    fn test_create_experiment_with_config() {
        let mut storage = InMemoryStorage::new();
        let config = serde_json::json!({"learning_rate": 0.001});
        let exp_id =
            storage.create_experiment("test-exp", Some(config)).expect("config should be valid");

        assert!(exp_id.starts_with("exp-"));
        assert_eq!(storage.experiment_count(), 1);
    }

    #[test]
    fn test_experiment_ids_are_unique_for_duplicate_names() {
        let mut storage = InMemoryStorage::new();
        let first = storage.create_experiment("same-name", None).expect("operation should succeed");
        let second =
            storage.create_experiment("same-name", None).expect("operation should succeed");

        assert_ne!(first, second);
        assert_eq!(storage.experiment_count(), 2);
    }

    #[test]
    fn test_create_run() {
        let (storage, _exp_id, run_id) = storage_with_run();

        assert!(run_id.starts_with("run-"));
        assert_eq!(storage.run_count(), 1);
        assert_eq!(
            storage.get_run_status(&run_id).expect("operation should succeed"),
            RunStatus::Pending
        );
    }

    #[test]
    fn test_create_run_invalid_experiment() {
        let mut storage = InMemoryStorage::new();

        match storage.create_run("fake-exp") {
            Err(StorageError::ExperimentNotFound(id)) => assert_eq!(id, "fake-exp"),
            other => panic!("Expected ExperimentNotFound, got {other:?}"),
        }
        assert_eq!(storage.run_count(), 0);
    }

    #[test]
    fn test_get_run_status_unknown_run() {
        let storage = InMemoryStorage::new();
        assert_run_not_found(storage.get_run_status("fake-run"), "fake-run");
    }

    #[test]
    fn test_start_run() {
        let (mut storage, _exp_id, run_id) = storage_with_run();

        storage.start_run(&run_id).expect("operation should succeed");
        assert_eq!(
            storage.get_run_status(&run_id).expect("operation should succeed"),
            RunStatus::Running
        );
    }

    #[test]
    fn test_start_run_invalid_state() {
        let (mut storage, _exp_id, run_id) = storage_with_run();

        storage.start_run(&run_id).expect("operation should succeed");
        assert_invalid_state(storage.start_run(&run_id));
        // A rejected transition must leave the current status untouched.
        assert_eq!(
            storage.get_run_status(&run_id).expect("operation should succeed"),
            RunStatus::Running
        );
    }

    #[test]
    fn test_start_run_unknown_run() {
        let mut storage = InMemoryStorage::new();
        assert_run_not_found(storage.start_run("fake-run"), "fake-run");
    }

    #[test]
    fn test_complete_run() {
        let (mut storage, _exp_id, run_id) = storage_with_run();

        storage.start_run(&run_id).expect("operation should succeed");
        storage.complete_run(&run_id, RunStatus::Success).expect("operation should succeed");

        assert_eq!(
            storage.get_run_status(&run_id).expect("operation should succeed"),
            RunStatus::Success
        );
    }

    #[test]
    fn test_complete_run_failed() {
        let (mut storage, _exp_id, run_id) = storage_with_run();

        storage.start_run(&run_id).expect("operation should succeed");
        storage.complete_run(&run_id, RunStatus::Failed).expect("operation should succeed");

        assert_eq!(
            storage.get_run_status(&run_id).expect("operation should succeed"),
            RunStatus::Failed
        );
    }

    #[test]
    fn test_complete_run_invalid_state() {
        let (mut storage, _exp_id, run_id) = storage_with_run();

        assert_invalid_state(storage.complete_run(&run_id, RunStatus::Success));
        assert_eq!(
            storage.get_run_status(&run_id).expect("operation should succeed"),
            RunStatus::Pending
        );
    }

    #[test]
    fn test_complete_run_twice_is_rejected() {
        let (mut storage, _exp_id, run_id) = storage_with_run();

        storage.start_run(&run_id).expect("operation should succeed");
        storage.complete_run(&run_id, RunStatus::Success).expect("operation should succeed");

        assert_invalid_state(storage.complete_run(&run_id, RunStatus::Failed));
        assert_eq!(
            storage.get_run_status(&run_id).expect("operation should succeed"),
            RunStatus::Success
        );
    }

    #[test]
    fn test_complete_run_unknown_run() {
        let mut storage = InMemoryStorage::new();
        assert_run_not_found(storage.complete_run("fake-run", RunStatus::Success), "fake-run");
    }

    #[test]
    fn test_log_metric() {
        let (mut storage, _exp_id, run_id) = storage_with_run();

        storage.log_metric(&run_id, "loss", 0, 0.5).expect("operation should succeed");
        storage.log_metric(&run_id, "loss", 1, 0.4).expect("operation should succeed");

        let metrics = storage.get_metrics(&run_id, "loss").expect("operation should succeed");
        assert_eq!(metrics.len(), 2);
        assert_eq!(metrics[0].step, 0);
        assert!((metrics[0].value - 0.5).abs() < f64::EPSILON);
        assert_eq!(metrics[1].step, 1);
        assert!((metrics[1].value - 0.4).abs() < f64::EPSILON);
        assert_eq!(storage.metric_key_count(), 1);
    }

    #[test]
    fn test_log_metric_invalid_run() {
        let mut storage = InMemoryStorage::new();

        assert_run_not_found(storage.log_metric("fake-run", "loss", 0, 0.5), "fake-run");
        assert_eq!(storage.metric_key_count(), 0);
    }

    #[test]
    fn test_get_metrics_ordering() {
        let (mut storage, _exp_id, run_id) = storage_with_run();

        storage.log_metric(&run_id, "loss", 2, 0.3).expect("operation should succeed");
        storage.log_metric(&run_id, "loss", 0, 0.5).expect("operation should succeed");
        storage.log_metric(&run_id, "loss", 1, 0.4).expect("operation should succeed");

        let metrics = storage.get_metrics(&run_id, "loss").expect("operation should succeed");
        assert_eq!(metrics.len(), 3);
        assert_eq!(metrics[0].step, 0);
        assert!((metrics[0].value - 0.5).abs() < f64::EPSILON);
        assert_eq!(metrics[1].step, 1);
        assert!((metrics[1].value - 0.4).abs() < f64::EPSILON);
        assert_eq!(metrics[2].step, 2);
        assert!((metrics[2].value - 0.3).abs() < f64::EPSILON);
    }

    #[test]
    fn test_get_metrics_duplicate_steps_keep_logging_order() {
        let (mut storage, _exp_id, run_id) = storage_with_run();

        storage.log_metric(&run_id, "loss", 1, 0.9).expect("operation should succeed");
        storage.log_metric(&run_id, "loss", 0, 0.5).expect("operation should succeed");
        storage.log_metric(&run_id, "loss", 1, 0.8).expect("operation should succeed");

        let metrics = storage.get_metrics(&run_id, "loss").expect("operation should succeed");
        assert_eq!(metrics.len(), 3);
        assert_eq!(metrics[0].step, 0);
        assert_eq!(metrics[1].step, 1);
        assert!((metrics[1].value - 0.9).abs() < f64::EPSILON);
        assert_eq!(metrics[2].step, 1);
        assert!((metrics[2].value - 0.8).abs() < f64::EPSILON);
    }

    #[test]
    fn test_get_metrics_keys_are_isolated() {
        let (mut storage, _exp_id, run_id) = storage_with_run();

        storage.log_metric(&run_id, "loss", 0, 0.5).expect("operation should succeed");
        storage.log_metric(&run_id, "accuracy", 0, 0.9).expect("operation should succeed");

        let loss = storage.get_metrics(&run_id, "loss").expect("operation should succeed");
        let accuracy = storage.get_metrics(&run_id, "accuracy").expect("operation should succeed");
        assert_eq!(loss.len(), 1);
        assert!((loss[0].value - 0.5).abs() < f64::EPSILON);
        assert_eq!(accuracy.len(), 1);
        assert!((accuracy[0].value - 0.9).abs() < f64::EPSILON);
        assert_eq!(storage.metric_key_count(), 2);
    }

    #[test]
    fn test_get_metrics_empty() {
        let (storage, _exp_id, run_id) = storage_with_run();

        let metrics = storage.get_metrics(&run_id, "loss").expect("operation should succeed");
        assert!(metrics.is_empty());
    }

    #[test]
    fn test_get_metrics_invalid_run() {
        let storage = InMemoryStorage::new();
        // Map to the length so the helper does not need `MetricPoint: Debug`.
        assert_run_not_found(
            storage.get_metrics("fake-run", "loss").map(|points| points.len()),
            "fake-run",
        );
    }

    #[test]
    fn test_log_artifact() {
        let (mut storage, _exp_id, run_id) = storage_with_run();

        let data = b"model weights data";
        let hash =
            storage.log_artifact(&run_id, "model.bin", data).expect("operation should succeed");

        let digest = hash.strip_prefix("sha256-").expect("hash should carry the sha256- prefix");
        // 16 digest bytes, hex-encoded.
        assert_eq!(digest.len(), 32);
        assert!(digest.chars().all(|c| c.is_ascii_hexdigit()));
        assert_eq!(storage.artifact_count(), 1);
    }

    #[test]
    fn test_log_artifact_known_digest() {
        let (mut storage, _exp_id, run_id) = storage_with_run();

        // SHA-256("") = e3b0c44298fc1c149afbf4c8996fb924 27ae41e4649b934ca495991b7852b855
        let hash = storage.log_artifact(&run_id, "empty.bin", b"").expect("operation should succeed");
        assert_eq!(hash, "sha256-e3b0c44298fc1c149afbf4c8996fb924");
    }

    #[test]
    fn test_log_artifact_same_content_is_deduplicated() {
        let (mut storage, _exp_id, run_id) = storage_with_run();

        let first =
            storage.log_artifact(&run_id, "model.bin", b"weights").expect("operation should succeed");
        let second =
            storage.log_artifact(&run_id, "model.bin", b"weights").expect("operation should succeed");
        assert_eq!(first, second);
        assert_eq!(storage.artifact_count(), 1);

        let third = storage
            .log_artifact(&run_id, "model.bin", b"other weights")
            .expect("operation should succeed");
        assert_ne!(first, third);
        assert_eq!(storage.artifact_count(), 2);
    }

    #[test]
    fn test_log_artifact_invalid_run() {
        let mut storage = InMemoryStorage::new();

        assert_run_not_found(storage.log_artifact("fake-run", "model.bin", b"data"), "fake-run");
        assert_eq!(storage.artifact_count(), 0);
    }

    #[test]
    fn test_set_and_get_span_id() {
        let (mut storage, _exp_id, run_id) = storage_with_run();

        assert!(storage.get_span_id(&run_id).expect("operation should succeed").is_none());

        storage.set_span_id(&run_id, "span-12345").expect("operation should succeed");

        assert_eq!(
            storage.get_span_id(&run_id).expect("operation should succeed"),
            Some("span-12345".to_string())
        );
    }

    #[test]
    fn test_span_id_unknown_run() {
        let mut storage = InMemoryStorage::new();

        assert_run_not_found(storage.set_span_id("fake-run", "span-1"), "fake-run");
        assert_run_not_found(storage.get_span_id("fake-run"), "fake-run");
    }

    #[test]
    fn test_experiment_log_lists_only_matching_runs() {
        let mut storage = InMemoryStorage::new();
        let exp1 = storage.create_experiment("exp-1", None).expect("operation should succeed");
        let exp2 = storage.create_experiment("exp-2", None).expect("operation should succeed");
        let run1 = storage.create_run(&exp1).expect("operation should succeed");
        let run2 = storage.create_run(&exp1).expect("operation should succeed");
        let run3 = storage.create_run(&exp2).expect("operation should succeed");

        // Compare as sorted sets so the test does not depend on the log's ordering.
        let mut log = storage.experiment_log(&exp1);
        log.sort();
        let mut expected = vec![
            format!("assignment_log: run={run1} experiment={exp1}"),
            format!("assignment_log: run={run2} experiment={exp1}"),
        ];
        expected.sort();
        assert_eq!(log, expected);

        assert_eq!(
            storage.experiment_log(&exp2),
            vec![format!("assignment_log: run={run3} experiment={exp2}")]
        );
        assert!(storage.experiment_log("fake-exp").is_empty());
    }

    #[test]
    fn test_multiple_experiments_and_runs() {
        let mut storage = InMemoryStorage::new();

        let exp1 = storage.create_experiment("exp-1", None).expect("operation should succeed");
        let exp2 = storage.create_experiment("exp-2", None).expect("operation should succeed");

        let run1 = storage.create_run(&exp1).expect("operation should succeed");
        let run2 = storage.create_run(&exp1).expect("operation should succeed");
        let run3 = storage.create_run(&exp2).expect("operation should succeed");

        assert_eq!(storage.experiment_count(), 2);
        assert_eq!(storage.run_count(), 3);

        storage.start_run(&run1).expect("operation should succeed");
        storage.start_run(&run2).expect("operation should succeed");

        assert_eq!(
            storage.get_run_status(&run1).expect("operation should succeed"),
            RunStatus::Running
        );
        assert_eq!(
            storage.get_run_status(&run2).expect("operation should succeed"),
            RunStatus::Running
        );
        assert_eq!(
            storage.get_run_status(&run3).expect("operation should succeed"),
            RunStatus::Pending
        );
    }
}