batuta/recipes/
experiment_tracking.rs1use crate::experiment::{
4 EnergyMetrics, ExperimentError, ExperimentRun, ExperimentStorage, ModelParadigm,
5};
6use crate::recipes::{ExperimentTrackingConfig, RecipeResult};
7
8#[derive(Debug)]
10pub struct ExperimentTrackingRecipe {
11 config: ExperimentTrackingConfig,
12 current_run: Option<ExperimentRun>,
13 start_time: Option<std::time::Instant>,
14}
15
16impl ExperimentTrackingRecipe {
17 pub fn new(config: ExperimentTrackingConfig) -> Self {
19 Self { config, current_run: None, start_time: None }
20 }
21
22 pub fn start_run(&mut self, run_id: impl Into<String>) -> &mut ExperimentRun {
24 let mut run = ExperimentRun::new(
25 run_id,
26 &self.config.experiment_name,
27 self.config.paradigm,
28 self.config.device.clone(),
29 );
30 run.platform = self.config.platform;
31 run.tags = self.config.tags.clone();
32 self.current_run = Some(run);
33 self.start_time = Some(std::time::Instant::now());
34 self.current_run.as_mut().expect("current_run was just set to Some")
35 }
36
37 pub fn log_metric(
39 &mut self,
40 name: impl Into<String>,
41 value: f64,
42 ) -> Result<(), ExperimentError> {
43 self.current_run
44 .as_mut()
45 .ok_or_else(|| ExperimentError::StorageError("No active run".to_string()))?
46 .log_metric(name, value);
47 Ok(())
48 }
49
50 pub fn log_param(
52 &mut self,
53 name: impl Into<String>,
54 value: serde_json::Value,
55 ) -> Result<(), ExperimentError> {
56 self.current_run
57 .as_mut()
58 .ok_or_else(|| ExperimentError::StorageError("No active run".to_string()))?
59 .log_param(name, value);
60 Ok(())
61 }
62
63 pub fn end_run(&mut self, success: bool) -> Result<RecipeResult, ExperimentError> {
65 let run = self
66 .current_run
67 .as_mut()
68 .ok_or_else(|| ExperimentError::StorageError("No active run".to_string()))?;
69
70 if success {
71 run.complete();
72 } else {
73 run.fail();
74 }
75
76 let duration = self.start_time.take().map(|t| t.elapsed().as_secs_f64()).unwrap_or(0.0);
77
78 if self.config.track_energy {
80 let power = self.config.device.estimated_power_watts() as f64;
81 let energy_joules = power * duration;
82 let mut energy = EnergyMetrics::new(energy_joules, power, power * 1.2, duration);
83
84 if let Some(carbon_intensity) = self.config.carbon_intensity {
85 energy = energy.with_carbon_intensity(carbon_intensity);
86 }
87
88 run.energy = Some(energy);
89 }
90
91 let mut result = RecipeResult::success("experiment-tracking");
92 result = result.with_metric("duration_seconds", duration);
93
94 if let Some(ref energy) = run.energy {
95 result = result.with_metric("energy_joules", energy.total_joules);
96 if let Some(co2) = energy.co2_grams {
97 result = result.with_metric("co2_grams", co2);
98 }
99 }
100
101 for (name, value) in &run.metrics {
102 result = result.with_metric(format!("run_{}", name), *value);
103 }
104
105 Ok(result)
106 }
107
108 pub fn current_run(&self) -> Option<&ExperimentRun> {
110 self.current_run.as_ref()
111 }
112
113 pub fn store_run<S: ExperimentStorage>(&self, storage: &S) -> Result<(), ExperimentError> {
115 if let Some(ref run) = self.current_run {
116 storage.store_run(run)?;
117 }
118 Ok(())
119 }
120}