use crate::engine::SynaDB;
use crate::error::{Result, SynaError};
use crate::types::Atom;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::Path;
use uuid::Uuid;
/// Lifecycle state of an experiment run.
///
/// Every run starts as [`RunStatus::Running`], which is also the `Default`; the other
/// variants are end states.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
pub enum RunStatus {
    /// The run is open and still accepts params, metrics and artifacts.
    #[default]
    Running,
    /// End state: the run was ended as completed.
    Completed,
    /// End state: the run was ended as failed.
    Failed,
    /// End state: the run was ended as killed.
    Killed,
}

impl std::fmt::Display for RunStatus {
    /// Writes the bare variant name, e.g. `Completed`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::Running => "Running",
            Self::Completed => "Completed",
            Self::Failed => "Failed",
            Self::Killed => "Killed",
        };
        f.write_str(label)
    }
}
/// Metadata record for one run of an experiment, persisted as JSON by the tracker.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Run {
    /// Unique run identifier (the tracker generates a UUID v4 string).
    pub id: String,
    /// Name of the experiment this run belongs to.
    pub experiment_name: String,
    /// Start time; the tracker stores seconds since the Unix epoch.
    pub started_at: u64,
    /// End time in the same unit as `started_at`; `None` until the run is ended.
    pub ended_at: Option<u64>,
    /// Current lifecycle status.
    pub status: RunStatus,
    /// Parameters logged for this run, keyed by parameter name.
    pub params: HashMap<String, String>,
    /// Labels supplied when the run was started.
    pub tags: Vec<String>,
}

impl Run {
    /// Builds a run that has just started: status [`RunStatus::Running`], no end time,
    /// and an empty parameter map.
    pub fn new(id: String, experiment_name: String, started_at: u64, tags: Vec<String>) -> Self {
        let params = HashMap::default();
        Run {
            tags,
            params,
            status: RunStatus::Running,
            ended_at: None,
            started_at,
            experiment_name,
            id,
        }
    }
}
/// Selection criteria for `ExperimentTracker::query_runs`.
///
/// Each field that is `Some` must match; fields left as `None` (the `Default`) match every run.
#[derive(Clone, Debug, Default)]
pub struct RunFilter {
    /// Keep only runs whose experiment name equals this exactly.
    pub experiment: Option<String>,
    /// Keep only runs currently in this status.
    pub status: Option<RunStatus>,
    /// Keep only runs that carry every one of these tags.
    pub tags: Option<Vec<String>>,
    /// Keep only runs whose params contain this exact `(name, value)` pair.
    pub param_filter: Option<(String, String)>,
}
/// Records experiment runs and their parameters, metrics and artifacts in a [`SynaDB`].
///
/// Everything is stored under a flat key namespace in the database:
///
/// - `exp/{experiment}/run/{run_id}/meta` — the [`Run`] record, JSON-encoded as `Atom::Text`
/// - `exp/{experiment}/run/{run_id}/param/{name}` — a parameter value as `Atom::Text`
/// - `exp/{experiment}/run/{run_id}/metric/{name}/{step}` — a metric sample as `Atom::Float`
/// - `exp/{experiment}/run/{run_id}/artifact/{name}` — an artifact payload as `Atom::Bytes`
///
/// Run metadata is updated by appending a new JSON record under the same `meta` key.
/// Looking a run up by id scans every key in the database, so per-call cost grows with
/// database size.
pub struct ExperimentTracker {
    db: SynaDB,
}

impl ExperimentTracker {
    /// Creates a tracker backed by a [`SynaDB`] at `path`.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `SynaDB::new`.
    pub fn new<P: AsRef<Path>>(path: P) -> Result<Self> {
        let db = SynaDB::new(path)?;
        Ok(Self { db })
    }

    /// Returns a shared reference to the underlying database.
    pub fn db(&self) -> &SynaDB {
        &self.db
    }

    /// Returns a mutable reference to the underlying database.
    pub fn db_mut(&mut self) -> &mut SynaDB {
        &mut self.db
    }

    /// Starts a new run of `experiment` and returns its id, a random UUID v4 string.
    ///
    /// The run begins in [`RunStatus::Running`] with `started_at` set to the current Unix
    /// time in whole seconds.
    ///
    /// # Errors
    ///
    /// Fails if the run metadata cannot be serialized or appended to the database.
    pub fn start_run(&mut self, experiment: &str, tags: Vec<String>) -> Result<String> {
        let run_id = Uuid::new_v4().to_string();
        let started_at = Self::since_epoch().as_secs();
        let run = Run::new(run_id.clone(), experiment.to_string(), started_at, tags);
        self.write_meta(&run_id, &run)?;
        Ok(run_id)
    }

