1use super::context::{QueryContext, QueryType};
4use super::types::{BoxFuture, Middleware, MiddlewareResult, Next, QueryResponse};
5use std::collections::HashMap;
6use std::sync::atomic::{AtomicU64, Ordering};
7use std::sync::{Arc, RwLock};
8use std::time::Instant;
9
/// Point-in-time snapshot of aggregated query statistics.
///
/// This is plain data: every duration is expressed in microseconds, and the
/// ratio helpers (`success_rate`, `cache_hit_rate`, `slow_query_rate`) are
/// derived from the counters below.
#[derive(Debug, Clone)]
pub struct QueryMetrics {
    /// Number of queries recorded.
    pub total_queries: u64,
    /// Number of recorded queries reported as successful.
    pub successful_queries: u64,
    /// Number of recorded queries reported as failed.
    pub failed_queries: u64,
    /// Sum of all recorded durations, in microseconds.
    pub total_time_us: u64,
    /// Mean recorded duration, in microseconds.
    pub avg_time_us: u64,
    /// Shortest recorded duration, in microseconds.
    pub min_time_us: u64,
    /// Longest recorded duration, in microseconds.
    pub max_time_us: u64,
    /// Number of recorded queries classified as slow.
    pub slow_queries: u64,
    /// Number of recorded queries served from a cache.
    pub cache_hits: u64,
    /// Query counts keyed by query type.
    pub queries_by_type: HashMap<String, u64>,
    /// Query counts keyed by model name.
    pub queries_by_model: HashMap<String, u64>,
}
36
37impl Default for QueryMetrics {
38 fn default() -> Self {
39 Self {
40 total_queries: 0,
41 successful_queries: 0,
42 failed_queries: 0,
43 total_time_us: 0,
44 avg_time_us: 0,
45 min_time_us: u64::MAX,
46 max_time_us: 0,
47 slow_queries: 0,
48 cache_hits: 0,
49 queries_by_type: HashMap::new(),
50 queries_by_model: HashMap::new(),
51 }
52 }
53}
54
55impl QueryMetrics {
56 pub fn new() -> Self {
58 Self::default()
59 }
60
61 pub fn success_rate(&self) -> f64 {
63 if self.total_queries == 0 {
64 1.0
65 } else {
66 self.successful_queries as f64 / self.total_queries as f64
67 }
68 }
69
70 pub fn cache_hit_rate(&self) -> f64 {
72 if self.total_queries == 0 {
73 0.0
74 } else {
75 self.cache_hits as f64 / self.total_queries as f64
76 }
77 }
78
79 pub fn slow_query_rate(&self) -> f64 {
81 if self.total_queries == 0 {
82 0.0
83 } else {
84 self.slow_queries as f64 / self.total_queries as f64
85 }
86 }
87}
88
89pub trait MetricsCollector: Send + Sync {
91 fn record_query(
93 &self,
94 query_type: QueryType,
95 model: Option<&str>,
96 duration_us: u64,
97 success: bool,
98 from_cache: bool,
99 );
100
101 fn get_metrics(&self) -> QueryMetrics;
103
104 fn reset(&self);
106}
107
/// In-process [`MetricsCollector`] that aggregates measurements in memory.
///
/// Scalar statistics are stored in atomics. The per-type and per-model
/// breakdowns are maps guarded by `RwLock`s.
#[derive(Debug)]
pub struct InMemoryMetricsCollector {
    /// Queries recorded since creation or the last reset.
    total_queries: AtomicU64,
    /// Queries recorded as successful.
    successful_queries: AtomicU64,
    /// Queries recorded as failed.
    failed_queries: AtomicU64,
    /// Sum of all recorded durations, in microseconds.
    total_time_us: AtomicU64,
    /// Shortest recorded duration; holds `u64::MAX` until the first sample.
    min_time_us: AtomicU64,
    /// Longest recorded duration; holds `0` until the first sample.
    max_time_us: AtomicU64,
    /// Queries whose duration reached `slow_threshold_us`.
    slow_queries: AtomicU64,
    /// Queries recorded as served from a cache.
    cache_hits: AtomicU64,
    /// Durations at or above this many microseconds count as slow.
    slow_threshold_us: u64,
    /// Query counts keyed by the `Debug` rendering of the query type.
    queries_by_type: RwLock<HashMap<String, u64>>,
    /// Query counts keyed by model name.
    queries_by_model: RwLock<HashMap<String, u64>>,
}
123
124impl InMemoryMetricsCollector {
125 pub fn new() -> Self {
127 Self::with_slow_threshold(1_000_000) }
129
130 pub fn with_slow_threshold(threshold_us: u64) -> Self {
132 Self {
133 total_queries: AtomicU64::new(0),
134 successful_queries: AtomicU64::new(0),
135 failed_queries: AtomicU64::new(0),
136 total_time_us: AtomicU64::new(0),
137 min_time_us: AtomicU64::new(u64::MAX),
138 max_time_us: AtomicU64::new(0),
139 slow_queries: AtomicU64::new(0),
140 cache_hits: AtomicU64::new(0),
141 slow_threshold_us: threshold_us,
142 queries_by_type: RwLock::new(HashMap::new()),
143 queries_by_model: RwLock::new(HashMap::new()),
144 }
145 }
146}
147
148impl Default for InMemoryMetricsCollector {
149 fn default() -> Self {
150 Self::new()
151 }
152}
153
154impl MetricsCollector for InMemoryMetricsCollector {
155 fn record_query(
156 &self,
157 query_type: QueryType,
158 model: Option<&str>,
159 duration_us: u64,
160 success: bool,
161 from_cache: bool,
162 ) {
163 self.total_queries.fetch_add(1, Ordering::SeqCst);
164
165 if success {
166 self.successful_queries.fetch_add(1, Ordering::SeqCst);
167 } else {
168 self.failed_queries.fetch_add(1, Ordering::SeqCst);
169 }
170
171 self.total_time_us.fetch_add(duration_us, Ordering::SeqCst);
172
173 loop {
175 let current = self.min_time_us.load(Ordering::SeqCst);
176 if duration_us >= current {
177 break;
178 }
179 if self
180 .min_time_us
181 .compare_exchange(current, duration_us, Ordering::SeqCst, Ordering::SeqCst)
182 .is_ok()
183 {
184 break;
185 }
186 }
187
188 loop {
190 let current = self.max_time_us.load(Ordering::SeqCst);
191 if duration_us <= current {
192 break;
193 }
194 if self
195 .max_time_us
196 .compare_exchange(current, duration_us, Ordering::SeqCst, Ordering::SeqCst)
197 .is_ok()
198 {
199 break;
200 }
201 }
202
203 if duration_us >= self.slow_threshold_us {
204 self.slow_queries.fetch_add(1, Ordering::SeqCst);
205 }
206
207 if from_cache {
208 self.cache_hits.fetch_add(1, Ordering::SeqCst);
209 }
210
211 {
213 let mut by_type = self.queries_by_type.write().unwrap();
214 let key = format!("{:?}", query_type);
215 *by_type.entry(key).or_insert(0) += 1;
216 }
217
218 if let Some(model) = model {
220 let mut by_model = self.queries_by_model.write().unwrap();
221 *by_model.entry(model.to_string()).or_insert(0) += 1;
222 }
223 }
224
225 fn get_metrics(&self) -> QueryMetrics {
226 let total = self.total_queries.load(Ordering::SeqCst);
227 let total_time = self.total_time_us.load(Ordering::SeqCst);
228 let avg = if total > 0 { total_time / total } else { 0 };
229 let min = self.min_time_us.load(Ordering::SeqCst);
230
231 QueryMetrics {
232 total_queries: total,
233 successful_queries: self.successful_queries.load(Ordering::SeqCst),
234 failed_queries: self.failed_queries.load(Ordering::SeqCst),
235 total_time_us: total_time,
236 avg_time_us: avg,
237 min_time_us: if min == u64::MAX { 0 } else { min },
238 max_time_us: self.max_time_us.load(Ordering::SeqCst),
239 slow_queries: self.slow_queries.load(Ordering::SeqCst),
240 cache_hits: self.cache_hits.load(Ordering::SeqCst),
241 queries_by_type: self.queries_by_type.read().unwrap().clone(),
242 queries_by_model: self.queries_by_model.read().unwrap().clone(),
243 }
244 }
245
246 fn reset(&self) {
247 self.total_queries.store(0, Ordering::SeqCst);
248 self.successful_queries.store(0, Ordering::SeqCst);
249 self.failed_queries.store(0, Ordering::SeqCst);
250 self.total_time_us.store(0, Ordering::SeqCst);
251 self.min_time_us.store(u64::MAX, Ordering::SeqCst);
252 self.max_time_us.store(0, Ordering::SeqCst);
253 self.slow_queries.store(0, Ordering::SeqCst);
254 self.cache_hits.store(0, Ordering::SeqCst);
255 self.queries_by_type.write().unwrap().clear();
256 self.queries_by_model.write().unwrap().clear();
257 }
258}
259
260pub struct MetricsMiddleware {
278 collector: Arc<dyn MetricsCollector>,
279}
280
281impl MetricsMiddleware {
282 pub fn new(collector: Arc<dyn MetricsCollector>) -> Self {
284 Self { collector }
285 }
286
287 pub fn in_memory() -> (Self, Arc<InMemoryMetricsCollector>) {
289 let collector = Arc::new(InMemoryMetricsCollector::new());
290 let middleware = Self::new(collector.clone());
291 (middleware, collector)
292 }
293
294 pub fn collector(&self) -> &Arc<dyn MetricsCollector> {
296 &self.collector
297 }
298}
299
300impl Middleware for MetricsMiddleware {
301 fn handle<'a>(
302 &'a self,
303 ctx: QueryContext,
304 next: Next<'a>,
305 ) -> BoxFuture<'a, MiddlewareResult<QueryResponse>> {
306 Box::pin(async move {
307 let query_type = ctx.query_type();
308 let model = ctx.metadata().model.clone();
309 let start = Instant::now();
310
311 let result = next.run(ctx).await;
312
313 let duration_us = start.elapsed().as_micros() as u64;
314 let (success, from_cache) = match &result {
315 Ok(response) => (true, response.from_cache),
316 Err(_) => (false, false),
317 };
318
319 self.collector.record_query(
320 query_type,
321 model.as_deref(),
322 duration_us,
323 success,
324 from_cache,
325 );
326
327 result
328 })
329 }
330
331 fn name(&self) -> &'static str {
332 "MetricsMiddleware"
333 }
334}
335
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_query_metrics_default() {
        let metrics = QueryMetrics::new();
        assert_eq!(metrics.total_queries, 0);
        assert_eq!(metrics.success_rate(), 1.0);
        assert_eq!(metrics.cache_hit_rate(), 0.0);
        assert_eq!(metrics.slow_query_rate(), 0.0);
    }

    #[test]
    fn test_in_memory_collector() {
        let collector = InMemoryMetricsCollector::new();

        collector.record_query(QueryType::Select, Some("User"), 1000, true, false);
        collector.record_query(QueryType::Select, Some("User"), 2000, true, true);
        collector.record_query(QueryType::Insert, Some("Post"), 500, false, false);

        let metrics = collector.get_metrics();
        assert_eq!(metrics.total_queries, 3);
        assert_eq!(metrics.successful_queries, 2);
        assert_eq!(metrics.failed_queries, 1);
        assert_eq!(metrics.cache_hits, 1);
        assert_eq!(metrics.min_time_us, 500);
        assert_eq!(metrics.max_time_us, 2000);
        assert_eq!(metrics.total_time_us, 3500);
        // Truncated integer mean: 3500 / 3.
        assert_eq!(metrics.avg_time_us, 1166);
    }

    #[test]
    fn test_empty_collector_snapshot() {
        let metrics = InMemoryMetricsCollector::new().get_metrics();
        assert_eq!(metrics.total_queries, 0);
        assert_eq!(metrics.avg_time_us, 0);
        // The internal u64::MAX sentinel must not leak into snapshots.
        assert_eq!(metrics.min_time_us, 0);
        assert_eq!(metrics.max_time_us, 0);
        assert!(metrics.queries_by_type.is_empty());
        assert!(metrics.queries_by_model.is_empty());
    }

    #[test]
    fn test_breakdown_by_type_and_model() {
        let collector = InMemoryMetricsCollector::new();

        collector.record_query(QueryType::Select, Some("User"), 10, true, false);
        collector.record_query(QueryType::Select, Some("User"), 10, true, false);
        collector.record_query(QueryType::Insert, Some("Post"), 10, true, false);
        collector.record_query(QueryType::Insert, None, 10, true, false);

        let metrics = collector.get_metrics();
        let select_key = format!("{:?}", QueryType::Select);
        let insert_key = format!("{:?}", QueryType::Insert);
        assert_eq!(metrics.queries_by_type.get(&select_key), Some(&2));
        assert_eq!(metrics.queries_by_type.get(&insert_key), Some(&2));
        assert_eq!(metrics.queries_by_model.get("User"), Some(&2));
        assert_eq!(metrics.queries_by_model.get("Post"), Some(&1));
        // A query without a model is not attributed to any model.
        assert_eq!(metrics.queries_by_model.len(), 2);
    }

    #[test]
    fn test_collector_reset() {
        let collector = InMemoryMetricsCollector::new();
        collector.record_query(QueryType::Select, Some("User"), 1000, true, false);

        assert_eq!(collector.get_metrics().total_queries, 1);

        collector.reset();

        let metrics = collector.get_metrics();
        assert_eq!(metrics.total_queries, 0);
        assert_eq!(metrics.min_time_us, 0);
        assert_eq!(metrics.max_time_us, 0);
        assert!(metrics.queries_by_type.is_empty());
        assert!(metrics.queries_by_model.is_empty());

        // Min/max tracking restarts cleanly after a reset.
        collector.record_query(QueryType::Select, None, 700, true, false);
        let metrics = collector.get_metrics();
        assert_eq!(metrics.min_time_us, 700);
        assert_eq!(metrics.max_time_us, 700);
    }

    #[test]
    fn test_slow_threshold_is_inclusive() {
        let collector = InMemoryMetricsCollector::with_slow_threshold(1000);

        collector.record_query(QueryType::Select, None, 999, true, false);
        collector.record_query(QueryType::Select, None, 1000, true, false);

        assert_eq!(collector.get_metrics().slow_queries, 1);
    }

    #[test]
    fn test_metrics_rates() {
        let collector = InMemoryMetricsCollector::with_slow_threshold(1000);

        collector.record_query(QueryType::Select, None, 500, true, true);
        collector.record_query(QueryType::Select, None, 500, true, false);
        collector.record_query(QueryType::Select, None, 2000, true, false);
        collector.record_query(QueryType::Select, None, 500, false, false);

        let metrics = collector.get_metrics();
        assert_eq!(metrics.total_queries, 4);
        assert!((metrics.success_rate() - 0.75).abs() < 0.01);
        assert!((metrics.cache_hit_rate() - 0.25).abs() < 0.01);
        assert!((metrics.slow_query_rate() - 0.25).abs() < 0.01);
    }

    #[test]
    fn test_concurrent_recording() {
        let collector = Arc::new(InMemoryMetricsCollector::new());

        let handles: Vec<_> = (0..4u64)
            .map(|thread_idx| {
                let collector = Arc::clone(&collector);
                std::thread::spawn(move || {
                    for i in 0..250u64 {
                        collector.record_query(
                            QueryType::Select,
                            Some("User"),
                            thread_idx * 1000 + i + 1,
                            i % 2 == 0,
                            false,
                        );
                    }
                })
            })
            .collect();
        for handle in handles {
            handle.join().expect("recording thread panicked");
        }

        let metrics = collector.get_metrics();
        assert_eq!(metrics.total_queries, 1000);
        assert_eq!(metrics.successful_queries, 500);
        assert_eq!(metrics.failed_queries, 500);
        assert_eq!(metrics.min_time_us, 1);
        assert_eq!(metrics.max_time_us, 3250);
        assert_eq!(metrics.queries_by_model.get("User"), Some(&1000));
    }
}