Skip to main content

entrenar/storage/sqlite/
metrics.rs

1//! Metrics operations for SQLite Backend.
2//!
3//! Contains metric logging and retrieval methods (ExperimentStorage trait implementation).
4
5use super::backend::SqliteBackend;
6use crate::storage::{ExperimentStorage, MetricPoint, Result, RunStatus, StorageError};
7use chrono::{DateTime, Utc};
8use rusqlite::params;
9use sha2::{Digest, Sha256};
10
11/// Map a RunStatus to its SQLite TEXT representation
12fn status_to_str(status: RunStatus) -> &'static str {
13    match status {
14        RunStatus::Pending => "pending",
15        RunStatus::Running => "running",
16        RunStatus::Success => "completed",
17        RunStatus::Failed => "failed",
18        RunStatus::Cancelled => "cancelled",
19    }
20}
21
22/// Parse a SQLite TEXT status back to RunStatus
23pub(crate) fn str_to_status(s: &str) -> RunStatus {
24    match s {
25        "pending" => RunStatus::Pending,
26        "running" => RunStatus::Running,
27        "completed" => RunStatus::Success,
28        "failed" => RunStatus::Failed,
29        "cancelled" => RunStatus::Cancelled,
30        _ => RunStatus::Failed,
31    }
32}
33
34impl ExperimentStorage for SqliteBackend {
35    fn create_experiment(
36        &mut self,
37        name: &str,
38        config: Option<serde_json::Value>,
39    ) -> Result<String> {
40        let id = Self::generate_id();
41        let config_json = config.map(|c| c.to_string());
42        let now = Utc::now().to_rfc3339();
43
44        let conn = self.lock_conn()?;
45        conn.execute(
46            "INSERT INTO experiments (id, name, config, created_at, updated_at) VALUES (?1, ?2, ?3, ?4, ?5)",
47            params![id, name, config_json, now, now],
48        )
49        .map_err(|e| StorageError::Backend(format!("Failed to create experiment: {e}")))?;
50
51        Ok(id)
52    }
53
54    fn create_run(&mut self, experiment_id: &str) -> Result<String> {
55        let conn = self.lock_conn()?;
56
57        // Verify experiment exists
58        let exists: bool = conn
59            .query_row(
60                "SELECT EXISTS(SELECT 1 FROM experiments WHERE id = ?1)",
61                [experiment_id],
62                |row| row.get(0),
63            )
64            .map_err(|e| StorageError::Backend(format!("Failed to check experiment: {e}")))?;
65
66        if !exists {
67            return Err(StorageError::ExperimentNotFound(experiment_id.to_string()));
68        }
69
70        let id = Self::generate_id();
71        let now = Utc::now().to_rfc3339();
72
73        conn.execute(
74            "INSERT INTO runs (id, experiment_id, status, start_time) VALUES (?1, ?2, 'pending', ?3)",
75            params![id, experiment_id, now],
76        )
77        .map_err(|e| StorageError::Backend(format!("Failed to create run: {e}")))?;
78
79        Ok(id)
80    }
81
82    fn start_run(&mut self, run_id: &str) -> Result<()> {
83        let conn = self.lock_conn()?;
84
85        let current_status: String = conn
86            .query_row("SELECT status FROM runs WHERE id = ?1", [run_id], |row| row.get(0))
87            .map_err(|e| match e {
88                rusqlite::Error::QueryReturnedNoRows => {
89                    StorageError::RunNotFound(run_id.to_string())
90                }
91                _ => StorageError::Backend(format!("Failed to get run status: {e}")),
92            })?;
93
94        if current_status != "pending" {
95            return Err(StorageError::InvalidState(format!(
96                "Cannot start run in {current_status} status"
97            )));
98        }
99
100        let now = Utc::now().to_rfc3339();
101        conn.execute(
102            "UPDATE runs SET status = 'running', start_time = ?1 WHERE id = ?2",
103            params![now, run_id],
104        )
105        .map_err(|e| StorageError::Backend(format!("Failed to start run: {e}")))?;
106
107        Ok(())
108    }
109
110    fn complete_run(&mut self, run_id: &str, status: RunStatus) -> Result<()> {
111        let conn = self.lock_conn()?;
112
113        let current_status: String = conn
114            .query_row("SELECT status FROM runs WHERE id = ?1", [run_id], |row| row.get(0))
115            .map_err(|e| match e {
116                rusqlite::Error::QueryReturnedNoRows => {
117                    StorageError::RunNotFound(run_id.to_string())
118                }
119                _ => StorageError::Backend(format!("Failed to get run status: {e}")),
120            })?;
121
122        if current_status != "running" {
123            return Err(StorageError::InvalidState(format!(
124                "Cannot complete run in {current_status} status"
125            )));
126        }
127
128        let now = Utc::now().to_rfc3339();
129        conn.execute(
130            "UPDATE runs SET status = ?1, end_time = ?2 WHERE id = ?3",
131            params![status_to_str(status), now, run_id],
132        )
133        .map_err(|e| StorageError::Backend(format!("Failed to complete run: {e}")))?;
134
135        Ok(())
136    }
137
138    fn log_metric(&mut self, run_id: &str, key: &str, step: u64, value: f64) -> Result<()> {
139        let conn = self.lock_conn()?;
140
141        // Verify run exists
142        let exists: bool = conn
143            .query_row("SELECT EXISTS(SELECT 1 FROM runs WHERE id = ?1)", [run_id], |row| {
144                row.get(0)
145            })
146            .map_err(|e| StorageError::Backend(format!("Failed to check run: {e}")))?;
147
148        if !exists {
149            return Err(StorageError::RunNotFound(run_id.to_string()));
150        }
151
152        let now = Utc::now().to_rfc3339();
153        conn.execute(
154            "INSERT INTO metrics (run_id, key, step, value, timestamp) VALUES (?1, ?2, ?3, ?4, ?5)",
155            params![run_id, key, step as i64, value, now],
156        )
157        .map_err(|e| StorageError::Backend(format!("Failed to log metric: {e}")))?;
158
159        Ok(())
160    }
161
162    fn log_artifact(&mut self, run_id: &str, key: &str, data: &[u8]) -> Result<String> {
163        let conn = self.lock_conn()?;
164
165        // Verify run exists
166        let exists: bool = conn
167            .query_row("SELECT EXISTS(SELECT 1 FROM runs WHERE id = ?1)", [run_id], |row| {
168                row.get(0)
169            })
170            .map_err(|e| StorageError::Backend(format!("Failed to check run: {e}")))?;
171
172        if !exists {
173            return Err(StorageError::RunNotFound(run_id.to_string()));
174        }
175
176        // Compute SHA-256 for content-addressable storage
177        let mut hasher = Sha256::new();
178        hasher.update(data);
179        let sha256 = format!("{:x}", hasher.finalize());
180
181        let id = Self::generate_id();
182        let size = data.len() as i64;
183
184        conn.execute(
185            "INSERT INTO artifacts (id, run_id, path, size_bytes, sha256, data) VALUES (?1, ?2, ?3, ?4, ?5, ?6)",
186            params![id, run_id, key, size, sha256, data],
187        )
188        .map_err(|e| StorageError::Backend(format!("Failed to log artifact: {e}")))?;
189
190        Ok(sha256)
191    }
192
193    fn get_metrics(&self, run_id: &str, key: &str) -> Result<Vec<MetricPoint>> {
194        let conn = self.lock_conn()?;
195
196        // Verify run exists
197        let exists: bool = conn
198            .query_row("SELECT EXISTS(SELECT 1 FROM runs WHERE id = ?1)", [run_id], |row| {
199                row.get(0)
200            })
201            .map_err(|e| StorageError::Backend(format!("Failed to check run: {e}")))?;
202
203        if !exists {
204            return Err(StorageError::RunNotFound(run_id.to_string()));
205        }
206
207        let mut stmt = conn
208            .prepare("SELECT step, value, timestamp FROM metrics WHERE run_id = ?1 AND key = ?2 ORDER BY step")
209            .map_err(|e| StorageError::Backend(format!("Failed to prepare metrics query: {e}")))?;
210
211        let points = stmt
212            .query_map(params![run_id, key], |row| {
213                let step: i64 = row.get(0)?;
214                let value: f64 = row.get(1)?;
215                let ts_str: String = row.get(2)?;
216                let timestamp: DateTime<Utc> = ts_str.parse().unwrap_or_else(|_| Utc::now());
217                Ok(MetricPoint::with_timestamp(step as u64, value, timestamp))
218            })
219            .map_err(|e| StorageError::Backend(format!("Failed to query metrics: {e}")))?
220            .collect::<std::result::Result<Vec<_>, _>>()
221            .map_err(|e| StorageError::Backend(format!("Failed to read metric row: {e}")))?;
222
223        Ok(points)
224    }
225
226    fn get_run_status(&self, run_id: &str) -> Result<RunStatus> {
227        let conn = self.lock_conn()?;
228
229        let status_str: String = conn
230            .query_row("SELECT status FROM runs WHERE id = ?1", [run_id], |row| row.get(0))
231            .map_err(|e| match e {
232                rusqlite::Error::QueryReturnedNoRows => {
233                    StorageError::RunNotFound(run_id.to_string())
234                }
235                _ => StorageError::Backend(format!("Failed to get run status: {e}")),
236            })?;
237
238        Ok(str_to_status(&status_str))
239    }
240
241    fn set_span_id(&mut self, run_id: &str, span_id: &str) -> Result<()> {
242        let conn = self.lock_conn()?;
243
244        // Verify run exists
245        let exists: bool = conn
246            .query_row("SELECT EXISTS(SELECT 1 FROM runs WHERE id = ?1)", [run_id], |row| {
247                row.get(0)
248            })
249            .map_err(|e| StorageError::Backend(format!("Failed to check run: {e}")))?;
250
251        if !exists {
252            return Err(StorageError::RunNotFound(run_id.to_string()));
253        }
254
255        conn.execute(
256            "INSERT OR REPLACE INTO span_ids (run_id, span_id) VALUES (?1, ?2)",
257            params![run_id, span_id],
258        )
259        .map_err(|e| StorageError::Backend(format!("Failed to set span ID: {e}")))?;
260
261        Ok(())
262    }
263
264    fn get_span_id(&self, run_id: &str) -> Result<Option<String>> {
265        let conn = self.lock_conn()?;
266
267        // Verify run exists
268        let exists: bool = conn
269            .query_row("SELECT EXISTS(SELECT 1 FROM runs WHERE id = ?1)", [run_id], |row| {
270                row.get(0)
271            })
272            .map_err(|e| StorageError::Backend(format!("Failed to check run: {e}")))?;
273
274        if !exists {
275            return Err(StorageError::RunNotFound(run_id.to_string()));
276        }
277
278        let result =
279            conn.query_row("SELECT span_id FROM span_ids WHERE run_id = ?1", [run_id], |row| {
280                row.get(0)
281            });
282
283        match result {
284            Ok(span_id) => Ok(Some(span_id)),
285            Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None),
286            Err(e) => Err(StorageError::Backend(format!("Failed to get span ID: {e}"))),
287        }
288    }
289}
290
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use crate::storage::ExperimentStorage;

    /// Fresh, isolated in-memory backend for each test.
    fn test_backend() -> SqliteBackend {
        SqliteBackend::open_in_memory().expect("in-memory db should succeed")
    }

    /// Assert that `result` failed with `RunNotFound` carrying `expected_id`.
    fn assert_run_not_found<T>(result: Result<T>, expected_id: &str) {
        match result {
            Err(StorageError::RunNotFound(id)) => assert_eq!(id, expected_id),
            Err(other) => panic!("expected RunNotFound({expected_id}), got {other:?}"),
            Ok(_) => panic!("expected RunNotFound({expected_id}), got Ok"),
        }
    }

    /// Assert that `result` failed with `InvalidState`.
    fn assert_invalid_state<T>(result: Result<T>) {
        match result {
            Err(StorageError::InvalidState(_)) => {}
            Err(other) => panic!("expected InvalidState, got {other:?}"),
            Ok(_) => panic!("expected InvalidState, got Ok"),
        }
    }

    #[test]
    fn test_status_to_str_all_variants() {
        assert_eq!(status_to_str(RunStatus::Pending), "pending");
        assert_eq!(status_to_str(RunStatus::Running), "running");
        assert_eq!(status_to_str(RunStatus::Success), "completed");
        assert_eq!(status_to_str(RunStatus::Failed), "failed");
        assert_eq!(status_to_str(RunStatus::Cancelled), "cancelled");
    }

    #[test]
    fn test_str_to_status_all_variants() {
        assert_eq!(str_to_status("pending"), RunStatus::Pending);
        assert_eq!(str_to_status("running"), RunStatus::Running);
        assert_eq!(str_to_status("completed"), RunStatus::Success);
        assert_eq!(str_to_status("failed"), RunStatus::Failed);
        assert_eq!(str_to_status("cancelled"), RunStatus::Cancelled);
    }

    #[test]
    fn test_str_to_status_unknown_defaults_failed() {
        assert_eq!(str_to_status("xyz"), RunStatus::Failed);
        assert_eq!(str_to_status(""), RunStatus::Failed);
    }

    #[test]
    fn test_create_experiment() {
        let mut backend = test_backend();
        let id = backend.create_experiment("test-exp", None).unwrap();
        assert!(!id.is_empty());
    }

    #[test]
    fn test_create_experiment_with_config() {
        let mut backend = test_backend();
        let config = serde_json::json!({"lr": 0.001, "epochs": 10});
        let id = backend.create_experiment("config-exp", Some(config)).unwrap();
        assert!(!id.is_empty());
    }

    #[test]
    fn test_create_run() {
        let mut backend = test_backend();
        let exp_id = backend.create_experiment("test", None).unwrap();
        let run_id = backend.create_run(&exp_id).unwrap();
        assert!(!run_id.is_empty());
    }

    #[test]
    fn test_create_run_nonexistent_experiment() {
        let mut backend = test_backend();
        match backend.create_run("nonexistent-exp") {
            Err(StorageError::ExperimentNotFound(id)) => assert_eq!(id, "nonexistent-exp"),
            Err(other) => panic!("expected ExperimentNotFound, got {other:?}"),
            Ok(id) => panic!("expected ExperimentNotFound, got run {id}"),
        }
    }

    #[test]
    fn test_start_run() {
        let mut backend = test_backend();
        let exp_id = backend.create_experiment("test", None).unwrap();
        let run_id = backend.create_run(&exp_id).unwrap();
        backend.start_run(&run_id).unwrap();
        let status = backend.get_run_status(&run_id).unwrap();
        assert_eq!(status, RunStatus::Running);
    }

    #[test]
    fn test_start_run_nonexistent() {
        let mut backend = test_backend();
        assert_run_not_found(backend.start_run("nonexistent-run"), "nonexistent-run");
    }

    #[test]
    fn test_start_run_not_pending() {
        let mut backend = test_backend();
        let exp_id = backend.create_experiment("test", None).unwrap();
        let run_id = backend.create_run(&exp_id).unwrap();
        backend.start_run(&run_id).unwrap();
        // Starting again must fail: the status is "running", not "pending".
        assert_invalid_state(backend.start_run(&run_id));
    }

    #[test]
    fn test_complete_run_success() {
        let mut backend = test_backend();
        let exp_id = backend.create_experiment("test", None).unwrap();
        let run_id = backend.create_run(&exp_id).unwrap();
        backend.start_run(&run_id).unwrap();
        backend.complete_run(&run_id, RunStatus::Success).unwrap();
        let status = backend.get_run_status(&run_id).unwrap();
        assert_eq!(status, RunStatus::Success);
    }

    #[test]
    fn test_complete_run_failed() {
        let mut backend = test_backend();
        let exp_id = backend.create_experiment("test", None).unwrap();
        let run_id = backend.create_run(&exp_id).unwrap();
        backend.start_run(&run_id).unwrap();
        backend.complete_run(&run_id, RunStatus::Failed).unwrap();
        let status = backend.get_run_status(&run_id).unwrap();
        assert_eq!(status, RunStatus::Failed);
    }

    #[test]
    fn test_complete_run_cancelled() {
        let mut backend = test_backend();
        let exp_id = backend.create_experiment("test", None).unwrap();
        let run_id = backend.create_run(&exp_id).unwrap();
        backend.start_run(&run_id).unwrap();
        backend.complete_run(&run_id, RunStatus::Cancelled).unwrap();
        let status = backend.get_run_status(&run_id).unwrap();
        assert_eq!(status, RunStatus::Cancelled);
    }

    #[test]
    fn test_complete_run_not_running() {
        let mut backend = test_backend();
        let exp_id = backend.create_experiment("test", None).unwrap();
        let run_id = backend.create_run(&exp_id).unwrap();
        // Completing a "pending" run must fail.
        assert_invalid_state(backend.complete_run(&run_id, RunStatus::Success));
    }

    #[test]
    fn test_complete_run_twice() {
        let mut backend = test_backend();
        let exp_id = backend.create_experiment("test", None).unwrap();
        let run_id = backend.create_run(&exp_id).unwrap();
        backend.start_run(&run_id).unwrap();
        backend.complete_run(&run_id, RunStatus::Success).unwrap();
        // A completed run is no longer "running" and cannot be completed again.
        assert_invalid_state(backend.complete_run(&run_id, RunStatus::Failed));
        assert_eq!(backend.get_run_status(&run_id).unwrap(), RunStatus::Success);
    }

    #[test]
    fn test_complete_run_nonexistent() {
        let mut backend = test_backend();
        assert_run_not_found(
            backend.complete_run("nonexistent-run", RunStatus::Success),
            "nonexistent-run",
        );
    }

    #[test]
    fn test_log_metric() {
        let mut backend = test_backend();
        let exp_id = backend.create_experiment("test", None).unwrap();
        let run_id = backend.create_run(&exp_id).unwrap();
        backend.log_metric(&run_id, "loss", 0, 0.5).unwrap();
        backend.log_metric(&run_id, "loss", 1, 0.4).unwrap();
        backend.log_metric(&run_id, "loss", 2, 0.3).unwrap();

        // The samples must actually have been persisted.
        let points = backend.get_metrics(&run_id, "loss").unwrap();
        let steps: Vec<_> = points.iter().map(|p| p.step).collect();
        assert_eq!(steps, [0, 1, 2]);
        assert!((points[2].value - 0.3).abs() < f64::EPSILON);
    }

    #[test]
    fn test_log_metric_nonexistent_run() {
        let mut backend = test_backend();
        assert_run_not_found(
            backend.log_metric("nonexistent-run", "loss", 0, 0.5),
            "nonexistent-run",
        );
    }

    #[test]
    fn test_get_metrics() {
        let mut backend = test_backend();
        let exp_id = backend.create_experiment("test", None).unwrap();
        let run_id = backend.create_run(&exp_id).unwrap();
        backend.log_metric(&run_id, "loss", 0, 0.5).unwrap();
        backend.log_metric(&run_id, "loss", 1, 0.4).unwrap();
        backend.log_metric(&run_id, "accuracy", 0, 0.8).unwrap();

        let loss_metrics = backend.get_metrics(&run_id, "loss").unwrap();
        assert_eq!(loss_metrics.len(), 2);
        assert_eq!(loss_metrics[0].step, 0);
        assert!((loss_metrics[0].value - 0.5).abs() < f64::EPSILON);
        assert_eq!(loss_metrics[1].step, 1);

        let acc_metrics = backend.get_metrics(&run_id, "accuracy").unwrap();
        assert_eq!(acc_metrics.len(), 1);
        assert!((acc_metrics[0].value - 0.8).abs() < f64::EPSILON);
    }

    #[test]
    fn test_metrics_isolated_per_run() {
        let mut backend = test_backend();
        let exp_id = backend.create_experiment("test", None).unwrap();
        let run_a = backend.create_run(&exp_id).unwrap();
        let run_b = backend.create_run(&exp_id).unwrap();
        backend.log_metric(&run_a, "loss", 0, 0.5).unwrap();

        assert_eq!(backend.get_metrics(&run_a, "loss").unwrap().len(), 1);
        assert!(backend.get_metrics(&run_b, "loss").unwrap().is_empty());
    }

    #[test]
    fn test_get_metrics_nonexistent_run() {
        let backend = test_backend();
        assert_run_not_found(backend.get_metrics("nonexistent-run", "loss"), "nonexistent-run");
    }

    #[test]
    fn test_get_metrics_empty() {
        let mut backend = test_backend();
        let exp_id = backend.create_experiment("test", None).unwrap();
        let run_id = backend.create_run(&exp_id).unwrap();
        let metrics = backend.get_metrics(&run_id, "loss").unwrap();
        assert!(metrics.is_empty());
    }

    #[test]
    fn test_get_run_status() {
        let mut backend = test_backend();
        let exp_id = backend.create_experiment("test", None).unwrap();
        let run_id = backend.create_run(&exp_id).unwrap();
        let status = backend.get_run_status(&run_id).unwrap();
        assert_eq!(status, RunStatus::Pending);
    }

    #[test]
    fn test_get_run_status_nonexistent() {
        let backend = test_backend();
        assert_run_not_found(backend.get_run_status("nonexistent-run"), "nonexistent-run");
    }

    #[test]
    fn test_log_artifact() {
        let mut backend = test_backend();
        let exp_id = backend.create_experiment("test", None).unwrap();
        let run_id = backend.create_run(&exp_id).unwrap();
        let sha = backend.log_artifact(&run_id, "model.bin", b"fake model data").unwrap();
        // SHA-256 rendered as 64 lowercase hex characters.
        assert_eq!(sha.len(), 64);
        assert!(sha.chars().all(|c| c.is_ascii_digit() || ('a'..='f').contains(&c)));
    }

    #[test]
    fn test_log_artifact_known_digest() {
        let mut backend = test_backend();
        let exp_id = backend.create_experiment("test", None).unwrap();
        let run_id = backend.create_run(&exp_id).unwrap();
        // FIPS 180-2 test vector: SHA-256("abc").
        let sha = backend.log_artifact(&run_id, "abc.txt", b"abc").unwrap();
        assert_eq!(sha, "ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad");
    }

    #[test]
    fn test_log_artifact_nonexistent_run() {
        let mut backend = test_backend();
        assert_run_not_found(
            backend.log_artifact("nonexistent-run", "file.bin", b"data"),
            "nonexistent-run",
        );
    }

    #[test]
    fn test_log_artifact_deterministic_hash() {
        let mut backend = test_backend();
        let exp_id = backend.create_experiment("test", None).unwrap();
        let run_id1 = backend.create_run(&exp_id).unwrap();
        let run_id2 = backend.create_run(&exp_id).unwrap();
        let sha1 = backend.log_artifact(&run_id1, "file.bin", b"same data").unwrap();
        let sha2 = backend.log_artifact(&run_id2, "file.bin", b"same data").unwrap();
        // Same data -> same SHA-256
        assert_eq!(sha1, sha2);
    }

    #[test]
    fn test_set_and_get_span_id() {
        let mut backend = test_backend();
        let exp_id = backend.create_experiment("test", None).unwrap();
        let run_id = backend.create_run(&exp_id).unwrap();

        // Initially no span ID
        let span = backend.get_span_id(&run_id).unwrap();
        assert!(span.is_none());

        // Set span ID
        backend.set_span_id(&run_id, "span-12345").unwrap();
        let span = backend.get_span_id(&run_id).unwrap();
        assert_eq!(span, Some("span-12345".to_string()));
    }

    #[test]
    fn test_set_span_id_nonexistent_run() {
        let mut backend = test_backend();
        assert_run_not_found(backend.set_span_id("nonexistent-run", "span-123"), "nonexistent-run");
    }

    #[test]
    fn test_get_span_id_nonexistent_run() {
        let backend = test_backend();
        assert_run_not_found(backend.get_span_id("nonexistent-run"), "nonexistent-run");
    }

    #[test]
    fn test_set_span_id_overwrite() {
        let mut backend = test_backend();
        let exp_id = backend.create_experiment("test", None).unwrap();
        let run_id = backend.create_run(&exp_id).unwrap();

        backend.set_span_id(&run_id, "span-1").unwrap();
        backend.set_span_id(&run_id, "span-2").unwrap();
        let span = backend.get_span_id(&run_id).unwrap();
        assert_eq!(span, Some("span-2".to_string()));
    }

    #[test]
    fn test_full_lifecycle() {
        let mut backend = test_backend();
        let exp_id = backend.create_experiment("lifecycle-test", None).unwrap();
        let run_id = backend.create_run(&exp_id).unwrap();

        assert_eq!(backend.get_run_status(&run_id).unwrap(), RunStatus::Pending);

        backend.start_run(&run_id).unwrap();
        assert_eq!(backend.get_run_status(&run_id).unwrap(), RunStatus::Running);

        backend.log_metric(&run_id, "loss", 0, 1.0).unwrap();
        backend.log_metric(&run_id, "loss", 1, 0.5).unwrap();

        backend.complete_run(&run_id, RunStatus::Success).unwrap();
        assert_eq!(backend.get_run_status(&run_id).unwrap(), RunStatus::Success);

        // Metrics remain readable after the run has completed.
        let metrics = backend.get_metrics(&run_id, "loss").unwrap();
        assert_eq!(metrics.len(), 2);
        assert!((metrics[0].value - 1.0).abs() < f64::EPSILON);
        assert!((metrics[1].value - 0.5).abs() < f64::EPSILON);
    }

    #[test]
    fn test_metrics_ordered_by_step() {
        let mut backend = test_backend();
        let exp_id = backend.create_experiment("test", None).unwrap();
        let run_id = backend.create_run(&exp_id).unwrap();
        // Insert out of order
        backend.log_metric(&run_id, "loss", 5, 0.1).unwrap();
        backend.log_metric(&run_id, "loss", 1, 0.5).unwrap();
        backend.log_metric(&run_id, "loss", 3, 0.3).unwrap();

        let metrics = backend.get_metrics(&run_id, "loss").unwrap();
        assert_eq!(metrics.len(), 3);
        assert_eq!(metrics[0].step, 1);
        assert_eq!(metrics[1].step, 3);
        assert_eq!(metrics[2].step, 5);
    }
}