prax_query/middleware/
metrics.rs

1//! Metrics middleware for query performance tracking.
2
3use super::context::{QueryContext, QueryType};
4use super::types::{BoxFuture, Middleware, MiddlewareResult, Next, QueryResponse};
5use std::collections::HashMap;
6use std::sync::atomic::{AtomicU64, Ordering};
7use std::sync::{Arc, RwLock};
8use std::time::Instant;
9
/// Point-in-time summary of query statistics.
///
/// Every duration field is expressed in microseconds.
#[derive(Debug, Clone)]
pub struct QueryMetrics {
    /// How many queries were executed in total.
    pub total_queries: u64,
    /// How many of those queries succeeded.
    pub successful_queries: u64,
    /// How many of those queries failed.
    pub failed_queries: u64,
    /// Sum of all execution times, in microseconds.
    pub total_time_us: u64,
    /// Mean execution time, in microseconds.
    pub avg_time_us: u64,
    /// Shortest execution time, in microseconds.
    pub min_time_us: u64,
    /// Longest execution time, in microseconds.
    pub max_time_us: u64,
    /// How many queries were classified as slow.
    pub slow_queries: u64,
    /// How many queries were answered from a cache.
    pub cache_hits: u64,
    /// Query counts keyed by query type.
    pub queries_by_type: HashMap<String, u64>,
    /// Query counts keyed by model name.
    pub queries_by_model: HashMap<String, u64>,
}

impl Default for QueryMetrics {
    /// Builds an empty summary.
    ///
    /// All counters are zero and both maps are empty. `min_time_us` starts at
    /// `u64::MAX` so that any observed duration compares smaller.
    fn default() -> Self {
        let queries_by_type = HashMap::new();
        let queries_by_model = HashMap::new();

        Self {
            total_queries: 0,
            successful_queries: 0,
            failed_queries: 0,
            total_time_us: 0,
            avg_time_us: 0,
            min_time_us: u64::MAX,
            max_time_us: 0,
            slow_queries: 0,
            cache_hits: 0,
            queries_by_type,
            queries_by_model,
        }
    }
}

impl QueryMetrics {
    /// Returns an empty summary; equivalent to [`Default::default`].
    pub fn new() -> Self {
        Default::default()
    }

    /// Fraction of queries that succeeded, in `0.0..=1.0`.
    ///
    /// Reports `1.0` when no query has been recorded.
    pub fn success_rate(&self) -> f64 {
        self.fraction_of_total(self.successful_queries)
            .unwrap_or(1.0)
    }

    /// Fraction of queries answered from the cache, in `0.0..=1.0`.
    ///
    /// Reports `0.0` when no query has been recorded.
    pub fn cache_hit_rate(&self) -> f64 {
        self.fraction_of_total(self.cache_hits).unwrap_or(0.0)
    }

    /// Fraction of queries classified as slow, in `0.0..=1.0`.
    ///
    /// Reports `0.0` when no query has been recorded.
    pub fn slow_query_rate(&self) -> f64 {
        self.fraction_of_total(self.slow_queries).unwrap_or(0.0)
    }

    /// Divides `count` by `total_queries`, or yields `None` when the total is zero.
    fn fraction_of_total(&self, count: u64) -> Option<f64> {
        match self.total_queries {
            0 => None,
            total => Some(count as f64 / total as f64),
        }
    }
}
88
89/// Interface for collecting metrics.
90pub trait MetricsCollector: Send + Sync {
91    /// Record a query execution.
92    fn record_query(
93        &self,
94        query_type: QueryType,
95        model: Option<&str>,
96        duration_us: u64,
97        success: bool,
98        from_cache: bool,
99    );
100
101    /// Get current metrics.
102    fn get_metrics(&self) -> QueryMetrics;
103
104    /// Reset all metrics.
105    fn reset(&self);
106}
107
/// Metrics collector that keeps all of its state in process memory.
///
/// Scalar statistics live in atomics; the per-type and per-model tallies are
/// each guarded by their own `RwLock`.
#[derive(Debug)]
pub struct InMemoryMetricsCollector {
    // Running counters.
    total_queries: AtomicU64,
    successful_queries: AtomicU64,
    failed_queries: AtomicU64,
    // Timing statistics, in microseconds. `min_time_us` starts at `u64::MAX`.
    total_time_us: AtomicU64,
    min_time_us: AtomicU64,
    max_time_us: AtomicU64,
    slow_queries: AtomicU64,
    cache_hits: AtomicU64,
    // Durations at or above this many microseconds count as slow.
    slow_threshold_us: u64,
    // Tallies keyed by query type and by model name.
    queries_by_type: RwLock<HashMap<String, u64>>,
    queries_by_model: RwLock<HashMap<String, u64>>,
}

impl InMemoryMetricsCollector {
    /// Builds a collector with the default slow-query threshold of one second.
    pub fn new() -> Self {
        // One second, expressed in microseconds.
        const DEFAULT_SLOW_THRESHOLD_US: u64 = 1_000_000;
        Self::with_slow_threshold(DEFAULT_SLOW_THRESHOLD_US)
    }

