use std::collections::HashMap;
use std::fmt::Write as FmtWrite;
/// Identifier of a tracked run.
///
/// A plain `String` alias. `InMemoryTracker::start_run` builds it as `run_<name>`, with every
/// space in the run name replaced by an underscore.
pub type RunId = String;
/// Lifecycle state of a tracked run.
///
/// `InMemoryTracker::start_run` records `Running`; a successful `end_run` stores whichever
/// status the caller passes in.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TrackerRunStatus {
/// The run has been started (set by `InMemoryTracker::start_run`).
Running,
/// Rendered as "finished"; presumably the run completed normally.
Finished,
/// Rendered as "failed"; presumably the run ended with an error.
Failed,
/// Rendered as "killed"; presumably the run was terminated from outside.
Killed,
}
impl TrackerRunStatus {
    /// Returns the lowercase name of this status: "running", "finished", "failed" or "killed".
    ///
    /// The result is `'static` because every arm yields a string literal; callers may keep
    /// the name after the status value itself has been dropped.
    fn as_str(&self) -> &'static str {
        match self {
            Self::Running => "running",
            Self::Finished => "finished",
            Self::Failed => "failed",
            Self::Killed => "killed",
        }
    }
}
/// Errors returned by [`ExperimentTracker`] operations.
#[derive(Debug, Clone, PartialEq)]
pub enum TrackingError {
/// An operation that needs a run was called on a tracker with no run started.
NoActiveRun,
/// `start_run` was called while a run is already in progress.
AlreadyStarted,
/// I/O failure described by the message. Not constructed by `InMemoryTracker` in this file;
/// presumably intended for trackers that write to disk or a remote service.
IoError(String),
/// An argument was rejected; the message names the offending value
/// (e.g. an empty key or a non-finite metric).
InvalidValue(String),
}
impl std::fmt::Display for TrackingError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::NoActiveRun => write!(f, "no active run — call start_run() first"),
Self::AlreadyStarted => write!(f, "a run is already in progress"),
Self::IoError(msg) => write!(f, "tracking I/O error: {msg}"),
Self::InvalidValue(msg) => write!(f, "invalid tracking value: {msg}"),
}
}
}
// Marker impl with no overridden methods: it makes `TrackingError` usable wherever a
// `std::error::Error` is expected (for example `Box<dyn std::error::Error>`).
impl std::error::Error for TrackingError {}
/// Interface for recording the metrics, parameters, artifacts and tags of an experiment run.
///
/// Every mutating operation reports failure through [`TrackingError`]. The behaviour notes
/// below describe `InMemoryTracker`, the only implementation in this file; other
/// implementations may differ.
pub trait ExperimentTracker {
/// Records `value` for the metric `key` at training step `step`.
///
/// `InMemoryTracker` rejects an empty key and non-finite values with
/// `TrackingError::InvalidValue`, and returns `TrackingError::NoActiveRun` when no run
/// has been started.
fn log_metric(&mut self, key: &str, value: f64, step: u64) -> Result<(), TrackingError>;
/// Records the parameter `key` with the string `value`.
///
/// `InMemoryTracker` overwrites any earlier value for the same key and rejects an empty key.
fn log_param(&mut self, key: &str, value: &str) -> Result<(), TrackingError>;
/// Records the artifact located at `local_path`.
///
/// `artifact_path` is optional; `InMemoryTracker` ignores it and stores only `local_path`.
/// Presumably it names a destination inside the run's artifact store; confirm against a
/// backend that uses it.
fn log_artifact(
&mut self,
local_path: &str,
artifact_path: Option<&str>,
) -> Result<(), TrackingError>;
/// Sets each `(key, value)` pair in `tags` on the run.
fn set_tags(&mut self, tags: &[(&str, &str)]) -> Result<(), TrackingError>;
/// Starts a run named `run_name` and returns its id.
///
/// `InMemoryTracker` returns `TrackingError::AlreadyStarted` when a run is in progress and
/// `TrackingError::InvalidValue` for an empty name.
fn start_run(&mut self, run_name: &str) -> Result<RunId, TrackingError>;
/// Ends the run, recording `status` as its final state.
///
/// `InMemoryTracker` returns `TrackingError::NoActiveRun` when no run has been started.
fn end_run(&mut self, status: TrackerRunStatus) -> Result<(), TrackingError>;
/// Returns the id of the tracker's run, or `None` when no run has been started.
fn get_run_id(&self) -> Option<&RunId>;
}
/// Snapshot of a tracker's contents as counts, produced by `InMemoryTracker::to_summary`.
#[derive(Debug, Clone)]
pub struct TrackingSummary {
/// Id of the tracker's run, or `None` when no run has been started.
pub run_id: Option<RunId>,
/// Number of distinct metric keys (not the number of logged data points).
pub total_metrics: usize,
/// Number of distinct parameter keys.
pub total_params: usize,
/// Number of logged artifact paths.
pub total_artifacts: usize,
/// Status name ("running", "finished", "failed" or "killed"), or "none" when the tracker
/// has no status yet.
pub status: String,
}
/// [`ExperimentTracker`] implementation that keeps everything it is given in process memory.
pub struct InMemoryTracker {
/// Id produced by `start_run`; `None` until a run has been started.
run_id: Option<RunId>,
/// Per-key series of `(step, value)` points, in the order they were logged.
metrics: HashMap<String, Vec<(u64, f64)>>,
/// Parameter values by key; logging the same key again overwrites the value.
params: HashMap<String, String>,
/// Tag values by key; setting the same key again overwrites the value.
tags: HashMap<String, String>,
/// The `local_path` of every logged artifact, in call order.
artifacts: Vec<String>,
/// Status of the run; `None` until a run has been started.
status: Option<TrackerRunStatus>,
}
impl InMemoryTracker {
    /// Creates an empty tracker: no run id, no status and no logged data.
    pub fn new() -> Self {
        Self {
            run_id: None,
            metrics: HashMap::new(),
            params: HashMap::new(),
            tags: HashMap::new(),
            artifacts: Vec::new(),
            status: None,
        }
    }

