1use std::collections::HashMap;
35use std::path::PathBuf;
36use std::sync::{Arc, Mutex};
37use std::time::{SystemTime, UNIX_EPOCH};
38
39use serde::{Deserialize, Serialize};
40
41use crate::storage::{ExperimentStorage, Result, RunStatus, StorageError};
42
43#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
45pub struct TracingConfig {
46 pub tracing_enabled: bool,
48
49 pub export_otlp: bool,
51
52 pub golden_trace_path: Option<PathBuf>,
54}
55
56impl Default for TracingConfig {
57 fn default() -> Self {
58 Self { tracing_enabled: true, export_otlp: false, golden_trace_path: None }
59 }
60}
61
62impl TracingConfig {
63 pub fn disabled() -> Self {
65 Self { tracing_enabled: false, export_otlp: false, golden_trace_path: None }
66 }
67
68 pub fn with_otlp_export(mut self) -> Self {
70 self.export_otlp = true;
71 self
72 }
73
74 pub fn with_golden_trace_path(mut self, path: impl Into<PathBuf>) -> Self {
76 self.golden_trace_path = Some(path.into());
77 self
78 }
79}
80
81pub struct Run<S: ExperimentStorage> {
86 pub id: String,
88 pub experiment_id: String,
90 pub(crate) storage: Arc<Mutex<S>>,
92 span: Option<String>,
94 config: TracingConfig,
96 pub(crate) step_counters: HashMap<String, u64>,
98 finished: bool,
100}
101
102impl<S: ExperimentStorage> Run<S> {
103 fn lock_storage(storage: &Arc<Mutex<S>>) -> Result<std::sync::MutexGuard<'_, S>> {
105 storage.lock().map_err(|e| StorageError::Backend(format!("mutex poisoned: {e}")))
106 }
107
108 fn lock(&self) -> Result<std::sync::MutexGuard<'_, S>> {
110 Self::lock_storage(&self.storage)
111 }
112
113 pub fn new(experiment_id: &str, storage: Arc<Mutex<S>>, config: TracingConfig) -> Result<Self> {
124 let run_id = {
126 let mut store = Self::lock_storage(&storage)?;
127 let run_id = store.create_run(experiment_id)?;
128 store.start_run(&run_id)?;
129 run_id
130 };
131
132 let span = if config.tracing_enabled {
134 let span_id = Self::create_span(&run_id);
135 Self::lock_storage(&storage)?.set_span_id(&run_id, &span_id)?;
136 Some(span_id)
137 } else {
138 None
139 };
140
141 Ok(Self {
142 id: run_id,
143 experiment_id: experiment_id.to_string(),
144 storage,
145 span,
146 config,
147 step_counters: HashMap::new(),
148 finished: false,
149 })
150 }
151
152 fn create_span(run_id: &str) -> String {
154 let now = SystemTime::now().duration_since(UNIX_EPOCH).unwrap_or_default();
155
156 format!("span-{}-{}", run_id, now.as_nanos())
157 }
158
159 pub fn log_metric(&mut self, key: &str, value: f64) -> Result<()> {
169 let step = *self.step_counters.get(key).unwrap_or(&0);
170 self.log_metric_at(key, step, value)?;
171 self.step_counters.insert(key.to_string(), step + 1);
172 Ok(())
173 }
174
175 pub fn log_metric_at(&mut self, key: &str, step: u64, value: f64) -> Result<()> {
183 if self.finished {
184 return Err(StorageError::InvalidState("Cannot log to finished run".to_string()));
185 }
186
187 self.lock()?.log_metric(&self.id, key, step, value)?;
188
189 if self.config.tracing_enabled {
191 self.emit_metric_event(key, step, value);
192 }
193
194 Ok(())
195 }
196
197 fn emit_metric_event(&self, key: &str, step: u64, value: f64) {
199 if self.span.is_some() {
201 let _ = (key, step, value);
202 }
203 }
204
205 pub fn finish(mut self, status: RunStatus) -> Result<()> {
214 if self.finished {
215 return Ok(());
216 }
217
218 self.lock()?.complete_run(&self.id, status)?;
219
220 self.finished = true;
221
222 if self.config.tracing_enabled {
224 self.end_span();
225 }
226
227 Ok(())
228 }
229
230 fn end_span(&self) {
232 let _ = self.span.as_ref();
234 }
235
236 pub fn span_id(&self) -> Option<&str> {
238 self.span.as_deref()
239 }
240
241 pub fn run_id(&self) -> &str {
243 &self.id
244 }
245
246 pub fn tracing_config(&self) -> &TracingConfig {
248 &self.config
249 }
250
251 pub fn is_finished(&self) -> bool {
253 self.finished
254 }
255
256 pub fn current_step(&self, key: &str) -> u64 {
258 *self.step_counters.get(key).unwrap_or(&0)
259 }
260}
261
262impl<S: ExperimentStorage> std::fmt::Debug for Run<S> {
263 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
264 f.debug_struct("Run")
265 .field("id", &self.id)
266 .field("experiment_id", &self.experiment_id)
267 .field("span", &self.span)
268 .field("finished", &self.finished)
269 .finish_non_exhaustive()
270 }
271}
272
#[cfg(test)]
mod tests {
    use super::*;
    use crate::storage::InMemoryStorage;

