kaccy_db/
statement_cache.rs

//! Prepared statement caching for improved query performance.
//!
//! This module provides:
//! - Statement pool with LRU eviction
//! - Query fingerprinting for cache key generation
//! - Statement reuse metrics
7
8use parking_lot::RwLock;
9use serde::{Deserialize, Serialize};
10use sha2::{Digest, Sha256};
11use std::collections::{HashMap, VecDeque};
12use std::sync::Arc;
13use std::time::{Duration, Instant};
14use tracing::{debug, info};
15
/// Tunable settings for a statement cache.
#[derive(Debug, Clone)]
pub struct StatementCacheConfig {
    /// Upper bound on how many statements are kept before older ones are evicted.
    pub max_size: usize,
    /// Lifetime of a cached statement, expressed in seconds.
    pub ttl_secs: u64,
    /// When `true`, cache keys are derived from a normalized fingerprint of the SQL.
    pub enable_fingerprinting: bool,
}

impl Default for StatementCacheConfig {
    /// Returns the stock configuration: 1000 entries, a one-hour TTL, and
    /// fingerprinting switched on.
    fn default() -> Self {
        StatementCacheConfig {
            enable_fingerprinting: true,
            ttl_secs: 3600, // one hour
            max_size: 1000,
        }
    }
}
36
/// Bookkeeping record kept for every statement held in the cache.
#[derive(Debug, Clone)]
struct CachedStatement {
    /// The SQL text exactly as it was handed to the cache.
    sql: String,
    /// Key under which this entry was stored.
    #[allow(dead_code)]
    fingerprint: String,
    /// Moment the entry was created; expiry is measured from here.
    cached_at: Instant,
    /// How many times the entry has been touched since it was created.
    reuse_count: u64,
    /// Moment of the most recent touch (equal to `cached_at` until then).
    last_accessed: Instant,
}

impl CachedStatement {
    /// Builds a fresh entry with a zero reuse count and both timestamps set
    /// to the same "now".
    fn new(sql: String, fingerprint: String) -> Self {
        let created = Instant::now();
        CachedStatement {
            sql,
            fingerprint,
            reuse_count: 0,
            cached_at: created,
            last_accessed: created,
        }
    }

    /// Reports whether the entry has lived strictly longer than `ttl`.
    fn is_expired(&self, ttl: Duration) -> bool {
        let age = self.cached_at.elapsed();
        age > ttl
    }

