1use super::context::{QueryContext, QueryType};
4use super::types::{BoxFuture, Middleware, MiddlewareResult, Next, QueryResponse};
5use std::collections::HashMap;
6use std::sync::atomic::{AtomicU64, Ordering};
7use std::sync::{Arc, RwLock};
8use std::time::Instant;
9
/// Point-in-time summary of query activity.
///
/// Every `*_time_us` field is a duration in microseconds. The type is a plain
/// value: it can be cloned, debug-printed and compared for equality, so two
/// snapshots can be checked against each other with `assert_eq!`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct QueryMetrics {
    /// Number of queries counted, whatever their outcome.
    pub total_queries: u64,
    /// Number of queries that completed successfully.
    pub successful_queries: u64,
    /// Number of queries that failed.
    pub failed_queries: u64,
    /// Sum of the durations of all counted queries.
    pub total_time_us: u64,
    /// Mean duration per query.
    pub avg_time_us: u64,
    /// Shortest duration seen.
    ///
    /// The `Default` value of this struct seeds the field with `u64::MAX`, so
    /// check `total_queries` before presenting it as a real measurement.
    pub min_time_us: u64,
    /// Longest duration seen.
    pub max_time_us: u64,
    /// Number of queries classified as slow.
    pub slow_queries: u64,
    /// Number of queries answered from a cache.
    pub cache_hits: u64,
    /// Query count keyed by a textual name of the query type.
    pub queries_by_type: HashMap<String, u64>,
    /// Query count keyed by model name.
    pub queries_by_model: HashMap<String, u64>,
}
36
37impl Default for QueryMetrics {
38 fn default() -> Self {
39 Self {
40 total_queries: 0,
41 successful_queries: 0,
42 failed_queries: 0,
43 total_time_us: 0,
44 avg_time_us: 0,
45 min_time_us: u64::MAX,
46 max_time_us: 0,
47 slow_queries: 0,
48 cache_hits: 0,
49 queries_by_type: HashMap::new(),
50 queries_by_model: HashMap::new(),
51 }
52 }
53}
54
55impl QueryMetrics {
56 pub fn new() -> Self {
58 Self::default()
59 }
60
61 pub fn success_rate(&self) -> f64 {
63 if self.total_queries == 0 {
64 1.0
65 } else {
66 self.successful_queries as f64 / self.total_queries as f64
67 }
68 }
69
70 pub fn cache_hit_rate(&self) -> f64 {
72 if self.total_queries == 0 {
73 0.0
74 } else {
75 self.cache_hits as f64 / self.total_queries as f64
76 }
77 }
78
79 pub fn slow_query_rate(&self) -> f64 {
81 if self.total_queries == 0 {
82 0.0
83 } else {
84 self.slow_queries as f64 / self.total_queries as f64
85 }
86 }
87}
88
89pub trait MetricsCollector: Send + Sync {
91 fn record_query(
93 &self,
94 query_type: QueryType,
95 model: Option<&str>,
96 duration_us: u64,
97 success: bool,
98 from_cache: bool,
99 );
100
101 fn get_metrics(&self) -> QueryMetrics;
103
104 fn reset(&self);
106}
107
/// Collector that keeps all of its state in process memory.
///
/// Scalar counters are stored in atomics and the two per-key breakdowns sit
/// behind `RwLock`s, so the state can be updated through a shared reference.
#[derive(Debug)]
pub struct InMemoryMetricsCollector {
    /// Number of queries recorded.
    total_queries: AtomicU64,
    /// Number of recorded queries flagged as successful.
    successful_queries: AtomicU64,
    /// Number of recorded queries flagged as failed.
    failed_queries: AtomicU64,
    /// Sum of all recorded durations, in microseconds.
    total_time_us: AtomicU64,
    /// Smallest recorded duration; `u64::MAX` until something is recorded.
    min_time_us: AtomicU64,
    /// Largest recorded duration.
    max_time_us: AtomicU64,
    /// Number of recorded queries at or above `slow_threshold_us`.
    slow_queries: AtomicU64,
    /// Number of recorded queries served from cache.
    cache_hits: AtomicU64,
    /// Duration, in microseconds, from which a query counts as slow.
    slow_threshold_us: u64,
    /// Per-query-type counts.
    queries_by_type: RwLock<HashMap<String, u64>>,
    /// Per-model counts.
    queries_by_model: RwLock<HashMap<String, u64>>,
}

impl InMemoryMetricsCollector {
    /// Slow-query threshold applied by `new`: one second, in microseconds.
    const DEFAULT_SLOW_THRESHOLD_US: u64 = 1_000_000;

    /// Creates an empty collector with a slow-query threshold of
    /// 1,000,000 microseconds.
    pub fn new() -> Self {
        Self::with_slow_threshold(Self::DEFAULT_SLOW_THRESHOLD_US)
    }