    /// Returns the `(step, value)` series logged for `key`, in logging order, or `None` when
    /// nothing was logged under that key.
    pub fn get_metric_history(&self, key: &str) -> Option<&Vec<(u64, f64)>> {
        self.metrics.get(key)
    }

    /// Returns the current value of parameter `key`, or `None` when it was never logged.
    pub fn get_param(&self, key: &str) -> Option<&str> {
        self.params.get(key).map(String::as_str)
    }

    /// Returns the number of distinct metric keys (not the number of data points).
    pub fn metric_count(&self) -> usize {
        self.metrics.len()
    }

    /// Serializes the tracker state as a single-line JSON object with the keys `run_id`,
    /// `status`, `params`, `tags`, `artifacts` and `metrics`.
    ///
    /// * `run_id` and `status` are JSON strings, or `null` when unset.
    /// * `params` and `tags` are rendered by `map_to_json_obj`.
    /// * `artifacts` is an array of paths in logging order.
    /// * `metrics` maps each key to an array of `[step, value]` pairs in logging order.
    ///
    /// Metric keys are written in sorted order so that the same data always produces the
    /// same text; `HashMap` iteration order differs between processes and would otherwise
    /// make the export unstable for diffs and snapshot comparisons.
    pub fn export_to_json(&self) -> String {
        let run_id_json = match self.run_id.as_deref() {
            Some(id) => format!("\"{}\"", escape_json(id)),
            None => "null".to_string(),
        };
        // Status names are fixed ASCII literals, so they need no escaping.
        let status_json = match self.status.as_ref() {
            Some(status) => format!("\"{}\"", status.as_str()),
            None => "null".to_string(),
        };
        let params_json = map_to_json_obj(&self.params);
        let tags_json = map_to_json_obj(&self.tags);

        // Writing into a `String` cannot fail, hence the ignored `write!` results below.
        let mut artifacts_json = String::from('[');
        for (i, path) in self.artifacts.iter().enumerate() {
            if i > 0 {
                artifacts_json.push(',');
            }
            let _ = write!(artifacts_json, "\"{}\"", escape_json(path));
        }
        artifacts_json.push(']');

        // Keys are unique, so an unstable sort cannot reorder equal elements.
        let mut series_by_key: Vec<(&String, &Vec<(u64, f64)>)> = self.metrics.iter().collect();
        series_by_key.sort_unstable_by(|a, b| a.0.cmp(b.0));

        let mut metrics_json = String::from('{');
        for (i, (key, series)) in series_by_key.into_iter().enumerate() {
            if i > 0 {
                metrics_json.push(',');
            }
            let _ = write!(metrics_json, "\"{}\":[", escape_json(key));
            for (j, (step, val)) in series.iter().enumerate() {
                if j > 0 {
                    metrics_json.push(',');
                }
                // `log_metric` only stores finite values, so `val` prints as a JSON number.
                let _ = write!(metrics_json, "[{step},{val}]");
            }
            metrics_json.push(']');
        }
        metrics_json.push('}');

        format!(
            r#"{{"run_id":{run_id_json},"status":{status_json},"params":{params_json},"tags":{tags_json},"artifacts":{artifacts_json},"metrics":{metrics_json}}}"#
        )
    }

