Skip to main content

entrenar/
trace.rs

//! Training Trace Module (ITP-SPEC-001)
//!
//! Provides observability into the training pipeline for empirical analysis.
//! Used to falsify the "Kernel Launch Overhead" hypothesis.
5
6use std::collections::HashMap;
7use std::fmt;
8use std::sync::{LazyLock, Mutex, PoisonError};
9use std::time::{Duration, Instant};
10
/// The lifecycle steps of a training operation.
///
/// Each variant identifies one kind of work that can be timed. The type is a
/// plain `Copy` enum so it can be used directly as a hash-map key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TraceStep {
    /// Forward pass through model
    Forward,
    /// Backward pass (gradient computation)
    Backward,
    /// Matrix multiplication kernel
    Matmul,
    /// Attention computation
    Attention,
    /// CPU transpose operation
    Transpose,
    /// Memory allocation
    Alloc,
    /// Data transfer overhead
    Transfer,
    /// VRAM ledger reservation (GPU-SHARE-001)
    LedgerReserve,
    /// VRAM ledger dead-PID / expired-lease cleanup
    LedgerCleanup,
    /// VRAM query (cuMemGetInfo / NVML)
    VramQuery,
    /// Wait-for-VRAM poll iteration (GPU-SHARE-003)
    WaitPoll,
    /// VRAM ledger release
    LedgerRelease,
}

