Skip to main content

entrenar/
run.rs

1//! Run Struct with Renacer Integration (ENT-002)
2//!
3//! Provides the `Run` struct which wraps experiment tracking with
4//! distributed tracing via Renacer spans.
5//!
6//! # Example
7//!
8//! ```
9//! use std::sync::{Arc, Mutex};
10//! use entrenar::storage::{InMemoryStorage, ExperimentStorage};
11//! use entrenar::run::{Run, TracingConfig};
12//!
13//! # fn main() -> Result<(), Box<dyn std::error::Error>> {
14//! let mut storage = InMemoryStorage::new();
15//! let exp_id = storage.create_experiment("my-exp", None)?;
16//! let storage = Arc::new(Mutex::new(storage));
17//!
18//! let config = TracingConfig::default();
19//! let mut run = Run::new(&exp_id, storage.clone(), config)?;
20//!
21//! // Log metrics - auto-increments step
22//! run.log_metric("loss", 0.5)?;
23//! run.log_metric("loss", 0.4)?;
24//!
25//! // Or log with explicit step
26//! run.log_metric_at("accuracy", 0, 0.85)?;
27//!
28//! // Complete the run
29//! run.finish(entrenar::storage::RunStatus::Success)?;
30//! # Ok(())
31//! # }
32//! ```
33
34use std::collections::HashMap;
35use std::path::PathBuf;
36use std::sync::{Arc, Mutex};
37use std::time::{SystemTime, UNIX_EPOCH};
38
39use serde::{Deserialize, Serialize};
40
41use crate::storage::{ExperimentStorage, Result, RunStatus, StorageError};
42
43/// Configuration for distributed tracing
44#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
45pub struct TracingConfig {
46    /// Whether tracing is enabled (creates Renacer spans)
47    pub tracing_enabled: bool,
48
49    /// Whether to export traces via OTLP
50    pub export_otlp: bool,
51
52    /// Path for golden trace storage
53    pub golden_trace_path: Option<PathBuf>,
54}
55
56impl Default for TracingConfig {
57    fn default() -> Self {
58        Self { tracing_enabled: true, export_otlp: false, golden_trace_path: None }
59    }
60}
61
62impl TracingConfig {
63    /// Create a disabled tracing configuration
64    pub fn disabled() -> Self {
65        Self { tracing_enabled: false, export_otlp: false, golden_trace_path: None }
66    }
67
68    /// Enable OTLP export
69    pub fn with_otlp_export(mut self) -> Self {
70        self.export_otlp = true;
71        self
72    }
73
74    /// Set golden trace path
75    pub fn with_golden_trace_path(mut self, path: impl Into<PathBuf>) -> Self {
76        self.golden_trace_path = Some(path.into());
77        self
78    }
79}
80
81/// A training run with integrated distributed tracing
82///
83/// Generic over the storage backend, allowing different backends
84/// for production (TruenoBackend) and testing (InMemoryStorage).
85pub struct Run<S: ExperimentStorage> {
86    /// Run ID
87    pub id: String,
88    /// Parent experiment ID
89    pub experiment_id: String,
90    /// Storage backend (shared)
91    pub(crate) storage: Arc<Mutex<S>>,
92    /// Renacer span ID (if tracing enabled)
93    span: Option<String>,
94    /// Tracing configuration
95    config: TracingConfig,
96    /// Current step counters per metric key
97    pub(crate) step_counters: HashMap<String, u64>,
98    /// Whether the run has been finished
99    finished: bool,
100}
101
102impl<S: ExperimentStorage> Run<S> {
103    /// Acquire the storage mutex, mapping poison errors to `StorageError::Backend`.
104    fn lock_storage(storage: &Arc<Mutex<S>>) -> Result<std::sync::MutexGuard<'_, S>> {
105        storage.lock().map_err(|e| StorageError::Backend(format!("mutex poisoned: {e}")))
106    }
107
108    /// Acquire the storage mutex on `self`, mapping poison errors to `StorageError::Backend`.
109    fn lock(&self) -> Result<std::sync::MutexGuard<'_, S>> {
110        Self::lock_storage(&self.storage)
111    }
112
113    /// Create a new run with tracing
114    ///
115    /// Creates a run in the storage backend, starts it, and optionally
116    /// creates a Renacer span for distributed tracing.
117    ///
118    /// # Arguments
119    ///
120    /// * `experiment_id` - ID of the parent experiment
121    /// * `storage` - Shared storage backend
122    /// * `config` - Tracing configuration
123    pub fn new(experiment_id: &str, storage: Arc<Mutex<S>>, config: TracingConfig) -> Result<Self> {
124        // Create run in storage
125        let run_id = {
126            let mut store = Self::lock_storage(&storage)?;
127            let run_id = store.create_run(experiment_id)?;
128            store.start_run(&run_id)?;
129            run_id
130        };
131
132        // Create span if tracing enabled
133        let span = if config.tracing_enabled {
134            let span_id = Self::create_span(&run_id);
135            Self::lock_storage(&storage)?.set_span_id(&run_id, &span_id)?;
136            Some(span_id)
137        } else {
138            None
139        };
140
141        Ok(Self {
142            id: run_id,
143            experiment_id: experiment_id.to_string(),
144            storage,
145            span,
146            config,
147            step_counters: HashMap::new(),
148            finished: false,
149        })
150    }
151
152    /// Create a Renacer span for this run
153    fn create_span(run_id: &str) -> String {
154        let now = SystemTime::now().duration_since(UNIX_EPOCH).unwrap_or_default();
155
156        format!("span-{}-{}", run_id, now.as_nanos())
157    }
158
159    /// Log a metric value, auto-incrementing the step
160    ///
161    /// Each metric key has its own step counter that starts at 0
162    /// and increments with each call.
163    ///
164    /// # Arguments
165    ///
166    /// * `key` - Metric name (e.g., "loss", "accuracy")
167    /// * `value` - Metric value
168    pub fn log_metric(&mut self, key: &str, value: f64) -> Result<()> {
169        let step = *self.step_counters.get(key).unwrap_or(&0);
170        self.log_metric_at(key, step, value)?;
171        self.step_counters.insert(key.to_string(), step + 1);
172        Ok(())
173    }
174
175    /// Log a metric value at a specific step
176    ///
177    /// # Arguments
178    ///
179    /// * `key` - Metric name
180    /// * `step` - Training step
181    /// * `value` - Metric value
182    pub fn log_metric_at(&mut self, key: &str, step: u64, value: f64) -> Result<()> {
183        if self.finished {
184            return Err(StorageError::InvalidState("Cannot log to finished run".to_string()));
185        }
186
187        self.lock()?.log_metric(&self.id, key, step, value)?;
188
189        // Emit span event if tracing enabled
190        if self.config.tracing_enabled {
191            self.emit_metric_event(key, step, value);
192        }
193
194        Ok(())
195    }
196
197    /// Emit a metric event to the Renacer span
198    fn emit_metric_event(&self, key: &str, step: u64, value: f64) {
199        // In a full implementation, this would call renacer::record_event()
200        if self.span.is_some() {
201            let _ = (key, step, value);
202        }
203    }
204
205    /// Finish the run with the given status
206    ///
207    /// Completes the run in storage and ends the Renacer span.
208    /// Consumes the Run to prevent further operations.
209    ///
210    /// # Arguments
211    ///
212    /// * `status` - Final run status
213    pub fn finish(mut self, status: RunStatus) -> Result<()> {
214        if self.finished {
215            return Ok(());
216        }
217
218        self.lock()?.complete_run(&self.id, status)?;
219
220        self.finished = true;
221
222        // End span if tracing enabled
223        if self.config.tracing_enabled {
224            self.end_span();
225        }
226
227        Ok(())
228    }
229
230    /// End the Renacer span
231    fn end_span(&self) {
232        // In a full implementation, this would call span.end()
233        let _ = self.span.as_ref();
234    }
235
236    /// Get the Renacer span ID (if tracing is enabled)
237    pub fn span_id(&self) -> Option<&str> {
238        self.span.as_deref()
239    }
240
241    /// Get the run ID
242    pub fn run_id(&self) -> &str {
243        &self.id
244    }
245
246    /// Get the tracing configuration
247    pub fn tracing_config(&self) -> &TracingConfig {
248        &self.config
249    }
250
251    /// Check if the run has been finished
252    pub fn is_finished(&self) -> bool {
253        self.finished
254    }
255
256    /// Get current step for a metric key
257    pub fn current_step(&self, key: &str) -> u64 {
258        *self.step_counters.get(key).unwrap_or(&0)
259    }
260}
261
262impl<S: ExperimentStorage> std::fmt::Debug for Run<S> {
263    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
264        f.debug_struct("Run")
265            .field("id", &self.id)
266            .field("experiment_id", &self.experiment_id)
267            .field("span", &self.span)
268            .field("finished", &self.finished)
269            .finish_non_exhaustive()
270    }
271}
272
#[cfg(test)]
mod tests {
    use super::*;
    use crate::storage::InMemoryStorage;