    /// Builds a [`TrackingSummary`] with the run id, the number of metric keys, parameter
    /// keys and artifacts, and the status name ("none" when no status is set).
    pub fn to_summary(&self) -> TrackingSummary {
        TrackingSummary {
            run_id: self.run_id.clone(),
            total_metrics: self.metrics.len(),
            total_params: self.params.len(),
            total_artifacts: self.artifacts.len(),
            status: self
                .status
                .as_ref()
                .map(|s| s.as_str().to_string())
                .unwrap_or_else(|| "none".to_string()),
        }
    }

    /// Returns the logged artifact paths in logging order.
    pub fn artifacts(&self) -> &[String] {
        &self.artifacts
    }

    /// Returns the tag map.
    pub fn tags(&self) -> &HashMap<String, String> {
        &self.tags
    }
}

impl Default for InMemoryTracker {
    /// Equivalent to [`InMemoryTracker::new`].
    fn default() -> Self {
        Self::new()
    }
}
impl ExperimentTracker for InMemoryTracker {
    /// Starts a new run and returns its id, `run_<run_name>` with spaces replaced by `_`.
    ///
    /// A tracker holds one run at a time. Starting a run after the previous one has ended
    /// discards the previous run's metrics, params, tags and artifacts so the new run does
    /// not inherit them.
    ///
    /// # Errors
    /// * `TrackingError::AlreadyStarted` if a run is currently in progress.
    /// * `TrackingError::InvalidValue` if `run_name` is empty.
    fn start_run(&mut self, run_name: &str) -> Result<RunId, TrackingError> {
        if self.ensure_active().is_ok() {
            return Err(TrackingError::AlreadyStarted);
        }
        if run_name.is_empty() {
            return Err(TrackingError::InvalidValue("run_name must not be empty".to_string()));
        }
        // Cleared only after validation, so a rejected call leaves the old data readable.
        self.metrics.clear();
        self.params.clear();
        self.tags.clear();
        self.artifacts.clear();
        let id = format!("run_{}", run_name.replace(' ', "_"));
        self.run_id = Some(id.clone());
        self.status = Some(TrackerRunStatus::Running);
        Ok(id)
    }

    /// Ends the run in progress and records `status` as its final state.
    ///
    /// The run id and all logged data stay readable afterwards (for `to_summary`,
    /// `export_to_json` and the accessors) until the next `start_run`.
    ///
    /// # Errors
    /// `TrackingError::NoActiveRun` if no run is in progress, including when the run has
    /// already been ended; a second call therefore cannot overwrite the recorded status.
    fn end_run(&mut self, status: TrackerRunStatus) -> Result<(), TrackingError> {
        self.ensure_active()?;
        self.status = Some(status);
        Ok(())
    }

