1use chrono::{DateTime, Utc};
10use parking_lot::RwLock;
11use serde::{Deserialize, Serialize};
12use std::collections::HashMap;
13use std::sync::Arc;
14use tracing::{debug, error, info, warn};
15use uuid::Uuid;
16
17#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
19pub enum LogLevel {
20 Debug,
22 Info,
24 Warn,
26 Error,
28}
29
30#[derive(Debug, Clone, Serialize, Deserialize, Default)]
32pub struct DbContext {
33 pub user_id: Option<Uuid>,
35 pub request_id: Option<String>,
37 pub session_id: Option<String>,
39 pub ip_address: Option<String>,
41 pub user_agent: Option<String>,
43 pub custom_fields: HashMap<String, String>,
45}
46
47impl DbContext {
48 pub fn with_user_id(mut self, user_id: Uuid) -> Self {
50 self.user_id = Some(user_id);
51 self
52 }
53
54 pub fn with_request_id(mut self, request_id: String) -> Self {
56 self.request_id = Some(request_id);
57 self
58 }
59
60 pub fn with_session_id(mut self, session_id: String) -> Self {
62 self.session_id = Some(session_id);
63 self
64 }
65
66 pub fn with_ip_address(mut self, ip_address: String) -> Self {
68 self.ip_address = Some(ip_address);
69 self
70 }
71
72 pub fn with_user_agent(mut self, user_agent: String) -> Self {
74 self.user_agent = Some(user_agent);
75 self
76 }
77
78 pub fn add_field(mut self, key: String, value: String) -> Self {
80 self.custom_fields.insert(key, value);
81 self
82 }
83}
84
85#[derive(Debug, Clone, Serialize, Deserialize)]
87pub struct QueryLogEntry {
88 pub timestamp: DateTime<Utc>,
90 pub level: LogLevel,
92 pub operation: String,
94 pub table: String,
96 pub duration_ms: u64,
98 pub rows_affected: Option<u64>,
100 pub success: bool,
102 pub error: Option<String>,
104 pub context: DbContext,
106 pub sql: Option<String>,
108}
109
110impl QueryLogEntry {
111 pub fn new(operation: String, table: String, duration_ms: u64) -> Self {
113 Self {
114 timestamp: Utc::now(),
115 level: LogLevel::Info,
116 operation,
117 table,
118 duration_ms,
119 rows_affected: None,
120 success: true,
121 error: None,
122 context: DbContext::default(),
123 sql: None,
124 }
125 }
126
127 pub fn with_level(mut self, level: LogLevel) -> Self {
129 self.level = level;
130 self
131 }
132
133 pub fn with_rows_affected(mut self, rows: u64) -> Self {
135 self.rows_affected = Some(rows);
136 self
137 }
138
139 pub fn with_error(mut self, error: String) -> Self {
141 self.success = false;
142 self.error = Some(error);
143 self.level = LogLevel::Error;
144 self
145 }
146
147 pub fn with_context(mut self, context: DbContext) -> Self {
149 self.context = context;
150 self
151 }
152
153 pub fn with_sql(mut self, sql: String) -> Self {
155 self.sql = Some(sql);
156 self
157 }
158
159 pub fn log(&self) {
161 let user_id = self.context.user_id.map(|id| id.to_string());
162 let request_id = self.context.request_id.as_deref();
163
164 match self.level {
165 LogLevel::Debug => {
166 debug!(
167 operation = %self.operation,
168 table = %self.table,
169 duration_ms = self.duration_ms,
170 user_id = ?user_id,
171 request_id = ?request_id,
172 success = self.success,
173 "Database query executed"
174 );
175 }
176 LogLevel::Info => {
177 info!(
178 operation = %self.operation,
179 table = %self.table,
180 duration_ms = self.duration_ms,
181 user_id = ?user_id,
182 request_id = ?request_id,
183 success = self.success,
184 "Database query executed"
185 );
186 }
187 LogLevel::Warn => {
188 warn!(
189 operation = %self.operation,
190 table = %self.table,
191 duration_ms = self.duration_ms,
192 user_id = ?user_id,
193 request_id = ?request_id,
194 success = self.success,
195 error = ?self.error,
196 "Database query warning"
197 );
198 }
199 LogLevel::Error => {
200 error!(
201 operation = %self.operation,
202 table = %self.table,
203 duration_ms = self.duration_ms,
204 user_id = ?user_id,
205 request_id = ?request_id,
206 success = self.success,
207 error = ?self.error,
208 "Database query failed"
209 );
210 }
211 }
212 }
213}
214
/// Flags query durations that stand out from recent history.
///
/// With at least `min_samples` history points, a duration is anomalous when it
/// exceeds `mean + std_dev_threshold * std_dev` of that history. With fewer
/// points, it is anomalous when it exceeds twice `baseline_avg_ms`.
#[derive(Clone, Debug)]
pub struct AnomalyDetector {
    /// Reference mean duration (ms) used while history is too short.
    baseline_avg_ms: f64,
    /// Number of population standard deviations above the mean that triggers a flag.
    std_dev_threshold: f64,
    /// History length required before the statistical test is used.
    min_samples: usize,
}
225
226impl Default for AnomalyDetector {
227 fn default() -> Self {
228 Self {
229 baseline_avg_ms: 100.0,
230 std_dev_threshold: 3.0,
231 min_samples: 10,
232 }
233 }
234}
235
236impl AnomalyDetector {
237 pub fn new(baseline_avg_ms: f64, std_dev_threshold: f64) -> Self {
239 Self {
240 baseline_avg_ms,
241 std_dev_threshold,
242 min_samples: 10,
243 }
244 }
245
246 pub fn is_anomalous(&self, duration_ms: u64, samples: &[u64]) -> bool {
248 if samples.len() < self.min_samples {
249 return duration_ms as f64 > self.baseline_avg_ms * 2.0;
251 }
252
253 let mean = samples.iter().sum::<u64>() as f64 / samples.len() as f64;
254 let variance = samples
255 .iter()
256 .map(|&x| {
257 let diff = x as f64 - mean;
258 diff * diff
259 })
260 .sum::<f64>()
261 / samples.len() as f64;
262 let std_dev = variance.sqrt();
263
264 duration_ms as f64 > mean + (self.std_dev_threshold * std_dev)
265 }
266
267 pub fn update_baseline(&mut self, samples: &[u64]) {
269 if samples.is_empty() {
270 return;
271 }
272
273 self.baseline_avg_ms = samples.iter().sum::<u64>() as f64 / samples.len() as f64;
274 }
275}
276
277#[derive(Debug)]
279pub struct PerformanceTracker {
280 samples: Arc<RwLock<HashMap<String, Vec<u64>>>>,
282 detector: AnomalyDetector,
284}
285
286impl Default for PerformanceTracker {
287 fn default() -> Self {
288 Self::new(AnomalyDetector::default())
289 }
290}
291
292impl PerformanceTracker {
293 pub fn new(detector: AnomalyDetector) -> Self {
295 Self {
296 samples: Arc::new(RwLock::new(HashMap::new())),
297 detector,
298 }
299 }
300
301 pub fn record(&self, operation: &str, table: &str, duration_ms: u64) {
303 let key = format!("{}:{}", operation, table);
304 let mut samples = self.samples.write();
305 let entry = samples.entry(key).or_default();
306
307 if entry.len() >= 100 {
309 entry.remove(0);
310 }
311
312 entry.push(duration_ms);
313 }
314
315 pub fn check_anomaly(&self, operation: &str, table: &str, duration_ms: u64) -> bool {
317 let key = format!("{}:{}", operation, table);
318 let samples = self.samples.read();
319
320 if let Some(history) = samples.get(&key) {
321 self.detector.is_anomalous(duration_ms, history)
322 } else {
323 false
324 }
325 }
326
327 pub fn get_stats(&self, operation: &str, table: &str) -> Option<QueryStats> {
329 let key = format!("{}:{}", operation, table);
330 let samples = self.samples.read();
331
332 samples.get(&key).map(|history| {
333 let count = history.len();
334 let sum: u64 = history.iter().sum();
335 let avg = if count > 0 { sum / count as u64 } else { 0 };
336 let min = history.iter().min().copied().unwrap_or(0);
337 let max = history.iter().max().copied().unwrap_or(0);
338
339 QueryStats {
340 operation: operation.to_string(),
341 table: table.to_string(),
342 count,
343 avg_duration_ms: avg,
344 min_duration_ms: min,
345 max_duration_ms: max,
346 }
347 })
348 }
349}
350
351#[derive(Debug, Clone, Serialize, Deserialize)]
353pub struct QueryStats {
354 pub operation: String,
356 pub table: String,
358 pub count: usize,
360 pub avg_duration_ms: u64,
362 pub min_duration_ms: u64,
364 pub max_duration_ms: u64,
366}
367
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_db_context_builder() {
        let user_id = Uuid::new_v4();
        let context = DbContext::default()
            .with_user_id(user_id)
            .with_request_id("req-123".to_string())
            .with_session_id("sess-456".to_string())
            .with_ip_address("127.0.0.1".to_string())
            .with_user_agent("test-agent".to_string())
            .add_field("custom".to_string(), "value".to_string());

        assert_eq!(context.user_id, Some(user_id));
        assert_eq!(context.request_id, Some("req-123".to_string()));
        assert_eq!(context.session_id, Some("sess-456".to_string()));
        assert_eq!(context.ip_address, Some("127.0.0.1".to_string()));
        assert_eq!(context.user_agent, Some("test-agent".to_string()));
        assert_eq!(
            context.custom_fields.get("custom"),
            Some(&"value".to_string())
        );
    }

    #[test]
    fn test_db_context_add_field_overwrites() {
        let context = DbContext::default()
            .add_field("k".to_string(), "v1".to_string())
            .add_field("k".to_string(), "v2".to_string());

        assert_eq!(context.custom_fields.len(), 1);
        assert_eq!(context.custom_fields.get("k"), Some(&"v2".to_string()));
    }

    #[test]
    fn test_query_log_entry() {
        let entry = QueryLogEntry::new("SELECT".to_string(), "users".to_string(), 50)
            .with_rows_affected(10)
            .with_sql("SELECT * FROM users".to_string());

        assert_eq!(entry.operation, "SELECT");
        assert_eq!(entry.table, "users");
        assert_eq!(entry.duration_ms, 50);
        assert_eq!(entry.rows_affected, Some(10));
        assert!(entry.success);
        assert_eq!(entry.error, None);
        assert_eq!(entry.level, LogLevel::Info);
        assert_eq!(entry.sql, Some("SELECT * FROM users".to_string()));
    }

    #[test]
    fn test_query_log_entry_with_error() {
        let entry = QueryLogEntry::new("INSERT".to_string(), "users".to_string(), 100)
            .with_error("Constraint violation".to_string());

        assert!(!entry.success);
        assert_eq!(entry.error, Some("Constraint violation".to_string()));
        assert_eq!(entry.level, LogLevel::Error);
    }

    #[test]
    fn test_query_log_entry_with_level_and_context() {
        let entry = QueryLogEntry::new("UPDATE".to_string(), "orders".to_string(), 5)
            .with_level(LogLevel::Warn)
            .with_context(DbContext::default().with_request_id("req-1".to_string()));

        assert_eq!(entry.level, LogLevel::Warn);
        assert_eq!(entry.context.request_id, Some("req-1".to_string()));
        assert!(entry.success);
    }

    #[test]
    fn test_anomaly_detector_with_insufficient_samples() {
        let detector = AnomalyDetector::default();
        let samples = vec![100, 110, 105];

        // Below min_samples the threshold is twice the 100 ms baseline.
        assert!(!detector.is_anomalous(150, &samples));
        assert!(detector.is_anomalous(300, &samples));
    }

    #[test]
    fn test_anomaly_detector_with_sufficient_samples() {
        let detector = AnomalyDetector::new(100.0, 3.0);
        let samples: Vec<u64> = vec![95, 100, 105, 98, 102, 99, 101, 103, 97, 100, 104, 96];

        // Mean 100, population std dev ~3.03, so the 3-sigma cutoff is ~109.
        assert!(!detector.is_anomalous(105, &samples));
        assert!(detector.is_anomalous(500, &samples));
    }

    #[test]
    fn test_anomaly_detector_update_baseline() {
        let mut detector = AnomalyDetector::default();
        let samples = vec![200, 210, 205, 215];

        detector.update_baseline(&samples);
        assert!((detector.baseline_avg_ms - 207.5).abs() < 0.1);

        // An empty slice leaves the baseline untouched.
        detector.update_baseline(&[]);
        assert!((detector.baseline_avg_ms - 207.5).abs() < 0.1);
    }

    #[test]
    fn test_performance_tracker_record() {
        let tracker = PerformanceTracker::default();

        tracker.record("SELECT", "users", 100);
        tracker.record("SELECT", "users", 110);
        tracker.record("SELECT", "users", 105);

        let stats = tracker.get_stats("SELECT", "users").unwrap();
        assert_eq!(stats.count, 3);
        assert_eq!(stats.min_duration_ms, 100);
        assert_eq!(stats.max_duration_ms, 110);
        assert_eq!(stats.avg_duration_ms, 105);
    }

    #[test]
    fn test_performance_tracker_unknown_key() {
        let tracker = PerformanceTracker::default();

        assert!(tracker.get_stats("SELECT", "users").is_none());
        assert!(!tracker.check_anomaly("SELECT", "users", 1000));
    }

    #[test]
    fn test_performance_tracker_small_history_uses_baseline() {
        let tracker = PerformanceTracker::default();
        tracker.record("SELECT", "users", 100);

        assert!(!tracker.check_anomaly("SELECT", "users", 150));
        assert!(tracker.check_anomaly("SELECT", "users", 250));
    }

    #[test]
    fn test_performance_tracker_check_anomaly() {
        let tracker = PerformanceTracker::default();

        // Two passes over 95..=104: mean 99.5, population std dev ~2.87.
        for i in 0..20 {
            tracker.record("SELECT", "users", 95 + (i % 10) as u64);
        }

        assert!(!tracker.check_anomaly("SELECT", "users", 105));
        assert!(tracker.check_anomaly("SELECT", "users", 1000));
    }

    #[test]
    fn test_performance_tracker_max_samples() {
        let tracker = PerformanceTracker::default();

        for i in 0..150 {
            tracker.record("SELECT", "users", i);
        }

        // Only the newest 100 samples (50..=149) survive; the oldest are evicted first.
        let stats = tracker.get_stats("SELECT", "users").unwrap();
        assert_eq!(stats.count, 100);
        assert_eq!(stats.min_duration_ms, 50);
        assert_eq!(stats.max_duration_ms, 149);
        assert_eq!(stats.avg_duration_ms, 99);
    }

    #[test]
    fn test_query_log_entry_log() {
        // Smoke test: emitting at every level, with and without an error, must not panic.
        let user_id = Uuid::new_v4();
        for &level in &[LogLevel::Debug, LogLevel::Info, LogLevel::Warn, LogLevel::Error] {
            QueryLogEntry::new("SELECT".to_string(), "users".to_string(), 50)
                .with_level(level)
                .with_context(DbContext::default().with_user_id(user_id))
                .log();
        }

        QueryLogEntry::new("INSERT".to_string(), "users".to_string(), 7)
            .with_error("boom".to_string())
            .log();
    }
}