1use std::time::{Duration, Instant};
20
/// Phase index whose report row is labelled `lock_acq`.
pub const LOCK_ACQ: usize = 0;
/// Phase index whose report row is labelled `ledger_rd`.
pub const LEDGER_RD: usize = 1;
/// Phase index whose report row is labelled `vram_qry`.
pub const VRAM_QRY: usize = 2;
/// Phase index whose report row is labelled `ledger_wr`.
pub const LEDGER_WR: usize = 3;
/// Phase index whose report row is labelled `lock_rel`.
pub const LOCK_REL: usize = 4;
/// Phase index whose report row is labelled `wait_poll`.
pub const WAIT_POLL: usize = 5;
/// Number of phase slots: one past the highest phase index.
const NUM_GPU_PHASES: usize = WAIT_POLL + 1;

/// Human-readable label for every phase, addressed by the phase's index constant.
const GPU_PHASE_NAMES: [&str; NUM_GPU_PHASES] = {
    // Filled by index so each label stays attached to its constant.
    let mut names = [""; NUM_GPU_PHASES];
    names[LOCK_ACQ] = "lock_acq";
    names[LEDGER_RD] = "ledger_rd";
    names[VRAM_QRY] = "vram_qry";
    names[LEDGER_WR] = "ledger_wr";
    names[LOCK_REL] = "lock_rel";
    names[WAIT_POLL] = "wait_poll";
    names
};
32
33pub struct GpuProfiler {
35 enabled: bool,
36 phase_start: Option<Instant>,
37 totals: [Duration; NUM_GPU_PHASES],
38 counts: [usize; NUM_GPU_PHASES],
39 op_count: usize,
40}
41
42impl GpuProfiler {
43 pub fn new(enabled: bool) -> Self {
45 Self {
46 enabled,
47 phase_start: None,
48 totals: [Duration::ZERO; NUM_GPU_PHASES],
49 counts: [0; NUM_GPU_PHASES],
50 op_count: 0,
51 }
52 }
53
54 pub fn disabled() -> Self {
56 Self::new(false)
57 }
58
59 #[inline]
61 pub fn begin(&mut self, _phase: usize) {
62 if !self.enabled {
63 return;
64 }
65 self.phase_start = Some(Instant::now());
66 }
67
68 #[inline]
70 pub fn end(&mut self, phase: usize) {
71 if !self.enabled {
72 return;
73 }
74 if let Some(start) = self.phase_start.take() {
75 self.totals[phase] += start.elapsed();
76 self.counts[phase] += 1;
77 }
78 }
79
80 #[inline]
82 pub fn span<F, R>(&mut self, phase: usize, f: F) -> R
83 where
84 F: FnOnce() -> R,
85 {
86 if !self.enabled {
87 return f();
88 }
89 self.begin(phase);
90 let result = f();
91 self.end(phase);
92 result
93 }
94
95 pub fn finish_op(&mut self) {
97 if self.enabled {
98 self.op_count += 1;
99 }
100 }
101
102 pub fn is_enabled(&self) -> bool {
104 self.enabled
105 }
106
107 pub fn total_time(&self) -> Duration {
109 self.totals.iter().copied().sum()
110 }
111
112 pub fn report(&self) -> String {
114 if self.op_count == 0 {
115 return String::from("[GpuProfiler] No operations recorded.");
116 }
117
118 let total_us = self.total_time().as_micros() as f64;
119 let mut out = format!(
120 "\n┌─ GPU Sharing Profiler ({} ops, total {:.1} ms) ─┐\n",
121 self.op_count,
122 total_us / 1000.0
123 );
124 out.push_str(&format!(
125 "│ {:>10} │ {:>6} │ {:>8} │ {:>6} │\n",
126 "phase", "count", "total_ms", "pct"
127 ));
128 out.push_str(&format!("│ {:-<10} │ {:-<6} │ {:-<8} │ {:-<6} │\n", "", "", "", ""));
129
130 for i in 0..NUM_GPU_PHASES {
131 let ms = self.totals[i].as_micros() as f64 / 1000.0;
132 let pct = if total_us > 0.0 {
133 self.totals[i].as_micros() as f64 / total_us * 100.0
134 } else {
135 0.0
136 };
137 out.push_str(&format!(
138 "│ {:>10} │ {:>6} │ {:>8.2} │ {:>5.1}% │\n",
139 GPU_PHASE_NAMES[i], self.counts[i], ms, pct
140 ));
141 }
142 out.push_str("└────────────┴────────┴──────────┴────────┘\n");
143 out
144 }
145
146 pub const LOCK_ACQ: usize = LOCK_ACQ;
148 pub const LEDGER_RD: usize = LEDGER_RD;
149 pub const VRAM_QRY: usize = VRAM_QRY;
150 pub const LEDGER_WR: usize = LEDGER_WR;
151 pub const LOCK_REL: usize = LOCK_REL;
152 pub const WAIT_POLL: usize = WAIT_POLL;
153}
154
#[cfg(test)]
mod tests {
    use super::*;

    /// A disabled profiler ignores begin/end/finish_op.
    #[test]
    fn disabled_profiler_is_noop() {
        let mut profiler = GpuProfiler::disabled();

        profiler.begin(LOCK_ACQ);
        profiler.end(LOCK_ACQ);
        profiler.finish_op();

        assert_eq!(profiler.op_count, 0);
        assert!(!profiler.is_enabled());
    }

    /// An enabled profiler counts the interval, the operation and the elapsed time.
    #[test]
    fn enabled_profiler_records() {
        let nap = Duration::from_millis(1);
        let floor = Duration::from_micros(500);
        let mut profiler = GpuProfiler::new(true);

        profiler.begin(LOCK_ACQ);
        std::thread::sleep(nap);
        profiler.end(LOCK_ACQ);
        profiler.finish_op();

        assert!(profiler.is_enabled());
        assert_eq!(profiler.op_count, 1);
        assert_eq!(profiler.counts[LOCK_ACQ], 1);
        assert!(profiler.totals[LOCK_ACQ] >= floor);
    }

    /// `span` hands back the closure's value and records one interval.
    #[test]
    fn span_returns_value() {
        let mut profiler = GpuProfiler::new(true);

        let answer = profiler.span(VRAM_QRY, || 42);

        assert_eq!(answer, 42);
        assert_eq!(profiler.counts[VRAM_QRY], 1);
    }

    /// With no finished operations the report is the placeholder message.
    #[test]
    fn report_empty_when_no_ops() {
        let profiler = GpuProfiler::new(true);

        let text = profiler.report();

        assert!(text.contains("No operations recorded"));
    }

    /// The rendered table mentions the phase label and the operation count.
    #[test]
    fn report_contains_phase_names() {
        let mut profiler = GpuProfiler::new(true);
        profiler.begin(LEDGER_RD);
        profiler.end(LEDGER_RD);
        profiler.finish_op();

        let text = profiler.report();

        assert!(text.contains("ledger_rd"));
        assert!(text.contains("1 ops"));
    }
}