// batuta/recipes/experiment_tracking.rs

//! Experiment tracking recipe implementation.
2
use std::time::Instant;

use crate::experiment::{
    EnergyMetrics, ExperimentError, ExperimentRun, ExperimentStorage, ModelParadigm,
};
use crate::recipes::{ExperimentTrackingConfig, RecipeResult};
7
8/// Experiment tracking recipe
9#[derive(Debug)]
10pub struct ExperimentTrackingRecipe {
11    config: ExperimentTrackingConfig,
12    current_run: Option<ExperimentRun>,
13    start_time: Option<std::time::Instant>,
14}
15
16impl ExperimentTrackingRecipe {
17    /// Create a new experiment tracking recipe
18    pub fn new(config: ExperimentTrackingConfig) -> Self {
19        Self { config, current_run: None, start_time: None }
20    }
21
22    /// Start a new experiment run
23    pub fn start_run(&mut self, run_id: impl Into<String>) -> &mut ExperimentRun {
24        let mut run = ExperimentRun::new(
25            run_id,
26            &self.config.experiment_name,
27            self.config.paradigm,
28            self.config.device.clone(),
29        );
30        run.platform = self.config.platform;
31        run.tags = self.config.tags.clone();
32        self.current_run = Some(run);
33        self.start_time = Some(std::time::Instant::now());
34        self.current_run.as_mut().expect("current_run was just set to Some")
35    }
36
37    /// Log a metric to the current run
38    pub fn log_metric(
39        &mut self,
40        name: impl Into<String>,
41        value: f64,
42    ) -> Result<(), ExperimentError> {
43        self.current_run
44            .as_mut()
45            .ok_or_else(|| ExperimentError::StorageError("No active run".to_string()))?
46            .log_metric(name, value);
47        Ok(())
48    }
49
50    /// Log a hyperparameter
51    pub fn log_param(
52        &mut self,
53        name: impl Into<String>,
54        value: serde_json::Value,
55    ) -> Result<(), ExperimentError> {
56        self.current_run
57            .as_mut()
58            .ok_or_else(|| ExperimentError::StorageError("No active run".to_string()))?
59            .log_param(name, value);
60        Ok(())
61    }
62
63    /// End the current run and calculate metrics
64    pub fn end_run(&mut self, success: bool) -> Result<RecipeResult, ExperimentError> {
65        let run = self
66            .current_run
67            .as_mut()
68            .ok_or_else(|| ExperimentError::StorageError("No active run".to_string()))?;
69
70        if success {
71            run.complete();
72        } else {
73            run.fail();
74        }
75
76        let duration = self.start_time.take().map(|t| t.elapsed().as_secs_f64()).unwrap_or(0.0);
77
78        // Calculate energy metrics if enabled
79        if self.config.track_energy {
80            let power = self.config.device.estimated_power_watts() as f64;
81            let energy_joules = power * duration;
82            let mut energy = EnergyMetrics::new(energy_joules, power, power * 1.2, duration);
83
84            if let Some(carbon_intensity) = self.config.carbon_intensity {
85                energy = energy.with_carbon_intensity(carbon_intensity);
86            }
87
88            run.energy = Some(energy);
89        }
90
91        let mut result = RecipeResult::success("experiment-tracking");
92        result = result.with_metric("duration_seconds", duration);
93
94        if let Some(ref energy) = run.energy {
95            result = result.with_metric("energy_joules", energy.total_joules);
96            if let Some(co2) = energy.co2_grams {
97                result = result.with_metric("co2_grams", co2);
98            }
99        }
100
101        for (name, value) in &run.metrics {
102            result = result.with_metric(format!("run_{}", name), *value);
103        }
104
105        Ok(result)
106    }
107
108    /// Get the current run
109    pub fn current_run(&self) -> Option<&ExperimentRun> {
110        self.current_run.as_ref()
111    }
112
113    /// Store the run to a backend
114    pub fn store_run<S: ExperimentStorage>(&self, storage: &S) -> Result<(), ExperimentError> {
115        if let Some(ref run) = self.current_run {
116            storage.store_run(run)?;
117        }
118        Ok(())
119    }
120}