prax_query/middleware/
timing.rs

1//! Timing middleware for measuring query execution time.
2
3use super::context::QueryContext;
4use super::types::{BoxFuture, Middleware, MiddlewareResult, Next, QueryResponse};
5use std::sync::atomic::{AtomicU64, Ordering};
6use std::time::Instant;
7
/// Result of timing a query, expressed in several units.
///
/// All three fields describe the same duration. The microsecond and
/// millisecond values are derived from `duration_ns` by integer division,
/// so they are truncated toward zero rather than rounded.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TimingResult {
    /// Execution time in nanoseconds.
    pub duration_ns: u64,
    /// Execution time in microseconds (truncated).
    pub duration_us: u64,
    /// Execution time in milliseconds (truncated).
    pub duration_ms: u64,
}

impl TimingResult {
    /// Create a timing result from a nanosecond count.
    ///
    /// Microseconds and milliseconds are computed by integer division, so
    /// e.g. `from_nanos(1_999_999)` yields `duration_us == 1_999` and
    /// `duration_ms == 1`.
    pub fn from_nanos(ns: u64) -> Self {
        Self {
            duration_ns: ns,
            duration_us: ns / 1_000,
            duration_ms: ns / 1_000_000,
        }
    }

    /// Create a timing result from a [`std::time::Duration`], such as the
    /// value returned by `Instant::elapsed`.
    ///
    /// `Duration::as_nanos` returns a `u128`; durations longer than
    /// `u64::MAX` nanoseconds (roughly 584 years) are clamped to `u64::MAX`
    /// instead of being silently truncated by an `as` cast.
    pub fn from_duration(duration: std::time::Duration) -> Self {
        // The clamp guarantees the value fits, so the cast below is lossless.
        let ns = duration.as_nanos().min(u128::from(u64::MAX)) as u64;
        Self::from_nanos(ns)
    }
}
29
/// Middleware that measures query execution time.
///
/// This is a lightweight middleware that only adds timing information
/// to the response. For more comprehensive metrics, use `MetricsMiddleware`.
///
/// Statistics are kept in two independent atomic counters (total time and
/// query count). Readers that race with in-flight queries may therefore see
/// a total and a count that are momentarily out of step with each other.
#[derive(Debug)]
pub struct TimingMiddleware {
    /// Total execution time in nanoseconds.
    total_time_ns: AtomicU64,
    /// Number of queries timed.
    query_count: AtomicU64,
}

impl TimingMiddleware {
    /// Create a new timing middleware with all statistics at zero.
    pub fn new() -> Self {
        Self {
            total_time_ns: AtomicU64::new(0),
            query_count: AtomicU64::new(0),
        }
    }

    /// Get the total execution time in nanoseconds.
    pub fn total_time_ns(&self) -> u64 {
        self.total_time_ns.load(Ordering::Relaxed)
    }

    /// Get the total execution time in microseconds (truncated).
    pub fn total_time_us(&self) -> u64 {
        self.total_time_ns() / 1_000
    }

    /// Get the total execution time in milliseconds (truncated).
    pub fn total_time_ms(&self) -> u64 {
        self.total_time_ns() / 1_000_000
    }

    /// Get the number of queries timed.
    pub fn query_count(&self) -> u64 {
        self.query_count.load(Ordering::Relaxed)
    }

    /// Get the average execution time in nanoseconds.
    ///
    /// Returns `0` when no queries have been timed, instead of dividing by
    /// zero. The total and the count are loaded separately, so under
    /// concurrent use the result is an approximation.
    pub fn avg_time_ns(&self) -> u64 {
        let count = self.query_count();
        if count == 0 {
            0
        } else {
            self.total_time_ns() / count
        }
    }

    /// Get the average execution time in microseconds (truncated).
    pub fn avg_time_us(&self) -> u64 {
        self.avg_time_ns() / 1_000
    }

    /// Get the average execution time in milliseconds (truncated).
    pub fn avg_time_ms(&self) -> u64 {
        self.avg_time_ns() / 1_000_000
    }

    /// Reset timing statistics to zero.
    ///
    /// The two counters are cleared one after the other, not atomically as
    /// a pair. A query recorded concurrently may land between the two stores.
    pub fn reset(&self) {
        self.total_time_ns.store(0, Ordering::SeqCst);
        self.query_count.store(0, Ordering::SeqCst);
    }
}

impl Default for TimingMiddleware {
    /// Equivalent to [`TimingMiddleware::new`].
    fn default() -> Self {
        Self::new()
    }
}
97
98impl Middleware for TimingMiddleware {
99    fn handle<'a>(
100        &'a self,
101        ctx: QueryContext,
102        next: Next<'a>,
103    ) -> BoxFuture<'a, MiddlewareResult<QueryResponse>> {
104        Box::pin(async move {
105            let start = Instant::now();
106
107            let result = next.run(ctx).await;
108
109            let elapsed = start.elapsed();
110            let elapsed_ns = elapsed.as_nanos() as u64;
111            let elapsed_us = elapsed.as_micros() as u64;
112
113            self.total_time_ns.fetch_add(elapsed_ns, Ordering::Relaxed);
114            self.query_count.fetch_add(1, Ordering::Relaxed);
115
116            // Update the response with execution time
117            result.map(|mut response| {
118                response.execution_time_us = elapsed_us;
119                response
120            })
121        })
122    }
123
124    fn name(&self) -> &'static str {
125        "TimingMiddleware"
126    }
127}
128
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_timing_result() {
        let result = TimingResult::from_nanos(1_500_000);
        assert_eq!(result.duration_ns, 1_500_000);
        assert_eq!(result.duration_us, 1500);
        assert_eq!(result.duration_ms, 1);
    }

    #[test]
    fn test_timing_result_truncates_sub_unit_values() {
        // Values below one microsecond/millisecond truncate to zero.
        let result = TimingResult::from_nanos(999);
        assert_eq!(result.duration_ns, 999);
        assert_eq!(result.duration_us, 0);
        assert_eq!(result.duration_ms, 0);
    }

    #[test]
    fn test_timing_middleware_initial_state() {
        let middleware = TimingMiddleware::new();
        assert_eq!(middleware.total_time_ns(), 0);
        assert_eq!(middleware.query_count(), 0);
        assert_eq!(middleware.avg_time_ns(), 0);
        assert_eq!(middleware.avg_time_us(), 0);
    }

    #[test]
    fn test_timing_middleware_averages_and_units() {
        let middleware = TimingMiddleware::new();
        middleware.total_time_ns.store(9_000_000, Ordering::SeqCst);
        middleware.query_count.store(3, Ordering::SeqCst);

        assert_eq!(middleware.total_time_us(), 9_000);
        assert_eq!(middleware.total_time_ms(), 9);
        assert_eq!(middleware.avg_time_ns(), 3_000_000);
        assert_eq!(middleware.avg_time_us(), 3_000);
    }

    #[test]
    fn test_timing_middleware_reset() {
        let middleware = TimingMiddleware::new();
        middleware.total_time_ns.store(1000, Ordering::SeqCst);
        middleware.query_count.store(2, Ordering::SeqCst);

        middleware.reset();

        assert_eq!(middleware.total_time_ns(), 0);
        assert_eq!(middleware.query_count(), 0);
    }
}
161
162