kizzasi_inference/
metrics.rs

1//! Performance metrics and profiling for inference
2//!
3//! This module provides tools for monitoring and profiling inference performance:
4//! - Latency tracking
5//! - Throughput measurement
6//! - Resource utilization
7//! - Bottleneck identification
8
9use std::collections::VecDeque;
10use std::time::{Duration, Instant};
11
/// Performance metrics for inference operations.
///
/// Tracks lifetime totals plus a bounded rolling window of recent
/// latencies, so callers can compute both overall and recent statistics.
#[derive(Debug, Clone)]
pub struct InferenceMetrics {
    /// Total number of inference steps recorded via `record_step`
    pub total_steps: u64,
    /// Total time spent in inference (microseconds), summed over all steps
    pub total_time_us: u64,
    /// Recent latencies in microseconds (bounded FIFO used for rolling statistics)
    recent_latencies: VecDeque<u64>,
    /// Maximum rolling window size; oldest samples are evicted beyond this
    window_size: usize,
    /// Peak latency observed (microseconds)
    pub peak_latency_us: u64,
    /// Minimum latency observed (microseconds); stays `u64::MAX` until a
    /// non-zero latency is recorded (zero latencies are ignored for the min)
    pub min_latency_us: u64,
    /// Instant this collector was created or last reset; basis for
    /// uptime and throughput calculations
    start_time: Instant,
}
30
31impl InferenceMetrics {
32    /// Create a new metrics collector
33    pub fn new() -> Self {
34        Self::with_window_size(100)
35    }
36
37    /// Create with custom window size for rolling statistics
38    pub fn with_window_size(window_size: usize) -> Self {
39        Self {
40            total_steps: 0,
41            total_time_us: 0,
42            recent_latencies: VecDeque::with_capacity(window_size),
43            window_size,
44            peak_latency_us: 0,
45            min_latency_us: u64::MAX,
46            start_time: Instant::now(),
47        }
48    }
49
50    /// Record a new inference step
51    pub fn record_step(&mut self, latency_us: u64) {
52        self.total_steps += 1;
53        self.total_time_us += latency_us;
54
55        // Update peak and min
56        self.peak_latency_us = self.peak_latency_us.max(latency_us);
57        if latency_us > 0 {
58            self.min_latency_us = self.min_latency_us.min(latency_us);
59        }
60
61        // Rolling window
62        if self.recent_latencies.len() >= self.window_size {
63            self.recent_latencies.pop_front();
64        }
65        self.recent_latencies.push_back(latency_us);
66    }
67
68    /// Get average latency (microseconds)
69    pub fn avg_latency_us(&self) -> f64 {
70        if self.total_steps == 0 {
71            0.0
72        } else {
73            self.total_time_us as f64 / self.total_steps as f64
74        }
75    }
76
77    /// Get recent average latency (microseconds)
78    pub fn recent_avg_latency_us(&self) -> f64 {
79        if self.recent_latencies.is_empty() {
80            0.0
81        } else {
82            let sum: u64 = self.recent_latencies.iter().sum();
83            sum as f64 / self.recent_latencies.len() as f64
84        }
85    }
86
87    /// Get throughput (steps per second)
88    pub fn throughput(&self) -> f64 {
89        let elapsed_secs = self.start_time.elapsed().as_secs_f64();
90        if elapsed_secs == 0.0 {
91            0.0
92        } else {
93            self.total_steps as f64 / elapsed_secs
94        }
95    }
96
97    /// Get p50, p95, p99 latencies (microseconds)
98    pub fn percentiles(&self) -> (u64, u64, u64) {
99        if self.recent_latencies.is_empty() {
100            return (0, 0, 0);
101        }
102
103        let mut sorted: Vec<u64> = self.recent_latencies.iter().copied().collect();
104        sorted.sort_unstable();
105
106        let p50_idx = (sorted.len() as f64 * 0.50) as usize;
107        let p95_idx = (sorted.len() as f64 * 0.95) as usize;
108        let p99_idx = (sorted.len() as f64 * 0.99) as usize;
109
110        (
111            sorted.get(p50_idx).copied().unwrap_or(0),
112            sorted.get(p95_idx).copied().unwrap_or(0),
113            sorted.get(p99_idx).copied().unwrap_or(0),
114        )
115    }
116
117    /// Reset all metrics
118    pub fn reset(&mut self) {
119        self.total_steps = 0;
120        self.total_time_us = 0;
121        self.recent_latencies.clear();
122        self.peak_latency_us = 0;
123        self.min_latency_us = u64::MAX;
124        self.start_time = Instant::now();
125    }
126
127    /// Get a summary report
128    pub fn summary(&self) -> MetricsSummary {
129        let (p50, p95, p99) = self.percentiles();
130
131        MetricsSummary {
132            total_steps: self.total_steps,
133            avg_latency_us: self.avg_latency_us(),
134            recent_avg_latency_us: self.recent_avg_latency_us(),
135            peak_latency_us: self.peak_latency_us,
136            min_latency_us: if self.min_latency_us == u64::MAX {
137                0
138            } else {
139                self.min_latency_us
140            },
141            throughput_per_sec: self.throughput(),
142            p50_latency_us: p50,
143            p95_latency_us: p95,
144            p99_latency_us: p99,
145            uptime_secs: self.start_time.elapsed().as_secs_f64(),
146        }
147    }
148}
149
150impl Default for InferenceMetrics {
151    fn default() -> Self {
152        Self::new()
153    }
154}
155
/// Summary of performance metrics
///
/// Plain-data snapshot produced by `InferenceMetrics::summary`; all
/// latency fields are in microseconds.
#[derive(Debug, Clone)]
pub struct MetricsSummary {
    /// Total number of recorded inference steps
    pub total_steps: u64,
    /// Lifetime average latency (µs)
    pub avg_latency_us: f64,
    /// Average latency over the rolling window (µs)
    pub recent_avg_latency_us: f64,
    /// Largest single-step latency observed (µs)
    pub peak_latency_us: u64,
    /// Smallest non-zero single-step latency observed (µs); 0 when nothing recorded
    pub min_latency_us: u64,
    /// Steps per second, measured against wall-clock uptime
    pub throughput_per_sec: f64,
    /// Median latency over the rolling window (µs)
    pub p50_latency_us: u64,
    /// 95th-percentile latency over the rolling window (µs)
    pub p95_latency_us: u64,
    /// 99th-percentile latency over the rolling window (µs)
    pub p99_latency_us: u64,
    /// Seconds since the collector was created or last reset
    pub uptime_secs: f64,
}
170
171impl MetricsSummary {
172    /// Print a formatted report
173    pub fn print_report(&self) {
174        println!("=== Inference Performance Report ===");
175        println!("Total steps: {}", self.total_steps);
176        println!("Uptime: {:.2}s", self.uptime_secs);
177        println!("Throughput: {:.2} steps/sec", self.throughput_per_sec);
178        println!("\nLatency (microseconds):");
179        println!("  Average: {:.2} µs", self.avg_latency_us);
180        println!("  Recent avg: {:.2} µs", self.recent_avg_latency_us);
181        println!("  Min: {} µs", self.min_latency_us);
182        println!("  Max: {} µs", self.peak_latency_us);
183        println!("\nPercentiles:");
184        println!("  P50: {} µs", self.p50_latency_us);
185        println!("  P95: {} µs", self.p95_latency_us);
186        println!("  P99: {} µs", self.p99_latency_us);
187        println!("=====================================");
188    }
189}
190
/// Timer for measuring operation duration
pub struct Timer {
    // Instant captured when the timer was started.
    start: Instant,
}