    /// Records parameter `key = value` for a running run.
    ///
    /// The value is stored under its own `param/{key}` entry and mirrored into the run's
    /// metadata (`Run::params`), which is what [`Self::query_runs`] filters on. Logging the
    /// same key again replaces the mirrored value.
    ///
    /// # Errors
    ///
    /// `RunNotFound` if no run has id `run_id`; `RunAlreadyEnded` if the run is no longer
    /// [`RunStatus::Running`]; otherwise any serialization or database error.
    pub fn log_param(&mut self, run_id: &str, key: &str, value: &str) -> Result<()> {
        let mut run = self.get_running_run(run_id)?;
        let prefix = Self::run_prefix(&run.experiment_name, run_id);
        let param_key = format!("{}param/{}", prefix, key);
        self.db.append(&param_key, Atom::Text(value.to_string()))?;
        run.params.insert(key.to_string(), value.to_string());
        self.write_meta(run_id, &run)
    }

    /// Records one sample of metric `key` for a running run.
    ///
    /// When `step` is `None`, the current Unix time in microseconds is used as the step.
    /// Logging the same metric at the same step twice writes to the same database key.
    ///
    /// # Errors
    ///
    /// `RunNotFound` if no run has id `run_id`; `RunAlreadyEnded` if the run is no longer
    /// [`RunStatus::Running`]; otherwise any database error.
    pub fn log_metric(
        &mut self,
        run_id: &str,
        key: &str,
        value: f64,
        step: Option<u64>,
    ) -> Result<()> {
        let run = self.get_running_run(run_id)?;
        let step_num = step.unwrap_or_else(|| {
            // Saturate rather than silently truncate the u128 microsecond count.
            u64::try_from(Self::since_epoch().as_micros()).unwrap_or(u64::MAX)
        });
        let prefix = Self::run_prefix(&run.experiment_name, run_id);
        let metric_key = format!("{}metric/{}/{}", prefix, key, step_num);
        self.db.append(&metric_key, Atom::Float(value))?;
        Ok(())
    }

    /// Stores `data` as artifact `name` of a running run.
    ///
    /// # Errors
    ///
    /// `RunNotFound` if no run has id `run_id`; `RunAlreadyEnded` if the run is no longer
    /// [`RunStatus::Running`]; otherwise any database error.
    pub fn log_artifact(&mut self, run_id: &str, name: &str, data: &[u8]) -> Result<()> {
        let run = self.get_running_run(run_id)?;
        let prefix = Self::run_prefix(&run.experiment_name, run_id);
        let artifact_key = format!("{}artifact/{}", prefix, name);
        self.db.append(&artifact_key, Atom::Bytes(data.to_vec()))?;
        Ok(())
    }

    /// Returns the bytes of artifact `name`, or `None` if the run has no such artifact.
    ///
    /// Works for runs in any status.
    ///
    /// # Errors
    ///
    /// `RunNotFound` if no run has id `run_id`; otherwise any database error.
    pub fn get_artifact(&mut self, run_id: &str, name: &str) -> Result<Option<Vec<u8>>> {
        let run = self.get_run_internal(run_id)?;
        let prefix = Self::run_prefix(&run.experiment_name, run_id);
        let artifact_key = format!("{}artifact/{}", prefix, name);
        match self.db.get(&artifact_key)? {
            Some(Atom::Bytes(data)) => Ok(Some(data)),
            _ => Ok(None),
        }
    }

    /// Lists the names of all artifacts stored for a run (in database key order).
    ///
    /// # Errors
    ///
    /// `RunNotFound` if no run has id `run_id`; otherwise any database error.
    pub fn list_artifacts(&mut self, run_id: &str) -> Result<Vec<String>> {
        let run = self.get_run_internal(run_id)?;
        let prefix = format!("{}artifact/", Self::run_prefix(&run.experiment_name, run_id));
        Ok(self
            .db
            .keys()
            .into_iter()
            .filter_map(|k| k.strip_prefix(&prefix).map(String::from))
            .collect())
    }

    /// Ends a running run with `status` and stamps `ended_at` with the current Unix time
    /// in seconds.
    ///
    /// Note: passing [`RunStatus::Running`] records an end time but leaves the run open
    /// for further logging.
    ///
    /// # Errors
    ///
    /// `RunNotFound` if no run has id `run_id`; `RunAlreadyEnded` if the run is no longer
    /// [`RunStatus::Running`]; otherwise any serialization or database error.
    pub fn end_run(&mut self, run_id: &str, status: RunStatus) -> Result<()> {
        let mut run = self.get_running_run(run_id)?;
        run.status = status;
        run.ended_at = Some(Self::since_epoch().as_secs());
        self.write_meta(run_id, &run)
    }

    /// Returns the latest metadata record for `run_id`.
    ///
    /// # Errors
    ///
    /// `RunNotFound` if no run has that id; otherwise any decoding or database error.
    pub fn get_run(&mut self, run_id: &str) -> Result<Run> {
        self.get_run_internal(run_id)
    }