    /// Creates an empty collector whose slow-query threshold is
    /// `threshold_us` microseconds.
    pub fn with_slow_threshold(threshold_us: u64) -> Self {
        let zeroed = || AtomicU64::new(0);

        Self {
            total_queries: zeroed(),
            successful_queries: zeroed(),
            failed_queries: zeroed(),
            total_time_us: zeroed(),
            // Seeded with the largest value so the first recorded duration
            // always becomes the minimum.
            min_time_us: AtomicU64::new(u64::MAX),
            max_time_us: zeroed(),
            slow_queries: zeroed(),
            cache_hits: zeroed(),
            slow_threshold_us: threshold_us,
            queries_by_type: RwLock::default(),
            queries_by_model: RwLock::default(),
        }
    }
}
147
148impl Default for InMemoryMetricsCollector {
149 fn default() -> Self {
150 Self::new()
151 }
152}
153
154impl MetricsCollector for InMemoryMetricsCollector {
155 fn record_query(
156 &self,
157 query_type: QueryType,
158 model: Option<&str>,
159 duration_us: u64,
160 success: bool,
161 from_cache: bool,
162 ) {
163 self.total_queries.fetch_add(1, Ordering::SeqCst);
164
165 if success {
166 self.successful_queries.fetch_add(1, Ordering::SeqCst);
167 } else {
168 self.failed_queries.fetch_add(1, Ordering::SeqCst);
169 }
170
171 self.total_time_us.fetch_add(duration_us, Ordering::SeqCst);
172
173 loop {
175 let current = self.min_time_us.load(Ordering::SeqCst);
176 if duration_us >= current {
177 break;
178 }
179 if self.min_time_us.compare_exchange(
180 current,
181 duration_us,
182 Ordering::SeqCst,
183 Ordering::SeqCst,
184 ).is_ok() {
185 break;
186 }
187 }
188
189 loop {
191 let current = self.max_time_us.load(Ordering::SeqCst);
192 if duration_us <= current {
193 break;
194 }
195 if self.max_time_us.compare_exchange(
196 current,
197 duration_us,
198 Ordering::SeqCst,
199 Ordering::SeqCst,
200 ).is_ok() {
201 break;
202 }
203 }
204
205 if duration_us >= self.slow_threshold_us {
206 self.slow_queries.fetch_add(1, Ordering::SeqCst);
207 }
208
209 if from_cache {
210 self.cache_hits.fetch_add(1, Ordering::SeqCst);
211 }
212
213 {
215 let mut by_type = self.queries_by_type.write().unwrap();
216 let key = format!("{:?}", query_type);
217 *by_type.entry(key).or_insert(0) += 1;
218 }
219
220 if let Some(model) = model {
222 let mut by_model = self.queries_by_model.write().unwrap();
223 *by_model.entry(model.to_string()).or_insert(0) += 1;
224 }
225 }
226
227 fn get_metrics(&self) -> QueryMetrics {
228 let total = self.total_queries.load(Ordering::SeqCst);
229 let total_time = self.total_time_us.load(Ordering::SeqCst);
230 let avg = if total > 0 { total_time / total } else { 0 };
231 let min = self.min_time_us.load(Ordering::SeqCst);
232
233 QueryMetrics {
234 total_queries: total,
235 successful_queries: self.successful_queries.load(Ordering::SeqCst),
236 failed_queries: self.failed_queries.load(Ordering::SeqCst),
237 total_time_us: total_time,
238 avg_time_us: avg,
239 min_time_us: if min == u64::MAX { 0 } else { min },
240 max_time_us: self.max_time_us.load(Ordering::SeqCst),
241 slow_queries: self.slow_queries.load(Ordering::SeqCst),
242 cache_hits: self.cache_hits.load(Ordering::SeqCst),
243 queries_by_type: self.queries_by_type.read().unwrap().clone(),
244 queries_by_model: self.queries_by_model.read().unwrap().clone(),
245 }
246 }
247
248 fn reset(&self) {
249 self.total_queries.store(0, Ordering::SeqCst);
250 self.successful_queries.store(0, Ordering::SeqCst);
251 self.failed_queries.store(0, Ordering::SeqCst);
252 self.total_time_us.store(0, Ordering::SeqCst);
253 self.min_time_us.store(u64::MAX, Ordering::SeqCst);
254 self.max_time_us.store(0, Ordering::SeqCst);
255 self.slow_queries.store(0, Ordering::SeqCst);
256 self.cache_hits.store(0, Ordering::SeqCst);
257 self.queries_by_type.write().unwrap().clear();
258 self.queries_by_model.write().unwrap().clear();
259 }
260}
261
262pub struct MetricsMiddleware {
280 collector: Arc<dyn MetricsCollector>,
281}
282
283impl MetricsMiddleware {
284 pub fn new(collector: Arc<dyn MetricsCollector>) -> Self {
286 Self { collector }
287 }
288
289 pub fn in_memory() -> (Self, Arc<InMemoryMetricsCollector>) {
291 let collector = Arc::new(InMemoryMetricsCollector::new());
292 let middleware = Self::new(collector.clone());
293 (middleware, collector)
294 }
295
296 pub fn collector(&self) -> &Arc<dyn MetricsCollector> {
298 &self.collector
299 }
300}
301
302impl Middleware for MetricsMiddleware {
303 fn handle<'a>(
304 &'a self,
305 ctx: QueryContext,
306 next: Next<'a>,
307 ) -> BoxFuture<'a, MiddlewareResult<QueryResponse>> {
308 Box::pin(async move {
309 let query_type = ctx.query_type();
310 let model = ctx.metadata().model.clone();
311 let start = Instant::now();
312
313 let result = next.run(ctx).await;
314
315 let duration_us = start.elapsed().as_micros() as u64;
316 let (success, from_cache) = match &result {
317 Ok(response) => (true, response.from_cache),
318 Err(_) => (false, false),
319 };
320
321 self.collector.record_query(
322 query_type,
323 model.as_deref(),
324 duration_us,
325 success,
326 from_cache,
327 );
328
329 result
330 })
331 }
332
333 fn name(&self) -> &'static str {
334 "MetricsMiddleware"
335 }
336}
337
#[cfg(test)]
mod tests {
    use super::*;