impl Timer {
    /// Start a new timer anchored at the current instant.
    pub fn start() -> Self {
        Timer { start: Instant::now() }
    }

    /// Elapsed time since `start`, truncated to whole microseconds.
    pub fn elapsed_us(&self) -> u64 {
        self.elapsed().as_micros() as u64
    }

    /// Elapsed time since `start`, truncated to whole milliseconds.
    pub fn elapsed_ms(&self) -> u64 {
        self.elapsed().as_millis() as u64
    }

    /// Elapsed time since `start` as a `Duration`.
    pub fn elapsed(&self) -> Duration {
        self.start.elapsed()
    }
}
219
/// Profiler for tracking different stages of inference
///
/// Each `_us` field accumulates the total microseconds spent in one
/// pipeline stage; `step_count` records how many steps contributed.
#[derive(Debug, Clone)]
pub struct InferenceProfiler {
    /// Accumulated time spent in tokenization (microseconds)
    pub tokenization_us: u64,
    /// Accumulated time spent in the model forward pass (microseconds)
    pub forward_pass_us: u64,
    /// Accumulated time spent in sampling (microseconds)
    pub sampling_us: u64,
    /// Accumulated time spent in constraint enforcement (microseconds)
    pub constraints_us: u64,
    /// Accumulated time spent in detokenization (microseconds)
    pub detokenization_us: u64,
    /// Number of profiled steps (advanced via `increment_step`)
    pub step_count: u64,
}
236
237impl InferenceProfiler {
238    /// Create a new profiler
239    pub fn new() -> Self {
240        Self {
241            tokenization_us: 0,
242            forward_pass_us: 0,
243            sampling_us: 0,
244            constraints_us: 0,
245            detokenization_us: 0,
246            step_count: 0,
247        }
248    }
249
250    /// Record tokenization time
251    pub fn record_tokenization(&mut self, duration_us: u64) {
252        self.tokenization_us += duration_us;
253    }
254
255    /// Record forward pass time
256    pub fn record_forward_pass(&mut self, duration_us: u64) {
257        self.forward_pass_us += duration_us;
258    }
259
260    /// Record sampling time
261    pub fn record_sampling(&mut self, duration_us: u64) {
262        self.sampling_us += duration_us;
263    }
264
265    /// Record constraint enforcement time
266    pub fn record_constraints(&mut self, duration_us: u64) {
267        self.constraints_us += duration_us;
268    }
269
270    /// Record detokenization time
271    pub fn record_detokenization(&mut self, duration_us: u64) {
272        self.detokenization_us += duration_us;
273    }
274
275    /// Increment step count
276    pub fn increment_step(&mut self) {
277        self.step_count += 1;
278    }
279
280    /// Get total time across all stages
281    pub fn total_time_us(&self) -> u64 {
282        self.tokenization_us
283            + self.forward_pass_us
284            + self.sampling_us
285            + self.constraints_us
286            + self.detokenization_us
287    }
288
289    /// Get breakdown percentages
290    pub fn breakdown(&self) -> ProfileBreakdown {
291        let total = self.total_time_us() as f64;
292        if total == 0.0 {
293            return ProfileBreakdown::default();
294        }
295
296        ProfileBreakdown {
297            tokenization_pct: (self.tokenization_us as f64 / total) * 100.0,
298            forward_pass_pct: (self.forward_pass_us as f64 / total) * 100.0,
299            sampling_pct: (self.sampling_us as f64 / total) * 100.0,
300            constraints_pct: (self.constraints_us as f64 / total) * 100.0,
301            detokenization_pct: (self.detokenization_us as f64 / total) * 100.0,
302        }
303    }
304
305    /// Print profiling report
306    pub fn print_report(&self) {
307        let breakdown = self.breakdown();
308        let total = self.total_time_us();
309
310        println!("=== Inference Profiling Report ===");
311        println!("Total steps: {}", self.step_count);
312        println!("Total time: {} µs ({:.2} ms)", total, total as f64 / 1000.0);
313        println!("\nTime breakdown:");
314        println!(
315            "  Tokenization:   {:>8} µs ({:>5.2}%)",
316            self.tokenization_us, breakdown.tokenization_pct
317        );
318        println!(
319            "  Forward pass:   {:>8} µs ({:>5.2}%)",
320            self.forward_pass_us, breakdown.forward_pass_pct
321        );
322        println!(
323            "  Sampling:       {:>8} µs ({:>5.2}%)",
324            self.sampling_us, breakdown.sampling_pct
325        );
326        println!(
327            "  Constraints:    {:>8} µs ({:>5.2}%)",
328            self.constraints_us, breakdown.constraints_pct
329        );
330        println!(
331            "  Detokenization: {:>8} µs ({:>5.2}%)",
332            self.detokenization_us, breakdown.detokenization_pct
333        );
334        println!("===================================");
335    }
336
337    /// Reset the profiler
338    pub fn reset(&mut self) {
339        *self = Self::new();
340    }
341}
342
343impl Default for InferenceProfiler {
344    fn default() -> Self {
345        Self::new()
346    }
347}
348
/// Breakdown of time spent in different stages
///
/// Percentage shares of total recorded time, produced by
/// `InferenceProfiler::breakdown`; all fields are 0.0 when no time
/// has been recorded yet.
#[derive(Debug, Clone, Default)]
pub struct ProfileBreakdown {
    /// Share of total time spent tokenizing (percent)
    pub tokenization_pct: f64,
    /// Share of total time in the model forward pass (percent)
    pub forward_pass_pct: f64,
    /// Share of total time sampling (percent)
    pub sampling_pct: f64,
    /// Share of total time enforcing constraints (percent)
    pub constraints_pct: f64,
    /// Share of total time detokenizing (percent)
    pub detokenization_pct: f64,
}
358
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;

    // A fresh collector starts with zeroed counters.
    #[test]
    fn test_metrics_creation() {
        let m = InferenceMetrics::new();
        assert_eq!(m.total_steps, 0);
        assert_eq!(m.total_time_us, 0);
    }

    // Totals and the lifetime average track recorded steps.
    #[test]
    fn test_record_step() {
        let mut m = InferenceMetrics::new();
        for latency in [1000, 2000] {
            m.record_step(latency);
        }

        assert_eq!(m.total_steps, 2);
        assert_eq!(m.total_time_us, 3000);
        assert_eq!(m.avg_latency_us(), 1500.0);
    }

    // Percentiles over 100 evenly spaced samples land in the expected bands.
    #[test]
    fn test_percentiles() {
        let mut m = InferenceMetrics::new();
        (1..=100).for_each(|i| m.record_step(i * 100));

        let (p50, p95, p99) = m.percentiles();
        assert!(p50 > 4000 && p50 < 6000);
        assert!(p95 > 9000);
        assert!(p99 > 9800);
    }

    // A started timer reports at least the slept duration.
    #[test]
    fn test_timer() {
        let t = Timer::start();
        thread::sleep(Duration::from_micros(100));
        assert!(t.elapsed_us() >= 100);
    }

    // Stage totals and percentage breakdown follow recorded durations.
    #[test]
    fn test_profiler() {
        let mut p = InferenceProfiler::new();

        p.record_tokenization(100);
        p.record_forward_pass(500);
        p.record_sampling(50);
        p.increment_step();

        assert_eq!(p.total_time_us(), 650);
        assert_eq!(p.step_count, 1);

        let b = p.breakdown();
        assert!((b.tokenization_pct - 15.38).abs() < 0.1);
        assert!((b.forward_pass_pct - 76.92).abs() < 0.1);
    }

    // summary() exposes totals plus min/max with the u64::MAX sentinel resolved.
    #[test]
    fn test_metrics_summary() {
        let mut m = InferenceMetrics::new();
        for latency in [1000, 2000, 1500] {
            m.record_step(latency);
        }

        let s = m.summary();
        assert_eq!(s.total_steps, 3);
        assert_eq!(s.min_latency_us, 1000);
        assert_eq!(s.peak_latency_us, 2000);
    }
}