    /// Build an in-memory storage containing one experiment and return the
    /// shared handle together with that experiment's ID.
    fn setup_storage() -> (Arc<Mutex<InMemoryStorage>>, String) {
        let mut storage = InMemoryStorage::new();
        let exp_id = storage
            .create_experiment("test-exp", None)
            .expect("creating the test experiment should succeed");
        (Arc::new(Mutex::new(storage)), exp_id)
    }

    #[test]
    fn test_tracing_config_default() {
        let config = TracingConfig::default();
        assert!(config.tracing_enabled);
        assert!(!config.export_otlp);
        assert!(config.golden_trace_path.is_none());
    }

    #[test]
    fn test_tracing_config_disabled() {
        let config = TracingConfig::disabled();
        assert!(!config.tracing_enabled);
    }

    #[test]
    fn test_tracing_config_builder() {
        let config =
            TracingConfig::default().with_otlp_export().with_golden_trace_path("/tmp/golden");

        assert!(config.tracing_enabled);
        assert!(config.export_otlp);
        assert_eq!(config.golden_trace_path, Some(PathBuf::from("/tmp/golden")));
    }

    #[test]
    fn test_run_new_creates_span() {
        let (storage, exp_id) = setup_storage();
        let config = TracingConfig::default();

        let run = Run::new(&exp_id, storage, config).expect("run creation should succeed");

        assert!(run.span_id().is_some());
        assert!(run
            .span_id()
            .expect("span ID should be set when tracing is enabled")
            .starts_with("span-"));
    }