impl fmt::Display for TraceStep {
    /// Writes the variant name (identical to the `Debug` output).
    ///
    /// The name is emitted through [`fmt::Formatter::pad`], so width, fill,
    /// alignment and precision flags such as `{:<15}` are honored. Forwarding
    /// to the derived `Debug` impl would silently ignore them.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Exhaustive on purpose: adding a variant must force an update here.
        let name = match self {
            Self::Forward => "Forward",
            Self::Backward => "Backward",
            Self::Matmul => "Matmul",
            Self::Attention => "Attention",
            Self::Transpose => "Transpose",
            Self::Alloc => "Alloc",
            Self::Transfer => "Transfer",
            Self::LedgerReserve => "LedgerReserve",
            Self::LedgerCleanup => "LedgerCleanup",
            Self::VramQuery => "VramQuery",
            Self::WaitPoll => "WaitPoll",
            Self::LedgerRelease => "LedgerRelease",
        };
        f.pad(name)
    }
}
45
46/// A single timing measurement.
47#[derive(Debug, Clone)]
48pub struct TraceMeasurement {
49    pub step: TraceStep,
50    pub duration: Duration,
51    pub metadata: String,
52}
53
54/// Thread-safe tracer for collecting timing measurements.
55/// ALB-099: Uses aggregated counters instead of unbounded Vec to prevent
56/// memory leak during long training runs (was ~2.8 GB at 28K steps).
57pub struct Tracer {
58    /// Legacy per-measurement storage (kept for backward compat, bounded)
59    measurements: Mutex<Vec<TraceMeasurement>>,
60    /// ALB-099: Aggregated totals — O(1) memory regardless of training length
61    aggregated: Mutex<HashMap<TraceStep, (usize, Duration)>>,
62    active_spans: Mutex<HashMap<TraceStep, Instant>>,
63    enabled: Mutex<bool>,
64}
65
66impl Tracer {
67    /// Create a new tracer.
68    pub fn new() -> Self {
69        Self {
70            measurements: Mutex::new(Vec::new()),
71            aggregated: Mutex::new(HashMap::new()),
72            active_spans: Mutex::new(HashMap::new()),
73            enabled: Mutex::new(false), // Disabled by default for performance
74        }
75    }
76
77    /// Enable tracing.
78    pub fn enable(&self) {
79        *self.enabled.lock().unwrap_or_else(PoisonError::into_inner) = true;
80    }
81
82    /// Disable tracing.
83    pub fn disable(&self) {
84        *self.enabled.lock().unwrap_or_else(PoisonError::into_inner) = false;
85    }
86
87    /// Check if tracing is enabled.
88    pub fn is_enabled(&self) -> bool {
89        *self.enabled.lock().unwrap_or_else(PoisonError::into_inner)
90    }
91
92    /// Start a timing span.
93    pub fn start(&self, step: TraceStep) {
94        if !self.is_enabled() {
95            return;
96        }
97        let mut spans = self.active_spans.lock().unwrap_or_else(PoisonError::into_inner);
98        spans.insert(step, Instant::now());
99    }
100
101    /// End a timing span and record measurement.
102    /// ALB-099: Aggregates into counters (O(1) memory) instead of appending to Vec.
103    pub fn end(&self, step: TraceStep, _metadata: impl Into<String>) {
104        if !self.is_enabled() {
105            return;
106        }
107        let mut spans = self.active_spans.lock().unwrap_or_else(PoisonError::into_inner);
108        if let Some(start) = spans.remove(&step) {
109            let duration = start.elapsed();
110            let mut agg = self.aggregated.lock().unwrap_or_else(PoisonError::into_inner);
111            let entry = agg.entry(step).or_insert((0, Duration::ZERO));
112            entry.0 += 1;
113            entry.1 += duration;
114        }
115    }
116
117    /// Run a closure within a measured span.
118    #[inline]
119    pub fn span<F, R>(&self, step: TraceStep, metadata: impl Into<String>, f: F) -> R
120    where
121        F: FnOnce() -> R,
122    {
123        if !self.is_enabled() {
124            return f();
125        }
126        self.start(step);
127        let result = f();
128        self.end(step, metadata);
129        result
130    }
131
132    /// Clear all measurements.
133    pub fn clear(&self) {
134        self.measurements.lock().unwrap_or_else(PoisonError::into_inner).clear();
135        self.aggregated.lock().unwrap_or_else(PoisonError::into_inner).clear();
136        self.active_spans.lock().unwrap_or_else(PoisonError::into_inner).clear();
137    }
138
139    /// Generate a report with Dr. Popper analysis.
140    /// ALB-099: Reads from aggregated counters (O(1) memory).
141    pub fn report(&self) -> String {
142        let agg = self.aggregated.lock().unwrap_or_else(PoisonError::into_inner);
143        if agg.is_empty() {
144            return "No measurements recorded. Enable tracing with TRACER.enable()".to_string();
145        }
146
147        let mut totals: HashMap<TraceStep, Duration> = HashMap::new();
148        let mut counts: HashMap<TraceStep, usize> = HashMap::new();
149        let mut total_time = Duration::ZERO;
150
151        for (&step, &(count, duration)) in agg.iter() {
152            totals.insert(step, duration);
153            counts.insert(step, count);
154            total_time += duration;
155        }
156
157        let mut output =
158            String::from("\n╔══════════════════════════════════════════════════════════════╗\n");
159        output.push_str("║       ENTRENAR TRACE REPORT (ITP-SPEC-001)                   ║\n");
160        output.push_str("╚══════════════════════════════════════════════════════════════╝\n");
161        output.push_str(&format!("Total Measured Time: {total_time:.2?}\n"));
162        output.push_str("────────────────────────────────────────────────────────────────\n");
163        output.push_str(&format!(
164            "{:<15} | {:<8} | {:<15} | {:<8}\n",
165            "Step", "Count", "Duration", "% Time"
166        ));
167        output.push_str("────────────────────────────────────────────────────────────────\n");
168
169        // Sort by duration descending
170        let mut sorted_steps: Vec<_> = totals.keys().collect();
171        sorted_steps.sort_by(|a, b| totals[b].cmp(&totals[a]));
172
173        for step in sorted_steps {
174            let duration = totals[step];
175            let count = counts[step];
176            let percentage = if total_time.as_nanos() > 0 {
177                (duration.as_secs_f64() / total_time.as_secs_f64()) * 100.0
178            } else {
179                0.0
180            };
181            output.push_str(&format!(
182                "{:<15} | {:<8} | {:<15.2?} | {:>7.2}%\n",
183                step.to_string(),
184                count,
185                duration,
186                percentage
187            ));
188        }
189        output.push_str("────────────────────────────────────────────────────────────────\n");
190
191        // Dr. Popper Analysis
192        let matmul_time = totals.get(&TraceStep::Matmul).copied().unwrap_or_default();
193        let transpose_time = totals.get(&TraceStep::Transpose).copied().unwrap_or_default();
194        let alloc_time = totals.get(&TraceStep::Alloc).copied().unwrap_or_default();
195        let compute_time = matmul_time;
196        let overhead_time = transpose_time + alloc_time;
197
198        if compute_time.as_nanos() > 0 {
199            let overhead_pct = (overhead_time.as_secs_f64()
200                / (compute_time + overhead_time).as_secs_f64())
201                * 100.0;
202
203            output.push_str("\n[Dr. Popper Analysis]\n");
204            output.push_str(&format!("CUDA Compute:   {compute_time:.2?}\n"));
205            output.push_str(&format!("CPU Overhead:   {overhead_time:.2?} ({overhead_pct:.2}%)\n"));
206
207            if overhead_pct > 50.0 {
208                output.push_str("\n🔴 FALSIFICATION: Overhead > 50%. Kernel fusion required.\n");
209            } else {
210                output.push_str("\n🟢 CORROBORATED: Compute dominates. Current approach viable.\n");
211            }
212        }
213
214        output
215    }
216}
217
218impl Default for Tracer {
219    fn default() -> Self {
220        Self::new()
221    }
222}
223
224/// Global tracer instance.
225pub static TRACER: LazyLock<Tracer> = LazyLock::new(Tracer::new);
226
#[cfg(test)]
mod tests {
    use super::*;