    /// An empty snapshot reports its "no data" rates.
    #[test]
    fn test_query_metrics_default() {
        let metrics = QueryMetrics::new();
        assert_eq!(metrics.total_queries, 0);
        assert_eq!(metrics.success_rate(), 1.0);
        assert_eq!(metrics.cache_hit_rate(), 0.0);
        assert_eq!(metrics.slow_query_rate(), 0.0);
    }

    /// Counters, timing aggregates and the per-model breakdown all reflect
    /// the recorded queries.
    #[test]
    fn test_in_memory_collector() {
        let collector = InMemoryMetricsCollector::new();

        collector.record_query(QueryType::Select, Some("User"), 1000, true, false);
        collector.record_query(QueryType::Select, Some("User"), 2000, true, true);
        collector.record_query(QueryType::Insert, Some("Post"), 500, false, false);

        let metrics = collector.get_metrics();
        assert_eq!(metrics.total_queries, 3);
        assert_eq!(metrics.successful_queries, 2);
        assert_eq!(metrics.failed_queries, 1);
        assert_eq!(metrics.cache_hits, 1);
        assert_eq!(metrics.min_time_us, 500);
        assert_eq!(metrics.max_time_us, 2000);
        assert_eq!(metrics.total_time_us, 3500);
        // Integer division: 3500 / 3.
        assert_eq!(metrics.avg_time_us, 1166);
        // All durations are below the default one-second threshold.
        assert_eq!(metrics.slow_queries, 0);
        assert_eq!(metrics.queries_by_model.get("User"), Some(&2));
        assert_eq!(metrics.queries_by_model.get("Post"), Some(&1));
    }

    /// `reset` clears the counters, the min/max trackers and both maps.
    #[test]
    fn test_collector_reset() {
        let collector = InMemoryMetricsCollector::new();
        collector.record_query(QueryType::Select, None, 1000, true, false);

        assert_eq!(collector.get_metrics().total_queries, 1);

        collector.reset();

        let metrics = collector.get_metrics();
        assert_eq!(metrics.total_queries, 0);
        assert_eq!(metrics.total_time_us, 0);
        assert_eq!(metrics.min_time_us, 0);
        assert_eq!(metrics.max_time_us, 0);
        assert!(metrics.queries_by_type.is_empty());
        assert!(metrics.queries_by_model.is_empty());
    }

    /// Rates are computed against the total number of recorded queries.
    #[test]
    fn test_metrics_rates() {
        let collector = InMemoryMetricsCollector::with_slow_threshold(1000);

        collector.record_query(QueryType::Select, None, 500, true, true);
        collector.record_query(QueryType::Select, None, 500, true, false);
        // The only query at or above the 1000 us threshold.
        collector.record_query(QueryType::Select, None, 2000, true, false);
        collector.record_query(QueryType::Select, None, 500, false, false);

        let metrics = collector.get_metrics();
        assert_eq!(metrics.total_queries, 4);
        assert_eq!(metrics.slow_queries, 1);
        assert!((metrics.success_rate() - 0.75).abs() < 0.01);
        assert!((metrics.cache_hit_rate() - 0.25).abs() < 0.01);
        assert!((metrics.slow_query_rate() - 0.25).abs() < 0.01);
    }
}
395
396