    #[test]
    fn test_run_new_without_tracing() {
        let (storage, exp_id) = setup_storage();
        let config = TracingConfig::disabled();

        let run = Run::new(&exp_id, storage, config).expect("run creation should succeed");

        assert!(run.span_id().is_none());
    }

    #[test]
    fn test_run_log_metric_auto_increment() {
        let (storage, exp_id) = setup_storage();
        let config = TracingConfig::disabled();

        let mut run = Run::new(&exp_id, storage.clone(), config).expect("run creation should succeed");

        run.log_metric("loss", 0.5).expect("logging a metric should succeed");
        run.log_metric("loss", 0.4).expect("logging a metric should succeed");
        run.log_metric("loss", 0.3).expect("logging a metric should succeed");

        assert_eq!(run.current_step("loss"), 3);

        let metrics = storage
            .lock()
            .expect("storage mutex should not be poisoned")
            .get_metrics(&run.id, "loss")
            .expect("metrics should be retrievable for the run");
        assert_eq!(metrics.len(), 3);
        assert_eq!(metrics[0].step, 0);
        assert_eq!(metrics[1].step, 1);
        assert_eq!(metrics[2].step, 2);
    }

    #[test]
    fn test_run_log_metric_at_explicit_step() {
        let (storage, exp_id) = setup_storage();
        let config = TracingConfig::disabled();

        let mut run = Run::new(&exp_id, storage.clone(), config).expect("run creation should succeed");

        run.log_metric_at("accuracy", 0, 0.7).expect("logging a metric should succeed");
        run.log_metric_at("accuracy", 10, 0.8).expect("logging a metric should succeed");
        run.log_metric_at("accuracy", 20, 0.9).expect("logging a metric should succeed");

        let metrics = storage
            .lock()
            .expect("storage mutex should not be poisoned")
            .get_metrics(&run.id, "accuracy")
            .expect("metrics should be retrievable for the run");
        assert_eq!(metrics.len(), 3);
        assert_eq!(metrics[0].step, 0);
        assert_eq!(metrics[1].step, 10);
        assert_eq!(metrics[2].step, 20);
    }