    /// Appends `(step, value)` to the series stored under `key`.
    ///
    /// # Errors
    /// * `TrackingError::NoActiveRun` if no run is in progress.
    /// * `TrackingError::InvalidValue` if `key` is empty or `value` is NaN or infinite.
    fn log_metric(&mut self, key: &str, value: f64, step: u64) -> Result<(), TrackingError> {
        self.ensure_active()?;
        if key.is_empty() {
            return Err(TrackingError::InvalidValue("metric key must not be empty".to_string()));
        }
        if !value.is_finite() {
            return Err(TrackingError::InvalidValue(format!(
                "metric value for '{key}' is not finite: {value}"
            )));
        }
        self.metrics.entry(key.to_string()).or_default().push((step, value));
        Ok(())
    }

    /// Stores `value` under parameter `key`, replacing any earlier value.
    ///
    /// # Errors
    /// * `TrackingError::NoActiveRun` if no run is in progress.
    /// * `TrackingError::InvalidValue` if `key` is empty.
    fn log_param(&mut self, key: &str, value: &str) -> Result<(), TrackingError> {
        self.ensure_active()?;
        if key.is_empty() {
            return Err(TrackingError::InvalidValue("param key must not be empty".to_string()));
        }
        self.params.insert(key.to_string(), value.to_string());
        Ok(())
    }

    /// Records `local_path` as an artifact of the run. The optional destination path is
    /// accepted for trait compatibility but not stored by this implementation.
    ///
    /// # Errors
    /// * `TrackingError::NoActiveRun` if no run is in progress.
    /// * `TrackingError::InvalidValue` if `local_path` is empty.
    fn log_artifact(
        &mut self,
        local_path: &str,
        _artifact_path: Option<&str>,
    ) -> Result<(), TrackingError> {
        self.ensure_active()?;
        if local_path.is_empty() {
            return Err(TrackingError::InvalidValue("artifact path must not be empty".to_string()));
        }
        self.artifacts.push(local_path.to_string());
        Ok(())
    }

    /// Sets every `(key, value)` pair in `tags`, replacing earlier values for the same keys.
    ///
    /// The call is all-or-nothing: every key is validated before any tag is stored, so a
    /// rejected call leaves the existing tags untouched.
    ///
    /// # Errors
    /// * `TrackingError::NoActiveRun` if no run is in progress.
    /// * `TrackingError::InvalidValue` if any key is empty.
    fn set_tags(&mut self, tags: &[(&str, &str)]) -> Result<(), TrackingError> {
        self.ensure_active()?;
        if tags.iter().any(|(k, _)| k.is_empty()) {
            return Err(TrackingError::InvalidValue("tag key must not be empty".to_string()));
        }
        for (k, v) in tags {
            self.tags.insert(k.to_string(), v.to_string());
        }
        Ok(())
    }

    /// Returns the id of the most recently started run, or `None` if none was ever started.
    /// The id remains available after the run has ended.
    fn get_run_id(&self) -> Option<&RunId> {
        self.run_id.as_ref()
    }
}

