1use parking_lot::RwLock;
9use serde::{Deserialize, Serialize};
10use sha2::{Digest, Sha256};
11use std::collections::{HashMap, VecDeque};
12use std::sync::Arc;
13use std::time::{Duration, Instant};
14use tracing::{debug, info};
15
/// Tuning knobs for [`StatementCache`].
#[derive(Debug, Clone)]
pub struct StatementCacheConfig {
    /// Upper bound on the number of cached statements; once reached, the least
    /// recently used entry is evicted before a new one is inserted.
    pub max_size: usize,
    /// Lifetime of an entry in seconds, measured from when it was cached
    /// (cache hits do not extend it).
    pub ttl_secs: u64,
    /// When `true`, statements are keyed by a SHA-256 hash of their normalized text,
    /// so statements differing only in literals share an entry. When `false`, the raw
    /// SQL text is used as the key.
    pub enable_fingerprinting: bool,
}

impl Default for StatementCacheConfig {
    /// 1000 entries, one-hour TTL, fingerprinting enabled.
    fn default() -> Self {
        let one_hour_secs: u64 = 60 * 60;
        Self {
            enable_fingerprinting: true,
            ttl_secs: one_hour_secs,
            max_size: 1_000,
        }
    }
}
36
/// A single cached SQL statement plus the bookkeeping used for expiry and diagnostics.
#[derive(Debug, Clone)]
struct CachedStatement {
    /// The SQL text that was passed to `StatementCache::put`.
    sql: String,
    /// The key this entry is stored under.
    // Never read in code; kept so it shows up in the derived `Debug` output.
    #[allow(dead_code)]
    fingerprint: String,
    /// When the entry was created; TTL expiry is measured from this instant.
    cached_at: Instant,
    /// Number of cache hits served by this entry.
    reuse_count: u64,
    /// When the entry was created or last hit.
    // Written by `touch` but not read anywhere yet; derived `Debug`/`Clone` impls do not
    // count as reads, so without this the `dead_code` lint fires.
    #[allow(dead_code)]
    last_accessed: Instant,
}

impl CachedStatement {
    /// Creates an entry stamped with the current time and a reuse count of zero.
    fn new(sql: String, fingerprint: String) -> Self {
        let now = Instant::now();
        Self {
            sql,
            fingerprint,
            cached_at: now,
            reuse_count: 0,
            last_accessed: now,
        }
    }

    /// Returns `true` once strictly more than `ttl` has elapsed since the entry was
    /// created. Hits do not extend the lifetime.
    fn is_expired(&self, ttl: Duration) -> bool {
        self.cached_at.elapsed() > ttl
    }

