1use std::collections::VecDeque;
10use std::time::{Duration, Instant};
11
12#[derive(Debug, Clone)]
14pub struct InferenceMetrics {
15 pub total_steps: u64,
17 pub total_time_us: u64,
19 recent_latencies: VecDeque<u64>,
21 window_size: usize,
23 pub peak_latency_us: u64,
25 pub min_latency_us: u64,
27 start_time: Instant,
29}
30
31impl InferenceMetrics {
32 pub fn new() -> Self {
34 Self::with_window_size(100)
35 }
36
37 pub fn with_window_size(window_size: usize) -> Self {
39 Self {
40 total_steps: 0,
41 total_time_us: 0,
42 recent_latencies: VecDeque::with_capacity(window_size),
43 window_size,
44 peak_latency_us: 0,
45 min_latency_us: u64::MAX,
46 start_time: Instant::now(),
47 }
48 }
49
50 pub fn record_step(&mut self, latency_us: u64) {
52 self.total_steps += 1;
53 self.total_time_us += latency_us;
54
55 self.peak_latency_us = self.peak_latency_us.max(latency_us);
57 if latency_us > 0 {
58 self.min_latency_us = self.min_latency_us.min(latency_us);
59 }
60
61 if self.recent_latencies.len() >= self.window_size {
63 self.recent_latencies.pop_front();
64 }
65 self.recent_latencies.push_back(latency_us);
66 }
67
68 pub fn avg_latency_us(&self) -> f64 {
70 if self.total_steps == 0 {
71 0.0
72 } else {
73 self.total_time_us as f64 / self.total_steps as f64
74 }
75 }
76
77 pub fn recent_avg_latency_us(&self) -> f64 {
79 if self.recent_latencies.is_empty() {
80 0.0
81 } else {
82 let sum: u64 = self.recent_latencies.iter().sum();
83 sum as f64 / self.recent_latencies.len() as f64
84 }
85 }
86
87 pub fn throughput(&self) -> f64 {
89 let elapsed_secs = self.start_time.elapsed().as_secs_f64();
90 if elapsed_secs == 0.0 {
91 0.0
92 } else {
93 self.total_steps as f64 / elapsed_secs
94 }
95 }
96
97 pub fn percentiles(&self) -> (u64, u64, u64) {
99 if self.recent_latencies.is_empty() {
100 return (0, 0, 0);
101 }
102
103 let mut sorted: Vec<u64> = self.recent_latencies.iter().copied().collect();
104 sorted.sort_unstable();
105
106 let p50_idx = (sorted.len() as f64 * 0.50) as usize;
107 let p95_idx = (sorted.len() as f64 * 0.95) as usize;
108 let p99_idx = (sorted.len() as f64 * 0.99) as usize;
109
110 (
111 sorted.get(p50_idx).copied().unwrap_or(0),
112 sorted.get(p95_idx).copied().unwrap_or(0),
113 sorted.get(p99_idx).copied().unwrap_or(0),
114 )
115 }
116
117 pub fn reset(&mut self) {
119 self.total_steps = 0;
120 self.total_time_us = 0;
121 self.recent_latencies.clear();
122 self.peak_latency_us = 0;
123 self.min_latency_us = u64::MAX;
124 self.start_time = Instant::now();
125 }
126
127 pub fn summary(&self) -> MetricsSummary {
129 let (p50, p95, p99) = self.percentiles();
130
131 MetricsSummary {
132 total_steps: self.total_steps,
133 avg_latency_us: self.avg_latency_us(),
134 recent_avg_latency_us: self.recent_avg_latency_us(),
135 peak_latency_us: self.peak_latency_us,
136 min_latency_us: if self.min_latency_us == u64::MAX {
137 0
138 } else {
139 self.min_latency_us
140 },
141 throughput_per_sec: self.throughput(),
142 p50_latency_us: p50,
143 p95_latency_us: p95,
144 p99_latency_us: p99,
145 uptime_secs: self.start_time.elapsed().as_secs_f64(),
146 }
147 }
148}
149
150impl Default for InferenceMetrics {
151 fn default() -> Self {
152 Self::new()
153 }
154}
155
156#[derive(Debug, Clone)]
158pub struct MetricsSummary {
159 pub total_steps: u64,
160 pub avg_latency_us: f64,
161 pub recent_avg_latency_us: f64,
162 pub peak_latency_us: u64,
163 pub min_latency_us: u64,
164 pub throughput_per_sec: f64,
165 pub p50_latency_us: u64,
166 pub p95_latency_us: u64,
167 pub p99_latency_us: u64,
168 pub uptime_secs: f64,
169}
170
171impl MetricsSummary {
172 pub fn print_report(&self) {
174 println!("=== Inference Performance Report ===");
175 println!("Total steps: {}", self.total_steps);
176 println!("Uptime: {:.2}s", self.uptime_secs);
177 println!("Throughput: {:.2} steps/sec", self.throughput_per_sec);
178 println!("\nLatency (microseconds):");
179 println!(" Average: {:.2} µs", self.avg_latency_us);
180 println!(" Recent avg: {:.2} µs", self.recent_avg_latency_us);
181 println!(" Min: {} µs", self.min_latency_us);
182 println!(" Max: {} µs", self.peak_latency_us);
183 println!("\nPercentiles:");
184 println!(" P50: {} µs", self.p50_latency_us);
185 println!(" P95: {} µs", self.p95_latency_us);
186 println!(" P99: {} µs", self.p99_latency_us);
187 println!("=====================================");
188 }
189}
190
/// A minimal stopwatch over [`Instant`] for timing inference phases.
pub struct Timer {
    // The instant the stopwatch was started.
    start: Instant,
}