impl InMemoryTracker {
    /// Returns `Ok(())` while a run is in progress, i.e. a run id exists and its status is
    /// still `Running`; otherwise `TrackingError::NoActiveRun`.
    fn ensure_active(&self) -> Result<(), TrackingError> {
        if self.run_id.is_some() && self.status == Some(TrackerRunStatus::Running) {
            Ok(())
        } else {
            Err(TrackingError::NoActiveRun)
        }
    }
}
/// Escapes `s` for use inside a double-quoted JSON string (the quotes are not added).
///
/// `"` and `\` are backslash-escaped. Every control character in U+0000..=U+001F is escaped
/// as well, because JSON forbids them unescaped inside strings: the short forms `\b`, `\f`,
/// `\n`, `\r`, `\t` where one exists, `\u00XX` otherwise. All other characters, including
/// non-ASCII ones, are copied through unchanged.
fn escape_json(s: &str) -> String {
    let mut out = String::with_capacity(s.len());
    for c in s.chars() {
        match c {
            '"' => out.push_str("\\\""),
            '\\' => out.push_str("\\\\"),
            '\n' => out.push_str("\\n"),
            '\r' => out.push_str("\\r"),
            '\t' => out.push_str("\\t"),
            '\u{08}' => out.push_str("\\b"),
            '\u{0C}' => out.push_str("\\f"),
            // Remaining C0 controls have no short form and need the \uXXXX escape.
            c if (c as u32) < 0x20 => out.push_str(&format!("\\u{:04x}", c as u32)),
            c => out.push(c),
        }
    }
    out
}
/// Renders `map` as a JSON object whose keys and values are JSON strings.
///
/// Entries are written in ascending key order so the same map always produces the same
/// text; `HashMap` iteration order is unspecified and varies between processes. Keys and
/// values are passed through `escape_json`. An empty map yields `{}`.
fn map_to_json_obj(map: &HashMap<String, String>) -> String {
    // Keys are unique, so an unstable sort cannot reorder equal elements.
    let mut entries: Vec<(&String, &String)> = map.iter().collect();
    entries.sort_unstable_by(|a, b| a.0.cmp(b.0));

    let mut out = String::from('{');
    for (i, (k, v)) in entries.into_iter().enumerate() {
        if i > 0 {
            out.push(',');
        }
        // Writing into a `String` cannot fail.
        let _ = write!(out, "\"{}\":\"{}\"", escape_json(k), escape_json(v));
    }
    out.push('}');
    out
}
#[cfg(test)]
mod tests {
    //! Unit tests for `InMemoryTracker` and `TrackingError`.
    use super::*;

    /// Returns a tracker with a run named "test_run" already started.
    fn started_tracker() -> InMemoryTracker {
        let mut t = InMemoryTracker::new();
        t.start_run("test_run").unwrap();
        t
    }

    // --- run lifecycle ---

    #[test]
    fn test_start_run_returns_id() {
        let mut t = InMemoryTracker::new();
        let id = t.start_run("training").unwrap();
        assert!(!id.is_empty());
        assert_eq!(t.get_run_id(), Some(&id));
    }

    #[test]
    fn test_start_run_id_derived_from_name() {
        let mut t = InMemoryTracker::new();
        let id = t.start_run("my first run").unwrap();
        assert_eq!(id, "run_my_first_run");
    }

    #[test]
    fn test_start_run_already_started() {
        let mut t = started_tracker();
        let err = t.start_run("second_run");
        assert!(matches!(err, Err(TrackingError::AlreadyStarted)));
    }

    #[test]
    fn test_start_run_empty_name_rejected() {
        let mut t = InMemoryTracker::new();
        assert!(matches!(t.start_run(""), Err(TrackingError::InvalidValue(_))));
        // A rejected start must not leave a half-initialised run behind.
        assert!(t.get_run_id().is_none());
    }

    #[test]
    fn test_end_run_no_active_run() {
        let mut t = InMemoryTracker::new();
        assert!(matches!(t.end_run(TrackerRunStatus::Finished), Err(TrackingError::NoActiveRun)));
    }

    #[test]
    fn test_end_run_sets_status() {
        let mut t = started_tracker();
        t.end_run(TrackerRunStatus::Failed).unwrap();
        let summary = t.to_summary();
        assert_eq!(summary.status, "failed");
    }

    // --- metrics ---

    #[test]
    fn test_log_metric_basic() {
        let mut t = started_tracker();
        t.log_metric("loss", 1.5, 0).unwrap();
        t.log_metric("loss", 1.2, 1).unwrap();
        let hist = t.get_metric_history("loss").unwrap();
        assert_eq!(hist.len(), 2);
        assert_eq!(hist[0], (0, 1.5));
        assert_eq!(hist[1], (1, 1.2));
    }