    /// Builds a collector that uses `threshold_us` (microseconds) as its
    /// slow-query threshold.
    pub fn with_slow_threshold(threshold_us: u64) -> Self {
        let zeroed = || AtomicU64::new(0);

        Self {
            total_queries: zeroed(),
            successful_queries: zeroed(),
            failed_queries: zeroed(),
            total_time_us: zeroed(),
            // Start at the largest value so the first recorded duration becomes the minimum.
            min_time_us: AtomicU64::new(u64::MAX),
            max_time_us: zeroed(),
            slow_queries: zeroed(),
            cache_hits: zeroed(),
            slow_threshold_us: threshold_us,
            queries_by_type: RwLock::new(HashMap::new()),
            queries_by_model: RwLock::new(HashMap::new()),
        }
    }
}

impl Default for InMemoryMetricsCollector {
    /// Same as [`InMemoryMetricsCollector::new`].
    fn default() -> Self {
        InMemoryMetricsCollector::new()
    }
}
153
154impl MetricsCollector for InMemoryMetricsCollector {
155    fn record_query(
156        &self,
157        query_type: QueryType,
158        model: Option<&str>,
159        duration_us: u64,
160        success: bool,
161        from_cache: bool,
162    ) {
163        self.total_queries.fetch_add(1, Ordering::SeqCst);
164
165        if success {
166            self.successful_queries.fetch_add(1, Ordering::SeqCst);
167        } else {
168            self.failed_queries.fetch_add(1, Ordering::SeqCst);
169        }
170
171        self.total_time_us.fetch_add(duration_us, Ordering::SeqCst);
172
173        // Update min (using compare-and-swap loop)
174        loop {
175            let current = self.min_time_us.load(Ordering::SeqCst);
176            if duration_us >= current {
177                break;
178            }
179            if self.min_time_us.compare_exchange(
180                current,
181                duration_us,
182                Ordering::SeqCst,
183                Ordering::SeqCst,
184            ).is_ok() {
185                break;
186            }
187        }
188
189        // Update max
190        loop {
191            let current = self.max_time_us.load(Ordering::SeqCst);
192            if duration_us <= current {
193                break;
194            }
195            if self.max_time_us.compare_exchange(
196                current,
197                duration_us,
198                Ordering::SeqCst,
199                Ordering::SeqCst,
200            ).is_ok() {
201                break;
202            }
203        }
204
205        if duration_us >= self.slow_threshold_us {
206            self.slow_queries.fetch_add(1, Ordering::SeqCst);
207        }
208
209        if from_cache {
210            self.cache_hits.fetch_add(1, Ordering::SeqCst);
211        }
212
213        // Update queries by type
214        {
215            let mut by_type = self.queries_by_type.write().unwrap();
216            let key = format!("{:?}", query_type);
217            *by_type.entry(key).or_insert(0) += 1;
218        }
219
220        // Update queries by model
221        if let Some(model) = model {
222            let mut by_model = self.queries_by_model.write().unwrap();
223            *by_model.entry(model.to_string()).or_insert(0) += 1;
224        }
225    }
226
227    fn get_metrics(&self) -> QueryMetrics {
228        let total = self.total_queries.load(Ordering::SeqCst);
229        let total_time = self.total_time_us.load(Ordering::SeqCst);
230        let avg = if total > 0 { total_time / total } else { 0 };
231        let min = self.min_time_us.load(Ordering::SeqCst);
232
233        QueryMetrics {
234            total_queries: total,
235            successful_queries: self.successful_queries.load(Ordering::SeqCst),
236            failed_queries: self.failed_queries.load(Ordering::SeqCst),
237            total_time_us: total_time,
238            avg_time_us: avg,
239            min_time_us: if min == u64::MAX { 0 } else { min },
240            max_time_us: self.max_time_us.load(Ordering::SeqCst),
241            slow_queries: self.slow_queries.load(Ordering::SeqCst),
242            cache_hits: self.cache_hits.load(Ordering::SeqCst),
243            queries_by_type: self.queries_by_type.read().unwrap().clone(),
244            queries_by_model: self.queries_by_model.read().unwrap().clone(),
245        }
246    }
247
248    fn reset(&self) {
249        self.total_queries.store(0, Ordering::SeqCst);
250        self.successful_queries.store(0, Ordering::SeqCst);
251        self.failed_queries.store(0, Ordering::SeqCst);
252        self.total_time_us.store(0, Ordering::SeqCst);
253        self.min_time_us.store(u64::MAX, Ordering::SeqCst);
254        self.max_time_us.store(0, Ordering::SeqCst);
255        self.slow_queries.store(0, Ordering::SeqCst);
256        self.cache_hits.store(0, Ordering::SeqCst);
257        self.queries_by_type.write().unwrap().clear();
258        self.queries_by_model.write().unwrap().clear();
259    }
260}
261
262/// Middleware that collects query metrics.
263///
264/// # Example
265///
266/// ```rust,ignore
267/// use prax_query::middleware::{MetricsMiddleware, InMemoryMetricsCollector};
268///
269/// let collector = Arc::new(InMemoryMetricsCollector::new());
270/// let metrics = MetricsMiddleware::new(collector.clone());
271///
272/// // Use middleware...
273///
274/// // Get metrics
275/// let stats = collector.get_metrics();
276/// println!("Total queries: {}", stats.total_queries);
277/// println!("Avg time: {}us", stats.avg_time_us);
278/// ```
279pub struct MetricsMiddleware {
280    collector: Arc<dyn MetricsCollector>,
281}
282
283impl MetricsMiddleware {
284    /// Create a new metrics middleware.
285    pub fn new(collector: Arc<dyn MetricsCollector>) -> Self {
286        Self { collector }
287    }
288
289    /// Create with default in-memory collector.
290    pub fn in_memory() -> (Self, Arc<InMemoryMetricsCollector>) {
291        let collector = Arc::new(InMemoryMetricsCollector::new());
292        let middleware = Self::new(collector.clone());
293        (middleware, collector)
294    }
295
296    /// Get the metrics collector.
297    pub fn collector(&self) -> &Arc<dyn MetricsCollector> {
298        &self.collector
299    }
300}
301
302impl Middleware for MetricsMiddleware {
303    fn handle<'a>(
304        &'a self,
305        ctx: QueryContext,
306        next: Next<'a>,
307    ) -> BoxFuture<'a, MiddlewareResult<QueryResponse>> {
308        Box::pin(async move {
309            let query_type = ctx.query_type();
310            let model = ctx.metadata().model.clone();
311            let start = Instant::now();
312
313            let result = next.run(ctx).await;
314
315            let duration_us = start.elapsed().as_micros() as u64;
316            let (success, from_cache) = match &result {
317                Ok(response) => (true, response.from_cache),
318                Err(_) => (false, false),
319            };
320
321            self.collector.record_query(
322                query_type,
323                model.as_deref(),
324                duration_us,
325                success,
326                from_cache,
327            );
328
329            result
330        })
331    }
332
333    fn name(&self) -> &'static str {
334        "MetricsMiddleware"
335    }
336}
337
#[cfg(test)]
mod tests {
    use super::*;

