Skip to main content

entrenar/train/transformer_trainer/
step_profiler.rs

1#![allow(dead_code)]
2//! Per-step wall-clock profiler for CUDA training (KAIZEN-047).
3//!
4//! Collects `Instant`-based timings for each phase of `train_step_single()`.
5//! Reports per-step breakdown and running statistics after N steps.
6//!
7//! # Contract (C-STEPPROF-001)
8//!
9//! - Zero-overhead when disabled: all methods are no-ops
10//! - No GPU synchronization added (relies on existing sync points)
11//! - Timings include CPU→GPU async dispatch latency (not pure kernel time)
12//! - Report interval configurable at construction
13//!
14//! # Phases measured
15//!
16//! 1. `embed`    — CPU embedding lookup
17//! 2. `h2d`      — Hidden state upload (H2D transfer + padding)
18//! 3. `forward`  — Block forward loop (includes D2D layer_input saves)
19//! 4. `norm_lm`  — Final RMSNorm + LM head GEMM + logits D2H
20//! 5. `loss`     — CPU softmax + cross-entropy + gradient
21//! 6. `grad_h2d` — Grad logits upload (H2D transfer)
22//! 7. `lm_bwd`   — LM head backward (GEMM_A + GEMM_B + clip)
23//! 8. `norm_bwd` — Final RMSNorm backward + clip
24//! 9. `blk_bwd`  — Block backward loop (includes recompute + optimizer)
25//! 10. `embed_bwd` — Embedding backward (D2H + clip + scatter-add)
26//! 11. `opt`      — CPU optimizer step (embedding + bookkeeping)
27
28use std::time::{Duration, Instant};
29
/// Phase indices (must match `PHASE_NAMES`).
///
/// Each constant is the slot of one phase in the `[Duration; NUM_PHASES]`
/// arrays kept by `StepProfiler`; the order follows the "Phases measured"
/// list in the module documentation.
const EMBED: usize = 0;
const H2D: usize = 1;
const FORWARD: usize = 2;
const NORM_LM: usize = 3;
const LOSS: usize = 4;
const GRAD_H2D: usize = 5;
const LM_BWD: usize = 6;
const NORM_BWD: usize = 7;
const BLK_BWD: usize = 8;
const EMBED_BWD: usize = 9;
const OPT: usize = 10;
/// Number of phases; sizes the per-phase arrays and `PHASE_NAMES`.
const NUM_PHASES: usize = 11;

/// Display name of each phase, indexed by the phase constants above.
/// Used as row labels in the text report and as keys in the JSON report.
const PHASE_NAMES: [&str; NUM_PHASES] = [
    "embed",
    "h2d",
    "forward",
    "norm_lm",
    "loss",
    "grad_h2d",
    "lm_bwd",
    "norm_bwd",
    "blk_bwd",
    "embed_bwd",
    "opt",
];

/// Maximum number of transformer layers to profile.
/// Layer indices at or above this value are ignored by the per-layer methods.
const MAX_LAYERS: usize = 64;

/// Per-operation indices for within-layer profiling (PMAT-483/entrenar#328).
/// These track where time goes INSIDE each layer's forward and backward.
/// Each constant is a slot in `StepProfiler::op_totals` and must match `OP_NAMES`.
const OP_RMSNORM_ATTN: usize = 0;
const OP_QKV_GEMM: usize = 1;
const OP_ATTENTION: usize = 2;
const OP_O_PROJ: usize = 3;
const OP_RMSNORM_FFN: usize = 4;
const OP_GATE_UP_GEMM: usize = 5;
const OP_SILU: usize = 6;
const OP_DOWN_GEMM: usize = 7;
const OP_LORA: usize = 8;
// Indices 9..=15 carry a `_BWD` suffix (presumably the backward-pass operations).
const OP_DOWN_BWD: usize = 9;
const OP_SWIGLU_BWD: usize = 10;
const OP_GATE_UP_BWD: usize = 11;
const OP_ATTN_BWD: usize = 12;
const OP_QKV_BWD: usize = 13;
const OP_NORM_BWD: usize = 14;
const OP_LORA_BWD: usize = 15;
/// Number of per-operation slots; sizes `op_totals` and `OP_NAMES`.
const NUM_OPS: usize = 16;

