use std::collections::HashMap;
use std::fmt;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{LazyLock, Mutex, PoisonError};
use std::time::{Duration, Instant};
10
/// Label identifying the kind of work a traced span is attributed to.
///
/// [`fmt::Display`] prints the variant name verbatim (e.g. `LedgerReserve`).
/// Only `Matmul`, `Transpose` and `Alloc` are treated specially, by the overhead
/// analysis in [`Tracer::report`]. The remaining variants are plain labels
/// chosen by callers.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum TraceStep {
    Forward,
    Backward,
    /// Counted as compute in the report's overhead analysis.
    Matmul,
    Attention,
    /// Counted as overhead in the report's overhead analysis.
    Transpose,
    /// Counted as overhead in the report's overhead analysis.
    Alloc,
    Transfer,
    LedgerReserve,
    LedgerCleanup,
    VramQuery,
    WaitPoll,
    LedgerRelease,
}

impl fmt::Display for TraceStep {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Spelled out so the user-facing text is pinned explicitly. Each string
        // must stay identical to its variant name.
        let label = match self {
            Self::Forward => "Forward",
            Self::Backward => "Backward",
            Self::Matmul => "Matmul",
            Self::Attention => "Attention",
            Self::Transpose => "Transpose",
            Self::Alloc => "Alloc",
            Self::Transfer => "Transfer",
            Self::LedgerReserve => "LedgerReserve",
            Self::LedgerCleanup => "LedgerCleanup",
            Self::VramQuery => "VramQuery",
            Self::WaitPoll => "WaitPoll",
            Self::LedgerRelease => "LedgerRelease",
        };
        f.write_str(label)
    }
}
45
46#[derive(Debug, Clone)]
48pub struct TraceMeasurement {
49 pub step: TraceStep,
50 pub duration: Duration,
51 pub metadata: String,
52}
53
54pub struct Tracer {
58 measurements: Mutex<Vec<TraceMeasurement>>,
60 aggregated: Mutex<HashMap<TraceStep, (usize, Duration)>>,
62 active_spans: Mutex<HashMap<TraceStep, Instant>>,
63 enabled: Mutex<bool>,
64}
65
66impl Tracer {
67 pub fn new() -> Self {
69 Self {
70 measurements: Mutex::new(Vec::new()),
71 aggregated: Mutex::new(HashMap::new()),
72 active_spans: Mutex::new(HashMap::new()),
73 enabled: Mutex::new(false), }
75 }
76
77 pub fn enable(&self) {
79 *self.enabled.lock().unwrap_or_else(PoisonError::into_inner) = true;
80 }
81
82 pub fn disable(&self) {
84 *self.enabled.lock().unwrap_or_else(PoisonError::into_inner) = false;
85 }
86
87 pub fn is_enabled(&self) -> bool {
89 *self.enabled.lock().unwrap_or_else(PoisonError::into_inner)
90 }
91
92 pub fn start(&self, step: TraceStep) {
94 if !self.is_enabled() {
95 return;
96 }
97 let mut spans = self.active_spans.lock().unwrap_or_else(PoisonError::into_inner);
98 spans.insert(step, Instant::now());
99 }
100
101 pub fn end(&self, step: TraceStep, _metadata: impl Into<String>) {
104 if !self.is_enabled() {
105 return;
106 }
107 let mut spans = self.active_spans.lock().unwrap_or_else(PoisonError::into_inner);
108 if let Some(start) = spans.remove(&step) {
109 let duration = start.elapsed();
110 let mut agg = self.aggregated.lock().unwrap_or_else(PoisonError::into_inner);
111 let entry = agg.entry(step).or_insert((0, Duration::ZERO));
112 entry.0 += 1;
113 entry.1 += duration;
114 }
115 }
116
117 #[inline]
119 pub fn span<F, R>(&self, step: TraceStep, metadata: impl Into<String>, f: F) -> R
120 where
121 F: FnOnce() -> R,
122 {
123 if !self.is_enabled() {
124 return f();
125 }
126 self.start(step);
127 let result = f();
128 self.end(step, metadata);
129 result
130 }
131
132 pub fn clear(&self) {
134 self.measurements.lock().unwrap_or_else(PoisonError::into_inner).clear();
135 self.aggregated.lock().unwrap_or_else(PoisonError::into_inner).clear();
136 self.active_spans.lock().unwrap_or_else(PoisonError::into_inner).clear();
137 }
138
139 pub fn report(&self) -> String {
142 let agg = self.aggregated.lock().unwrap_or_else(PoisonError::into_inner);
143 if agg.is_empty() {
144 return "No measurements recorded. Enable tracing with TRACER.enable()".to_string();
145 }
146
147 let mut totals: HashMap<TraceStep, Duration> = HashMap::new();
148 let mut counts: HashMap<TraceStep, usize> = HashMap::new();
149 let mut total_time = Duration::ZERO;
150
151 for (&step, &(count, duration)) in agg.iter() {
152 totals.insert(step, duration);
153 counts.insert(step, count);
154 total_time += duration;
155 }
156
157 let mut output =
158 String::from("\n╔══════════════════════════════════════════════════════════════╗\n");
159 output.push_str("║ ENTRENAR TRACE REPORT (ITP-SPEC-001) ║\n");
160 output.push_str("╚══════════════════════════════════════════════════════════════╝\n");
161 output.push_str(&format!("Total Measured Time: {total_time:.2?}\n"));
162 output.push_str("────────────────────────────────────────────────────────────────\n");
163 output.push_str(&format!(
164 "{:<15} | {:<8} | {:<15} | {:<8}\n",
165 "Step", "Count", "Duration", "% Time"
166 ));
167 output.push_str("────────────────────────────────────────────────────────────────\n");
168
169 let mut sorted_steps: Vec<_> = totals.keys().collect();
171 sorted_steps.sort_by(|a, b| totals[b].cmp(&totals[a]));
172
173 for step in sorted_steps {
174 let duration = totals[step];
175 let count = counts[step];
176 let percentage = if total_time.as_nanos() > 0 {
177 (duration.as_secs_f64() / total_time.as_secs_f64()) * 100.0
178 } else {
179 0.0
180 };
181 output.push_str(&format!(
182 "{:<15} | {:<8} | {:<15.2?} | {:>7.2}%\n",
183 step.to_string(),
184 count,
185 duration,
186 percentage
187 ));
188 }
189 output.push_str("────────────────────────────────────────────────────────────────\n");
190
191 let matmul_time = totals.get(&TraceStep::Matmul).copied().unwrap_or_default();
193 let transpose_time = totals.get(&TraceStep::Transpose).copied().unwrap_or_default();
194 let alloc_time = totals.get(&TraceStep::Alloc).copied().unwrap_or_default();
195 let compute_time = matmul_time;
196 let overhead_time = transpose_time + alloc_time;
197
198 if compute_time.as_nanos() > 0 {
199 let overhead_pct = (overhead_time.as_secs_f64()
200 / (compute_time + overhead_time).as_secs_f64())
201 * 100.0;
202
203 output.push_str("\n[Dr. Popper Analysis]\n");
204 output.push_str(&format!("CUDA Compute: {compute_time:.2?}\n"));
205 output.push_str(&format!("CPU Overhead: {overhead_time:.2?} ({overhead_pct:.2}%)\n"));
206
207 if overhead_pct > 50.0 {
208 output.push_str("\n🔴 FALSIFICATION: Overhead > 50%. Kernel fusion required.\n");
209 } else {
210 output.push_str("\n🟢 CORROBORATED: Compute dominates. Current approach viable.\n");
211 }
212 }
213
214 output
215 }
216}
217
218impl Default for Tracer {
219 fn default() -> Self {
220 Self::new()
221 }
222}
223
224pub static TRACER: LazyLock<Tracer> = LazyLock::new(Tracer::new);
226
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a tracer whose aggregates are pre-seeded with
    /// `(step, count, milliseconds)` entries, bypassing real timing.
    fn seeded(entries: &[(TraceStep, usize, u64)]) -> Tracer {
        let tracer = Tracer::new();
        {
            let mut agg = tracer.aggregated.lock().expect("lock acquisition should succeed");
            for &(step, count, ms) in entries {
                agg.insert(step, (count, Duration::from_millis(ms)));
            }
        }
        tracer
    }

    #[test]
    fn test_trace_step_display() {
        assert_eq!(TraceStep::Forward.to_string(), "Forward");
        assert_eq!(TraceStep::Backward.to_string(), "Backward");
        assert_eq!(TraceStep::Matmul.to_string(), "Matmul");
        assert_eq!(TraceStep::Attention.to_string(), "Attention");
        assert_eq!(TraceStep::Transpose.to_string(), "Transpose");
        assert_eq!(TraceStep::Alloc.to_string(), "Alloc");
        assert_eq!(TraceStep::Transfer.to_string(), "Transfer");
        assert_eq!(TraceStep::LedgerReserve.to_string(), "LedgerReserve");
        assert_eq!(TraceStep::LedgerCleanup.to_string(), "LedgerCleanup");
        assert_eq!(TraceStep::VramQuery.to_string(), "VramQuery");
        assert_eq!(TraceStep::WaitPoll.to_string(), "WaitPoll");
        assert_eq!(TraceStep::LedgerRelease.to_string(), "LedgerRelease");
    }

    #[test]
    fn test_trace_step_clone() {
        let step = TraceStep::Forward;
        let cloned = step;
        assert_eq!(step, cloned);
    }

    #[test]
    fn test_trace_step_hash() {
        use std::collections::HashSet;
        let mut set = HashSet::new();
        set.insert(TraceStep::Forward);
        set.insert(TraceStep::Forward);
        assert_eq!(set.len(), 1);
        set.insert(TraceStep::Backward);
        assert_eq!(set.len(), 2);
    }

    #[test]
    fn test_tracer_new() {
        let tracer = Tracer::new();
        assert!(!tracer.is_enabled());
    }

    #[test]
    fn test_tracer_default() {
        let tracer = Tracer::default();
        assert!(!tracer.is_enabled());
    }

    #[test]
    fn test_tracer_enable_disable() {
        let tracer = Tracer::new();
        assert!(!tracer.is_enabled());
        tracer.enable();
        assert!(tracer.is_enabled());
        tracer.disable();
        assert!(!tracer.is_enabled());
    }

    #[test]
    fn test_tracer_start_end_disabled() {
        let tracer = Tracer::new();
        tracer.start(TraceStep::Forward);
        tracer.end(TraceStep::Forward, "test");
        // Nothing may be recorded while tracing is off.
        assert!(tracer.report().contains("No measurements recorded"));
    }

    #[test]
    fn test_tracer_start_end_enabled() {
        let tracer = Tracer::new();
        tracer.enable();
        tracer.start(TraceStep::Matmul);
        tracer.end(TraceStep::Matmul, "2x2");
        let report = tracer.report();
        assert!(report.contains("Matmul"));
    }

    #[test]
    fn test_tracer_span_disabled() {
        let tracer = Tracer::new();
        let result = tracer.span(TraceStep::Forward, "test", || 42);
        assert_eq!(result, 42);
    }

    #[test]
    fn test_tracer_span_enabled() {
        let tracer = Tracer::new();
        tracer.enable();
        let result = tracer.span(TraceStep::Attention, "4 heads", || "done");
        assert_eq!(result, "done");
        let report = tracer.report();
        assert!(report.contains("Attention"));
    }

    #[test]
    fn test_tracer_span_counts_repeated_calls() {
        let tracer = Tracer::new();
        tracer.enable();
        for _ in 0..3 {
            tracer.span(TraceStep::Backward, "step", || ());
        }
        let agg = tracer.aggregated.lock().expect("lock acquisition should succeed");
        assert_eq!(agg.get(&TraceStep::Backward).map(|&(count, _)| count), Some(3));
    }

    #[test]
    fn test_tracer_disable_stops_recording() {
        let tracer = Tracer::new();
        tracer.enable();
        tracer.disable();
        tracer.span(TraceStep::Transfer, "h2d", || ());
        assert!(tracer.report().contains("No measurements recorded"));
    }

    #[test]
    fn test_tracer_clear() {
        let tracer = Tracer::new();
        tracer.enable();
        tracer.start(TraceStep::Forward);
        tracer.end(TraceStep::Forward, "test");
        tracer.clear();
        let report = tracer.report();
        assert!(report.contains("No measurements recorded"));
    }

    #[test]
    fn test_tracer_report_empty() {
        let tracer = Tracer::new();
        let report = tracer.report();
        assert!(report.contains("No measurements recorded"));
    }

    #[test]
    fn test_tracer_report_with_measurements() {
        let tracer = Tracer::new();
        tracer.enable();

        tracer.start(TraceStep::Matmul);
        tracer.end(TraceStep::Matmul, "512x512");

        tracer.start(TraceStep::Transpose);
        tracer.end(TraceStep::Transpose, "256x256");

        let report = tracer.report();
        assert!(report.contains("ENTRENAR TRACE REPORT"));
        assert!(report.contains("Matmul"));
        assert!(report.contains("Transpose"));
        assert!(report.contains("% Time"));
    }

    #[test]
    fn test_tracer_report_percentages() {
        let report = seeded(&[(TraceStep::Forward, 1, 30), (TraceStep::Backward, 2, 10)]).report();
        assert!(report.contains("75.00%"));
        assert!(report.contains("25.00%"));
    }

    #[test]
    fn test_tracer_report_dr_popper_analysis() {
        let report = seeded(&[(TraceStep::Matmul, 1, 50), (TraceStep::Transpose, 1, 10)]).report();
        assert!(report.contains("Dr. Popper Analysis"));
        assert!(report.contains("CUDA Compute:"));
        assert!(report.contains("CPU Overhead:"));
        // 10 / (50 + 10) is about 16.7% overhead, below the 50% threshold.
        assert!(report.contains("CORROBORATED"));
        assert!(!report.contains("FALSIFICATION"));
    }

    #[test]
    fn test_tracer_report_dr_popper_falsification() {
        let report = seeded(&[(TraceStep::Matmul, 1, 10), (TraceStep::Alloc, 1, 30)]).report();
        // 30 / (10 + 30) = 75% overhead.
        assert!(report.contains("(75.00%)"));
        assert!(report.contains("FALSIFICATION"));
        assert!(!report.contains("CORROBORATED"));
    }

    #[test]
    fn test_tracer_report_without_matmul_skips_analysis() {
        let report = seeded(&[(TraceStep::Forward, 1, 10)]).report();
        assert!(report.contains("Forward"));
        assert!(!report.contains("Dr. Popper Analysis"));
    }

    #[test]
    fn test_tracer_end_without_start() {
        let tracer = Tracer::new();
        tracer.enable();
        tracer.end(TraceStep::Forward, "no start");
        let report = tracer.report();
        assert!(report.contains("No measurements recorded"));
    }

    #[test]
    fn test_trace_measurement_clone() {
        let measurement = TraceMeasurement {
            step: TraceStep::Forward,
            duration: Duration::from_millis(100),
            metadata: "test".to_string(),
        };
        let cloned = measurement.clone();
        assert_eq!(measurement.step, cloned.step);
        assert_eq!(measurement.duration, cloned.duration);
        assert_eq!(measurement.metadata, cloned.metadata);
    }

    #[test]
    fn test_trace_measurement_debug() {
        let measurement = TraceMeasurement {
            step: TraceStep::Backward,
            duration: Duration::from_micros(50),
            metadata: "grad".to_string(),
        };
        let debug_str = format!("{measurement:?}");
        assert!(debug_str.contains("TraceMeasurement"));
        assert!(debug_str.contains("Backward"));
    }
}