    /// `Display` prints each variant's identifier.
    #[test]
    fn test_trace_step_display() {
        let expected = [
            (TraceStep::Forward, "Forward"),
            (TraceStep::Backward, "Backward"),
            (TraceStep::Matmul, "Matmul"),
            (TraceStep::Attention, "Attention"),
            (TraceStep::Transpose, "Transpose"),
            (TraceStep::Alloc, "Alloc"),
            (TraceStep::Transfer, "Transfer"),
            (TraceStep::LedgerReserve, "LedgerReserve"),
            (TraceStep::LedgerCleanup, "LedgerCleanup"),
            (TraceStep::VramQuery, "VramQuery"),
            (TraceStep::WaitPoll, "WaitPoll"),
            (TraceStep::LedgerRelease, "LedgerRelease"),
        ];
        for &(step, name) in &expected {
            assert_eq!(step.to_string(), name);
        }
    }

    /// A copied step compares equal to its source.
    #[test]
    fn test_trace_step_clone() {
        let original = TraceStep::Forward;
        let duplicate = original;
        assert_eq!(original, duplicate);
    }

    /// Equal steps hash to the same set entry; distinct steps do not.
    #[test]
    fn test_trace_step_hash() {
        use std::collections::HashSet;
        let mut seen: HashSet<TraceStep> =
            [TraceStep::Forward, TraceStep::Forward].iter().copied().collect();
        assert_eq!(seen.len(), 1);
        seen.insert(TraceStep::Backward);
        assert_eq!(seen.len(), 2);
    }

    /// A freshly constructed tracer is disabled.
    #[test]
    fn test_tracer_new() {
        assert!(!Tracer::new().is_enabled());
    }

    /// The `Default` tracer is disabled as well.
    #[test]
    fn test_tracer_default() {
        assert!(!Tracer::default().is_enabled());
    }

    /// `enable` / `disable` toggle the flag reported by `is_enabled`.
    #[test]
    fn test_tracer_enable_disable() {
        let t = Tracer::new();
        assert!(!t.is_enabled());

        t.enable();
        assert!(t.is_enabled());

        t.disable();
        assert!(!t.is_enabled());
    }

