1#![allow(dead_code)]
2use std::time::{Duration, Instant};
29
// Phase indices.
//
// Each value is a slot in `StepProfiler::current` / `StepProfiler::totals` and the
// position of the matching label in `PHASE_NAMES`, so the two lists must stay in
// the same order. The glosses below are inferred from the identifiers only —
// TODO confirm against the training loop that drives the profiler.

/// Presumably the embedding lookup.
const EMBED: usize = 0;
/// Presumably a host-to-device copy; summed with `GRAD_H2D` as "transfer" time by
/// `print_json_report`.
const H2D: usize = 1;
/// Forward pass; `print_json_report` uses its share of wall time in the bottleneck heuristic.
const FORWARD: usize = 2;
/// Presumably final norm + LM head.
const NORM_LM: usize = 3;
/// Presumably the loss computation.
const LOSS: usize = 4;
/// Presumably the gradient host-to-device copy; summed with `H2D` as "transfer" time.
const GRAD_H2D: usize = 5;
/// Presumably the LM-head backward pass.
const LM_BWD: usize = 6;
/// Presumably the final-norm backward pass.
const NORM_BWD: usize = 7;
/// Presumably the transformer-block backward pass.
const BLK_BWD: usize = 8;
/// Presumably the embedding backward pass.
const EMBED_BWD: usize = 9;
/// Presumably the optimizer update.
const OPT: usize = 10;
/// Number of phases; length of the per-phase arrays and of `PHASE_NAMES`.
const NUM_PHASES: usize = 11;
43
/// Labels for the phases, indexed by the phase constants (`EMBED` .. `OPT`).
///
/// Used as the row label in `print_report` and as the JSON key in
/// `print_json_report`, so the order here must match the constant values.
const PHASE_NAMES: [&str; NUM_PHASES] = [
    "embed",
    "h2d",
    "forward",
    "norm_lm",
    "loss",
    "grad_h2d",
    "lm_bwd",
    "norm_bwd",
    "blk_bwd",
    "embed_bwd",
    "opt",
];
57
/// Length of the per-layer accumulator tables. Timings reported for a layer index
/// at or above this value are silently dropped.
const MAX_LAYERS: usize = 64;
60
// Per-operation indices.
//
// Each value is a slot in `StepProfiler::op_totals` and the position of the matching
// label in `OP_NAMES`. `print_json_report` treats QKV_GEMM, O_PROJ, GATE_UP_GEMM,
// DOWN_GEMM, QKV_BWD, GATE_UP_BWD and DOWN_BWD as GEMM time when computing `gemm_pct`.
// What each op covers beyond its name is not visible here — TODO confirm with callers.