    /// Builds an in-memory store holding a single experiment and returns the
    /// shared handle together with that experiment's ID.
    fn setup_storage() -> (Arc<Mutex<InMemoryStorage>>, String) {
        let mut storage = InMemoryStorage::new();
        let exp_id = storage
            .create_experiment("test-exp", None)
            .expect("experiment creation should succeed");
        (Arc::new(Mutex::new(storage)), exp_id)
    }

    #[test]
    fn test_tracing_config_default() {
        let config = TracingConfig::default();
        assert!(config.tracing_enabled);
        assert!(!config.export_otlp);
        assert!(config.golden_trace_path.is_none());
    }

    #[test]
    fn test_tracing_config_disabled() {
        let config = TracingConfig::disabled();
        assert!(!config.tracing_enabled);
        assert!(!config.export_otlp);
        assert!(config.golden_trace_path.is_none());
    }

    #[test]
    fn test_tracing_config_builder() {
        let config =
            TracingConfig::default().with_otlp_export().with_golden_trace_path("/tmp/golden");

        assert!(config.tracing_enabled);
        assert!(config.export_otlp);
        assert_eq!(config.golden_trace_path, Some(PathBuf::from("/tmp/golden")));
    }

    #[test]
    fn test_run_new_creates_span() {
        let (storage, exp_id) = setup_storage();
        let config = TracingConfig::default();

        let run = Run::new(&exp_id, storage, config).expect("run creation should succeed");

        let span = run.span_id().expect("span should exist when tracing is enabled");
        assert!(span.starts_with("span-"));
    }

    #[test]
    fn test_run_new_without_tracing() {
        let (storage, exp_id) = setup_storage();
        let config = TracingConfig::disabled();

        let run = Run::new(&exp_id, storage, config).expect("run creation should succeed");

        assert!(run.span_id().is_none());
    }

    #[test]
    fn test_run_log_metric_auto_increment() {
        let (storage, exp_id) = setup_storage();
        let config = TracingConfig::disabled();

        let mut run =
            Run::new(&exp_id, storage.clone(), config).expect("run creation should succeed");

        run.log_metric("loss", 0.5).expect("logging should succeed");
        run.log_metric("loss", 0.4).expect("logging should succeed");
        run.log_metric("loss", 0.3).expect("logging should succeed");

        assert_eq!(run.current_step("loss"), 3);

        let metrics = storage
            .lock()
            .expect("lock acquisition should succeed")
            .get_metrics(&run.id, "loss")
            .expect("metrics query should succeed");
        assert_eq!(metrics.len(), 3);
        assert_eq!(metrics[0].step, 0);
        assert_eq!(metrics[1].step, 1);
        assert_eq!(metrics[2].step, 2);
    }

    #[test]
    fn test_run_log_metric_at_explicit_step() {
        let (storage, exp_id) = setup_storage();
        let config = TracingConfig::disabled();

        let mut run =
            Run::new(&exp_id, storage.clone(), config).expect("run creation should succeed");

        run.log_metric_at("accuracy", 0, 0.7).expect("logging should succeed");
        run.log_metric_at("accuracy", 10, 0.8).expect("logging should succeed");
        run.log_metric_at("accuracy", 20, 0.9).expect("logging should succeed");

        let metrics = storage
            .lock()
            .expect("lock acquisition should succeed")
            .get_metrics(&run.id, "accuracy")
            .expect("metrics query should succeed");
        assert_eq!(metrics.len(), 3);
        assert_eq!(metrics[0].step, 0);
        assert_eq!(metrics[1].step, 10);
        assert_eq!(metrics[2].step, 20);
    }