    /// Largest absolute error tolerated when comparing computed rates.
    const RATE_TOLERANCE: f64 = 0.01;

    /// Feeds each `(type, model, duration_us, success, from_cache)` sample to `collector`, in order.
    fn record_all(
        collector: &InMemoryMetricsCollector,
        samples: Vec<(QueryType, Option<&str>, u64, bool, bool)>,
    ) {
        for (query_type, model, duration_us, success, from_cache) in samples {
            collector.record_query(query_type, model, duration_us, success, from_cache);
        }
    }

    #[test]
    fn test_query_metrics_default() {
        let empty = QueryMetrics::new();

        assert_eq!(empty.total_queries, 0);
        assert_eq!(empty.success_rate(), 1.0);
        assert_eq!(empty.cache_hit_rate(), 0.0);
    }

    #[test]
    fn test_in_memory_collector() {
        let collector = InMemoryMetricsCollector::new();
        record_all(
            &collector,
            vec![
                (QueryType::Select, Some("User"), 1000, true, false),
                (QueryType::Select, Some("User"), 2000, true, true),
                (QueryType::Insert, Some("Post"), 500, false, false),
            ],
        );

        let snapshot = collector.get_metrics();

        assert_eq!(snapshot.total_queries, 3);
        assert_eq!(snapshot.successful_queries, 2);
        assert_eq!(snapshot.failed_queries, 1);
        assert_eq!(snapshot.cache_hits, 1);
        assert_eq!(snapshot.min_time_us, 500);
        assert_eq!(snapshot.max_time_us, 2000);
    }

    #[test]
    fn test_collector_reset() {
        let collector = InMemoryMetricsCollector::new();
        record_all(&collector, vec![(QueryType::Select, None, 1000, true, false)]);

        let before_reset = collector.get_metrics().total_queries;
        assert_eq!(before_reset, 1);

        collector.reset();

        let after_reset = collector.get_metrics().total_queries;
        assert_eq!(after_reset, 0);
    }

    #[test]
    fn test_metrics_rates() {
        // Threshold of 1000us: only the 2000us sample below counts as slow.
        let collector = InMemoryMetricsCollector::with_slow_threshold(1000);
        record_all(
            &collector,
            vec![
                (QueryType::Select, None, 500, true, true),
                (QueryType::Select, None, 500, true, false),
                (QueryType::Select, None, 2000, true, false),
                (QueryType::Select, None, 500, false, false),
            ],
        );

        let snapshot = collector.get_metrics();

        assert_eq!(snapshot.total_queries, 4);
        assert!((snapshot.success_rate() - 0.75).abs() < RATE_TOLERANCE);
        assert!((snapshot.cache_hit_rate() - 0.25).abs() < RATE_TOLERANCE);
        assert!((snapshot.slow_query_rate() - 0.25).abs() < RATE_TOLERANCE);
    }
}
395
396