    /// Records a cache hit: increments `reuse_count` and refreshes `last_accessed`.
    fn touch(&mut self) {
        self.reuse_count += 1;
        self.last_accessed = Instant::now();
    }
}
74
/// SQL statement cache keyed by fingerprint, with LRU eviction and TTL-based expiry.
///
/// All state lives behind `Arc<RwLock<..>>`. Cloning a `StatementCache` therefore yields
/// another handle to the same underlying cache, not an independent copy.
#[derive(Debug, Clone)]
pub struct StatementCache {
    /// Size limit, TTL and fingerprinting settings.
    config: StatementCacheConfig,
    /// Cached statements keyed by fingerprint (see `StatementCache::fingerprint`).
    cache: Arc<RwLock<HashMap<String, CachedStatement>>>,
    /// Fingerprints in recency order. Newly cached and recently hit entries are pushed
    /// to the front, and eviction pops from the back.
    lru_queue: Arc<RwLock<VecDeque<String>>>,
    /// Activity counters, returned as a snapshot by `StatementCache::stats`.
    stats: Arc<RwLock<CacheStats>>,
}
83
84#[derive(Debug, Clone, Default, Serialize, Deserialize)]
85pub struct CacheStats {
86 pub hits: u64,
88 pub misses: u64,
90 pub evictions: u64,
92 pub expirations: u64,
94 pub cached_statements: usize,
96}
97
98impl CacheStats {
99 pub fn hit_rate(&self) -> f64 {
101 let total = self.hits + self.misses;
102 if total == 0 {
103 0.0
104 } else {
105 self.hits as f64 / total as f64
106 }
107 }
108
109 pub fn reset(&mut self) {
111 self.hits = 0;
112 self.misses = 0;
113 self.evictions = 0;
114 self.expirations = 0;
115 }
116}
117
118impl Default for StatementCache {
119 fn default() -> Self {
120 Self::new(StatementCacheConfig::default())
121 }
122}
123
124impl StatementCache {
125 pub fn new(config: StatementCacheConfig) -> Self {
127 Self {
128 config,
129 cache: Arc::new(RwLock::new(HashMap::new())),
130 lru_queue: Arc::new(RwLock::new(VecDeque::new())),
131 stats: Arc::new(RwLock::new(CacheStats::default())),
132 }
133 }
134
135 pub fn fingerprint(&self, sql: &str) -> String {
137 if !self.config.enable_fingerprinting {
138 return sql.to_string();
139 }
140
141 let normalized = normalize_sql(sql);
143
144 let mut hasher = Sha256::new();
146 hasher.update(normalized.as_bytes());
147 let result = hasher.finalize();
148
149 format!("{:x}", result)
150 }
151
152 pub fn get(&self, sql: &str) -> Option<String> {
154 let fingerprint = self.fingerprint(sql);
155
156 let mut cache = self.cache.write();
157 let mut stats = self.stats.write();
158
159 if let Some(stmt) = cache.get_mut(&fingerprint) {
160 if stmt.is_expired(Duration::from_secs(self.config.ttl_secs)) {
162 cache.remove(&fingerprint);
163 self.remove_from_lru(&fingerprint);
164 stats.expirations += 1;
165 stats.misses += 1;
166 debug!(fingerprint = %fingerprint, "Statement expired");
167 return None;
168 }
169
170 stmt.touch();
171 stats.hits += 1;
172
173 self.update_lru(&fingerprint);
175
176 debug!(
177 fingerprint = %fingerprint,
178 reuse_count = stmt.reuse_count,
179 "Statement cache hit"
180 );
181
182 Some(stmt.sql.clone())
183 } else {
184 stats.misses += 1;
185 debug!(fingerprint = %fingerprint, "Statement cache miss");
186 None
187 }
188 }
189
190 pub fn put(&self, sql: String) {
192 let fingerprint = self.fingerprint(&sql);
193
194 let mut cache = self.cache.write();
195 let mut stats = self.stats.write();
196
197 while cache.len() >= self.config.max_size {
199 if let Some(oldest) = self.lru_queue.write().pop_back() {
200 cache.remove(&oldest);
201 stats.evictions += 1;
202 debug!(fingerprint = %oldest, "Statement evicted from cache");
203 } else {
204 break;
205 }
206 }
207
208 let stmt = CachedStatement::new(sql, fingerprint.clone());
210 cache.insert(fingerprint.clone(), stmt);
211
212 self.lru_queue.write().push_front(fingerprint.clone());
214
215 stats.cached_statements = cache.len();
216
217 debug!(
218 fingerprint = %fingerprint,
219 cache_size = cache.len(),
220 "Statement cached"
221 );
222 }
223
224 pub fn clear(&self) {
226 self.cache.write().clear();
227 self.lru_queue.write().clear();
228 self.stats.write().cached_statements = 0;
229
230 info!("Statement cache cleared");
231 }
232
233 pub fn stats(&self) -> CacheStats {
235 self.stats.read().clone()
236 }
237
238 pub fn remove_expired(&self) {
240 let ttl = Duration::from_secs(self.config.ttl_secs);
241 let mut cache = self.cache.write();
242 let mut stats = self.stats.write();
243
244 let expired: Vec<String> = cache
245 .iter()
246 .filter(|(_, stmt)| stmt.is_expired(ttl))
247 .map(|(fp, _)| fp.clone())
248 .collect();
249
250 for fp in &expired {
251 cache.remove(fp);
252 self.remove_from_lru(fp);
253 stats.expirations += 1;
254 }
255
256 stats.cached_statements = cache.len();
257
258 if !expired.is_empty() {
259 info!(count = expired.len(), "Removed expired statements");
260 }
261 }
262
263 fn update_lru(&self, fingerprint: &str) {
265 let mut queue = self.lru_queue.write();
266
267 if let Some(pos) = queue.iter().position(|fp| fp == fingerprint) {
269 queue.remove(pos);
270 }
271
272 queue.push_front(fingerprint.to_string());
274 }
275
276 fn remove_from_lru(&self, fingerprint: &str) {
278 let mut queue = self.lru_queue.write();
279 if let Some(pos) = queue.iter().position(|fp| fp == fingerprint) {
280 queue.remove(pos);
281 }
282 }
283}
284
285fn normalize_sql(sql: &str) -> String {
287 let normalized = sql
289 .split_whitespace()
290 .collect::<Vec<_>>()
291 .join(" ")
292 .to_lowercase();
293
294 replace_literals(&normalized)
298}
299
300fn replace_literals(sql: &str) -> String {
302 let re_string = regex::Regex::new(r"'[^']*'").unwrap();
304 let sql = re_string.replace_all(sql, "?");
305
306 let re_number = regex::Regex::new(r"\b\d+\b").unwrap();
308 let sql = re_number.replace_all(&sql, "?");
309
310 sql.to_string()
311}
312
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_statement_cache_config_default() {
        let config = StatementCacheConfig::default();
        assert_eq!(config.max_size, 1000);
        assert_eq!(config.ttl_secs, 3600);
        assert!(config.enable_fingerprinting);
    }

    #[test]
    fn test_fingerprint_generation() {
        let cache = StatementCache::default();

        // Statements differing only in a numeric literal share a fingerprint.
        let fp1 = cache.fingerprint("SELECT * FROM users WHERE id = 1");
        let fp2 = cache.fingerprint("SELECT * FROM users WHERE id = 2");
        assert_eq!(fp1, fp2);

        // SHA-256 rendered as hex is 64 characters.
        assert_eq!(fp1.len(), 64);

        // Structurally different statements do not collide.
        let fp3 = cache.fingerprint("SELECT * FROM orders WHERE id = 1");
        assert_ne!(fp1, fp3);
    }

    #[test]
    fn test_fingerprint_disabled_uses_raw_sql() {
        let cache = StatementCache::new(StatementCacheConfig {
            enable_fingerprinting: false,
            ..Default::default()
        });
        let sql = "SELECT * FROM users WHERE id = 1";
        assert_eq!(cache.fingerprint(sql), sql);
    }

    #[test]
    fn test_cache_put_and_get() {
        let cache = StatementCache::default();
        let sql = "SELECT * FROM users WHERE id = 1".to_string();

        cache.put(sql.clone());

        let result = cache.get(&sql);
        assert_eq!(result, Some(sql));
    }

    #[test]
    fn test_cache_miss() {
        let cache = StatementCache::default();
        let sql = "SELECT * FROM users WHERE id = 1";

        assert!(cache.get(sql).is_none());
        assert_eq!(cache.stats().misses, 1);
    }

    #[test]
    fn test_cache_stats() {
        let cache = StatementCache::default();
        let sql = "SELECT * FROM users WHERE id = 1".to_string();

        cache.get(&sql); // miss
        cache.put(sql.clone());
        cache.get(&sql); // hit
        cache.get(&sql); // hit

        let stats = cache.stats();
        assert_eq!(stats.hits, 2);
        assert_eq!(stats.misses, 1);
        assert_eq!(stats.cached_statements, 1);
    }

    #[test]
    fn test_cache_hit_rate() {
        let stats = CacheStats {
            hits: 80,
            misses: 20,
            evictions: 0,
            expirations: 0,
            cached_statements: 100,
        };

        assert!((stats.hit_rate() - 0.8).abs() < 0.001);
    }

    #[test]
    fn test_cache_hit_rate_without_lookups_is_zero() {
        assert_eq!(CacheStats::default().hit_rate(), 0.0);
    }

    #[test]
    fn test_cache_eviction() {
        let config = StatementCacheConfig {
            max_size: 3,
            enable_fingerprinting: false,
            ..Default::default()
        };

        let cache = StatementCache::new(config);

        cache.put("SELECT 1".to_string());
        cache.put("SELECT 2".to_string());
        cache.put("SELECT 3".to_string());
        cache.put("SELECT 4".to_string());

        let stats = cache.stats();
        assert_eq!(stats.cached_statements, 3);
        assert_eq!(stats.evictions, 1);

        // The least recently used entry is the one that was evicted.
        assert!(cache.get("SELECT 1").is_none());
        assert!(cache.get("SELECT 4").is_some());
    }

    #[test]
    fn test_put_same_statement_twice_keeps_one_entry() {
        let cache = StatementCache::default();
        let sql = "SELECT * FROM users WHERE id = 1".to_string();

        cache.put(sql.clone());
        cache.put(sql.clone());

        let stats = cache.stats();
        assert_eq!(stats.cached_statements, 1);
        assert_eq!(stats.evictions, 0);
        assert_eq!(cache.get(&sql), Some(sql.clone()));
    }

    #[test]
    fn test_cache_clear() {
        let cache = StatementCache::default();

        cache.put("SELECT 1".to_string());
        cache.put("SELECT 2".to_string());

        cache.clear();

        assert_eq!(cache.stats().cached_statements, 0);
        assert!(cache.get("SELECT 1").is_none());
        assert!(cache.get("SELECT 2").is_none());
    }

    #[test]
    fn test_normalize_sql() {
        assert_eq!(
            normalize_sql("SELECT * FROM users WHERE id = 1"),
            "select * from users where id = ?"
        );
        assert_eq!(
            normalize_sql("  SELECT\n  *\tFROM users WHERE name = 'Bob'  "),
            "select * from users where name = ?"
        );
    }

    #[test]
    fn test_replace_literals() {
        let sql = "SELECT * FROM users WHERE id = 123 AND name = 'john'";
        assert_eq!(
            replace_literals(sql),
            "SELECT * FROM users WHERE id = ? AND name = ?"
        );
    }

    #[test]
    fn test_statement_reuse_count() {
        let cache = StatementCache::default();
        let sql = "SELECT * FROM users WHERE id = 1".to_string();

        cache.put(sql.clone());

        for _ in 0..5 {
            cache.get(&sql);
        }

        let fingerprint = cache.fingerprint(&sql);
        let cached_cache = cache.cache.read();
        let stmt = cached_cache.get(&fingerprint).unwrap();

        assert_eq!(stmt.reuse_count, 5);
    }

    #[test]
    fn test_stats_reset() {
        let mut stats = CacheStats {
            hits: 100,
            misses: 50,
            evictions: 10,
            expirations: 5,
            cached_statements: 200,
        };

        stats.reset();

        assert_eq!(stats.hits, 0);
        assert_eq!(stats.misses, 0);
        assert_eq!(stats.evictions, 0);
        assert_eq!(stats.expirations, 0);
        // The size gauge describes current contents and is not reset.
        assert_eq!(stats.cached_statements, 200);
    }
}