impl Timer {
    /// Starts a new stopwatch at the current instant.
    pub fn start() -> Self {
        Timer {
            start: Instant::now(),
        }
    }

    /// Time elapsed since start, as a [`Duration`].
    pub fn elapsed(&self) -> Duration {
        self.start.elapsed()
    }

    /// Whole microseconds elapsed since start.
    pub fn elapsed_us(&self) -> u64 {
        self.elapsed().as_micros() as u64
    }

    /// Whole milliseconds elapsed since start.
    pub fn elapsed_ms(&self) -> u64 {
        self.elapsed().as_millis() as u64
    }
}
219
/// Accumulates per-phase wall time across inference steps.
///
/// Each `record_*` method adds microseconds to its phase bucket;
/// [`breakdown`](Self::breakdown) turns the buckets into percentages.
#[derive(Debug, Clone)]
pub struct InferenceProfiler {
    /// Cumulative time spent tokenizing input, in microseconds.
    pub tokenization_us: u64,
    /// Cumulative time spent in the model forward pass, in microseconds.
    pub forward_pass_us: u64,
    /// Cumulative time spent sampling tokens, in microseconds.
    pub sampling_us: u64,
    /// Cumulative time spent applying constraints, in microseconds.
    pub constraints_us: u64,
    /// Cumulative time spent detokenizing output, in microseconds.
    pub detokenization_us: u64,
    /// Number of completed steps (see [`increment_step`](Self::increment_step)).
    pub step_count: u64,
}

impl InferenceProfiler {
    /// Creates a profiler with every accumulator zeroed.
    pub fn new() -> Self {
        Self {
            tokenization_us: 0,
            forward_pass_us: 0,
            sampling_us: 0,
            constraints_us: 0,
            detokenization_us: 0,
            step_count: 0,
        }
    }

    /// Adds time spent tokenizing, in microseconds.
    pub fn record_tokenization(&mut self, us: u64) {
        self.tokenization_us += us;
    }

    /// Adds time spent in the forward pass, in microseconds.
    pub fn record_forward_pass(&mut self, us: u64) {
        self.forward_pass_us += us;
    }

    /// Adds time spent sampling, in microseconds.
    pub fn record_sampling(&mut self, us: u64) {
        self.sampling_us += us;
    }

    /// Adds time spent on constraints, in microseconds.
    pub fn record_constraints(&mut self, us: u64) {
        self.constraints_us += us;
    }

    /// Adds time spent detokenizing, in microseconds.
    pub fn record_detokenization(&mut self, us: u64) {
        self.detokenization_us += us;
    }

    /// Bumps the completed-step counter by one.
    pub fn increment_step(&mut self) {
        self.step_count += 1;
    }

    /// Total time across all profiled phases, in microseconds.
    pub fn total_time_us(&self) -> u64 {
        let phases = [
            self.tokenization_us,
            self.forward_pass_us,
            self.sampling_us,
            self.constraints_us,
            self.detokenization_us,
        ];
        phases.iter().sum()
    }

    /// Per-phase share of total time as percentages; all zeros when no
    /// time has been recorded (avoids dividing by zero).
    pub fn breakdown(&self) -> ProfileBreakdown {
        let total = self.total_time_us() as f64;
        if total == 0.0 {
            return ProfileBreakdown::default();
        }

        // Same expression shape as before: (us / total) * 100.
        let pct = |us: u64| (us as f64 / total) * 100.0;
        ProfileBreakdown {
            tokenization_pct: pct(self.tokenization_us),
            forward_pass_pct: pct(self.forward_pass_us),
            sampling_pct: pct(self.sampling_us),
            constraints_pct: pct(self.constraints_us),
            detokenization_pct: pct(self.detokenization_us),
        }
    }