    #[test]
    fn test_run_log_metric_at_does_not_advance_counter() {
        let (storage, exp_id) = setup_storage();
        let config = TracingConfig::disabled();

        let mut run = Run::new(&exp_id, storage, config).expect("run creation should succeed");

        run.log_metric_at("loss", 10, 0.5).expect("logging should succeed");

        assert_eq!(run.current_step("loss"), 0);
    }

    #[test]
    fn test_run_log_metric_with_tracing_enabled() {
        let (storage, exp_id) = setup_storage();
        let config = TracingConfig::default();

        let mut run =
            Run::new(&exp_id, storage.clone(), config).expect("run creation should succeed");

        run.log_metric("loss", 0.5).expect("logging should succeed");
        run.log_metric("loss", 0.25).expect("logging should succeed");

        assert_eq!(run.current_step("loss"), 2);

        let metrics = storage
            .lock()
            .expect("lock acquisition should succeed")
            .get_metrics(&run.id, "loss")
            .expect("metrics query should succeed");
        assert_eq!(metrics.len(), 2);
    }

    #[test]
    fn test_run_multiple_metrics() {
        let (storage, exp_id) = setup_storage();
        let config = TracingConfig::disabled();

        let mut run =
            Run::new(&exp_id, storage.clone(), config).expect("run creation should succeed");

        run.log_metric("loss", 0.5).expect("logging should succeed");
        run.log_metric("accuracy", 0.8).expect("logging should succeed");
        run.log_metric("loss", 0.4).expect("logging should succeed");

        assert_eq!(run.current_step("loss"), 2);
        assert_eq!(run.current_step("accuracy"), 1);
    }

    #[test]
    fn test_run_finish_success() {
        let (storage, exp_id) = setup_storage();
        let config = TracingConfig::disabled();

        let run =
            Run::new(&exp_id, storage.clone(), config).expect("run creation should succeed");
        let run_id = run.id.clone();

        run.finish(RunStatus::Success).expect("finishing should succeed");

        let status = storage
            .lock()
            .expect("lock acquisition should succeed")
            .get_run_status(&run_id)
            .expect("status query should succeed");
        assert_eq!(status, RunStatus::Success);
    }

    #[test]
    fn test_run_finish_failed() {
        let (storage, exp_id) = setup_storage();
        let config = TracingConfig::disabled();

        let run =
            Run::new(&exp_id, storage.clone(), config).expect("run creation should succeed");
        let run_id = run.id.clone();

        run.finish(RunStatus::Failed).expect("finishing should succeed");

        let status = storage
            .lock()
            .expect("lock acquisition should succeed")
            .get_run_status(&run_id)
            .expect("status query should succeed");
        assert_eq!(status, RunStatus::Failed);
    }

    #[test]
    fn test_run_finish_with_tracing_enabled() {
        let (storage, exp_id) = setup_storage();
        let config = TracingConfig::default();

        let run =
            Run::new(&exp_id, storage.clone(), config).expect("run creation should succeed");
        let run_id = run.id.clone();

        run.finish(RunStatus::Success).expect("finishing should succeed");

        let status = storage
            .lock()
            .expect("lock acquisition should succeed")
            .get_run_status(&run_id)
            .expect("status query should succeed");
        assert_eq!(status, RunStatus::Success);
    }

    #[test]
    fn test_run_stores_span_id() {
        let (storage, exp_id) = setup_storage();
        let config = TracingConfig::default();

        let run =
            Run::new(&exp_id, storage.clone(), config).expect("run creation should succeed");
        let span_id = run
            .span_id()
            .expect("span should exist when tracing is enabled")
            .to_string();

        let stored_span = storage
            .lock()
            .expect("lock acquisition should succeed")
            .get_span_id(&run.id)
            .expect("span query should succeed")
            .expect("span ID should be stored for the run");
        assert_eq!(stored_span, span_id);
    }

    #[test]
    fn test_run_accessors() {
        let (storage, exp_id) = setup_storage();
        let config = TracingConfig::default();

        let run = Run::new(&exp_id, storage, config).expect("run creation should succeed");

        assert!(!run.is_finished());
        assert!(run.run_id().starts_with("run-"));
        assert!(run.tracing_config().tracing_enabled);
    }

    #[test]
    fn test_run_debug() {
        let (storage, exp_id) = setup_storage();
        let config = TracingConfig::disabled();

        let run = Run::new(&exp_id, storage, config).expect("run creation should succeed");
        let debug_str = format!("{run:?}");

        assert!(debug_str.contains("Run"));
        assert!(debug_str.contains(&run.id));
    }
}