Skip to main content

prax_query/middleware/
timing.rs

1//! Timing middleware for measuring query execution time.
2
3use super::context::QueryContext;
4use super::types::{BoxFuture, Middleware, MiddlewareResult, Next, QueryResponse};
5use std::sync::atomic::{AtomicU64, Ordering};
6use std::time::Instant;
7
/// Elapsed time of a single query, pre-computed in three units.
///
/// The microsecond and millisecond fields are derived from the nanosecond
/// value by integer division, so any sub-unit remainder is truncated
/// (1_500_000 ns is reported as 1500 us and 1 ms).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct TimingResult {
    /// Execution time in nanoseconds.
    pub duration_ns: u64,
    /// Execution time in whole microseconds (`duration_ns / 1_000`).
    pub duration_us: u64,
    /// Execution time in whole milliseconds (`duration_ns / 1_000_000`).
    pub duration_ms: u64,
}

impl TimingResult {
    /// Build a result from an elapsed time expressed in nanoseconds.
    ///
    /// The coarser units are obtained by truncating integer division, so a
    /// value below one microsecond yields `duration_us == 0` and
    /// `duration_ms == 0`.
    #[must_use]
    pub const fn from_nanos(ns: u64) -> Self {
        Self {
            duration_ns: ns,
            duration_us: ns / 1_000,
            duration_ms: ns / 1_000_000,
        }
    }
}
29
/// Middleware that measures query execution time.
///
/// This is a lightweight middleware that only adds timing information
/// to the response. For more comprehensive metrics, use `MetricsMiddleware`.
///
/// It also keeps two running totals (accumulated time and number of timed
/// queries) in independent atomics, so the statistics can be read through
/// a shared reference. Because the two counters are updated and read
/// separately, a reader racing with an in-flight query may observe one
/// counter updated and not yet the other.
#[derive(Debug, Default)]
pub struct TimingMiddleware {
    /// Total execution time in nanoseconds.
    total_time_ns: AtomicU64,
    /// Number of queries timed.
    query_count: AtomicU64,
}

impl TimingMiddleware {
    /// Create a new timing middleware with both counters at zero.
    ///
    /// This is a `const fn`, so the middleware can be placed in a `static`.
    #[must_use]
    pub const fn new() -> Self {
        Self {
            total_time_ns: AtomicU64::new(0),
            query_count: AtomicU64::new(0),
        }
    }

    /// Get the total execution time in nanoseconds.
    #[must_use]
    pub fn total_time_ns(&self) -> u64 {
        self.total_time_ns.load(Ordering::Relaxed)
    }

    /// Get the total execution time in whole microseconds (truncated).
    #[must_use]
    pub fn total_time_us(&self) -> u64 {
        self.total_time_ns() / 1_000
    }

    /// Get the total execution time in whole milliseconds (truncated).
    #[must_use]
    pub fn total_time_ms(&self) -> u64 {
        self.total_time_ns() / 1_000_000
    }

    /// Get the number of queries timed.
    #[must_use]
    pub fn query_count(&self) -> u64 {
        self.query_count.load(Ordering::Relaxed)
    }

    /// Get the average execution time in nanoseconds.
    ///
    /// Returns `0` when no query has been timed yet. The total and the
    /// count are loaded one after the other, not as a single snapshot.
    #[must_use]
    pub fn avg_time_ns(&self) -> u64 {
        // `checked_div` yields `None` for a zero count instead of panicking.
        self.total_time_ns()
            .checked_div(self.query_count())
            .unwrap_or(0)
    }

    /// Get the average execution time in whole microseconds (truncated).
    #[must_use]
    pub fn avg_time_us(&self) -> u64 {
        self.avg_time_ns() / 1_000
    }

    /// Reset timing statistics.
    ///
    /// Both counters are set back to zero with two separate stores; a query
    /// that completes between them can leave the counters briefly out of
    /// step with each other.
    pub fn reset(&self) {
        self.total_time_ns.store(0, Ordering::SeqCst);
        self.query_count.store(0, Ordering::SeqCst);
    }
}
94
95impl Middleware for TimingMiddleware {
96    fn handle<'a>(
97        &'a self,
98        ctx: QueryContext,
99        next: Next<'a>,
100    ) -> BoxFuture<'a, MiddlewareResult<QueryResponse>> {
101        Box::pin(async move {
102            let start = Instant::now();
103
104            let result = next.run(ctx).await;
105
106            let elapsed = start.elapsed();
107            let elapsed_ns = elapsed.as_nanos() as u64;
108            let elapsed_us = elapsed.as_micros() as u64;
109
110            self.total_time_ns.fetch_add(elapsed_ns, Ordering::Relaxed);
111            self.query_count.fetch_add(1, Ordering::Relaxed);
112
113            // Update the response with execution time
114            result.map(|mut response| {
115                response.execution_time_us = elapsed_us;
116                response
117            })
118        })
119    }
120
121    fn name(&self) -> &'static str {
122        "TimingMiddleware"
123    }
124}
125
#[cfg(test)]
mod tests {
    use super::*;

    /// `from_nanos` keeps the raw value and truncates to whole us / ms.
    #[test]
    fn test_timing_result() {
        let TimingResult {
            duration_ns,
            duration_us,
            duration_ms,
        } = TimingResult::from_nanos(1_500_000);

        assert_eq!(duration_ns, 1_500_000);
        assert_eq!(duration_us, 1500);
        assert_eq!(duration_ms, 1);
    }

    /// A freshly constructed middleware reports zero for every statistic.
    #[test]
    fn test_timing_middleware_initial_state() {
        let fresh = TimingMiddleware::new();

        let observed = (
            fresh.total_time_ns(),
            fresh.query_count(),
            fresh.avg_time_ns(),
        );

        assert_eq!(observed, (0, 0, 0));
    }

    /// `reset` clears counters that were previously non-zero.
    #[test]
    fn test_timing_middleware_reset() {
        let timer = TimingMiddleware::new();

        // Seed the counters directly so the test does not need a query chain.
        timer.query_count.store(2, Ordering::SeqCst);
        timer.total_time_ns.store(1000, Ordering::SeqCst);

        timer.reset();

        let after = (timer.total_time_ns(), timer.query_count());
        assert_eq!(after.0, 0);
        assert_eq!(after.1, 0);
    }
}
157}