/// A sink for experiment telemetry: scalar metrics, free-form text, and an
/// end-of-run signal.
///
/// The `Send + Sync` supertraits mean a tracker can be shared between threads,
/// for example as `Arc<dyn ExperimentTracker>` or `&dyn ExperimentTracker`.
/// Only [`log_metric`](ExperimentTracker::log_metric) must be implemented;
/// [`log_text`](ExperimentTracker::log_text) and
/// [`finish`](ExperimentTracker::finish) default to doing nothing.
pub trait ExperimentTracker: Send + Sync {
    /// Records the scalar `value` under `key` at the given `step`.
    ///
    /// The meaning of `step` (iteration, epoch, ...) is left to the caller.
    fn log_metric(&self, key: &str, value: f64, step: usize);

    /// Records a free-form text `value` under `key`.
    ///
    /// The default implementation ignores both arguments.
    fn log_text(&self, key: &str, value: &str) {
        // Consume the arguments explicitly so the default body stays free of
        // `unused_variables` warnings while keeping readable parameter names.
        let _ = key;
        let _ = value;
    }

    /// Intended to be called once logging is complete.
    ///
    /// The default implementation does nothing.
    fn finish(&self) {}
}
/// An [`ExperimentTracker`] that discards everything it is given.
///
/// Suitable as a stand-in when no tracking backend is configured. Every
/// method has an empty body, so calls through a concrete `NoopTracker` are
/// trivially inlinable.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct NoopTracker;

impl ExperimentTracker for NoopTracker {
    #[inline]
    fn log_metric(&self, _key: &str, _value: f64, _step: usize) {}

    #[inline]
    fn log_text(&self, _key: &str, _value: &str) {}

    #[inline]
    fn finish(&self) {}
}
/// Configuration for the Weights & Biases tracker (enabled by the `wandb`
/// feature).
///
/// NOTE(review): in this file the tracker only emits `tracing` events tagged
/// `tracker = "wandb"`; no W&B client calls are visible here.
#[cfg(feature = "wandb")]
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct WandbTracker {
    /// Project name included in every emitted event.
    pub project: String,
    /// Optional human-readable run name; `None` leaves the run unnamed.
    pub run_name: Option<String>,
}
/// Forwards every tracker call to an INFO-level `tracing` event tagged
/// `tracker = "wandb"`, carrying the project and, when set, the run name.
#[cfg(feature = "wandb")]
impl ExperimentTracker for WandbTracker {
    fn log_metric(&self, key: &str, value: f64, step: usize) {
        tracing::info!(
            tracker = "wandb",
            project = %self.project,
            // `Option<&str>` records the field only when a run name is set,
            // so events for unnamed runs simply omit it.
            run_name = self.run_name.as_deref(),
            key,
            value,
            step,
            "log_metric"
        );
    }

    fn log_text(&self, key: &str, value: &str) {
        tracing::info!(
            tracker = "wandb",
            project = %self.project,
            run_name = self.run_name.as_deref(),
            key,
            value,
            "log_text"
        );
    }

    fn finish(&self) {
        tracing::info!(
            tracker = "wandb",
            project = %self.project,
            run_name = self.run_name.as_deref(),
            "finish"
        );
    }
}
/// Configuration for the MLflow tracker (enabled by the `mlflow` feature).
///
/// NOTE(review): in this file the tracker only emits `tracing` events tagged
/// `tracker = "mlflow"`; no MLflow REST calls are visible here.
#[cfg(feature = "mlflow")]
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MlflowTracker {
    /// Tracking server URI included in every emitted event.
    pub tracking_uri: String,
    /// Experiment name included in every emitted event.
    pub experiment_name: String,
}
/// Routes each tracker call to an INFO-level `tracing` event tagged
/// `tracker = "mlflow"`, annotated with the tracking URI and experiment name.
#[cfg(feature = "mlflow")]
impl ExperimentTracker for MlflowTracker {
    fn log_metric(&self, key: &str, value: f64, step: usize) {
        let uri = &self.tracking_uri;
        let experiment = &self.experiment_name;
        tracing::info!(
            tracker = "mlflow",
            uri = %uri,
            experiment = %experiment,
            key = key,
            value = value,
            step = step,
            "log_metric"
        );
    }

    fn log_text(&self, key: &str, value: &str) {
        let uri = &self.tracking_uri;
        let experiment = &self.experiment_name;
        tracing::info!(
            tracker = "mlflow",
            uri = %uri,
            experiment = %experiment,
            key = key,
            value = value,
            "log_text"
        );
    }

    fn finish(&self) {
        let uri = &self.tracking_uri;
        let experiment = &self.experiment_name;
        tracing::info!(
            tracker = "mlflow",
            uri = %uri,
            experiment = %experiment,
            "finish"
        );
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Test double that records every `log_metric` call and relies on the
    /// trait's default `log_text` / `finish` implementations.
    struct RecordingTracker {
        calls: std::sync::Mutex<Vec<(String, f64, usize)>>,
    }

    impl RecordingTracker {
        fn new() -> Self {
            Self {
                calls: std::sync::Mutex::new(Vec::new()),
            }
        }

        /// Snapshot of all `(key, value, step)` triples recorded so far.
        fn recorded(&self) -> Vec<(String, f64, usize)> {
            self.calls
                .lock()
                .expect("recording mutex poisoned")
                .clone()
        }
    }

    impl ExperimentTracker for RecordingTracker {
        fn log_metric(&self, key: &str, value: f64, step: usize) {
            self.calls
                .lock()
                .expect("recording mutex poisoned")
                .push((key.to_owned(), value, step));
        }
    }

    #[test]
    fn noop_tracker_compiles_and_is_silent() {
        let t = NoopTracker;
        t.log_metric("loss", 0.5, 0);
        t.log_text("prompt", "hello");
        t.finish();
    }

    #[test]
    fn recording_tracker_captures_metrics() {
        let t = RecordingTracker::new();
        t.log_metric("accuracy", 0.8, 1);
        t.log_metric("loss", 0.2, 2);
        let calls = t.recorded();
        assert_eq!(calls.len(), 2);
        assert_eq!(calls[0], ("accuracy".to_owned(), 0.8, 1));
        assert_eq!(calls[1], ("loss".to_owned(), 0.2, 2));
    }

    #[test]
    fn experiment_tracker_is_object_safe() {
        let t: Box<dyn ExperimentTracker> = Box::new(NoopTracker);
        t.log_metric("x", 1.0, 0);
        t.finish();
    }

    #[test]
    fn default_log_text_and_finish_are_no_ops() {
        let t = RecordingTracker::new();
        t.log_text("key", "val");
        t.finish();
        assert!(t.recorded().is_empty());
    }

    /// Exercises the `Send + Sync` supertrait bound: a borrowed trait object
    /// must be usable from several threads at once.
    #[test]
    fn tracker_trait_object_is_shareable_across_threads() {
        let recorder = RecordingTracker::new();
        let shared: &dyn ExperimentTracker = &recorder;
        std::thread::scope(|s| {
            for step in 0..4_usize {
                s.spawn(move || shared.log_metric("loss", step as f64, step));
            }
        });
        let mut steps: Vec<usize> = recorder
            .recorded()
            .into_iter()
            .map(|(_, _, step)| step)
            .collect();
        // Thread scheduling order is unspecified; compare as a sorted set.
        steps.sort_unstable();
        assert_eq!(steps, vec![0, 1, 2, 3]);
    }
}