1use super::backend::SqliteBackend;
6use crate::storage::{ExperimentStorage, MetricPoint, Result, RunStatus, StorageError};
7use chrono::{DateTime, Utc};
8use rusqlite::params;
9use sha2::{Digest, Sha256};
10
11fn status_to_str(status: RunStatus) -> &'static str {
13 match status {
14 RunStatus::Pending => "pending",
15 RunStatus::Running => "running",
16 RunStatus::Success => "completed",
17 RunStatus::Failed => "failed",
18 RunStatus::Cancelled => "cancelled",
19 }
20}
21
22pub(crate) fn str_to_status(s: &str) -> RunStatus {
24 match s {
25 "pending" => RunStatus::Pending,
26 "running" => RunStatus::Running,
27 "completed" => RunStatus::Success,
28 "failed" => RunStatus::Failed,
29 "cancelled" => RunStatus::Cancelled,
30 _ => RunStatus::Failed,
31 }
32}
33
34impl ExperimentStorage for SqliteBackend {
35 fn create_experiment(
36 &mut self,
37 name: &str,
38 config: Option<serde_json::Value>,
39 ) -> Result<String> {
40 let id = Self::generate_id();
41 let config_json = config.map(|c| c.to_string());
42 let now = Utc::now().to_rfc3339();
43
44 let conn = self.lock_conn()?;
45 conn.execute(
46 "INSERT INTO experiments (id, name, config, created_at, updated_at) VALUES (?1, ?2, ?3, ?4, ?5)",
47 params![id, name, config_json, now, now],
48 )
49 .map_err(|e| StorageError::Backend(format!("Failed to create experiment: {e}")))?;
50
51 Ok(id)
52 }
53
54 fn create_run(&mut self, experiment_id: &str) -> Result<String> {
55 let conn = self.lock_conn()?;
56
57 let exists: bool = conn
59 .query_row(
60 "SELECT EXISTS(SELECT 1 FROM experiments WHERE id = ?1)",
61 [experiment_id],
62 |row| row.get(0),
63 )
64 .map_err(|e| StorageError::Backend(format!("Failed to check experiment: {e}")))?;
65
66 if !exists {
67 return Err(StorageError::ExperimentNotFound(experiment_id.to_string()));
68 }
69
70 let id = Self::generate_id();
71 let now = Utc::now().to_rfc3339();
72
73 conn.execute(
74 "INSERT INTO runs (id, experiment_id, status, start_time) VALUES (?1, ?2, 'pending', ?3)",
75 params![id, experiment_id, now],
76 )
77 .map_err(|e| StorageError::Backend(format!("Failed to create run: {e}")))?;
78
79 Ok(id)
80 }
81
82 fn start_run(&mut self, run_id: &str) -> Result<()> {
83 let conn = self.lock_conn()?;
84
85 let current_status: String = conn
86 .query_row("SELECT status FROM runs WHERE id = ?1", [run_id], |row| row.get(0))
87 .map_err(|e| match e {
88 rusqlite::Error::QueryReturnedNoRows => {
89 StorageError::RunNotFound(run_id.to_string())
90 }
91 _ => StorageError::Backend(format!("Failed to get run status: {e}")),
92 })?;
93
94 if current_status != "pending" {
95 return Err(StorageError::InvalidState(format!(
96 "Cannot start run in {current_status} status"
97 )));
98 }
99
100 let now = Utc::now().to_rfc3339();
101 conn.execute(
102 "UPDATE runs SET status = 'running', start_time = ?1 WHERE id = ?2",
103 params![now, run_id],
104 )
105 .map_err(|e| StorageError::Backend(format!("Failed to start run: {e}")))?;
106
107 Ok(())
108 }
109
110 fn complete_run(&mut self, run_id: &str, status: RunStatus) -> Result<()> {
111 let conn = self.lock_conn()?;
112
113 let current_status: String = conn
114 .query_row("SELECT status FROM runs WHERE id = ?1", [run_id], |row| row.get(0))
115 .map_err(|e| match e {
116 rusqlite::Error::QueryReturnedNoRows => {
117 StorageError::RunNotFound(run_id.to_string())
118 }
119 _ => StorageError::Backend(format!("Failed to get run status: {e}")),
120 })?;
121
122 if current_status != "running" {
123 return Err(StorageError::InvalidState(format!(
124 "Cannot complete run in {current_status} status"
125 )));
126 }
127
128 let now = Utc::now().to_rfc3339();
129 conn.execute(
130 "UPDATE runs SET status = ?1, end_time = ?2 WHERE id = ?3",
131 params![status_to_str(status), now, run_id],
132 )
133 .map_err(|e| StorageError::Backend(format!("Failed to complete run: {e}")))?;
134
135 Ok(())
136 }
137
138 fn log_metric(&mut self, run_id: &str, key: &str, step: u64, value: f64) -> Result<()> {
139 let conn = self.lock_conn()?;
140
141 let exists: bool = conn
143 .query_row("SELECT EXISTS(SELECT 1 FROM runs WHERE id = ?1)", [run_id], |row| {
144 row.get(0)
145 })
146 .map_err(|e| StorageError::Backend(format!("Failed to check run: {e}")))?;
147
148 if !exists {
149 return Err(StorageError::RunNotFound(run_id.to_string()));
150 }
151
152 let now = Utc::now().to_rfc3339();
153 conn.execute(
154 "INSERT INTO metrics (run_id, key, step, value, timestamp) VALUES (?1, ?2, ?3, ?4, ?5)",
155 params![run_id, key, step as i64, value, now],
156 )
157 .map_err(|e| StorageError::Backend(format!("Failed to log metric: {e}")))?;
158
159 Ok(())
160 }
161
162 fn log_artifact(&mut self, run_id: &str, key: &str, data: &[u8]) -> Result<String> {
163 let conn = self.lock_conn()?;
164
165 let exists: bool = conn
167 .query_row("SELECT EXISTS(SELECT 1 FROM runs WHERE id = ?1)", [run_id], |row| {
168 row.get(0)
169 })
170 .map_err(|e| StorageError::Backend(format!("Failed to check run: {e}")))?;
171
172 if !exists {
173 return Err(StorageError::RunNotFound(run_id.to_string()));
174 }
175
176 let mut hasher = Sha256::new();
178 hasher.update(data);
179 let sha256 = format!("{:x}", hasher.finalize());
180
181 let id = Self::generate_id();
182 let size = data.len() as i64;
183
184 conn.execute(
185 "INSERT INTO artifacts (id, run_id, path, size_bytes, sha256, data) VALUES (?1, ?2, ?3, ?4, ?5, ?6)",
186 params![id, run_id, key, size, sha256, data],
187 )
188 .map_err(|e| StorageError::Backend(format!("Failed to log artifact: {e}")))?;
189
190 Ok(sha256)
191 }
192
193 fn get_metrics(&self, run_id: &str, key: &str) -> Result<Vec<MetricPoint>> {
194 let conn = self.lock_conn()?;
195
196 let exists: bool = conn
198 .query_row("SELECT EXISTS(SELECT 1 FROM runs WHERE id = ?1)", [run_id], |row| {
199 row.get(0)
200 })
201 .map_err(|e| StorageError::Backend(format!("Failed to check run: {e}")))?;
202
203 if !exists {
204 return Err(StorageError::RunNotFound(run_id.to_string()));
205 }
206
207 let mut stmt = conn
208 .prepare("SELECT step, value, timestamp FROM metrics WHERE run_id = ?1 AND key = ?2 ORDER BY step")
209 .map_err(|e| StorageError::Backend(format!("Failed to prepare metrics query: {e}")))?;
210
211 let points = stmt
212 .query_map(params![run_id, key], |row| {
213 let step: i64 = row.get(0)?;
214 let value: f64 = row.get(1)?;
215 let ts_str: String = row.get(2)?;
216 let timestamp: DateTime<Utc> = ts_str.parse().unwrap_or_else(|_| Utc::now());
217 Ok(MetricPoint::with_timestamp(step as u64, value, timestamp))
218 })
219 .map_err(|e| StorageError::Backend(format!("Failed to query metrics: {e}")))?
220 .collect::<std::result::Result<Vec<_>, _>>()
221 .map_err(|e| StorageError::Backend(format!("Failed to read metric row: {e}")))?;
222
223 Ok(points)
224 }
225
226 fn get_run_status(&self, run_id: &str) -> Result<RunStatus> {
227 let conn = self.lock_conn()?;
228
229 let status_str: String = conn
230 .query_row("SELECT status FROM runs WHERE id = ?1", [run_id], |row| row.get(0))
231 .map_err(|e| match e {
232 rusqlite::Error::QueryReturnedNoRows => {
233 StorageError::RunNotFound(run_id.to_string())
234 }
235 _ => StorageError::Backend(format!("Failed to get run status: {e}")),
236 })?;
237
238 Ok(str_to_status(&status_str))
239 }
240
241 fn set_span_id(&mut self, run_id: &str, span_id: &str) -> Result<()> {
242 let conn = self.lock_conn()?;
243
244 let exists: bool = conn
246 .query_row("SELECT EXISTS(SELECT 1 FROM runs WHERE id = ?1)", [run_id], |row| {
247 row.get(0)
248 })
249 .map_err(|e| StorageError::Backend(format!("Failed to check run: {e}")))?;
250
251 if !exists {
252 return Err(StorageError::RunNotFound(run_id.to_string()));
253 }
254
255 conn.execute(
256 "INSERT OR REPLACE INTO span_ids (run_id, span_id) VALUES (?1, ?2)",
257 params![run_id, span_id],
258 )
259 .map_err(|e| StorageError::Backend(format!("Failed to set span ID: {e}")))?;
260
261 Ok(())
262 }
263
264 fn get_span_id(&self, run_id: &str) -> Result<Option<String>> {
265 let conn = self.lock_conn()?;
266
267 let exists: bool = conn
269 .query_row("SELECT EXISTS(SELECT 1 FROM runs WHERE id = ?1)", [run_id], |row| {
270 row.get(0)
271 })
272 .map_err(|e| StorageError::Backend(format!("Failed to check run: {e}")))?;
273
274 if !exists {
275 return Err(StorageError::RunNotFound(run_id.to_string()));
276 }
277
278 let result =
279 conn.query_row("SELECT span_id FROM span_ids WHERE run_id = ?1", [run_id], |row| {
280 row.get(0)
281 });
282
283 match result {
284 Ok(span_id) => Ok(Some(span_id)),
285 Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None),
286 Err(e) => Err(StorageError::Backend(format!("Failed to get span ID: {e}"))),
287 }
288 }
289}
290
#[cfg(test)]
// Tests unwrap on purpose: a panic carrying the underlying error is the failure report.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use crate::storage::ExperimentStorage;

    /// Opens a fresh in-memory backend for a single test.
    fn test_backend() -> SqliteBackend {
        SqliteBackend::open_in_memory().expect("in-memory db should succeed")
    }

    /// Returns a backend holding one experiment with a single pending run,
    /// together with that run's id.
    fn backend_with_run() -> (SqliteBackend, String) {
        let mut backend = test_backend();
        let exp_id = backend.create_experiment("test", None).unwrap();
        let run_id = backend.create_run(&exp_id).unwrap();
        (backend, run_id)
    }

    #[test]
    fn test_status_to_str_all_variants() {
        assert_eq!(status_to_str(RunStatus::Pending), "pending");
        assert_eq!(status_to_str(RunStatus::Running), "running");
        assert_eq!(status_to_str(RunStatus::Success), "completed");
        assert_eq!(status_to_str(RunStatus::Failed), "failed");
        assert_eq!(status_to_str(RunStatus::Cancelled), "cancelled");
    }

    #[test]
    fn test_str_to_status_all_variants() {
        assert_eq!(str_to_status("pending"), RunStatus::Pending);
        assert_eq!(str_to_status("running"), RunStatus::Running);
        assert_eq!(str_to_status("completed"), RunStatus::Success);
        assert_eq!(str_to_status("failed"), RunStatus::Failed);
        assert_eq!(str_to_status("cancelled"), RunStatus::Cancelled);
    }

    #[test]
    fn test_str_to_status_unknown_defaults_failed() {
        assert_eq!(str_to_status("xyz"), RunStatus::Failed);
        assert_eq!(str_to_status(""), RunStatus::Failed);
    }

    #[test]
    fn test_create_experiment() {
        let mut backend = test_backend();
        let id = backend.create_experiment("test-exp", None).unwrap();
        assert!(!id.is_empty());
    }

    #[test]
    fn test_create_experiment_with_config() {
        let mut backend = test_backend();
        let config = serde_json::json!({"lr": 0.001, "epochs": 10});
        let id = backend.create_experiment("config-exp", Some(config)).unwrap();
        assert!(!id.is_empty());
    }

    #[test]
    fn test_create_run() {
        let (_backend, run_id) = backend_with_run();
        assert!(!run_id.is_empty());
    }

    #[test]
    fn test_create_run_nonexistent_experiment() {
        let mut backend = test_backend();
        let result = backend.create_run("nonexistent-exp");
        assert!(matches!(result, Err(StorageError::ExperimentNotFound(_))));
    }

    #[test]
    fn test_start_run() {
        let (mut backend, run_id) = backend_with_run();
        backend.start_run(&run_id).unwrap();
        assert_eq!(backend.get_run_status(&run_id).unwrap(), RunStatus::Running);
    }

    #[test]
    fn test_start_run_nonexistent() {
        let mut backend = test_backend();
        let result = backend.start_run("nonexistent-run");
        assert!(matches!(result, Err(StorageError::RunNotFound(_))));
    }

    #[test]
    fn test_start_run_not_pending() {
        let (mut backend, run_id) = backend_with_run();
        backend.start_run(&run_id).unwrap();

        // The run is already running, so a second start must be rejected.
        let result = backend.start_run(&run_id);
        assert!(matches!(result, Err(StorageError::InvalidState(_))));
        assert_eq!(backend.get_run_status(&run_id).unwrap(), RunStatus::Running);
    }

    #[test]
    fn test_complete_run_success() {
        let (mut backend, run_id) = backend_with_run();
        backend.start_run(&run_id).unwrap();
        backend.complete_run(&run_id, RunStatus::Success).unwrap();
        assert_eq!(backend.get_run_status(&run_id).unwrap(), RunStatus::Success);
    }

    #[test]
    fn test_complete_run_failed() {
        let (mut backend, run_id) = backend_with_run();
        backend.start_run(&run_id).unwrap();
        backend.complete_run(&run_id, RunStatus::Failed).unwrap();
        assert_eq!(backend.get_run_status(&run_id).unwrap(), RunStatus::Failed);
    }

    #[test]
    fn test_complete_run_cancelled() {
        let (mut backend, run_id) = backend_with_run();
        backend.start_run(&run_id).unwrap();
        backend.complete_run(&run_id, RunStatus::Cancelled).unwrap();
        assert_eq!(backend.get_run_status(&run_id).unwrap(), RunStatus::Cancelled);
    }

    #[test]
    fn test_complete_run_not_running() {
        let (mut backend, run_id) = backend_with_run();

        // The run was never started, so it is still pending.
        let result = backend.complete_run(&run_id, RunStatus::Success);
        assert!(matches!(result, Err(StorageError::InvalidState(_))));
        assert_eq!(backend.get_run_status(&run_id).unwrap(), RunStatus::Pending);
    }

    #[test]
    fn test_complete_run_nonexistent() {
        let mut backend = test_backend();
        let result = backend.complete_run("nonexistent-run", RunStatus::Success);
        assert!(matches!(result, Err(StorageError::RunNotFound(_))));
    }

    #[test]
    fn test_log_metric() {
        let (mut backend, run_id) = backend_with_run();
        backend.log_metric(&run_id, "loss", 0, 0.5).unwrap();
        backend.log_metric(&run_id, "loss", 1, 0.4).unwrap();
        backend.log_metric(&run_id, "loss", 2, 0.3).unwrap();
    }

    #[test]
    fn test_log_metric_nonexistent_run() {
        let mut backend = test_backend();
        let result = backend.log_metric("nonexistent-run", "loss", 0, 0.5);
        assert!(matches!(result, Err(StorageError::RunNotFound(_))));
    }

    #[test]
    fn test_get_metrics() {
        let (mut backend, run_id) = backend_with_run();
        backend.log_metric(&run_id, "loss", 0, 0.5).unwrap();
        backend.log_metric(&run_id, "loss", 1, 0.4).unwrap();
        backend.log_metric(&run_id, "accuracy", 0, 0.8).unwrap();

        let loss_metrics = backend.get_metrics(&run_id, "loss").unwrap();
        assert_eq!(loss_metrics.len(), 2);
        assert_eq!(loss_metrics[0].step, 0);
        assert!((loss_metrics[0].value - 0.5).abs() < f64::EPSILON);
        assert_eq!(loss_metrics[1].step, 1);
        assert!((loss_metrics[1].value - 0.4).abs() < f64::EPSILON);

        let acc_metrics = backend.get_metrics(&run_id, "accuracy").unwrap();
        assert_eq!(acc_metrics.len(), 1);
        assert_eq!(acc_metrics[0].step, 0);
        assert!((acc_metrics[0].value - 0.8).abs() < f64::EPSILON);
    }

    #[test]
    fn test_get_metrics_nonexistent_run() {
        let backend = test_backend();
        let result = backend.get_metrics("nonexistent-run", "loss");
        assert!(matches!(result, Err(StorageError::RunNotFound(_))));
    }

    #[test]
    fn test_get_metrics_empty() {
        let (backend, run_id) = backend_with_run();
        let metrics = backend.get_metrics(&run_id, "loss").unwrap();
        assert!(metrics.is_empty());
    }

    #[test]
    fn test_get_run_status() {
        let (backend, run_id) = backend_with_run();
        assert_eq!(backend.get_run_status(&run_id).unwrap(), RunStatus::Pending);
    }

    #[test]
    fn test_get_run_status_nonexistent() {
        let backend = test_backend();
        let result = backend.get_run_status("nonexistent-run");
        assert!(matches!(result, Err(StorageError::RunNotFound(_))));
    }

    #[test]
    fn test_log_artifact() {
        let (mut backend, run_id) = backend_with_run();
        let sha = backend.log_artifact(&run_id, "model.bin", b"fake model data").unwrap();

        // A SHA-256 digest is 32 bytes, i.e. 64 lowercase hex characters.
        assert_eq!(sha.len(), 64);
        assert!(sha.chars().all(|c| c.is_ascii_digit() || ('a'..='f').contains(&c)));
    }

    #[test]
    fn test_log_artifact_known_digest() {
        let (mut backend, run_id) = backend_with_run();
        let sha = backend.log_artifact(&run_id, "abc.txt", b"abc").unwrap();

        // Standard SHA-256 test vector for the three-byte message "abc".
        assert_eq!(sha, "ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad");
    }

    #[test]
    fn test_log_artifact_nonexistent_run() {
        let mut backend = test_backend();
        let result = backend.log_artifact("nonexistent-run", "file.bin", b"data");
        assert!(matches!(result, Err(StorageError::RunNotFound(_))));
    }

    #[test]
    fn test_log_artifact_deterministic_hash() {
        let mut backend = test_backend();
        let exp_id = backend.create_experiment("test", None).unwrap();
        let run_id1 = backend.create_run(&exp_id).unwrap();
        let run_id2 = backend.create_run(&exp_id).unwrap();

        // Identical bytes must hash identically regardless of the owning run.
        let sha1 = backend.log_artifact(&run_id1, "file.bin", b"same data").unwrap();
        let sha2 = backend.log_artifact(&run_id2, "file.bin", b"same data").unwrap();
        assert_eq!(sha1, sha2);
    }

    #[test]
    fn test_set_and_get_span_id() {
        let (mut backend, run_id) = backend_with_run();

        // Nothing has been recorded for the run yet.
        assert!(backend.get_span_id(&run_id).unwrap().is_none());

        backend.set_span_id(&run_id, "span-12345").unwrap();
        assert_eq!(backend.get_span_id(&run_id).unwrap(), Some("span-12345".to_string()));
    }

    #[test]
    fn test_set_span_id_nonexistent_run() {
        let mut backend = test_backend();
        let result = backend.set_span_id("nonexistent-run", "span-123");
        assert!(matches!(result, Err(StorageError::RunNotFound(_))));
    }

    #[test]
    fn test_get_span_id_nonexistent_run() {
        let backend = test_backend();
        let result = backend.get_span_id("nonexistent-run");
        assert!(matches!(result, Err(StorageError::RunNotFound(_))));
    }

    #[test]
    fn test_set_span_id_overwrite() {
        let (mut backend, run_id) = backend_with_run();

        backend.set_span_id(&run_id, "span-1").unwrap();
        backend.set_span_id(&run_id, "span-2").unwrap();
        assert_eq!(backend.get_span_id(&run_id).unwrap(), Some("span-2".to_string()));
    }

    #[test]
    fn test_full_lifecycle() {
        let mut backend = test_backend();
        let exp_id = backend.create_experiment("lifecycle-test", None).unwrap();
        let run_id = backend.create_run(&exp_id).unwrap();

        assert_eq!(backend.get_run_status(&run_id).unwrap(), RunStatus::Pending);

        backend.start_run(&run_id).unwrap();
        assert_eq!(backend.get_run_status(&run_id).unwrap(), RunStatus::Running);

        backend.log_metric(&run_id, "loss", 0, 1.0).unwrap();
        backend.log_metric(&run_id, "loss", 1, 0.5).unwrap();

        backend.complete_run(&run_id, RunStatus::Success).unwrap();
        assert_eq!(backend.get_run_status(&run_id).unwrap(), RunStatus::Success);

        let metrics = backend.get_metrics(&run_id, "loss").unwrap();
        assert_eq!(metrics.len(), 2);
    }

    #[test]
    fn test_metrics_ordered_by_step() {
        let (mut backend, run_id) = backend_with_run();

        // Log out of order; reads must come back sorted by step.
        backend.log_metric(&run_id, "loss", 5, 0.1).unwrap();
        backend.log_metric(&run_id, "loss", 1, 0.5).unwrap();
        backend.log_metric(&run_id, "loss", 3, 0.3).unwrap();

        let metrics = backend.get_metrics(&run_id, "loss").unwrap();
        let steps: Vec<u64> = metrics.iter().map(|point| point.step).collect();
        assert_eq!(steps, vec![1, 3, 5]);
    }
}