    /// Finds the metadata key of `run_id` by scanning all keys and decodes its record.
    fn get_run_internal(&mut self, run_id: &str) -> Result<Run> {
        for key in self.db.keys() {
            let is_target = matches!(Self::parse_meta_key(&key), Some((_, id)) if id == run_id);
            if !is_target {
                continue;
            }
            if let Some(Atom::Text(json)) = self.db.get(&key)? {
                return Self::decode_run(&json);
            }
        }
        Err(SynaError::RunNotFound(run_id.to_string()))
    }

    /// Returns every run of `experiment`, oldest `started_at` first.
    ///
    /// Only genuine run-metadata keys are read, so params or artifacts whose names end in
    /// `meta` are never mistaken for runs.
    ///
    /// # Errors
    ///
    /// Fails on the first metadata record that cannot be decoded, or on a database error.
    pub fn list_runs(&mut self, experiment: &str) -> Result<Vec<Run>> {
        let mut runs = Vec::new();
        for key in self.db.keys() {
            let in_experiment =
                matches!(Self::parse_meta_key(&key), Some((exp, _)) if exp == experiment);
            if !in_experiment {
                continue;
            }
            if let Some(Atom::Text(json)) = self.db.get(&key)? {
                runs.push(Self::decode_run(&json)?);
            }
        }
        runs.sort_by_key(|r| r.started_at);
        Ok(runs)
    }

    /// Returns all parameters logged for a run, read from their individual `param/` entries.
    ///
    /// # Errors
    ///
    /// `RunNotFound` if no run has id `run_id`; otherwise any database error.
    pub fn get_params(&mut self, run_id: &str) -> Result<HashMap<String, String>> {
        let run = self.get_run_internal(run_id)?;
        let prefix = format!("{}param/", Self::run_prefix(&run.experiment_name, run_id));
        let mut params = HashMap::new();
        for key in self.db.keys() {
            if let Some(param_name) = key.strip_prefix(&prefix) {
                if let Some(Atom::Text(value)) = self.db.get(&key)? {
                    params.insert(param_name.to_string(), value);
                }
            }
        }
        Ok(params)
    }

    /// Returns the value of parameter `param_name`, or `None` if it was never logged.
    ///
    /// # Errors
    ///
    /// `RunNotFound` if no run has id `run_id`; otherwise any database error.
    pub fn get_param(&mut self, run_id: &str, param_name: &str) -> Result<Option<String>> {
        let run = self.get_run_internal(run_id)?;
        let prefix = Self::run_prefix(&run.experiment_name, run_id);
        let param_key = format!("{}param/{}", prefix, param_name);
        match self.db.get(&param_key)? {
            Some(Atom::Text(value)) => Ok(Some(value)),
            _ => Ok(None),
        }
    }

    /// Returns the `(step, value)` samples of metric `metric_name`, sorted by step.
    ///
    /// Entries whose step segment is not an unsigned integer, or whose value is not a
    /// float, are ignored. An unknown metric yields an empty vector.
    ///
    /// # Errors
    ///
    /// `RunNotFound` if no run has id `run_id`; otherwise any database error.
    pub fn get_metric(&mut self, run_id: &str, metric_name: &str) -> Result<Vec<(u64, f64)>> {
        let run = self.get_run_internal(run_id)?;
        let prefix = format!(
            "{}metric/{}/",
            Self::run_prefix(&run.experiment_name, run_id),
            metric_name
        );
        let mut metrics = Vec::new();
        for key in self.db.keys() {
            let step = key
                .strip_prefix(&prefix)
                .and_then(|step_str| step_str.parse::<u64>().ok());
            if let Some(step) = step {
                if let Some(Atom::Float(value)) = self.db.get(&key)? {
                    metrics.push((step, value));
                }
            }
        }
        metrics.sort_by_key(|(step, _)| *step);
        Ok(metrics)
    }

    /// Returns every metric of a run as `name -> (step, value)` samples, each sorted by step.
    ///
    /// The step is the final `/`-separated segment of a metric key, so metric names may
    /// themselves contain `/` (e.g. `train/loss`), consistent with [`Self::get_metric`].
    ///
    /// # Errors
    ///
    /// `RunNotFound` if no run has id `run_id`; otherwise any database error.
    pub fn get_all_metrics(&mut self, run_id: &str) -> Result<HashMap<String, Vec<(u64, f64)>>> {
        let run = self.get_run_internal(run_id)?;
        let prefix = format!("{}metric/", Self::run_prefix(&run.experiment_name, run_id));
        let mut metrics: HashMap<String, Vec<(u64, f64)>> = HashMap::new();
        for key in self.db.keys() {
            let sample = key
                .strip_prefix(&prefix)
                .and_then(|rest| rest.rsplit_once('/'));
            if let Some((metric_name, step_str)) = sample {
                if let Ok(step) = step_str.parse::<u64>() {
                    if let Some(Atom::Float(value)) = self.db.get(&key)? {
                        metrics
                            .entry(metric_name.to_string())
                            .or_default()
                            .push((step, value));
                    }
                }
            }
        }
        for values in metrics.values_mut() {
            values.sort_by_key(|(step, _)| *step);
        }
        Ok(metrics)
    }