    /// Records one more reuse and stamps the access time.
    fn touch(&mut self) {
        self.reuse_count += 1;
        self.last_accessed = Instant::now();
    }
}
74
75/// Statement cache with LRU eviction
76#[derive(Debug, Clone)]
77pub struct StatementCache {
78    config: StatementCacheConfig,
79    cache: Arc<RwLock<HashMap<String, CachedStatement>>>,
80    lru_queue: Arc<RwLock<VecDeque<String>>>,
81    stats: Arc<RwLock<CacheStats>>,
82}
83
84#[derive(Debug, Clone, Default, Serialize, Deserialize)]
85pub struct CacheStats {
86    /// Number of cache hits
87    pub hits: u64,
88    /// Number of cache misses
89    pub misses: u64,
90    /// Number of evictions
91    pub evictions: u64,
92    /// Number of expired entries removed
93    pub expirations: u64,
94    /// Total statements in cache
95    pub cached_statements: usize,
96}
97
98impl CacheStats {
99    /// Calculate cache hit rate
100    pub fn hit_rate(&self) -> f64 {
101        let total = self.hits + self.misses;
102        if total == 0 {
103            0.0
104        } else {
105            self.hits as f64 / total as f64
106        }
107    }
108
109    /// Reset statistics
110    pub fn reset(&mut self) {
111        self.hits = 0;
112        self.misses = 0;
113        self.evictions = 0;
114        self.expirations = 0;
115    }
116}
117
118impl Default for StatementCache {
119    fn default() -> Self {
120        Self::new(StatementCacheConfig::default())
121    }
122}
123
124impl StatementCache {
125    /// Create a new statement cache
126    pub fn new(config: StatementCacheConfig) -> Self {
127        Self {
128            config,
129            cache: Arc::new(RwLock::new(HashMap::new())),
130            lru_queue: Arc::new(RwLock::new(VecDeque::new())),
131            stats: Arc::new(RwLock::new(CacheStats::default())),
132        }
133    }
134
135    /// Generate a fingerprint for a SQL query
136    pub fn fingerprint(&self, sql: &str) -> String {
137        if !self.config.enable_fingerprinting {
138            return sql.to_string();
139        }
140
141        // Normalize SQL for fingerprinting
142        let normalized = normalize_sql(sql);
143
144        // Generate SHA256 hash
145        let mut hasher = Sha256::new();
146        hasher.update(normalized.as_bytes());
147        let result = hasher.finalize();
148
149        format!("{:x}", result)
150    }
151
152    /// Get a cached statement
153    pub fn get(&self, sql: &str) -> Option<String> {
154        let fingerprint = self.fingerprint(sql);
155
156        let mut cache = self.cache.write();
157        let mut stats = self.stats.write();
158
159        if let Some(stmt) = cache.get_mut(&fingerprint) {
160            // Check if expired
161            if stmt.is_expired(Duration::from_secs(self.config.ttl_secs)) {
162                cache.remove(&fingerprint);
163                self.remove_from_lru(&fingerprint);
164                stats.expirations += 1;
165                stats.misses += 1;
166                debug!(fingerprint = %fingerprint, "Statement expired");
167                return None;
168            }
169
170            stmt.touch();
171            stats.hits += 1;
172
173            // Move to front of LRU queue
174            self.update_lru(&fingerprint);
175
176            debug!(
177                fingerprint = %fingerprint,
178                reuse_count = stmt.reuse_count,
179                "Statement cache hit"
180            );
181
182            Some(stmt.sql.clone())
183        } else {
184            stats.misses += 1;
185            debug!(fingerprint = %fingerprint, "Statement cache miss");
186            None
187        }
188    }
189
190    /// Cache a statement
191    pub fn put(&self, sql: String) {
192        let fingerprint = self.fingerprint(&sql);
193
194        let mut cache = self.cache.write();
195        let mut stats = self.stats.write();
196
197        // Evict if cache is full
198        while cache.len() >= self.config.max_size {
199            if let Some(oldest) = self.lru_queue.write().pop_back() {
200                cache.remove(&oldest);
201                stats.evictions += 1;
202                debug!(fingerprint = %oldest, "Statement evicted from cache");
203            } else {
204                break;
205            }
206        }
207
208        // Add to cache
209        let stmt = CachedStatement::new(sql, fingerprint.clone());
210        cache.insert(fingerprint.clone(), stmt);
211
212        // Add to LRU queue
213        self.lru_queue.write().push_front(fingerprint.clone());
214
215        stats.cached_statements = cache.len();
216
217        debug!(
218            fingerprint = %fingerprint,
219            cache_size = cache.len(),
220            "Statement cached"
221        );
222    }
223
224    /// Clear the cache
225    pub fn clear(&self) {
226        self.cache.write().clear();
227        self.lru_queue.write().clear();
228        self.stats.write().cached_statements = 0;
229
230        info!("Statement cache cleared");
231    }
232
233    /// Get cache statistics
234    pub fn stats(&self) -> CacheStats {
235        self.stats.read().clone()
236    }
237
238    /// Remove expired entries
239    pub fn remove_expired(&self) {
240        let ttl = Duration::from_secs(self.config.ttl_secs);
241        let mut cache = self.cache.write();
242        let mut stats = self.stats.write();
243
244        let expired: Vec<String> = cache
245            .iter()
246            .filter(|(_, stmt)| stmt.is_expired(ttl))
247            .map(|(fp, _)| fp.clone())
248            .collect();
249
250        for fp in &expired {
251            cache.remove(fp);
252            self.remove_from_lru(fp);
253            stats.expirations += 1;
254        }
255
256        stats.cached_statements = cache.len();
257
258        if !expired.is_empty() {
259            info!(count = expired.len(), "Removed expired statements");
260        }
261    }
262
263    /// Update LRU queue (move to front)
264    fn update_lru(&self, fingerprint: &str) {
265        let mut queue = self.lru_queue.write();
266
267        // Remove from current position
268        if let Some(pos) = queue.iter().position(|fp| fp == fingerprint) {
269            queue.remove(pos);
270        }
271
272        // Add to front
273        queue.push_front(fingerprint.to_string());
274    }
275
276    /// Remove from LRU queue
277    fn remove_from_lru(&self, fingerprint: &str) {
278        let mut queue = self.lru_queue.write();
279        if let Some(pos) = queue.iter().position(|fp| fp == fingerprint) {
280            queue.remove(pos);
281        }
282    }
283}
284
285/// Normalize SQL for fingerprinting
286fn normalize_sql(sql: &str) -> String {
287    // Remove extra whitespace and normalize to lowercase
288    let normalized = sql
289        .split_whitespace()
290        .collect::<Vec<_>>()
291        .join(" ")
292        .to_lowercase();
293
294    // Replace literal values with placeholders
295    // This is a simple approach; a more sophisticated version would use a SQL parser
296
297    replace_literals(&normalized)
298}
299
300/// Replace literal values with placeholders
301fn replace_literals(sql: &str) -> String {
302    // Replace string literals
303    let re_string = regex::Regex::new(r"'[^']*'").unwrap();
304    let sql = re_string.replace_all(sql, "?");
305
306    // Replace numeric literals
307    let re_number = regex::Regex::new(r"\b\d+\b").unwrap();
308    let sql = re_number.replace_all(&sql, "?");
309
310    sql.to_string()
311}
312
#[cfg(test)]
mod tests {
    use super::*;