/// Display name of each operation, indexed by the `OP_*` constants above.
/// Used as keys of the `ops` object in the JSON report.
const OP_NAMES: [&str; NUM_OPS] = [
    "rmsnorm_attn",
    "qkv_gemm",
    "attention",
    "o_proj",
    "rmsnorm_ffn",
    "gate_up_gemm",
    "silu",
    "down_gemm",
    "lora",
    "down_bwd",
    "swiglu_bwd",
    "gate_up_bwd",
    "attn_bwd",
    "qkv_bwd",
    "norm_bwd",
    "lora_bwd",
];
99
/// Per-step timing accumulator.
///
/// Usage: call `begin_step()` at the top of a step, then `begin(phase)` before
/// each section and `end(phase)` after it. Call `finish_step()` to record the
/// step and optionally print a report.
///
/// Per-layer profiling (PMAT-480): call `begin_layer()` before each layer and
/// `end_layer_fwd(layer)` / `end_layer_bwd(layer)` after it inside the
/// forward/backward loops to capture per-layer timing. The per-layer breakdown
/// is reported once at least one layer has been recorded.
pub struct StepProfiler {
    /// When false, the recording methods return immediately
    enabled: bool,
    /// Report every N steps (0 = never auto-report)
    report_interval: usize,
    /// Current step's phase timings
    current: [Duration; NUM_PHASES],
    /// Phase start timestamp (set by `begin`, consumed by `end`)
    phase_start: Option<Instant>,
    /// Step-level start timestamp (set by `begin_step`, consumed by `finish_step`)
    step_start: Option<Instant>,
    /// Accumulated totals across all steps
    totals: [Duration; NUM_PHASES],
    /// Total wall-clock across all steps
    total_wall: Duration,
    /// Number of completed steps
    step_count: usize,
    /// Per-step wall-clock durations (for percentile analysis); one entry per completed step
    step_durations: Vec<Duration>,
    /// Per-layer forward timing (accumulated across steps, PMAT-480)
    layer_fwd_totals: Vec<Duration>,
    /// Per-layer backward timing (accumulated across steps, PMAT-480)
    layer_bwd_totals: Vec<Duration>,
    /// Layer-level start timestamp (set by `begin_layer`)
    layer_start: Option<Instant>,
    /// Highest recorded layer index + 1; grows whenever a larger index is recorded
    num_layers: usize,
    /// PMAT-483: Per-operation timing (accumulated across all layers and steps)
    op_totals: [Duration; NUM_OPS],
    /// Per-operation start timestamp (set by `begin_op`, consumed by `end_op`)
    op_start: Option<Instant>,
}
139
140impl StepProfiler {
141    /// Create a new profiler.
142    ///
143    /// - `enabled`: if false, all methods are no-ops
144    /// - `report_interval`: print summary every N steps (0 = manual only)
145    pub fn new(enabled: bool, report_interval: usize) -> Self {
146        Self {
147            enabled,
148            report_interval,
149            current: [Duration::ZERO; NUM_PHASES],
150            phase_start: None,
151            step_start: None,
152            totals: [Duration::ZERO; NUM_PHASES],
153            total_wall: Duration::ZERO,
154            step_count: 0,
155            step_durations: Vec::new(),
156            layer_fwd_totals: vec![Duration::ZERO; MAX_LAYERS],
157            layer_bwd_totals: vec![Duration::ZERO; MAX_LAYERS],
158            layer_start: None,
159            num_layers: 0,
160            op_totals: [Duration::ZERO; NUM_OPS],
161            op_start: None,
162        }
163    }
164
165    /// Disabled (no-op) profiler.
166    pub fn disabled() -> Self {
167        Self::new(false, 0)
168    }
169
170    /// Mark the start of a training step.
171    #[inline]
172    pub fn begin_step(&mut self) {
173        if !self.enabled {
174            return;
175        }
176        self.current = [Duration::ZERO; NUM_PHASES];
177        self.step_start = Some(Instant::now());
178    }
179
180    /// Mark the start of a phase. Must call `end(phase)` to record.
181    #[inline]
182    pub fn begin(&mut self, _phase: usize) {
183        if !self.enabled {
184            return;
185        }
186        self.phase_start = Some(Instant::now());
187    }
188
189    /// Record elapsed time for a phase (since last `begin`).
190    #[inline]
191    pub fn end(&mut self, phase: usize) {
192        if !self.enabled {
193            return;
194        }
195        if let Some(start) = self.phase_start.take() {
196            self.current[phase] += start.elapsed();
197        }
198    }
199
200    /// Start per-layer timing (PMAT-480). Call before each layer's forward or backward.
201    #[inline]
202    pub fn begin_layer(&mut self) {
203        if !self.enabled {
204            return;
205        }
206        self.layer_start = Some(Instant::now());
207    }
208
209    /// Record per-layer forward time (PMAT-480). Call after layer forward completes.
210    #[inline]
211    pub fn end_layer_fwd(&mut self, layer: usize) {
212        if !self.enabled {
213            return;
214        }
215        if let Some(start) = self.layer_start.take() {
216            if layer < MAX_LAYERS {
217                self.layer_fwd_totals[layer] += start.elapsed();
218                if layer >= self.num_layers {
219                    self.num_layers = layer + 1;
220                }
221            }
222        }
223    }
224
225    /// Record per-layer backward time (PMAT-480). Call after layer backward completes.
226    #[inline]
227    pub fn end_layer_bwd(&mut self, layer: usize) {
228        if !self.enabled {
229            return;
230        }
231        if let Some(start) = self.layer_start.take() {
232            if layer < MAX_LAYERS {
233                self.layer_bwd_totals[layer] += start.elapsed();
234                if layer >= self.num_layers {
235                    self.num_layers = layer + 1;
236                }
237            }
238        }
239    }
240
241    /// Finish the current step. Records totals and optionally prints report.
242    pub fn finish_step(&mut self) {
243        if !self.enabled {
244            return;
245        }
246        let step_wall = self.step_start.take().map_or(Duration::ZERO, |s| s.elapsed());
247
248        for i in 0..NUM_PHASES {
249            self.totals[i] += self.current[i];
250        }
251        self.total_wall += step_wall;
252        self.step_count += 1;
253        self.step_durations.push(step_wall);
254
255        if self.report_interval > 0 && self.step_count.is_multiple_of(self.report_interval) {
256            self.print_report();
257        }
258    }
259
260    /// Print cumulative profiling report to stdout.
261    pub fn print_report(&self) {
262        if self.step_count == 0 {
263            return;
264        }
265
266        let total_us = self.total_wall.as_micros() as f64;
267        let avg_step_us = total_us / self.step_count as f64;
268
269        println!(
270            "\n┌─ Step Profiler ({} steps, avg {:.1} ms/step) ─┐",
271            self.step_count,
272            avg_step_us / 1000.0
273        );
274        println!("│ {:>10} │ {:>8} │ {:>6} │ {:>8} │", "phase", "total_ms", "pct", "avg_ms");
275        println!("│ {:-<10} │ {:-<8} │ {:-<6} │ {:-<8} │", "", "", "", "");
276
277        let mut accounted = Duration::ZERO;
278        for i in 0..NUM_PHASES {
279            let t = self.totals[i];
280            accounted += t;
281            let ms = t.as_micros() as f64 / 1000.0;
282            let pct = if total_us > 0.0 { t.as_micros() as f64 / total_us * 100.0 } else { 0.0 };
283            let avg = ms / self.step_count as f64;
284            println!("│ {:>10} │ {:>8.1} │ {:>5.1}% │ {:>8.2} │", PHASE_NAMES[i], ms, pct, avg);
285        }
286
287        let unaccounted = self.total_wall.saturating_sub(accounted);
288        let unaccounted_pct =
289            if total_us > 0.0 { unaccounted.as_micros() as f64 / total_us * 100.0 } else { 0.0 };
290        println!(
291            "│ {:>10} │ {:>8.1} │ {:>5.1}% │ {:>8.2} │",
292            "other",
293            unaccounted.as_micros() as f64 / 1000.0,
294            unaccounted_pct,
295            unaccounted.as_micros() as f64 / 1000.0 / self.step_count as f64
296        );
297
298        println!(
299            "│ {:>10} │ {:>8.1} │ {:>5}  │ {:>8.2} │",
300            "TOTAL",
301            total_us / 1000.0,
302            "100%",
303            avg_step_us / 1000.0
304        );
305        println!("└────────────┴──────────┴────────┴──────────┘");
306
307        // Per-layer breakdown (PMAT-480)
308        if self.num_layers > 0 && self.step_count > 0 {
309            println!(
310                "\n┌─ Per-Layer Profile ({} layers, {} steps) ─┐",
311                self.num_layers, self.step_count
312            );
313            println!(
314                "│ {:>5} │ {:>8} │ {:>8} │ {:>8} │ {:>8} │",
315                "layer", "fwd_ms", "bwd_ms", "fwd_avg", "bwd_avg"
316            );
317            println!("│ {:->5} │ {:->8} │ {:->8} │ {:->8} │ {:->8} │", "", "", "", "", "");
318            let mut fwd_total = Duration::ZERO;
319            let mut bwd_total = Duration::ZERO;
320            for i in 0..self.num_layers {
321                let fwd = self.layer_fwd_totals[i];
322                let bwd = self.layer_bwd_totals[i];
323                fwd_total += fwd;
324                bwd_total += bwd;
325                let fwd_ms = fwd.as_micros() as f64 / 1000.0;
326                let bwd_ms = bwd.as_micros() as f64 / 1000.0;
327                let fwd_avg = fwd_ms / self.step_count as f64;
328                let bwd_avg = bwd_ms / self.step_count as f64;
329                println!(
330                    "│ {i:>5} │ {fwd_ms:>8.1} │ {bwd_ms:>8.1} │ {fwd_avg:>8.2} │ {bwd_avg:>8.2} │"
331                );
332            }
333            let fwd_total_ms = fwd_total.as_micros() as f64 / 1000.0;
334            let bwd_total_ms = bwd_total.as_micros() as f64 / 1000.0;
335            println!(
336                "│ {:>5} │ {:>8.1} │ {:>8.1} │ {:>8.2} │ {:>8.2} │",
337                "TOTAL",
338                fwd_total_ms,
339                bwd_total_ms,
340                fwd_total_ms / self.step_count as f64,
341                bwd_total_ms / self.step_count as f64
342            );
343            println!("└───────┴──────────┴──────────┴──────────┴──────────┘");
344
345            // Identify hotspot layers (>1.5x average)
346            let avg_fwd = fwd_total / self.num_layers as u32;
347            let avg_bwd = bwd_total / self.num_layers as u32;
348            let mut hotspots = Vec::new();
349            for i in 0..self.num_layers {
350                let fwd_ratio = if avg_fwd.as_nanos() > 0 {
351                    self.layer_fwd_totals[i].as_nanos() as f64 / avg_fwd.as_nanos() as f64
352                } else {
353                    0.0
354                };
355                let bwd_ratio = if avg_bwd.as_nanos() > 0 {
356                    self.layer_bwd_totals[i].as_nanos() as f64 / avg_bwd.as_nanos() as f64
357                } else {
358                    0.0
359                };
360                if fwd_ratio > 1.5 || bwd_ratio > 1.5 {
361                    hotspots.push((i, fwd_ratio, bwd_ratio));
362                }
363            }
364            if !hotspots.is_empty() {
365                println!("  Hotspot layers (>1.5x average):");
366                for (layer, fwd_r, bwd_r) in &hotspots {
367                    println!("    L{layer}: fwd {fwd_r:.1}x, bwd {bwd_r:.1}x");
368                }
369            }
370        }
371
372        // Percentiles for step wall-clock
373        if self.step_durations.len() >= 10 {
374            let mut sorted: Vec<u128> =
375                self.step_durations.iter().map(std::time::Duration::as_micros).collect();
376            sorted.sort_unstable();
377            let p50 = sorted[sorted.len() / 2];
378            let p95 = sorted[sorted.len() * 95 / 100];
379            let p99 = sorted[sorted.len() * 99 / 100];
380            println!(
381                "  Step latency: p50={:.1}ms p95={:.1}ms p99={:.1}ms",
382                p50 as f64 / 1000.0,
383                p95 as f64 / 1000.0,
384                p99 as f64 / 1000.0
385            );
386        }
387    }
388
389    /// PMAT-483/entrenar#328: Start timing a per-operation phase within a layer.
390    #[inline]
391    pub fn begin_op(&mut self) {
392        if !self.enabled {
393            return;
394        }
395        self.op_start = Some(Instant::now());
396    }
397
398    /// PMAT-483/entrenar#328: Record elapsed time for a per-operation phase.
399    #[inline]
400    pub fn end_op(&mut self, op: usize) {
401        if !self.enabled {
402            return;
403        }
404        if let Some(start) = self.op_start.take() {
405            if op < NUM_OPS {
406                self.op_totals[op] += start.elapsed();
407            }
408        }
409    }
410
411    /// PMAT-483: Feed pre-accumulated microseconds for an operation (from CudaBlockScratch).
412    #[inline]
413    pub fn end_op_raw(&mut self, op: usize, us: u64) {
414        if !self.enabled || op >= NUM_OPS {
415            return;
416        }
417        self.op_totals[op] += Duration::from_micros(us);
418    }
419
    /// Whether the profiler is active, i.e. the `enabled` flag passed to `new`.
    pub fn is_enabled(&self) -> bool {
        self.enabled
    }
424
    /// Number of steps completed so far (incremented by `finish_step()` while
    /// the profiler is enabled).
    pub fn step_count(&self) -> usize {
        self.step_count
    }
429
430    /// PMAT-483: Emit structured JSON profiling report to stderr.
431    /// Parseable by canary scripts for scientific analysis.
432    pub fn print_json_report(&self) {
433        if self.step_count == 0 {
434            return;
435        }
436
437        let total_us = self.total_wall.as_micros() as f64;
438        let avg_step_ms = total_us / self.step_count as f64 / 1000.0;
439
440        let mut accounted_us = 0u128;
441        let mut phases = Vec::new();
442        for i in 0..NUM_PHASES {
443            let t = self.totals[i];
444            accounted_us += t.as_micros();
445            let ms = t.as_micros() as f64 / 1000.0;
446            let pct = if total_us > 0.0 { t.as_micros() as f64 / total_us * 100.0 } else { 0.0 };
447            let avg = ms / self.step_count as f64;
448            phases.push(format!(
449                "\"{}\":{{\"total_ms\":{:.1},\"pct\":{:.1},\"avg_ms\":{:.2}}}",
450                PHASE_NAMES[i], ms, pct, avg
451            ));
452        }
453
454        let wall_coverage = if total_us > 0.0 { accounted_us as f64 / total_us } else { 0.0 };
455
456        let mut layers_json = Vec::new();
457        for i in 0..self.num_layers {
458            let fwd_ms = self.layer_fwd_totals[i].as_micros() as f64 / 1000.0;
459            let bwd_ms = self.layer_bwd_totals[i].as_micros() as f64 / 1000.0;
460            layers_json
461                .push(format!("{{\"layer\":{i},\"fwd_ms\":{fwd_ms:.1},\"bwd_ms\":{bwd_ms:.1}}}"));
462        }
463
464        // Classify bottleneck based on phase distribution
465        let forward_pct = if total_us > 0.0 {
466            self.totals[FORWARD].as_micros() as f64 / total_us * 100.0
467        } else {
468            0.0
469        };
470        let transfer_pct = if total_us > 0.0 {
471            (self.totals[H2D].as_micros() + self.totals[GRAD_H2D].as_micros()) as f64 / total_us
472                * 100.0
473        } else {
474            0.0
475        };
476        let bottleneck = if transfer_pct > 30.0 {
477            "transfer"
478        } else if forward_pct < 20.0 {
479            "launch"
480        } else {
481            "memory_bw"
482        };
483
484        // PMAT-483: Per-operation breakdown
485        let mut ops_json = Vec::new();
486        let mut total_op_us = 0u128;
487        for i in 0..NUM_OPS {
488            let t = self.op_totals[i];
489            let us = t.as_micros();
490            total_op_us += us;
491            if us > 0 {
492                let ms = us as f64 / 1000.0;
493                ops_json.push(format!("\"{}\":{:.1}", OP_NAMES[i], ms));
494            }
495        }
496
497        // Classify which operation type dominates
498        let gemm_us = self.op_totals[OP_QKV_GEMM].as_micros()
499            + self.op_totals[OP_O_PROJ].as_micros()
500            + self.op_totals[OP_GATE_UP_GEMM].as_micros()
501            + self.op_totals[OP_DOWN_GEMM].as_micros();
502        let gemm_bwd_us = self.op_totals[OP_QKV_BWD].as_micros()
503            + self.op_totals[OP_GATE_UP_BWD].as_micros()
504            + self.op_totals[OP_DOWN_BWD].as_micros();
505        let total_gemm_us = gemm_us + gemm_bwd_us;
506        let gemm_pct =
507            if total_op_us > 0 { total_gemm_us as f64 / total_op_us as f64 * 100.0 } else { 0.0 };
508
509        eprintln!(
510            "{{\"_profiler\":\"step_profiler_v2\",\"steps\":{},\"avg_step_ms\":{:.2},\"wall_coverage\":{:.3},\"bottleneck\":\"{}\",\"gemm_pct\":{:.1},\"phases\":{{{}}},\"per_layer\":[{}],\"ops\":{{{}}}}}",
511            self.step_count,
512            avg_step_ms,
513            wall_coverage,
514            bottleneck,
515            gemm_pct,
516            phases.join(","),
517            layers_json.join(","),
518            ops_json.join(","),
519        );
520    }
521
522    /// PMAT-483: Feed per-layer timing data from InstructGpuTrainingState.
523    /// Call once per step after forward+backward complete.
524    pub fn record_layer_times(&mut self, fwd_us: &[u64], bwd_us: &[u64]) {
525        if !self.enabled {
526            return;
527        }
528        for (i, &us) in fwd_us.iter().enumerate() {
529            if i < MAX_LAYERS && us > 0 {
530                self.layer_fwd_totals[i] += std::time::Duration::from_micros(us);
531                if i >= self.num_layers {
532                    self.num_layers = i + 1;
533                }
534            }
535        }
536        for (i, &us) in bwd_us.iter().enumerate() {
537            if i < MAX_LAYERS && us > 0 {
538                self.layer_bwd_totals[i] += std::time::Duration::from_micros(us);
539                if i >= self.num_layers {
540                    self.num_layers = i + 1;
541                }
542            }
543        }
544    }
545
546    // Phase constants for external callers
547    pub const EMBED: usize = EMBED;
548
549    // Per-operation constants (PMAT-483/entrenar#328)
550    pub const OP_RMSNORM_ATTN: usize = OP_RMSNORM_ATTN;
551    pub const OP_QKV_GEMM: usize = OP_QKV_GEMM;
552    pub const OP_ATTENTION: usize = OP_ATTENTION;
553    pub const OP_O_PROJ: usize = OP_O_PROJ;
554    pub const OP_RMSNORM_FFN: usize = OP_RMSNORM_FFN;
555    pub const OP_GATE_UP_GEMM: usize = OP_GATE_UP_GEMM;
556    pub const OP_SILU: usize = OP_SILU;
557    pub const OP_DOWN_GEMM: usize = OP_DOWN_GEMM;
558    pub const OP_LORA: usize = OP_LORA;
559    pub const OP_DOWN_BWD: usize = OP_DOWN_BWD;
560    pub const OP_SWIGLU_BWD: usize = OP_SWIGLU_BWD;
561    pub const OP_GATE_UP_BWD: usize = OP_GATE_UP_BWD;
562    pub const OP_ATTN_BWD: usize = OP_ATTN_BWD;
563    pub const OP_QKV_BWD: usize = OP_QKV_BWD;
564    pub const OP_NORM_BWD: usize = OP_NORM_BWD;
565    pub const OP_LORA_BWD: usize = OP_LORA_BWD;
566    pub const H2D: usize = H2D;
567    pub const FORWARD: usize = FORWARD;
568    pub const NORM_LM: usize = NORM_LM;
569    pub const LOSS: usize = LOSS;
570    pub const GRAD_H2D: usize = GRAD_H2D;
571    pub const LM_BWD: usize = LM_BWD;
572    pub const NORM_BWD: usize = NORM_BWD;
573    pub const BLK_BWD: usize = BLK_BWD;
574    pub const EMBED_BWD: usize = EMBED_BWD;
575    pub const OPT: usize = OPT;
576}
577
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Block the test thread for `ms` milliseconds.
    fn nap(ms: u64) {
        std::thread::sleep(Duration::from_millis(ms));
    }

    /// Enabled profiler that never reports automatically.
    fn active() -> StepProfiler {
        StepProfiler::new(true, 0)
    }

    /// Run one full step that times `phase`, sleeping `ms` inside it (0 = no sleep).
    fn timed_step(prof: &mut StepProfiler, phase: usize, ms: u64) {
        prof.begin_step();
        prof.begin(phase);
        if ms > 0 {
            nap(ms);
        }
        prof.end(phase);
        prof.finish_step();
    }

    /// Run one full step that sleeps `ms` without recording any phase.
    fn idle_step(prof: &mut StepProfiler, ms: u64) {
        prof.begin_step();
        nap(ms);
        prof.finish_step();
    }

    /// Time a 1 ms forward pass for `layer`.
    fn timed_layer_fwd(prof: &mut StepProfiler, layer: usize) {
        prof.begin_layer();
        nap(1);
        prof.end_layer_fwd(layer);
    }

    /// Time a 1 ms backward pass for `layer`.
    fn timed_layer_bwd(prof: &mut StepProfiler, layer: usize) {
        prof.begin_layer();
        nap(1);
        prof.end_layer_bwd(layer);
    }

    #[test]
    fn test_disabled_profiler_is_noop() {
        let mut prof = StepProfiler::disabled();
        timed_step(&mut prof, StepProfiler::EMBED, 0);
        assert_eq!(prof.step_count(), 0);
    }

    #[test]
    fn test_enabled_profiler_counts_steps() {
        let mut prof = active();
        timed_step(&mut prof, StepProfiler::EMBED, 1);
        assert_eq!(prof.step_count(), 1);
        assert!(prof.totals[EMBED] >= Duration::from_micros(500));
    }

    #[test]
    fn test_multiple_steps_accumulate() {
        let mut prof = active();
        (0..3).for_each(|_| timed_step(&mut prof, StepProfiler::LOSS, 1));
        assert_eq!(prof.step_count(), 3);
        assert!(prof.totals[LOSS] >= Duration::from_millis(3));
    }

    #[test]
    fn test_unaccounted_time_captured() {
        let mut prof = active();
        // Sleeping without recording any phase should surface as "other" time
        idle_step(&mut prof, 2);
        assert!(prof.total_wall >= Duration::from_millis(2));
        // No phase may have picked up any of that time
        assert!(prof.totals.iter().all(|spent| *spent == Duration::ZERO));
    }

    #[test]
    fn test_is_enabled() {
        assert!(active().is_enabled());
        assert!(!StepProfiler::disabled().is_enabled());
    }

    #[test]
    fn test_step_count_starts_at_zero() {
        assert_eq!(active().step_count(), 0);
    }

    #[test]
    fn test_disabled_profiler_step_count_stays_zero() {
        let mut prof = StepProfiler::disabled();
        (0..5).for_each(|_| timed_step(&mut prof, StepProfiler::EMBED, 0));
        assert_eq!(prof.step_count(), 0);
    }

    #[test]
    fn test_all_phase_constants() {
        // Public phase constants listed in index order
        let in_order = [
            StepProfiler::EMBED,
            StepProfiler::H2D,
            StepProfiler::FORWARD,
            StepProfiler::NORM_LM,
            StepProfiler::LOSS,
            StepProfiler::GRAD_H2D,
            StepProfiler::LM_BWD,
            StepProfiler::NORM_BWD,
            StepProfiler::BLK_BWD,
            StepProfiler::EMBED_BWD,
            StepProfiler::OPT,
        ];
        for (expected, actual) in in_order.iter().enumerate() {
            assert_eq!(*actual, expected);
        }
    }

    #[test]
    fn test_phase_names_count() {
        assert_eq!(PHASE_NAMES.len(), NUM_PHASES);
        assert_eq!(NUM_PHASES, 11);
    }

    #[test]
    fn test_multiple_phases_in_one_step() {
        let mut prof = active();
        prof.begin_step();
        for phase in [StepProfiler::EMBED, StepProfiler::FORWARD, StepProfiler::LOSS] {
            prof.begin(phase);
            nap(1);
            prof.end(phase);
        }
        prof.finish_step();

        assert_eq!(prof.step_count(), 1);
        for recorded in [EMBED, FORWARD, LOSS] {
            assert!(prof.totals[recorded] > Duration::ZERO);
        }
        // Phases that were never timed stay at zero
        for untouched in [H2D, NORM_LM] {
            assert_eq!(prof.totals[untouched], Duration::ZERO);
        }
    }

    #[test]
    fn test_end_without_begin_is_noop() {
        let mut prof = active();
        prof.begin_step();
        // No `begin` is pending, so `end` has nothing to record
        prof.end(StepProfiler::EMBED);
        prof.finish_step();
        assert_eq!(prof.totals[EMBED], Duration::ZERO);
    }

    #[test]
    fn test_print_report_empty_is_noop() {
        let prof = active();
        // With zero steps the report must return early without panicking
        prof.print_report();
        assert_eq!(prof.step_count(), 0);
    }

    #[test]
    fn test_print_report_with_data() {
        let mut prof = active();
        (0..3).for_each(|_| timed_step(&mut prof, StepProfiler::EMBED, 1));
        // Printing must not panic
        prof.print_report();
        assert_eq!(prof.step_count(), 3);
    }

    #[test]
    fn test_report_interval_auto_print() {
        // Interval of 2: reports fire after steps 2 and 4
        let mut prof = StepProfiler::new(true, 2);
        (0..4).for_each(|_| timed_step(&mut prof, StepProfiler::LOSS, 0));
        assert_eq!(prof.step_count(), 4);
    }

    #[test]
    fn test_step_durations_tracked() {
        let mut prof = active();
        (0..5).for_each(|_| idle_step(&mut prof, 1));
        assert_eq!(prof.step_durations.len(), 5);
        assert!(prof.step_durations.iter().all(|d| *d >= Duration::from_micros(500)));
    }

    #[test]
    fn test_total_wall_accumulates() {
        let mut prof = active();
        idle_step(&mut prof, 2);
        idle_step(&mut prof, 2);
        assert!(prof.total_wall >= Duration::from_millis(4));
    }

    #[test]
    fn test_percentiles_with_enough_steps() {
        let mut prof = active();
        (0..20).for_each(|_| idle_step(&mut prof, 1));
        // 20 steps is above the 10-step threshold, so percentiles are printed
        prof.print_report();
        assert_eq!(prof.step_durations.len(), 20);
    }

    #[test]
    fn test_finish_step_without_begin_step() {
        let mut prof = active();
        // No `begin_step`, so there is no start timestamp to measure from
        prof.finish_step();
        // The step still counts, with zero wall time
        assert_eq!(prof.step_count(), 1);
        assert_eq!(prof.total_wall, Duration::ZERO);
    }

    // --- Per-layer profiling tests (PMAT-480) ---

    #[test]
    fn test_per_layer_fwd_timing() {
        let mut prof = active();
        prof.begin_step();
        (0..3).for_each(|layer| timed_layer_fwd(&mut prof, layer));
        prof.finish_step();
        assert_eq!(prof.num_layers, 3);
        assert!(prof.layer_fwd_totals[..3].iter().all(|d| *d >= Duration::from_micros(500)));
    }

    #[test]
    fn test_per_layer_bwd_timing() {
        let mut prof = active();
        prof.begin_step();
        (0..4).rev().for_each(|layer| timed_layer_bwd(&mut prof, layer));
        prof.finish_step();
        assert_eq!(prof.num_layers, 4);
        assert!(prof.layer_bwd_totals[..4].iter().all(|d| *d >= Duration::from_micros(500)));
    }

    #[test]
    fn test_per_layer_accumulates_across_steps() {
        let mut prof = active();
        for _ in 0..3 {
            prof.begin_step();
            timed_layer_fwd(&mut prof, 0);
            prof.finish_step();
        }
        assert!(prof.layer_fwd_totals[0] >= Duration::from_millis(3));
    }

    #[test]
    fn test_per_layer_disabled_is_noop() {
        let mut prof = StepProfiler::disabled();
        prof.begin_layer();
        prof.end_layer_fwd(0);
        prof.begin_layer();
        prof.end_layer_bwd(0);
        assert_eq!(prof.num_layers, 0);
    }

    #[test]
    fn test_per_layer_out_of_bounds_ignored() {
        let mut prof = active();
        prof.begin_step();
        prof.begin_layer();
        // An index beyond MAX_LAYERS must be dropped silently
        prof.end_layer_fwd(MAX_LAYERS + 1);
        prof.finish_step();
        assert_eq!(prof.num_layers, 0);
    }

    #[test]
    fn test_per_layer_report_prints() {
        let mut prof = active();
        for _ in 0..2 {
            prof.begin_step();
            (0..3).for_each(|layer| timed_layer_fwd(&mut prof, layer));
            (0..3).rev().for_each(|layer| timed_layer_bwd(&mut prof, layer));
            prof.finish_step();
        }
        // The per-layer table must print without panicking
        prof.print_report();
        assert_eq!(prof.num_layers, 3);
    }
}