    /// Prints a phase-by-phase timing report to stdout.
    pub fn print_report(&self) {
        let breakdown = self.breakdown();
        let total = self.total_time_us();

        println!("=== Inference Profiling Report ===");
        println!("Total steps: {}", self.step_count);
        println!("Total time: {} µs ({:.2} ms)", total, total as f64 / 1000.0);
        println!("\nTime breakdown:");
        println!(
            " Tokenization: {:>8} µs ({:>5.2}%)",
            self.tokenization_us, breakdown.tokenization_pct
        );
        println!(
            " Forward pass: {:>8} µs ({:>5.2}%)",
            self.forward_pass_us, breakdown.forward_pass_pct
        );
        println!(
            " Sampling: {:>8} µs ({:>5.2}%)",
            self.sampling_us, breakdown.sampling_pct
        );
        println!(
            " Constraints: {:>8} µs ({:>5.2}%)",
            self.constraints_us, breakdown.constraints_pct
        );
        println!(
            " Detokenization: {:>8} µs ({:>5.2}%)",
            self.detokenization_us, breakdown.detokenization_pct
        );
        println!("===================================");
    }

    /// Resets every accumulator back to zero.
    pub fn reset(&mut self) {
        *self = Self::new();
    }
}

impl Default for InferenceProfiler {
    /// Same as [`InferenceProfiler::new`].
    fn default() -> Self {
        Self::new()
    }
}

/// Per-phase percentage shares produced by [`InferenceProfiler::breakdown`].
/// All fields sum to ~100 when any time has been recorded; all zero otherwise.
#[derive(Debug, Clone, Default)]
pub struct ProfileBreakdown {
    /// Share of total time spent tokenizing, in percent.
    pub tokenization_pct: f64,
    /// Share of total time spent in the forward pass, in percent.
    pub forward_pass_pct: f64,
    /// Share of total time spent sampling, in percent.
    pub sampling_pct: f64,
    /// Share of total time spent on constraints, in percent.
    pub constraints_pct: f64,
    /// Share of total time spent detokenizing, in percent.
    pub detokenization_pct: f64,
}
358
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;

    // A freshly constructed metrics object starts with zeroed counters.
    #[test]
    fn test_metrics_creation() {
        let m = InferenceMetrics::new();
        assert_eq!(m.total_steps, 0);
        assert_eq!(m.total_time_us, 0);
    }

    // Recording steps updates totals and the lifetime average.
    #[test]
    fn test_record_step() {
        let mut m = InferenceMetrics::new();
        for latency in [1000, 2000] {
            m.record_step(latency);
        }

        assert_eq!(m.total_steps, 2);
        assert_eq!(m.total_time_us, 3000);
        assert_eq!(m.avg_latency_us(), 1500.0);
    }

    // Percentiles over 100 evenly spaced samples land in the expected bands.
    #[test]
    fn test_percentiles() {
        let mut m = InferenceMetrics::new();
        (1..=100u64).for_each(|i| m.record_step(i * 100));

        let (p50, p95, p99) = m.percentiles();
        assert!(p50 > 4000 && p50 < 6000);
        assert!(p95 > 9000);
        assert!(p99 > 9800);
    }

    // The stopwatch reports at least the slept duration.
    #[test]
    fn test_timer() {
        let t = Timer::start();
        thread::sleep(Duration::from_micros(100));

        assert!(t.elapsed_us() >= 100);
    }

    // Phase accumulators sum correctly and percentages match.
    #[test]
    fn test_profiler() {
        let mut p = InferenceProfiler::new();

        p.record_tokenization(100);
        p.record_forward_pass(500);
        p.record_sampling(50);
        p.increment_step();

        assert_eq!(p.total_time_us(), 650);
        assert_eq!(p.step_count, 1);

        let b = p.breakdown();
        assert!((b.tokenization_pct - 15.38).abs() < 0.1);
        assert!((b.forward_pass_pct - 76.92).abs() < 0.1);
    }

    // A summary snapshot reflects the recorded extrema and totals.
    #[test]
    fn test_metrics_summary() {
        let mut m = InferenceMetrics::new();
        for latency in [1000, 2000, 1500] {
            m.record_step(latency);
        }

        let s = m.summary();
        assert_eq!(s.total_steps, 3);
        assert_eq!(s.min_latency_us, 1000);
        assert_eq!(s.peak_latency_us, 2000);
    }
}