    #[test]
    fn test_get_metric_history_unknown_key() {
        let t = started_tracker();
        assert!(t.get_metric_history("missing").is_none());
    }

    #[test]
    fn test_log_metric_no_run() {
        let mut t = InMemoryTracker::new();
        assert!(matches!(t.log_metric("x", 1.0, 0), Err(TrackingError::NoActiveRun)));
    }

    #[test]
    fn test_log_metric_nan_rejected() {
        let mut t = started_tracker();
        assert!(matches!(t.log_metric("x", f64::NAN, 0), Err(TrackingError::InvalidValue(_))));
    }

    #[test]
    fn test_log_metric_inf_rejected() {
        let mut t = started_tracker();
        assert!(matches!(t.log_metric("x", f64::INFINITY, 0), Err(TrackingError::InvalidValue(_))));
    }

    #[test]
    fn test_log_metric_neg_inf_rejected() {
        let mut t = started_tracker();
        assert!(matches!(
            t.log_metric("x", f64::NEG_INFINITY, 0),
            Err(TrackingError::InvalidValue(_))
        ));
        // A rejected value must not create an empty series.
        assert_eq!(t.metric_count(), 0);
    }

    #[test]
    fn test_log_metric_empty_key_rejected() {
        let mut t = started_tracker();
        assert!(matches!(t.log_metric("", 1.0, 0), Err(TrackingError::InvalidValue(_))));
    }

    #[test]
    fn test_metric_count() {
        let mut t = started_tracker();
        t.log_metric("loss", 1.0, 0).unwrap();
        t.log_metric("acc", 0.9, 0).unwrap();
        assert_eq!(t.metric_count(), 2);
    }

    // --- params ---

    #[test]
    fn test_log_param_and_retrieve() {
        let mut t = started_tracker();
        t.log_param("lr", "0.001").unwrap();
        assert_eq!(t.get_param("lr"), Some("0.001"));
    }

    #[test]
    fn test_get_param_unknown_key() {
        let t = started_tracker();
        assert_eq!(t.get_param("missing"), None);
    }

    #[test]
    fn test_log_param_overwrite() {
        let mut t = started_tracker();
        t.log_param("lr", "0.001").unwrap();
        t.log_param("lr", "0.0001").unwrap();
        assert_eq!(t.get_param("lr"), Some("0.0001"));
    }

    #[test]
    fn test_log_param_no_run() {
        let mut t = InMemoryTracker::new();
        assert!(matches!(t.log_param("k", "v"), Err(TrackingError::NoActiveRun)));
    }

    #[test]
    fn test_log_param_empty_key_rejected() {
        let mut t = started_tracker();
        assert!(matches!(t.log_param("", "v"), Err(TrackingError::InvalidValue(_))));
    }

    // --- artifacts ---

    #[test]
    fn test_log_artifact_stored() {
        let mut t = started_tracker();
        t.log_artifact("/tmp/model.bin", None).unwrap();
        assert_eq!(t.artifacts(), &["/tmp/model.bin".to_string()]);
    }

    #[test]
    fn test_log_artifact_no_run() {
        let mut t = InMemoryTracker::new();
        assert!(matches!(t.log_artifact("/tmp/x", None), Err(TrackingError::NoActiveRun)));
    }

    #[test]
    fn test_log_artifact_empty_path_rejected() {
        let mut t = started_tracker();
        assert!(matches!(t.log_artifact("", None), Err(TrackingError::InvalidValue(_))));
        assert!(t.artifacts().is_empty());
    }

    // --- tags ---

    #[test]
    fn test_set_tags() {
        let mut t = started_tracker();
        t.set_tags(&[("framework", "trustformers"), ("version", "0.1.0")]).unwrap();
        assert_eq!(t.tags().get("framework").map(|s| s.as_str()), Some("trustformers"));
    }