    /// Query text shared by the tests that only need one cacheable statement.
    const USER_QUERY: &str = "SELECT * FROM users WHERE id = 1";

    #[test]
    fn test_statement_cache_config_default() {
        let StatementCacheConfig {
            max_size,
            ttl_secs,
            enable_fingerprinting,
        } = StatementCacheConfig::default();

        assert_eq!(max_size, 1000);
        assert_eq!(ttl_secs, 3600);
        assert!(enable_fingerprinting);
    }

    #[test]
    fn test_fingerprint_generation() {
        let cache = StatementCache::default();

        let first = cache.fingerprint(USER_QUERY);
        let second = cache.fingerprint("SELECT * FROM users WHERE id = 2");

        // Queries differing only in a literal normalize to the same fingerprint.
        assert_eq!(first, second);
    }

    #[test]
    fn test_cache_put_and_get() {
        let cache = StatementCache::default();
        cache.put(USER_QUERY.to_string());

        let fetched = cache.get(USER_QUERY);
        assert!(fetched.is_some());
        assert_eq!(fetched.unwrap(), USER_QUERY.to_string());
    }

    #[test]
    fn test_cache_miss() {
        let cache = StatementCache::default();
        assert!(cache.get(USER_QUERY).is_none());
    }

    #[test]
    fn test_cache_stats() {
        let cache = StatementCache::default();

        // One miss before anything is cached.
        cache.get(USER_QUERY);

        cache.put(USER_QUERY.to_string());

        // Two hits afterwards.
        for _ in 0..2 {
            cache.get(USER_QUERY);
        }

        let snapshot = cache.stats();
        assert_eq!(snapshot.hits, 2);
        assert_eq!(snapshot.misses, 1);
        assert_eq!(snapshot.cached_statements, 1);
    }

    #[test]
    fn test_cache_hit_rate() {
        let snapshot = CacheStats {
            hits: 80,
            misses: 20,
            evictions: 0,
            expirations: 0,
            cached_statements: 100,
        };

        let delta = (snapshot.hit_rate() - 0.8).abs();
        assert!(delta < 0.001);
    }

    #[test]
    fn test_cache_eviction() {
        // Fingerprinting is off so each statement keeps its own key.
        let cache = StatementCache::new(StatementCacheConfig {
            max_size: 3,
            enable_fingerprinting: false,
            ..Default::default()
        });

        // The fourth insert has to push out the oldest entry.
        for n in 1..=4 {
            cache.put(format!("SELECT {}", n));
        }

        let snapshot = cache.stats();
        assert_eq!(snapshot.cached_statements, 3);
        assert_eq!(snapshot.evictions, 1);
    }

    #[test]
    fn test_cache_clear() {
        let cache = StatementCache::default();
        for text in ["SELECT 1", "SELECT 2"] {
            cache.put(text.to_string());
        }

        cache.clear();

        assert_eq!(cache.stats().cached_statements, 0);
    }

    #[test]
    fn test_normalize_sql() {
        let normalized = normalize_sql("SELECT  *  FROM  users  WHERE  id  =  1");
        assert_eq!(normalized, "select * from users where id = ?");
    }

    #[test]
    fn test_replace_literals() {
        let replaced = replace_literals("SELECT * FROM users WHERE id = 123 AND name = 'john'");

        assert!(replaced.contains("?"));
        for literal in ["123", "john"] {
            assert!(!replaced.contains(literal));
        }
    }

    #[test]
    fn test_statement_reuse_count() {
        let cache = StatementCache::default();
        cache.put(USER_QUERY.to_string());

        // Every successful lookup should bump the reuse counter once.
        (0..5).for_each(|_| {
            cache.get(USER_QUERY);
        });

        let key = cache.fingerprint(USER_QUERY);
        let entries = cache.cache.read();
        let entry = entries.get(&key).unwrap();

        assert_eq!(entry.reuse_count, 5);
    }

    #[test]
    fn test_stats_reset() {
        let mut snapshot = CacheStats {
            hits: 100,
            misses: 50,
            evictions: 10,
            expirations: 5,
            cached_statements: 200,
        };

        snapshot.reset();

        assert_eq!(snapshot.hits, 0);
        assert_eq!(snapshot.misses, 0);
        assert_eq!(snapshot.evictions, 0);
        assert_eq!(snapshot.expirations, 0);
    }
}