    /// `start` / `end` on a disabled tracer must not panic.
    #[test]
    fn test_tracer_start_end_disabled() {
        let t = Tracer::new();
        t.start(TraceStep::Forward);
        t.end(TraceStep::Forward, "test");
    }

    /// A completed start/end pair shows up in the report.
    #[test]
    fn test_tracer_start_end_enabled() {
        let t = Tracer::new();
        t.enable();
        t.start(TraceStep::Matmul);
        t.end(TraceStep::Matmul, "2x2");
        assert!(t.report().contains("Matmul"));
    }

    /// A disabled tracer still runs the closure and returns its value.
    #[test]
    fn test_tracer_span_disabled() {
        let t = Tracer::new();
        let value = t.span(TraceStep::Forward, "test", || 42);
        assert_eq!(value, 42);
    }

    /// An enabled tracer returns the closure's value and records the span.
    #[test]
    fn test_tracer_span_enabled() {
        let t = Tracer::new();
        t.enable();
        let value = t.span(TraceStep::Attention, "4 heads", || "done");
        assert_eq!(value, "done");
        assert!(t.report().contains("Attention"));
    }

    /// `clear` discards previously recorded data.
    #[test]
    fn test_tracer_clear() {
        let t = Tracer::new();
        t.enable();
        t.start(TraceStep::Forward);
        t.end(TraceStep::Forward, "test");
        t.clear();
        assert!(t.report().contains("No measurements recorded"));
    }

    /// With nothing recorded the report is the "no measurements" hint.
    #[test]
    fn test_tracer_report_empty() {
        assert!(Tracer::new().report().contains("No measurements recorded"));
    }

    /// The report contains the banner, the recorded steps and the table header.
    #[test]
    fn test_tracer_report_with_measurements() {
        let t = Tracer::new();
        t.enable();

        for &(step, label) in &[(TraceStep::Matmul, "512x512"), (TraceStep::Transpose, "256x256")] {
            t.start(step);
            t.end(step, label);
        }

        let text = t.report();
        for needle in &["ENTRENAR TRACE REPORT", "Matmul", "Transpose", "% Time"] {
            assert!(text.contains(needle));
        }
    }

    /// Non-zero Matmul time triggers the Dr. Popper section.
    #[test]
    fn test_tracer_report_dr_popper_analysis() {
        let t = Tracer::new();

        // Seed the aggregated counters directly; `report` reads only from them.
        {
            let mut counters = t.aggregated.lock().expect("lock acquisition should succeed");
            counters.insert(TraceStep::Matmul, (1, Duration::from_millis(50)));
            counters.insert(TraceStep::Transpose, (1, Duration::from_millis(10)));
        }

        let text = t.report();
        for needle in &["Dr. Popper Analysis", "CUDA Compute:", "CPU Overhead:"] {
            assert!(text.contains(needle));
        }
    }

    /// `end` without a matching `start` is ignored and records nothing.
    #[test]
    fn test_tracer_end_without_start() {
        let t = Tracer::new();
        t.enable();
        t.end(TraceStep::Forward, "no start");
        assert!(t.report().contains("No measurements recorded"));
    }

    /// Cloning a measurement copies every field.
    #[test]
    fn test_trace_measurement_clone() {
        let sample = TraceMeasurement {
            step: TraceStep::Forward,
            duration: Duration::from_millis(100),
            metadata: "test".to_string(),
        };
        let copy = sample.clone();
        assert_eq!(sample.step, copy.step);
        assert_eq!(sample.duration, copy.duration);
        assert_eq!(sample.metadata, copy.metadata);
    }

    /// The derived `Debug` output names the struct and the step.
    #[test]
    fn test_trace_measurement_debug() {
        let sample = TraceMeasurement {
            step: TraceStep::Backward,
            duration: Duration::from_micros(50),
            metadata: "grad".to_string(),
        };
        let rendered = format!("{sample:?}");
        assert!(rendered.contains("TraceMeasurement"));
        assert!(rendered.contains("Backward"));
    }
}
410}