    #[test]
    fn test_set_tags_no_run() {
        let mut t = InMemoryTracker::new();
        assert!(matches!(t.set_tags(&[("k", "v")]), Err(TrackingError::NoActiveRun)));
    }

    #[test]
    fn test_set_tags_empty_key_rejected() {
        let mut t = started_tracker();
        assert!(matches!(t.set_tags(&[("", "v")]), Err(TrackingError::InvalidValue(_))));
        assert!(t.tags().is_empty());
    }

    // --- summary and export ---

    #[test]
    fn test_summary_fields() {
        let mut t = started_tracker();
        t.log_metric("loss", 0.5, 0).unwrap();
        t.log_param("epochs", "5").unwrap();
        t.log_artifact("/tmp/ckpt.bin", None).unwrap();
        let s = t.to_summary();
        assert!(s.run_id.is_some());
        assert_eq!(s.total_metrics, 1);
        assert_eq!(s.total_params, 1);
        assert_eq!(s.total_artifacts, 1);
        assert_eq!(s.status, "running");
    }

    #[test]
    fn test_summary_before_start() {
        let s = InMemoryTracker::new().to_summary();
        assert!(s.run_id.is_none());
        assert_eq!(s.total_metrics, 0);
        assert_eq!(s.total_params, 0);
        assert_eq!(s.total_artifacts, 0);
        assert_eq!(s.status, "none");
    }

    #[test]
    fn test_export_json_structure() {
        let mut t = started_tracker();
        t.log_metric("loss", 2.0, 0).unwrap();
        t.log_param("lr", "0.01").unwrap();
        t.set_tags(&[("model", "gpt2")]).unwrap();
        t.log_artifact("/tmp/weights.bin", None).unwrap();
        let json = t.export_to_json();
        assert!(json.contains("\"run_id\""));
        assert!(json.contains("\"params\""));
        assert!(json.contains("\"metrics\""));
        assert!(json.contains("\"tags\""));
        assert!(json.contains("\"artifacts\""));
        assert!(json.contains("\"loss\""));
        assert!(json.contains("\"lr\""));
        assert!(json.contains("gpt2"));
    }

    #[test]
    fn test_export_json_metric_series_layout() {
        let mut t = started_tracker();
        t.log_metric("loss", 2.0, 0).unwrap();
        t.log_metric("loss", 1.5, 1).unwrap();
        let json = t.export_to_json();
        // Points are `[step, value]` pairs in logging order.
        assert!(json.contains("\"loss\":[[0,2],[1,1.5]]"));
    }

    #[test]
    fn test_export_json_escapes_quotes_and_backslashes() {
        let mut t = started_tracker();
        t.log_param("note", "say \"hi\"").unwrap();
        t.log_artifact("C:\\models\\m.bin", None).unwrap();
        let json = t.export_to_json();
        assert!(json.contains("say \\\"hi\\\""));
        assert!(json.contains("C:\\\\models\\\\m.bin"));
    }

    #[test]
    fn test_export_json_status_after_end() {
        let mut t = started_tracker();
        t.end_run(TrackerRunStatus::Finished).unwrap();
        let json = t.export_to_json();
        assert!(json.contains("\"status\":\"finished\""));
        assert!(json.contains("\"run_id\":\"run_test_run\""));
    }

    #[test]
    fn test_export_json_no_run() {
        let t = InMemoryTracker::new();
        let json = t.export_to_json();
        assert!(json.contains("\"run_id\":null"));
        assert!(json.contains("\"status\":null"));
    }

    // --- errors ---

    #[test]
    fn test_tracking_error_display() {
        assert!(TrackingError::NoActiveRun.to_string().contains("no active run"));
        assert!(TrackingError::AlreadyStarted.to_string().contains("already in progress"));
        assert!(TrackingError::IoError("disk full".to_string()).to_string().contains("disk full"));
        assert!(TrackingError::InvalidValue("nan".to_string()).to_string().contains("nan"));
    }
}