oxify_authz/
profiling.rs

//! Performance profiling utilities for authorization operations.
//!
//! This module provides tools to measure and analyze the performance
//! of authorization checks, helping identify bottlenecks and optimization opportunities.
//!
//! # Example
//!
//! ```no_run
//! use oxify_authz::profiling::*;
//! use std::time::Duration;
//!
//! # async fn example() -> Result<(), Box<dyn std::error::Error>> {
//! let profiler = AuthzProfiler::new();
//!
//! // Profile a check operation
//! let result = profiler.profile_async("check_document_viewer", async {
//!     // Your authorization check here
//!     Ok::<bool, Box<dyn std::error::Error>>(true)
//! }).await?;
//!
//! // Get profiling statistics
//! let stats = profiler.get_stats();
//! println!("Total operations: {}", stats.total_operations);
//! println!("Average latency: {:?}", stats.avg_latency());
//! # Ok(())
//! # }
//! ```
28
29use serde::{Deserialize, Serialize};
30use std::collections::HashMap;
31use std::future::Future;
32use std::sync::atomic::{AtomicU64, Ordering};
33use std::sync::{Arc, Mutex};
34use std::time::{Duration, Instant};
35
36/// Performance profiler for authorization operations
37#[derive(Clone)]
38pub struct AuthzProfiler {
39    metrics: Arc<Mutex<HashMap<String, OperationMetrics>>>,
40}
41
/// Latency metrics accumulated for a single named operation.
///
/// All latency and duration fields are in nanoseconds. The percentile fields
/// are estimated from a bounded window of the most recent samples
/// (`recent_latencies`), not from the full call history, whereas the count,
/// total, min and max cover every recorded call.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OperationMetrics {
    /// Operation name
    pub name: String,
    /// Number of times this operation was called
    pub call_count: u64,
    /// Total time spent in this operation, in nanoseconds
    pub total_duration_ns: u64,
    /// Minimum observed latency (ns); `u64::MAX` until the first sample is recorded
    pub min_latency_ns: u64,
    /// Maximum observed latency (ns); 0 until the first sample is recorded
    pub max_latency_ns: u64,
    /// P50 latency approximation (median), in nanoseconds, over the recent window
    pub p50_latency_ns: u64,
    /// P95 latency approximation, in nanoseconds, over the recent window
    pub p95_latency_ns: u64,
    /// P99 latency approximation, in nanoseconds, over the recent window
    pub p99_latency_ns: u64,
    /// Recent latencies (ns) used for percentile calculation, oldest first,
    /// bounded to the last 1000 samples.
    ///
    /// Not serialized: a deserialized value starts with an empty window.
    #[serde(skip)]
    pub recent_latencies: Vec<u64>,
}
65
66impl OperationMetrics {
67    fn new(name: String) -> Self {
68        Self {
69            name,
70            call_count: 0,
71            total_duration_ns: 0,
72            min_latency_ns: u64::MAX,
73            max_latency_ns: 0,
74            p50_latency_ns: 0,
75            p95_latency_ns: 0,
76            p99_latency_ns: 0,
77            recent_latencies: Vec::with_capacity(1000),
78        }
79    }
80
81    /// Record a new measurement
82    fn record(&mut self, duration_ns: u64) {
83        self.call_count += 1;
84        self.total_duration_ns += duration_ns;
85        self.min_latency_ns = self.min_latency_ns.min(duration_ns);
86        self.max_latency_ns = self.max_latency_ns.max(duration_ns);
87
88        // Keep last 1000 measurements for percentile calculation
89        if self.recent_latencies.len() >= 1000 {
90            self.recent_latencies.remove(0);
91        }
92        self.recent_latencies.push(duration_ns);
93
94        // Update percentiles
95        self.update_percentiles();
96    }
97
98    /// Update percentile calculations based on recent latencies
99    fn update_percentiles(&mut self) {
100        if self.recent_latencies.is_empty() {
101            return;
102        }
103
104        let mut sorted = self.recent_latencies.clone();
105        sorted.sort_unstable();
106
107        let len = sorted.len();
108        self.p50_latency_ns = sorted[len / 2];
109        self.p95_latency_ns = sorted[(len as f64 * 0.95) as usize];
110        self.p99_latency_ns = sorted[(len as f64 * 0.99) as usize];
111    }
112
113    /// Get average latency
114    pub fn avg_latency_ns(&self) -> u64 {
115        if self.call_count == 0 {
116            0
117        } else {
118            self.total_duration_ns / self.call_count
119        }
120    }
121
122    /// Get average latency as Duration
123    pub fn avg_latency(&self) -> Duration {
124        Duration::from_nanos(self.avg_latency_ns())
125    }
126
127    /// Check if operation meets performance target
128    pub fn meets_target(&self, target_p99_ns: u64) -> bool {
129        self.p99_latency_ns <= target_p99_ns
130    }
131}
132
133impl AuthzProfiler {
134    /// Create a new profiler
135    pub fn new() -> Self {
136        Self {
137            metrics: Arc::new(Mutex::new(HashMap::new())),
138        }
139    }
140
141    /// Profile a synchronous operation
142    pub fn profile<F, T>(&self, operation_name: &str, f: F) -> T
143    where
144        F: FnOnce() -> T,
145    {
146        let start = Instant::now();
147        let result = f();
148        let duration = start.elapsed();
149
150        self.record_measurement(operation_name, duration);
151        result
152    }
153
154    /// Profile an asynchronous operation
155    pub async fn profile_async<F, T>(&self, operation_name: &str, fut: F) -> T
156    where
157        F: Future<Output = T>,
158    {
159        let start = Instant::now();
160        let result = fut.await;
161        let duration = start.elapsed();
162
163        self.record_measurement(operation_name, duration);
164        result
165    }
166
167    /// Record a measurement manually
168    pub fn record_measurement(&self, operation_name: &str, duration: Duration) {
169        let duration_ns = duration.as_nanos() as u64;
170
171        let mut metrics = self.metrics.lock().unwrap();
172        let entry = metrics
173            .entry(operation_name.to_string())
174            .or_insert_with(|| OperationMetrics::new(operation_name.to_string()));
175
176        entry.record(duration_ns);
177    }
178
179    /// Get metrics for a specific operation
180    pub fn get_operation_metrics(&self, operation_name: &str) -> Option<OperationMetrics> {
181        self.metrics.lock().unwrap().get(operation_name).cloned()
182    }
183
184    /// Get all profiling statistics
185    pub fn get_stats(&self) -> ProfilingStats {
186        let metrics = self.metrics.lock().unwrap();
187
188        let total_operations: u64 = metrics.values().map(|m| m.call_count).sum();
189        let total_duration_ns: u64 = metrics.values().map(|m| m.total_duration_ns).sum();
190
191        let operations: Vec<OperationMetrics> = metrics.values().cloned().collect();
192
193        ProfilingStats {
194            total_operations,
195            total_duration_ns,
196            operations,
197        }
198    }
199
200    /// Reset all profiling data
201    pub fn reset(&self) {
202        self.metrics.lock().unwrap().clear();
203    }
204
205    /// Generate a performance report
206    pub fn generate_report(&self) -> String {
207        let stats = self.get_stats();
208        let mut report = String::new();
209
210        report.push_str("=== Authorization Performance Report ===\n\n");
211        report.push_str(&format!("Total Operations: {}\n", stats.total_operations));
212        report.push_str(&format!(
213            "Total Time: {:.2}ms\n\n",
214            stats.total_duration_ns as f64 / 1_000_000.0
215        ));
216
217        report.push_str("Operation Breakdown:\n");
218        report.push_str(&format!(
219            "{:<30} {:>10} {:>12} {:>12} {:>12} {:>12}\n",
220            "Operation", "Calls", "Avg (μs)", "P50 (μs)", "P95 (μs)", "P99 (μs)"
221        ));
222        report.push_str(&"-".repeat(100));
223        report.push('\n');
224
225        let mut sorted_ops = stats.operations.clone();
226        sorted_ops.sort_by_key(|m| std::cmp::Reverse(m.call_count));
227
228        for metric in sorted_ops {
229            report.push_str(&format!(
230                "{:<30} {:>10} {:>12.2} {:>12.2} {:>12.2} {:>12.2}\n",
231                metric.name,
232                metric.call_count,
233                metric.avg_latency_ns() as f64 / 1_000.0,
234                metric.p50_latency_ns as f64 / 1_000.0,
235                metric.p95_latency_ns as f64 / 1_000.0,
236                metric.p99_latency_ns as f64 / 1_000.0,
237            ));
238        }
239
240        report
241    }
242
243    /// Export metrics as JSON
244    pub fn export_json(&self) -> serde_json::Result<String> {
245        let stats = self.get_stats();
246        serde_json::to_string_pretty(&stats)
247    }
248}
249
250impl Default for AuthzProfiler {
251    fn default() -> Self {
252        Self::new()
253    }
254}
255
/// Snapshot of profiling statistics across all recorded operations.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProfilingStats {
    /// Total number of operations profiled (sum of every operation's `call_count`)
    pub total_operations: u64,
    /// Total time spent in all operations, in nanoseconds
    pub total_duration_ns: u64,
    /// Metrics for individual operations, one entry per operation name
    pub operations: Vec<OperationMetrics>,
}
266
267impl ProfilingStats {
268    /// Get average latency across all operations
269    pub fn avg_latency(&self) -> Duration {
270        if self.total_operations == 0 {
271            Duration::from_nanos(0)
272        } else {
273            Duration::from_nanos(self.total_duration_ns / self.total_operations)
274        }
275    }
276
277    /// Find slowest operation
278    pub fn slowest_operation(&self) -> Option<&OperationMetrics> {
279        self.operations.iter().max_by_key(|m| m.p99_latency_ns)
280    }
281
282    /// Find most called operation
283    pub fn most_called_operation(&self) -> Option<&OperationMetrics> {
284        self.operations.iter().max_by_key(|m| m.call_count)
285    }
286}
287
/// Lightweight performance counter for hot paths.
///
/// Tracks only a sample count and a running total using atomics, so
/// recording never takes a lock. Clones share the same counters.
#[derive(Clone, Debug)]
pub struct PerfCounter {
    /// Number of recorded samples.
    count: Arc<AtomicU64>,
    /// Sum of all recorded durations, in nanoseconds.
    total_ns: Arc<AtomicU64>,
}
294
295impl PerfCounter {
296    /// Create a new performance counter
297    pub fn new() -> Self {
298        Self {
299            count: Arc::new(AtomicU64::new(0)),
300            total_ns: Arc::new(AtomicU64::new(0)),
301        }
302    }
303
304    /// Record a measurement
305    pub fn record(&self, duration: Duration) {
306        self.count.fetch_add(1, Ordering::Relaxed);
307        self.total_ns
308            .fetch_add(duration.as_nanos() as u64, Ordering::Relaxed);
309    }
310
311    /// Get count
312    pub fn count(&self) -> u64 {
313        self.count.load(Ordering::Relaxed)
314    }
315
316    /// Get average latency in nanoseconds
317    pub fn avg_ns(&self) -> u64 {
318        let count = self.count();
319        if count == 0 {
320            0
321        } else {
322            self.total_ns.load(Ordering::Relaxed) / count
323        }
324    }
325
326    /// Get average latency as Duration
327    pub fn avg_duration(&self) -> Duration {
328        Duration::from_nanos(self.avg_ns())
329    }
330
331    /// Reset the counter
332    pub fn reset(&self) {
333        self.count.store(0, Ordering::Relaxed);
334        self.total_ns.store(0, Ordering::Relaxed);
335    }
336}
337
338impl Default for PerfCounter {
339    fn default() -> Self {
340        Self::new()
341    }
342}
343
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;
    use std::time::Duration;

    #[test]
    fn test_profiler_basic() {
        let profiler = AuthzProfiler::new();

        // Ten synchronous calls, each sleeping for at least 1ms.
        (0..10).for_each(|_| {
            profiler.profile("test_op", || thread::sleep(Duration::from_millis(1)));
        });

        assert_eq!(profiler.get_stats().total_operations, 10);

        let op = profiler
            .get_operation_metrics("test_op")
            .expect("test_op should have been recorded");
        assert_eq!(op.call_count, 10);
        assert!(op.avg_latency() >= Duration::from_millis(1));
    }

    #[tokio::test]
    async fn test_profiler_async() {
        let profiler = AuthzProfiler::new();

        // Five async calls; the timer is created lazily inside the async block,
        // so it starts only once the profiler begins polling it.
        for _round in 0..5 {
            let work = async { tokio::time::sleep(Duration::from_millis(2)).await };
            profiler.profile_async("async_op", work).await;
        }

        let op = profiler
            .get_operation_metrics("async_op")
            .expect("async_op should have been recorded");
        assert_eq!(op.call_count, 5);
        assert!(op.avg_latency() >= Duration::from_millis(2));
    }

    #[test]
    fn test_operation_metrics() {
        let mut metrics = OperationMetrics::new("test".to_string());

        // Samples of 1ms, 2ms, 3ms and 10ms.
        for &millis in &[1u64, 2, 3, 10] {
            metrics.record(millis * 1_000_000);
        }

        assert_eq!(metrics.call_count, 4);
        assert_eq!(metrics.min_latency_ns, 1_000_000);
        assert_eq!(metrics.max_latency_ns, 10_000_000);
        assert_eq!(metrics.avg_latency_ns(), 4_000_000);
    }

    #[test]
    fn test_perf_counter() {
        let counter = PerfCounter::new();
        let sample = Duration::from_micros(100);

        (0..100).for_each(|_| counter.record(sample));

        assert_eq!(counter.count(), 100);
        assert_eq!(counter.avg_ns(), 100_000);
        assert_eq!(counter.avg_duration(), sample);
    }

    #[test]
    fn test_report_generation() {
        let profiler = AuthzProfiler::new();

        let runs = [("check_permission", 100), ("write_tuple", 200)];
        for &(name, micros) in &runs {
            profiler.profile(name, || thread::sleep(Duration::from_micros(micros)));
        }

        let report = profiler.generate_report();
        for &needle in &["Authorization Performance Report", "check_permission", "write_tuple"] {
            assert!(report.contains(needle));
        }
    }

    #[test]
    fn test_json_export() {
        let profiler = AuthzProfiler::new();
        profiler.profile("test_op", || {});

        let json = profiler.export_json().expect("stats should serialize");
        for &key in &["total_operations", "test_op"] {
            assert!(json.contains(key));
        }
    }

    #[test]
    fn test_percentile_calculation() {
        let mut metrics = OperationMetrics::new("test".to_string());

        // Uniform samples from 1ms up to 100ms.
        (1..=100).for_each(|i| metrics.record(i * 1_000_000));

        assert!(metrics.p50_latency_ns > 40_000_000); // ~50ms
        assert!(metrics.p95_latency_ns > 90_000_000); // ~95ms
        assert!(metrics.p99_latency_ns > 95_000_000); // ~99ms
    }

    #[test]
    fn test_profiler_reset() {
        let profiler = AuthzProfiler::new();
        let total = |p: &AuthzProfiler| p.get_stats().total_operations;

        profiler.profile("test", || {});
        assert_eq!(total(&profiler), 1);

        profiler.reset();
        assert_eq!(total(&profiler), 0);
    }

    #[test]
    fn test_meets_target() {
        let mut metrics = OperationMetrics::new("test".to_string());

        // Every sample is 50μs.
        (0..100).for_each(|_| metrics.record(50_000));

        let (loose_target_ns, tight_target_ns) = (100_000, 10_000);
        assert!(metrics.meets_target(loose_target_ns));
        assert!(!metrics.meets_target(tight_target_ns));
    }

    #[test]
    fn test_stats_slowest_operation() {
        let profiler = AuthzProfiler::new();
        profiler.record_measurement("fast_op", Duration::from_micros(10));
        profiler.record_measurement("slow_op", Duration::from_millis(100));

        let stats = profiler.get_stats();
        let slowest_name = stats.slowest_operation().map(|m| m.name.as_str());
        assert_eq!(slowest_name, Some("slow_op"));
    }

    #[test]
    fn test_stats_most_called() {
        let profiler = AuthzProfiler::new();
        let tick = Duration::from_micros(1);

        (0..100).for_each(|_| profiler.record_measurement("frequent_op", tick));
        profiler.record_measurement("rare_op", tick);

        let stats = profiler.get_stats();
        let top = stats
            .most_called_operation()
            .expect("at least one operation was recorded");
        assert_eq!((top.name.as_str(), top.call_count), ("frequent_op", 100));
    }
}