    #[test]
    fn test_run_multiple_metrics() {
        let (storage, exp_id) = setup_storage();
        let config = TracingConfig::disabled();

        let mut run = Run::new(&exp_id, storage.clone(), config).expect("run creation should succeed");

        run.log_metric("loss", 0.5).expect("logging a metric should succeed");
        run.log_metric("accuracy", 0.8).expect("logging a metric should succeed");
        run.log_metric("loss", 0.4).expect("logging a metric should succeed");

        // Counters are tracked per metric key.
        assert_eq!(run.current_step("loss"), 2);
        assert_eq!(run.current_step("accuracy"), 1);
    }

    #[test]
    fn test_run_finish_success() {
        let (storage, exp_id) = setup_storage();
        let config = TracingConfig::disabled();

        let run = Run::new(&exp_id, storage.clone(), config).expect("run creation should succeed");
        // `finish` consumes the run, so keep the ID for the status lookup.
        let run_id = run.id.clone();

        run.finish(RunStatus::Success).expect("finishing the run should succeed");

        let status = storage
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
            .get_run_status(&run_id)
            .expect("run status should be retrievable");
        assert_eq!(status, RunStatus::Success);
    }

    #[test]
    fn test_run_finish_failed() {
        let (storage, exp_id) = setup_storage();
        let config = TracingConfig::disabled();

        let run = Run::new(&exp_id, storage.clone(), config).expect("run creation should succeed");
        // `finish` consumes the run, so keep the ID for the status lookup.
        let run_id = run.id.clone();

        run.finish(RunStatus::Failed).expect("finishing the run should succeed");

        let status = storage
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
            .get_run_status(&run_id)
            .expect("run status should be retrievable");
        assert_eq!(status, RunStatus::Failed);
    }

    #[test]
    fn test_run_stores_span_id() {
        let (storage, exp_id) = setup_storage();
        let config = TracingConfig::default();

        let run = Run::new(&exp_id, storage.clone(), config).expect("run creation should succeed");
        let span_id =
            run.span_id().expect("span ID should be set when tracing is enabled").to_string();

        let stored_span = storage
            .lock()
            .expect("storage mutex should not be poisoned")
            .get_span_id(&run.id)
            .expect("span ID lookup should succeed")
            .expect("span ID should have been stored for the run");
        assert_eq!(stored_span, span_id);
    }

    #[test]
    fn test_run_accessors() {
        let (storage, exp_id) = setup_storage();
        let config = TracingConfig::default();

        let run = Run::new(&exp_id, storage, config).expect("run creation should succeed");

        assert!(!run.is_finished());
        assert!(run.run_id().starts_with("run-"));
        assert!(run.tracing_config().tracing_enabled);
    }

    #[test]
    fn test_run_debug() {
        let (storage, exp_id) = setup_storage();
        let config = TracingConfig::disabled();

        let run = Run::new(&exp_id, storage, config).expect("run creation should succeed");
        let debug_str = format!("{run:?}");

        assert!(debug_str.contains("Run"));
        assert!(debug_str.contains(&run.id));
    }
}