    /// Returns all runs (across experiments) matching `filter`, newest `started_at` first.
    ///
    /// Metadata records that fail to decode are skipped rather than reported, unlike
    /// [`Self::list_runs`].
    ///
    /// # Errors
    ///
    /// Only database errors are propagated.
    pub fn query_runs(&mut self, filter: RunFilter) -> Result<Vec<Run>> {
        let mut runs = Vec::new();
        for key in self.db.keys() {
            if Self::parse_meta_key(&key).is_none() {
                continue;
            }
            if let Some(Atom::Text(json)) = self.db.get(&key)? {
                if let Ok(run) = Self::decode_run(&json) {
                    if Self::run_matches(&run, &filter) {
                        runs.push(run);
                    }
                }
            }
        }
        runs.sort_by_key(|r| std::cmp::Reverse(r.started_at));
        Ok(runs)
    }

    /// Returns true when `run` satisfies every criterion that is set in `filter`.
    fn run_matches(run: &Run, filter: &RunFilter) -> bool {
        if let Some(experiment) = &filter.experiment {
            if run.experiment_name != *experiment {
                return false;
            }
        }
        if let Some(status) = filter.status {
            if run.status != status {
                return false;
            }
        }
        if let Some(tags) = &filter.tags {
            if !tags.iter().all(|t| run.tags.contains(t)) {
                return false;
            }
        }
        if let Some((param_key, param_value)) = &filter.param_filter {
            if run.params.get(param_key) != Some(param_value) {
                return false;
            }
        }
        true
    }

    /// Loads `run_id` and checks that it is still [`RunStatus::Running`].
    fn get_running_run(&mut self, run_id: &str) -> Result<Run> {
        let run = self.get_run_internal(run_id)?;
        if run.status != RunStatus::Running {
            return Err(SynaError::RunAlreadyEnded(run_id.to_string()));
        }
        Ok(run)
    }

    /// Serializes `run` and appends it as the latest metadata record of `run_id`.
    fn write_meta(&mut self, run_id: &str, run: &Run) -> Result<()> {
        // NOTE(review): serde failures are reported as `InvalidPath`; a dedicated error
        // variant would be clearer.
        let json =
            serde_json::to_string(run).map_err(|e| SynaError::InvalidPath(e.to_string()))?;
        let meta_key = format!("{}meta", Self::run_prefix(&run.experiment_name, run_id));
        self.db.append(&meta_key, Atom::Text(json))?;
        Ok(())
    }

    /// Decodes a JSON metadata record, mapping serde errors the same way as `write_meta`.
    fn decode_run(json: &str) -> Result<Run> {
        serde_json::from_str(json).map_err(|e| SynaError::InvalidPath(e.to_string()))
    }

    /// Key prefix shared by every entry of one run: `exp/{experiment}/run/{run_id}/`.
    fn run_prefix(experiment: &str, run_id: &str) -> String {
        format!("exp/{}/run/{}/", experiment, run_id)
    }

    /// Splits a run-metadata key `exp/{experiment}/run/{run_id}/meta` into
    /// `(experiment, run_id)`.
    ///
    /// Returns `None` for every other key. This includes param and artifact keys whose
    /// name ends in `/meta` (e.g. a param literally named `meta`), which a plain
    /// prefix/suffix test would mistake for run metadata.
    fn parse_meta_key(key: &str) -> Option<(&str, &str)> {
        let inner = key.strip_prefix("exp/")?.strip_suffix("/meta")?;
        // Split at the LAST `/run/` so experiment names may themselves contain `/run/`.
        let (experiment, run_id) = inner.rsplit_once("/run/")?;
        if run_id.is_empty() || run_id.contains('/') {
            return None;
        }
        Some((experiment, run_id))
    }