// Forward ops (by name).
const OP_RMSNORM_ATTN: usize = 0;
const OP_QKV_GEMM: usize = 1;
const OP_ATTENTION: usize = 2;
const OP_O_PROJ: usize = 3;
const OP_RMSNORM_FFN: usize = 4;
const OP_GATE_UP_GEMM: usize = 5;
const OP_SILU: usize = 6;
const OP_DOWN_GEMM: usize = 7;
const OP_LORA: usize = 8;
// Backward ops (by name).
const OP_DOWN_BWD: usize = 9;
const OP_SWIGLU_BWD: usize = 10;
const OP_GATE_UP_BWD: usize = 11;
const OP_ATTN_BWD: usize = 12;
const OP_QKV_BWD: usize = 13;
const OP_NORM_BWD: usize = 14;
const OP_LORA_BWD: usize = 15;
/// Number of tracked ops; length of `op_totals` and of `OP_NAMES`.
const NUM_OPS: usize = 16;
80
/// Labels for the tracked ops, indexed by the `OP_*` constants.
///
/// Used as the JSON key in the `ops` object emitted by `print_json_report`, so the
/// order here must match the constant values.
const OP_NAMES: [&str; NUM_OPS] = [
    "rmsnorm_attn",
    "qkv_gemm",
    "attention",
    "o_proj",
    "rmsnorm_ffn",
    "gate_up_gemm",
    "silu",
    "down_gemm",
    "lora",
    "down_bwd",
    "swiglu_bwd",
    "gate_up_bwd",
    "attn_bwd",
    "qkv_bwd",
    "norm_bwd",
    "lora_bwd",
];
99
100pub struct StepProfiler {
109 enabled: bool,
110 report_interval: usize,
112 current: [Duration; NUM_PHASES],
114 phase_start: Option<Instant>,
116 step_start: Option<Instant>,
118 totals: [Duration; NUM_PHASES],
120 total_wall: Duration,
122 step_count: usize,
124 step_durations: Vec<Duration>,
126 layer_fwd_totals: Vec<Duration>,
128 layer_bwd_totals: Vec<Duration>,
130 layer_start: Option<Instant>,
132 num_layers: usize,
134 op_totals: [Duration; NUM_OPS],
136 op_start: Option<Instant>,
138}
139
140impl StepProfiler {
141 pub fn new(enabled: bool, report_interval: usize) -> Self {
146 Self {
147 enabled,
148 report_interval,
149 current: [Duration::ZERO; NUM_PHASES],
150 phase_start: None,
151 step_start: None,
152 totals: [Duration::ZERO; NUM_PHASES],
153 total_wall: Duration::ZERO,
154 step_count: 0,
155 step_durations: Vec::new(),
156 layer_fwd_totals: vec![Duration::ZERO; MAX_LAYERS],
157 layer_bwd_totals: vec![Duration::ZERO; MAX_LAYERS],
158 layer_start: None,
159 num_layers: 0,
160 op_totals: [Duration::ZERO; NUM_OPS],
161 op_start: None,
162 }
163 }
164
165 pub fn disabled() -> Self {
167 Self::new(false, 0)
168 }
169
170 #[inline]
172 pub fn begin_step(&mut self) {
173 if !self.enabled {
174 return;
175 }
176 self.current = [Duration::ZERO; NUM_PHASES];
177 self.step_start = Some(Instant::now());
178 }
179
180 #[inline]
182 pub fn begin(&mut self, _phase: usize) {
183 if !self.enabled {
184 return;
185 }
186 self.phase_start = Some(Instant::now());
187 }
188
189 #[inline]
191 pub fn end(&mut self, phase: usize) {
192 if !self.enabled {
193 return;
194 }
195 if let Some(start) = self.phase_start.take() {
196 self.current[phase] += start.elapsed();
197 }
198 }
199
200 #[inline]
202 pub fn begin_layer(&mut self) {
203 if !self.enabled {
204 return;
205 }
206 self.layer_start = Some(Instant::now());
207 }
208
209 #[inline]
211 pub fn end_layer_fwd(&mut self, layer: usize) {
212 if !self.enabled {
213 return;
214 }
215 if let Some(start) = self.layer_start.take() {
216 if layer < MAX_LAYERS {
217 self.layer_fwd_totals[layer] += start.elapsed();
218 if layer >= self.num_layers {
219 self.num_layers = layer + 1;
220 }
221 }
222 }
223 }
224
225 #[inline]
227 pub fn end_layer_bwd(&mut self, layer: usize) {
228 if !self.enabled {
229 return;
230 }
231 if let Some(start) = self.layer_start.take() {
232 if layer < MAX_LAYERS {
233 self.layer_bwd_totals[layer] += start.elapsed();
234 if layer >= self.num_layers {
235 self.num_layers = layer + 1;
236 }
237 }
238 }
239 }
240
241 pub fn finish_step(&mut self) {
243 if !self.enabled {
244 return;
245 }
246 let step_wall = self.step_start.take().map_or(Duration::ZERO, |s| s.elapsed());
247
248 for i in 0..NUM_PHASES {
249 self.totals[i] += self.current[i];
250 }
251 self.total_wall += step_wall;
252 self.step_count += 1;
253 self.step_durations.push(step_wall);
254
255 if self.report_interval > 0 && self.step_count.is_multiple_of(self.report_interval) {
256 self.print_report();
257 }
258 }
259
260 pub fn print_report(&self) {
262 if self.step_count == 0 {
263 return;
264 }
265
266 let total_us = self.total_wall.as_micros() as f64;
267 let avg_step_us = total_us / self.step_count as f64;
268
269 println!(
270 "\n┌─ Step Profiler ({} steps, avg {:.1} ms/step) ─┐",
271 self.step_count,
272 avg_step_us / 1000.0
273 );
274 println!("│ {:>10} │ {:>8} │ {:>6} │ {:>8} │", "phase", "total_ms", "pct", "avg_ms");
275 println!("│ {:-<10} │ {:-<8} │ {:-<6} │ {:-<8} │", "", "", "", "");
276
277 let mut accounted = Duration::ZERO;
278 for i in 0..NUM_PHASES {
279 let t = self.totals[i];
280 accounted += t;
281 let ms = t.as_micros() as f64 / 1000.0;
282 let pct = if total_us > 0.0 { t.as_micros() as f64 / total_us * 100.0 } else { 0.0 };
283 let avg = ms / self.step_count as f64;
284 println!("│ {:>10} │ {:>8.1} │ {:>5.1}% │ {:>8.2} │", PHASE_NAMES[i], ms, pct, avg);
285 }
286
287 let unaccounted = self.total_wall.saturating_sub(accounted);
288 let unaccounted_pct =
289 if total_us > 0.0 { unaccounted.as_micros() as f64 / total_us * 100.0 } else { 0.0 };
290 println!(
291 "│ {:>10} │ {:>8.1} │ {:>5.1}% │ {:>8.2} │",
292 "other",
293 unaccounted.as_micros() as f64 / 1000.0,
294 unaccounted_pct,
295 unaccounted.as_micros() as f64 / 1000.0 / self.step_count as f64
296 );
297
298 println!(
299 "│ {:>10} │ {:>8.1} │ {:>5} │ {:>8.2} │",
300 "TOTAL",
301 total_us / 1000.0,
302 "100%",
303 avg_step_us / 1000.0
304 );
305 println!("└────────────┴──────────┴────────┴──────────┘");
306
307 if self.num_layers > 0 && self.step_count > 0 {
309 println!(
310 "\n┌─ Per-Layer Profile ({} layers, {} steps) ─┐",
311 self.num_layers, self.step_count
312 );
313 println!(
314 "│ {:>5} │ {:>8} │ {:>8} │ {:>8} │ {:>8} │",
315 "layer", "fwd_ms", "bwd_ms", "fwd_avg", "bwd_avg"
316 );
317 println!("│ {:->5} │ {:->8} │ {:->8} │ {:->8} │ {:->8} │", "", "", "", "", "");
318 let mut fwd_total = Duration::ZERO;
319 let mut bwd_total = Duration::ZERO;
320 for i in 0..self.num_layers {
321 let fwd = self.layer_fwd_totals[i];
322 let bwd = self.layer_bwd_totals[i];
323 fwd_total += fwd;
324 bwd_total += bwd;
325 let fwd_ms = fwd.as_micros() as f64 / 1000.0;
326 let bwd_ms = bwd.as_micros() as f64 / 1000.0;
327 let fwd_avg = fwd_ms / self.step_count as f64;
328 let bwd_avg = bwd_ms / self.step_count as f64;
329 println!(
330 "│ {i:>5} │ {fwd_ms:>8.1} │ {bwd_ms:>8.1} │ {fwd_avg:>8.2} │ {bwd_avg:>8.2} │"
331 );
332 }
333 let fwd_total_ms = fwd_total.as_micros() as f64 / 1000.0;
334 let bwd_total_ms = bwd_total.as_micros() as f64 / 1000.0;
335 println!(
336 "│ {:>5} │ {:>8.1} │ {:>8.1} │ {:>8.2} │ {:>8.2} │",
337 "TOTAL",
338 fwd_total_ms,
339 bwd_total_ms,
340 fwd_total_ms / self.step_count as f64,
341 bwd_total_ms / self.step_count as f64
342 );
343 println!("└───────┴──────────┴──────────┴──────────┴──────────┘");
344
345 let avg_fwd = fwd_total / self.num_layers as u32;
347 let avg_bwd = bwd_total / self.num_layers as u32;
348 let mut hotspots = Vec::new();
349 for i in 0..self.num_layers {
350 let fwd_ratio = if avg_fwd.as_nanos() > 0 {
351 self.layer_fwd_totals[i].as_nanos() as f64 / avg_fwd.as_nanos() as f64
352 } else {
353 0.0
354 };
355 let bwd_ratio = if avg_bwd.as_nanos() > 0 {
356 self.layer_bwd_totals[i].as_nanos() as f64 / avg_bwd.as_nanos() as f64
357 } else {
358 0.0
359 };
360 if fwd_ratio > 1.5 || bwd_ratio > 1.5 {
361 hotspots.push((i, fwd_ratio, bwd_ratio));
362 }
363 }
364 if !hotspots.is_empty() {
365 println!(" Hotspot layers (>1.5x average):");
366 for (layer, fwd_r, bwd_r) in &hotspots {
367 println!(" L{layer}: fwd {fwd_r:.1}x, bwd {bwd_r:.1}x");
368 }
369 }
370 }
371
372 if self.step_durations.len() >= 10 {
374 let mut sorted: Vec<u128> =
375 self.step_durations.iter().map(std::time::Duration::as_micros).collect();
376 sorted.sort_unstable();
377 let p50 = sorted[sorted.len() / 2];
378 let p95 = sorted[sorted.len() * 95 / 100];
379 let p99 = sorted[sorted.len() * 99 / 100];
380 println!(
381 " Step latency: p50={:.1}ms p95={:.1}ms p99={:.1}ms",
382 p50 as f64 / 1000.0,
383 p95 as f64 / 1000.0,
384 p99 as f64 / 1000.0
385 );
386 }
387 }
388
389 #[inline]
391 pub fn begin_op(&mut self) {
392 if !self.enabled {
393 return;
394 }
395 self.op_start = Some(Instant::now());
396 }
397
398 #[inline]
400 pub fn end_op(&mut self, op: usize) {
401 if !self.enabled {
402 return;
403 }
404 if let Some(start) = self.op_start.take() {
405 if op < NUM_OPS {
406 self.op_totals[op] += start.elapsed();
407 }
408 }
409 }
410
411 #[inline]
413 pub fn end_op_raw(&mut self, op: usize, us: u64) {
414 if !self.enabled || op >= NUM_OPS {
415 return;
416 }
417 self.op_totals[op] += Duration::from_micros(us);
418 }
419
420 pub fn is_enabled(&self) -> bool {
422 self.enabled
423 }
424
425 pub fn step_count(&self) -> usize {
427 self.step_count
428 }
429
430 pub fn print_json_report(&self) {
433 if self.step_count == 0 {
434 return;
435 }
436
437 let total_us = self.total_wall.as_micros() as f64;
438 let avg_step_ms = total_us / self.step_count as f64 / 1000.0;
439
440 let mut accounted_us = 0u128;
441 let mut phases = Vec::new();
442 for i in 0..NUM_PHASES {
443 let t = self.totals[i];
444 accounted_us += t.as_micros();
445 let ms = t.as_micros() as f64 / 1000.0;
446 let pct = if total_us > 0.0 { t.as_micros() as f64 / total_us * 100.0 } else { 0.0 };
447 let avg = ms / self.step_count as f64;
448 phases.push(format!(
449 "\"{}\":{{\"total_ms\":{:.1},\"pct\":{:.1},\"avg_ms\":{:.2}}}",
450 PHASE_NAMES[i], ms, pct, avg
451 ));
452 }
453
454 let wall_coverage = if total_us > 0.0 { accounted_us as f64 / total_us } else { 0.0 };
455
456 let mut layers_json = Vec::new();
457 for i in 0..self.num_layers {
458 let fwd_ms = self.layer_fwd_totals[i].as_micros() as f64 / 1000.0;
459 let bwd_ms = self.layer_bwd_totals[i].as_micros() as f64 / 1000.0;
460 layers_json
461 .push(format!("{{\"layer\":{i},\"fwd_ms\":{fwd_ms:.1},\"bwd_ms\":{bwd_ms:.1}}}"));
462 }
463
464 let forward_pct = if total_us > 0.0 {
466 self.totals[FORWARD].as_micros() as f64 / total_us * 100.0
467 } else {
468 0.0
469 };
470 let transfer_pct = if total_us > 0.0 {
471 (self.totals[H2D].as_micros() + self.totals[GRAD_H2D].as_micros()) as f64 / total_us
472 * 100.0
473 } else {
474 0.0
475 };
476 let bottleneck = if transfer_pct > 30.0 {
477 "transfer"
478 } else if forward_pct < 20.0 {
479 "launch"
480 } else {
481 "memory_bw"
482 };
483
484 let mut ops_json = Vec::new();
486 let mut total_op_us = 0u128;
487 for i in 0..NUM_OPS {
488 let t = self.op_totals[i];
489 let us = t.as_micros();
490 total_op_us += us;
491 if us > 0 {
492 let ms = us as f64 / 1000.0;
493 ops_json.push(format!("\"{}\":{:.1}", OP_NAMES[i], ms));
494 }
495 }
496
497 let gemm_us = self.op_totals[OP_QKV_GEMM].as_micros()
499 + self.op_totals[OP_O_PROJ].as_micros()
500 + self.op_totals[OP_GATE_UP_GEMM].as_micros()
501 + self.op_totals[OP_DOWN_GEMM].as_micros();
502 let gemm_bwd_us = self.op_totals[OP_QKV_BWD].as_micros()
503 + self.op_totals[OP_GATE_UP_BWD].as_micros()
504 + self.op_totals[OP_DOWN_BWD].as_micros();
505 let total_gemm_us = gemm_us + gemm_bwd_us;
506 let gemm_pct =
507 if total_op_us > 0 { total_gemm_us as f64 / total_op_us as f64 * 100.0 } else { 0.0 };
508
509 eprintln!(
510 "{{\"_profiler\":\"step_profiler_v2\",\"steps\":{},\"avg_step_ms\":{:.2},\"wall_coverage\":{:.3},\"bottleneck\":\"{}\",\"gemm_pct\":{:.1},\"phases\":{{{}}},\"per_layer\":[{}],\"ops\":{{{}}}}}",
511 self.step_count,
512 avg_step_ms,
513 wall_coverage,
514 bottleneck,
515 gemm_pct,
516 phases.join(","),
517 layers_json.join(","),
518 ops_json.join(","),
519 );
520 }
521
522 pub fn record_layer_times(&mut self, fwd_us: &[u64], bwd_us: &[u64]) {
525 if !self.enabled {
526 return;
527 }
528 for (i, &us) in fwd_us.iter().enumerate() {
529 if i < MAX_LAYERS && us > 0 {
530 self.layer_fwd_totals[i] += std::time::Duration::from_micros(us);
531 if i >= self.num_layers {
532 self.num_layers = i + 1;
533 }
534 }
535 }
536 for (i, &us) in bwd_us.iter().enumerate() {
537 if i < MAX_LAYERS && us > 0 {
538 self.layer_bwd_totals[i] += std::time::Duration::from_micros(us);
539 if i >= self.num_layers {
540 self.num_layers = i + 1;
541 }
542 }
543 }
544 }
545
546 pub const EMBED: usize = EMBED;
548
549 pub const OP_RMSNORM_ATTN: usize = OP_RMSNORM_ATTN;
551 pub const OP_QKV_GEMM: usize = OP_QKV_GEMM;
552 pub const OP_ATTENTION: usize = OP_ATTENTION;
553 pub const OP_O_PROJ: usize = OP_O_PROJ;
554 pub const OP_RMSNORM_FFN: usize = OP_RMSNORM_FFN;
555 pub const OP_GATE_UP_GEMM: usize = OP_GATE_UP_GEMM;
556 pub const OP_SILU: usize = OP_SILU;
557 pub const OP_DOWN_GEMM: usize = OP_DOWN_GEMM;
558 pub const OP_LORA: usize = OP_LORA;
559 pub const OP_DOWN_BWD: usize = OP_DOWN_BWD;
560 pub const OP_SWIGLU_BWD: usize = OP_SWIGLU_BWD;
561 pub const OP_GATE_UP_BWD: usize = OP_GATE_UP_BWD;
562 pub const OP_ATTN_BWD: usize = OP_ATTN_BWD;
563 pub const OP_QKV_BWD: usize = OP_QKV_BWD;
564 pub const OP_NORM_BWD: usize = OP_NORM_BWD;
565 pub const OP_LORA_BWD: usize = OP_LORA_BWD;
566 pub const H2D: usize = H2D;
567 pub const FORWARD: usize = FORWARD;
568 pub const NORM_LM: usize = NORM_LM;
569 pub const LOSS: usize = LOSS;
570 pub const GRAD_H2D: usize = GRAD_H2D;
571 pub const LM_BWD: usize = LM_BWD;
572 pub const NORM_BWD: usize = NORM_BWD;
573 pub const BLK_BWD: usize = BLK_BWD;
574 pub const EMBED_BWD: usize = EMBED_BWD;
575 pub const OPT: usize = OPT;
576}
577
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Pause inserted wherever a test needs a measurable amount of elapsed time.
    const TICK: Duration = Duration::from_millis(1);
    /// Conservative lower bound asserted after one `TICK` of recorded time.
    const HALF_TICK: Duration = Duration::from_micros(500);

    /// Enabled profiler with automatic reports switched off.
    fn recording() -> StepProfiler {
        StepProfiler::new(true, 0)
    }

    /// Opens and immediately closes `phase`.
    fn touch_phase(p: &mut StepProfiler, phase: usize) {
        p.begin(phase);
        p.end(phase);
    }

    /// Opens `phase`, sleeps one `TICK`, then closes it.
    fn timed_phase(p: &mut StepProfiler, phase: usize) {
        p.begin(phase);
        std::thread::sleep(TICK);
        p.end(phase);
    }

    /// Runs one step with no phases that lasts at least `pause`.
    fn idle_step(p: &mut StepProfiler, pause: Duration) {
        p.begin_step();
        std::thread::sleep(pause);
        p.finish_step();
    }

    /// Times one `TICK` of forward work for `layer`.
    fn timed_fwd_layer(p: &mut StepProfiler, layer: usize) {
        p.begin_layer();
        std::thread::sleep(TICK);
        p.end_layer_fwd(layer);
    }

    /// Times one `TICK` of backward work for `layer`.
    fn timed_bwd_layer(p: &mut StepProfiler, layer: usize) {
        p.begin_layer();
        std::thread::sleep(TICK);
        p.end_layer_bwd(layer);
    }

    #[test]
    fn test_disabled_profiler_is_noop() {
        let mut p = StepProfiler::disabled();
        p.begin_step();
        touch_phase(&mut p, StepProfiler::EMBED);
        p.finish_step();
        assert_eq!(p.step_count(), 0);
    }

    #[test]
    fn test_enabled_profiler_counts_steps() {
        let mut p = recording();
        p.begin_step();
        timed_phase(&mut p, StepProfiler::EMBED);
        p.finish_step();
        assert_eq!(p.step_count(), 1);
        assert!(p.totals[EMBED] >= HALF_TICK);
    }

    #[test]
    fn test_multiple_steps_accumulate() {
        let mut p = recording();
        for _ in 0..3 {
            p.begin_step();
            timed_phase(&mut p, StepProfiler::LOSS);
            p.finish_step();
        }
        assert_eq!(p.step_count(), 3);
        assert!(p.totals[LOSS] >= Duration::from_millis(3));
    }

    #[test]
    fn test_unaccounted_time_captured() {
        let mut p = recording();
        idle_step(&mut p, Duration::from_millis(2));
        assert!(p.total_wall >= Duration::from_millis(2));
        // No phase was opened, so every phase total must still be zero.
        for spent in &p.totals {
            assert_eq!(*spent, Duration::ZERO);
        }
    }

    #[test]
    fn test_is_enabled() {
        assert!(recording().is_enabled());
        assert!(!StepProfiler::disabled().is_enabled());
    }

    #[test]
    fn test_step_count_starts_at_zero() {
        assert_eq!(recording().step_count(), 0);
    }

    #[test]
    fn test_disabled_profiler_step_count_stays_zero() {
        let mut p = StepProfiler::disabled();
        for _ in 0..5 {
            p.begin_step();
            touch_phase(&mut p, StepProfiler::EMBED);
            p.finish_step();
        }
        assert_eq!(p.step_count(), 0);
    }

    #[test]
    fn test_all_phase_constants() {
        // Listed in index order: each constant must equal its position.
        let in_order = [
            StepProfiler::EMBED,
            StepProfiler::H2D,
            StepProfiler::FORWARD,
            StepProfiler::NORM_LM,
            StepProfiler::LOSS,
            StepProfiler::GRAD_H2D,
            StepProfiler::LM_BWD,
            StepProfiler::NORM_BWD,
            StepProfiler::BLK_BWD,
            StepProfiler::EMBED_BWD,
            StepProfiler::OPT,
        ];
        for (position, value) in in_order.iter().enumerate() {
            assert_eq!(*value, position);
        }
    }

    #[test]
    fn test_phase_names_count() {
        assert_eq!(PHASE_NAMES.len(), NUM_PHASES);
        assert_eq!(NUM_PHASES, 11);
    }

    #[test]
    fn test_multiple_phases_in_one_step() {
        let mut p = recording();
        p.begin_step();
        timed_phase(&mut p, StepProfiler::EMBED);
        timed_phase(&mut p, StepProfiler::FORWARD);
        timed_phase(&mut p, StepProfiler::LOSS);
        p.finish_step();

        assert_eq!(p.step_count(), 1);
        for used in [EMBED, FORWARD, LOSS].iter() {
            assert!(p.totals[*used] > Duration::ZERO);
        }
        // Phases that were never opened stay at zero.
        assert_eq!(p.totals[H2D], Duration::ZERO);
        assert_eq!(p.totals[NORM_LM], Duration::ZERO);
    }

    #[test]
    fn test_end_without_begin_is_noop() {
        let mut p = recording();
        p.begin_step();
        p.end(StepProfiler::EMBED);
        p.finish_step();
        assert_eq!(p.totals[EMBED], Duration::ZERO);
    }

    #[test]
    fn test_print_report_empty_is_noop() {
        let p = recording();
        p.print_report();
        assert_eq!(p.step_count(), 0);
    }

    #[test]
    fn test_print_report_with_data() {
        let mut p = recording();
        for _ in 0..3 {
            p.begin_step();
            timed_phase(&mut p, StepProfiler::EMBED);
            p.finish_step();
        }
        p.print_report();
        assert_eq!(p.step_count(), 3);
    }

    #[test]
    fn test_report_interval_auto_print() {
        // Interval of 2 over 4 steps: the automatic report path runs twice.
        let mut p = StepProfiler::new(true, 2);
        for _ in 0..4 {
            p.begin_step();
            touch_phase(&mut p, StepProfiler::LOSS);
            p.finish_step();
        }
        assert_eq!(p.step_count(), 4);
    }

    #[test]
    fn test_step_durations_tracked() {
        let mut p = recording();
        for _ in 0..5 {
            idle_step(&mut p, TICK);
        }
        assert_eq!(p.step_durations.len(), 5);
        assert!(p.step_durations.iter().all(|d| *d >= HALF_TICK));
    }

    #[test]
    fn test_total_wall_accumulates() {
        let mut p = recording();
        idle_step(&mut p, Duration::from_millis(2));
        idle_step(&mut p, Duration::from_millis(2));
        assert!(p.total_wall >= Duration::from_millis(4));
    }

    #[test]
    fn test_percentiles_with_enough_steps() {
        let mut p = recording();
        for _ in 0..20 {
            idle_step(&mut p, TICK);
        }
        p.print_report();
        assert_eq!(p.step_durations.len(), 20);
    }

    #[test]
    fn test_finish_step_without_begin_step() {
        let mut p = recording();
        p.finish_step();
        // The step is still counted, but with no wall time attached.
        assert_eq!(p.step_count(), 1);
        assert_eq!(p.total_wall, Duration::ZERO);
    }

    #[test]
    fn test_per_layer_fwd_timing() {
        let mut p = recording();
        p.begin_step();
        for layer in 0..3 {
            timed_fwd_layer(&mut p, layer);
        }
        p.finish_step();
        assert_eq!(p.num_layers, 3);
        assert!(p.layer_fwd_totals[..3].iter().all(|d| *d >= HALF_TICK));
    }

    #[test]
    fn test_per_layer_bwd_timing() {
        let mut p = recording();
        p.begin_step();
        for layer in (0..4).rev() {
            timed_bwd_layer(&mut p, layer);
        }
        p.finish_step();
        assert_eq!(p.num_layers, 4);
        assert!(p.layer_bwd_totals[..4].iter().all(|d| *d >= HALF_TICK));
    }

    #[test]
    fn test_per_layer_accumulates_across_steps() {
        let mut p = recording();
        for _ in 0..3 {
            p.begin_step();
            timed_fwd_layer(&mut p, 0);
            p.finish_step();
        }
        assert!(p.layer_fwd_totals[0] >= Duration::from_millis(3));
    }

    #[test]
    fn test_per_layer_disabled_is_noop() {
        let mut p = StepProfiler::disabled();
        p.begin_layer();
        p.end_layer_fwd(0);
        p.begin_layer();
        p.end_layer_bwd(0);
        assert_eq!(p.num_layers, 0);
    }

    #[test]
    fn test_per_layer_out_of_bounds_ignored() {
        let mut p = recording();
        p.begin_step();
        p.begin_layer();
        p.end_layer_fwd(MAX_LAYERS + 1);
        p.finish_step();
        assert_eq!(p.num_layers, 0);
    }

    #[test]
    fn test_per_layer_report_prints() {
        let mut p = recording();
        for _ in 0..2 {
            p.begin_step();
            for layer in 0..3 {
                timed_fwd_layer(&mut p, layer);
            }
            for layer in (0..3).rev() {
                timed_bwd_layer(&mut p, layer);
            }
            p.finish_step();
        }
        p.print_report();
        assert_eq!(p.num_layers, 3);
    }
}