    /// Time elapsed since the Unix epoch; a system clock set before the epoch yields zero.
    fn since_epoch() -> std::time::Duration {
        std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap_or_default()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::{tempdir, TempDir};

    /// Creates an `ExperimentTracker` on `file_name` inside a fresh temporary directory.
    ///
    /// Bind the returned `TempDir` for the whole test (`let (_dir, ..)`, never `let (_, ..)`):
    /// `TempDir` removes its directory when dropped.
    fn fresh_tracker(file_name: &str) -> (TempDir, ExperimentTracker) {
        let dir = tempdir().unwrap();
        let tracker = ExperimentTracker::new(dir.path().join(file_name)).unwrap();
        (dir, tracker)
    }

    #[test]
    fn test_run_status_default() {
        let status = RunStatus::default();
        assert_eq!(status, RunStatus::Running);
    }

    #[test]
    fn test_run_status_display() {
        assert_eq!(format!("{}", RunStatus::Running), "Running");
        assert_eq!(format!("{}", RunStatus::Completed), "Completed");
        assert_eq!(format!("{}", RunStatus::Failed), "Failed");
        assert_eq!(format!("{}", RunStatus::Killed), "Killed");
    }

    #[test]
    fn test_run_new() {
        let run = Run::new(
            "run_123".to_string(),
            "mnist".to_string(),
            1234567890,
            vec!["baseline".to_string()],
        );
        assert_eq!(run.id, "run_123");
        assert_eq!(run.experiment_name, "mnist");
        assert_eq!(run.started_at, 1234567890);
        assert_eq!(run.ended_at, None);
        assert_eq!(run.status, RunStatus::Running);
        assert!(run.params.is_empty());
        assert_eq!(run.tags, vec!["baseline".to_string()]);
    }

    #[test]
    fn test_experiment_tracker_new() {
        // Built by hand rather than via `fresh_tracker` so construction itself is asserted.
        let dir = tempdir().unwrap();
        let db_path = dir.path().join("test_tracker.db");
        let tracker = ExperimentTracker::new(&db_path);
        assert!(tracker.is_ok());
    }

    #[test]
    fn test_start_run() {
        let (_dir, mut tracker) = fresh_tracker("test_start_run.db");
        let run_id = tracker
            .start_run("mnist", vec!["baseline".to_string()])
            .unwrap();
        assert!(uuid::Uuid::parse_str(&run_id).is_ok());
        let run = tracker.get_run(&run_id).unwrap();
        assert_eq!(run.id, run_id);
        assert_eq!(run.experiment_name, "mnist");
        assert_eq!(run.status, RunStatus::Running);
        assert_eq!(run.tags, vec!["baseline".to_string()]);
    }

    #[test]
    fn test_log_param() {
        let (_dir, mut tracker) = fresh_tracker("test_log_param.db");
        let run_id = tracker.start_run("mnist", vec![]).unwrap();
        tracker
            .log_param(&run_id, "learning_rate", "0.001")
            .unwrap();
        tracker.log_param(&run_id, "batch_size", "32").unwrap();
        let run = tracker.get_run(&run_id).unwrap();
        assert_eq!(run.params.get("learning_rate"), Some(&"0.001".to_string()));
        assert_eq!(run.params.get("batch_size"), Some(&"32".to_string()));
    }

    #[test]
    fn test_log_metric() {
        let (_dir, mut tracker) = fresh_tracker("test_log_metric.db");
        let run_id = tracker.start_run("mnist", vec![]).unwrap();
        tracker.log_metric(&run_id, "loss", 0.5, Some(1)).unwrap();
        tracker.log_metric(&run_id, "loss", 0.3, Some(2)).unwrap();
        tracker
            .log_metric(&run_id, "accuracy", 0.85, Some(1))
            .unwrap();
        let metrics = tracker.get_all_metrics(&run_id).unwrap();
        let loss_values = metrics.get("loss").unwrap();
        assert_eq!(loss_values.len(), 2);
        assert_eq!(loss_values[0], (1, 0.5));
        assert_eq!(loss_values[1], (2, 0.3));
        let accuracy_values = metrics.get("accuracy").unwrap();
        assert_eq!(accuracy_values.len(), 1);
        assert_eq!(accuracy_values[0], (1, 0.85));
    }

    #[test]
    fn test_get_metric() {
        let (_dir, mut tracker) = fresh_tracker("test_get_metric.db");
        let run_id = tracker.start_run("mnist", vec![]).unwrap();
        tracker.log_metric(&run_id, "loss", 0.5, Some(1)).unwrap();
        tracker.log_metric(&run_id, "loss", 0.3, Some(2)).unwrap();
        tracker.log_metric(&run_id, "loss", 0.1, Some(3)).unwrap();
        tracker
            .log_metric(&run_id, "accuracy", 0.85, Some(1))
            .unwrap();
        let loss_values = tracker.get_metric(&run_id, "loss").unwrap();
        assert_eq!(loss_values.len(), 3);
        assert_eq!(loss_values[0], (1, 0.5));
        assert_eq!(loss_values[1], (2, 0.3));
        assert_eq!(loss_values[2], (3, 0.1));
        let accuracy_values = tracker.get_metric(&run_id, "accuracy").unwrap();
        assert_eq!(accuracy_values.len(), 1);
        assert_eq!(accuracy_values[0], (1, 0.85));
        let empty_values = tracker.get_metric(&run_id, "nonexistent").unwrap();
        assert!(empty_values.is_empty());
    }

    #[test]
    fn test_get_params() {
        let (_dir, mut tracker) = fresh_tracker("test_get_params.db");
        let run_id = tracker.start_run("mnist", vec![]).unwrap();
        tracker
            .log_param(&run_id, "learning_rate", "0.001")
            .unwrap();
        tracker.log_param(&run_id, "batch_size", "32").unwrap();
        tracker.log_param(&run_id, "optimizer", "adam").unwrap();
        let params = tracker.get_params(&run_id).unwrap();
        assert_eq!(params.len(), 3);
        assert_eq!(params.get("learning_rate"), Some(&"0.001".to_string()));
        assert_eq!(params.get("batch_size"), Some(&"32".to_string()));
        assert_eq!(params.get("optimizer"), Some(&"adam".to_string()));
    }

    #[test]
    fn test_get_param() {
        let (_dir, mut tracker) = fresh_tracker("test_get_param.db");
        let run_id = tracker.start_run("mnist", vec![]).unwrap();
        tracker
            .log_param(&run_id, "learning_rate", "0.001")
            .unwrap();
        tracker.log_param(&run_id, "batch_size", "32").unwrap();
        let lr = tracker.get_param(&run_id, "learning_rate").unwrap();
        assert_eq!(lr, Some("0.001".to_string()));
        let batch_size = tracker.get_param(&run_id, "batch_size").unwrap();
        assert_eq!(batch_size, Some("32".to_string()));
        let nonexistent = tracker.get_param(&run_id, "nonexistent").unwrap();
        assert_eq!(nonexistent, None);
    }

    #[test]
    fn test_log_artifact() {
        let (_dir, mut tracker) = fresh_tracker("test_log_artifact.db");
        let run_id = tracker.start_run("mnist", vec![]).unwrap();
        let artifact_data = vec![0u8, 1, 2, 3, 4, 5];
        tracker
            .log_artifact(&run_id, "model.pt", &artifact_data)
            .unwrap();
        // Read the raw database entry to pin the on-disk key layout.
        let run = tracker.get_run(&run_id).unwrap();
        let artifact_key = format!(
            "exp/{}/run/{}/artifact/model.pt",
            run.experiment_name, run_id
        );
        let stored = tracker.db_mut().get(&artifact_key).unwrap();
        assert!(matches!(stored, Some(Atom::Bytes(data)) if data == artifact_data));
    }

    #[test]
    fn test_get_artifact() {
        let (_dir, mut tracker) = fresh_tracker("test_get_artifact.db");
        let run_id = tracker.start_run("mnist", vec![]).unwrap();
        let artifact_data = vec![0u8, 1, 2, 3, 4, 5];
        tracker
            .log_artifact(&run_id, "model.pt", &artifact_data)
            .unwrap();
        let retrieved = tracker.get_artifact(&run_id, "model.pt").unwrap();
        assert_eq!(retrieved, Some(artifact_data));
        let nonexistent = tracker.get_artifact(&run_id, "nonexistent.pt").unwrap();
        assert_eq!(nonexistent, None);
    }

    #[test]
    fn test_list_artifacts() {
        let (_dir, mut tracker) = fresh_tracker("test_list_artifacts.db");
        let run_id = tracker.start_run("mnist", vec![]).unwrap();
        tracker
            .log_artifact(&run_id, "model.pt", &[1, 2, 3])
            .unwrap();
        tracker
            .log_artifact(&run_id, "config.json", &[4, 5, 6])
            .unwrap();
        tracker
            .log_artifact(&run_id, "weights.bin", &[7, 8, 9])
            .unwrap();
        let artifacts = tracker.list_artifacts(&run_id).unwrap();
        assert_eq!(artifacts.len(), 3);
        assert!(artifacts.contains(&"model.pt".to_string()));
        assert!(artifacts.contains(&"config.json".to_string()));
        assert!(artifacts.contains(&"weights.bin".to_string()));
    }

    #[test]
    fn test_list_artifacts_empty() {
        let (_dir, mut tracker) = fresh_tracker("test_list_artifacts_empty.db");
        let run_id = tracker.start_run("mnist", vec![]).unwrap();
        let artifacts = tracker.list_artifacts(&run_id).unwrap();
        assert!(artifacts.is_empty());
    }

    #[test]
    fn test_end_run() {
        let (_dir, mut tracker) = fresh_tracker("test_end_run.db");
        let run_id = tracker.start_run("mnist", vec![]).unwrap();
        tracker.end_run(&run_id, RunStatus::Completed).unwrap();
        let run = tracker.get_run(&run_id).unwrap();
        assert_eq!(run.status, RunStatus::Completed);
        assert!(run.ended_at.is_some());
    }

    #[test]
    fn test_cannot_log_to_ended_run() {
        let (_dir, mut tracker) = fresh_tracker("test_ended_run.db");
        let run_id = tracker.start_run("mnist", vec![]).unwrap();
        tracker.end_run(&run_id, RunStatus::Completed).unwrap();
        let result = tracker.log_param(&run_id, "key", "value");
        assert!(matches!(result, Err(SynaError::RunAlreadyEnded(_))));
        let result = tracker.log_metric(&run_id, "loss", 0.5, Some(1));
        assert!(matches!(result, Err(SynaError::RunAlreadyEnded(_))));
        let result = tracker.log_artifact(&run_id, "file", &[1, 2, 3]);
        assert!(matches!(result, Err(SynaError::RunAlreadyEnded(_))));
        let result = tracker.end_run(&run_id, RunStatus::Failed);
        assert!(matches!(result, Err(SynaError::RunAlreadyEnded(_))));
    }

    #[test]
    fn test_list_runs() {
        let (_dir, mut tracker) = fresh_tracker("test_list_runs.db");
        let run1 = tracker.start_run("mnist", vec!["v1".to_string()]).unwrap();
        let run2 = tracker.start_run("mnist", vec!["v2".to_string()]).unwrap();
        let run3 = tracker.start_run("cifar", vec![]).unwrap();
        let mnist_runs = tracker.list_runs("mnist").unwrap();
        assert_eq!(mnist_runs.len(), 2);
        assert!(mnist_runs.iter().any(|r| r.id == run1));
        assert!(mnist_runs.iter().any(|r| r.id == run2));
        let cifar_runs = tracker.list_runs("cifar").unwrap();
        assert_eq!(cifar_runs.len(), 1);
        assert_eq!(cifar_runs[0].id, run3);
    }

    #[test]
    fn test_run_not_found() {
        let (_dir, mut tracker) = fresh_tracker("test_not_found.db");
        let result = tracker.get_run("nonexistent");
        assert!(matches!(result, Err(SynaError::RunNotFound(_))));
    }

    #[test]
    fn test_run_status_serialization() {
        let status = RunStatus::Completed;
        let serialized = serde_json::to_string(&status).unwrap();
        let deserialized: RunStatus = serde_json::from_str(&serialized).unwrap();
        assert_eq!(status, deserialized);
    }

    #[test]
    fn test_run_serialization() {
        let run = Run::new(
            "run_123".to_string(),
            "mnist".to_string(),
            1234567890,
            vec!["baseline".to_string()],
        );
        let serialized = serde_json::to_string(&run).unwrap();
        let deserialized: Run = serde_json::from_str(&serialized).unwrap();
        assert_eq!(run.id, deserialized.id);
        assert_eq!(run.experiment_name, deserialized.experiment_name);
        assert_eq!(run.status, deserialized.status);
    }

    #[test]
    fn test_query_runs_no_filter() {
        let (_dir, mut tracker) = fresh_tracker("test_query_no_filter.db");
        let _run1 = tracker.start_run("mnist", vec![]).unwrap();
        let _run2 = tracker.start_run("mnist", vec![]).unwrap();
        let _run3 = tracker.start_run("cifar", vec![]).unwrap();
        let filter = RunFilter::default();
        let runs = tracker.query_runs(filter).unwrap();
        assert_eq!(runs.len(), 3);
    }

    #[test]
    fn test_query_runs_by_experiment() {
        let (_dir, mut tracker) = fresh_tracker("test_query_by_exp.db");
        let run1 = tracker.start_run("mnist", vec![]).unwrap();
        let run2 = tracker.start_run("mnist", vec![]).unwrap();
        let _run3 = tracker.start_run("cifar", vec![]).unwrap();
        let filter = RunFilter {
            experiment: Some("mnist".to_string()),
            ..Default::default()
        };
        let runs = tracker.query_runs(filter).unwrap();
        assert_eq!(runs.len(), 2);
        assert!(runs.iter().any(|r| r.id == run1));
        assert!(runs.iter().any(|r| r.id == run2));
    }

    #[test]
    fn test_query_runs_by_status() {
        let (_dir, mut tracker) = fresh_tracker("test_query_by_status.db");
        let run1 = tracker.start_run("mnist", vec![]).unwrap();
        let run2 = tracker.start_run("mnist", vec![]).unwrap();
        let run3 = tracker.start_run("mnist", vec![]).unwrap();
        tracker.end_run(&run1, RunStatus::Completed).unwrap();
        tracker.end_run(&run2, RunStatus::Failed).unwrap();
        let filter = RunFilter {
            status: Some(RunStatus::Completed),
            ..Default::default()
        };
        let runs = tracker.query_runs(filter).unwrap();
        assert_eq!(runs.len(), 1);
        assert_eq!(runs[0].id, run1);
        let filter = RunFilter {
            status: Some(RunStatus::Running),
            ..Default::default()
        };
        let runs = tracker.query_runs(filter).unwrap();
        assert_eq!(runs.len(), 1);
        assert_eq!(runs[0].id, run3);
    }

    #[test]
    fn test_query_runs_by_tags() {
        let (_dir, mut tracker) = fresh_tracker("test_query_by_tags.db");
        let run1 = tracker
            .start_run("mnist", vec!["baseline".to_string()])
            .unwrap();
        let run2 = tracker
            .start_run("mnist", vec!["baseline".to_string(), "v2".to_string()])
            .unwrap();
        let _run3 = tracker
            .start_run("mnist", vec!["experimental".to_string()])
            .unwrap();
        let filter = RunFilter {
            tags: Some(vec!["baseline".to_string()]),
            ..Default::default()
        };
        let runs = tracker.query_runs(filter).unwrap();
        assert_eq!(runs.len(), 2);
        assert!(runs.iter().any(|r| r.id == run1));
        assert!(runs.iter().any(|r| r.id == run2));
        let filter = RunFilter {
            tags: Some(vec!["baseline".to_string(), "v2".to_string()]),
            ..Default::default()
        };
        let runs = tracker.query_runs(filter).unwrap();
        assert_eq!(runs.len(), 1);
        assert_eq!(runs[0].id, run2);
    }

    #[test]
    fn test_query_runs_by_param() {
        let (_dir, mut tracker) = fresh_tracker("test_query_by_param.db");
        let run1 = tracker.start_run("mnist", vec![]).unwrap();
        tracker.log_param(&run1, "learning_rate", "0.001").unwrap();
        let run2 = tracker.start_run("mnist", vec![]).unwrap();
        tracker.log_param(&run2, "learning_rate", "0.01").unwrap();
        let run3 = tracker.start_run("mnist", vec![]).unwrap();
        tracker.log_param(&run3, "learning_rate", "0.001").unwrap();
        let filter = RunFilter {
            param_filter: Some(("learning_rate".to_string(), "0.001".to_string())),
            ..Default::default()
        };
        let runs = tracker.query_runs(filter).unwrap();
        assert_eq!(runs.len(), 2);
        assert!(runs.iter().any(|r| r.id == run1));
        assert!(runs.iter().any(|r| r.id == run3));
    }

    #[test]
    fn test_query_runs_combined_filters() {
        let (_dir, mut tracker) = fresh_tracker("test_query_combined.db");
        let run1 = tracker
            .start_run("mnist", vec!["baseline".to_string()])
            .unwrap();
        tracker.log_param(&run1, "lr", "0.001").unwrap();
        tracker.end_run(&run1, RunStatus::Completed).unwrap();
        let run2 = tracker
            .start_run("mnist", vec!["baseline".to_string()])
            .unwrap();
        tracker.log_param(&run2, "lr", "0.001").unwrap();
        let run3 = tracker
            .start_run("cifar", vec!["baseline".to_string()])
            .unwrap();
        tracker.log_param(&run3, "lr", "0.001").unwrap();
        tracker.end_run(&run3, RunStatus::Completed).unwrap();
        let filter = RunFilter {
            experiment: Some("mnist".to_string()),
            status: Some(RunStatus::Completed),
            tags: Some(vec!["baseline".to_string()]),
            param_filter: Some(("lr".to_string(), "0.001".to_string())),
        };
        let runs = tracker.query_runs(filter).unwrap();
        assert_eq!(runs.len(), 1);
        assert_eq!(runs[0].id, run1);
    }

    #[test]
    fn test_query_runs_sorted_by_start_time() {
        let (_dir, mut tracker) = fresh_tracker("test_query_sorted.db");
        // `started_at` has one-second resolution, so the runs are spaced a second apart
        // to get distinct, strictly ordered start times.
        let run1 = tracker.start_run("mnist", vec![]).unwrap();
        std::thread::sleep(std::time::Duration::from_secs(1));
        let run2 = tracker.start_run("mnist", vec![]).unwrap();
        std::thread::sleep(std::time::Duration::from_secs(1));
        let run3 = tracker.start_run("mnist", vec![]).unwrap();
        let filter = RunFilter::default();
        let runs = tracker.query_runs(filter).unwrap();
        assert_eq!(runs.len(), 3);
        // Newest first.
        assert_eq!(runs[0].id, run3);
        assert_eq!(runs[1].id, run2);
        assert_eq!(runs[2].id, run1);
    }

    #[test]
    fn test_run_filter_default() {
        let filter = RunFilter::default();
        assert!(filter.experiment.is_none());
        assert!(filter.status.is_none());
        assert!(filter.tags.is_none());
        assert!(filter.param_filter.is_none());
    }
}