1#![allow(dead_code)]
11#![allow(clippy::expect_used)]
26
27use std::collections::hash_map::DefaultHasher;
28use std::hash::{Hash, Hasher};
29use std::num::NonZeroUsize;
30use std::time::Instant;
31
32use lru::LruCache;
33
/// Number of prepared statements a cache built with the default size can hold.
///
/// Used by `StatementCache::with_default_size` and `StatementCacheConfig::default`.
pub const DEFAULT_MAX_STATEMENTS: usize = 256;
36
/// A prepared-statement handle paired with the SQL text it was prepared from.
#[derive(Debug, Clone)]
pub struct PreparedStatement {
    /// Handle identifying the prepared statement.
    /// NOTE(review): presumably the value returned by the prepare call — confirm with callers.
    handle: i32,
    /// `hash_sql(&sql)`, cached so the statement can be used as a cache key cheaply.
    sql_hash: u64,
    /// The exact SQL text this handle was prepared for.
    sql: String,
    /// When this value was constructed.
    created_at: Instant,
}
51
52impl PreparedStatement {
53 pub fn new(handle: i32, sql: String) -> Self {
55 Self {
56 handle,
57 sql_hash: hash_sql(&sql),
58 sql,
59 created_at: Instant::now(),
60 }
61 }
62
63 #[must_use]
65 pub fn handle(&self) -> i32 {
66 self.handle
67 }
68
69 #[must_use]
71 pub fn sql_hash(&self) -> u64 {
72 self.sql_hash
73 }
74
75 #[must_use]
77 pub fn sql(&self) -> &str {
78 &self.sql
79 }
80
81 #[must_use]
83 pub fn created_at(&self) -> Instant {
84 self.created_at
85 }
86
87 #[must_use]
89 pub fn age(&self) -> std::time::Duration {
90 self.created_at.elapsed()
91 }
92}
93
94pub struct StatementCache {
100 cache: LruCache<u64, PreparedStatement>,
102 max_size: usize,
104 hits: u64,
106 misses: u64,
108 pending_prepare: Option<String>,
112}
113
114impl StatementCache {
115 #[must_use]
121 pub fn new(max_size: usize) -> Self {
122 assert!(max_size > 0, "max_size must be greater than 0");
123 Self {
124 cache: LruCache::new(NonZeroUsize::new(max_size).expect("max_size > 0")),
125 max_size,
126 hits: 0,
127 misses: 0,
128 pending_prepare: None,
129 }
130 }
131
132 #[must_use]
134 pub fn with_default_size() -> Self {
135 Self::new(DEFAULT_MAX_STATEMENTS)
136 }
137
138 pub fn get(&mut self, sql: &str) -> Option<i32> {
143 let hash = hash_sql(sql);
144 if let Some(stmt) = self.cache.get(&hash) {
145 self.hits += 1;
146 tracing::trace!(sql = sql, handle = stmt.handle, "statement cache hit");
147 Some(stmt.handle)
148 } else {
149 self.misses += 1;
150 tracing::trace!(sql = sql, "statement cache miss");
151 None
152 }
153 }
154
155 pub(crate) fn set_pending(&mut self, key: Option<String>) {
162 self.pending_prepare = key;
163 }
164
165 pub(crate) fn take_pending(&mut self) -> Option<String> {
167 self.pending_prepare.take()
168 }
169
170 pub fn peek(&self, sql: &str) -> Option<&PreparedStatement> {
172 let hash = hash_sql(sql);
173 self.cache.peek(&hash)
174 }
175
176 pub fn insert(&mut self, stmt: PreparedStatement) -> Option<PreparedStatement> {
180 let hash = stmt.sql_hash;
181 tracing::debug!(
182 sql = stmt.sql(),
183 handle = stmt.handle,
184 "caching prepared statement"
185 );
186
187 let evicted = if self.cache.len() >= self.max_size {
189 self.cache.pop_lru().map(|(_, stmt)| stmt)
191 } else {
192 None
193 };
194
195 self.cache.put(hash, stmt);
196 evicted
197 }
198
199 pub fn remove(&mut self, sql: &str) -> Option<PreparedStatement> {
203 let hash = hash_sql(sql);
204 self.cache.pop(&hash)
205 }
206
207 pub fn clear(&mut self) -> impl Iterator<Item = PreparedStatement> + '_ {
212 let mut statements = Vec::with_capacity(self.cache.len());
213 while let Some((_, stmt)) = self.cache.pop_lru() {
214 statements.push(stmt);
215 }
216 tracing::debug!(count = statements.len(), "cleared statement cache");
217 statements.into_iter()
218 }
219
220 #[must_use]
222 pub fn len(&self) -> usize {
223 self.cache.len()
224 }
225
226 #[must_use]
228 pub fn is_empty(&self) -> bool {
229 self.cache.is_empty()
230 }
231
232 #[must_use]
234 pub fn max_size(&self) -> usize {
235 self.max_size
236 }
237
238 #[must_use]
240 pub fn hits(&self) -> u64 {
241 self.hits
242 }
243
244 #[must_use]
246 pub fn misses(&self) -> u64 {
247 self.misses
248 }
249
250 #[must_use]
252 pub fn hit_ratio(&self) -> f64 {
253 let total = self.hits + self.misses;
254 if total == 0 {
255 0.0
256 } else {
257 self.hits as f64 / total as f64
258 }
259 }
260
261 pub fn reset_stats(&mut self) {
263 self.hits = 0;
264 self.misses = 0;
265 }
266
267 #[must_use]
269 pub fn stats(&self) -> StatementCacheStats {
270 StatementCacheStats {
271 hits: self.hits,
272 misses: self.misses,
273 entries: self.cache.len(),
274 }
275 }
276}
277
/// Point-in-time snapshot of a statement cache's counters and occupancy.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct StatementCacheStats {
    /// Lookups that returned a cached handle.
    pub hits: u64,
    /// Lookups that found no cached handle.
    pub misses: u64,
    /// Number of statements cached when the snapshot was taken.
    pub entries: usize,
}
293
294impl Default for StatementCache {
295 fn default() -> Self {
296 Self::with_default_size()
297 }
298}
299
300impl std::fmt::Debug for StatementCache {
301 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
302 f.debug_struct("StatementCache")
303 .field("len", &self.cache.len())
304 .field("max_size", &self.max_size)
305 .field("hits", &self.hits)
306 .field("misses", &self.misses)
307 .finish()
308 }
309}
310
/// Hashes SQL text into the 64-bit key used by the statement cache.
///
/// Uses the standard library's `DefaultHasher`, whose `new()` instances all start
/// from the same state, so equal strings hash equally. The algorithm is not
/// guaranteed to stay the same across Rust releases, so the value should not be
/// persisted. Different strings can collide; compare the SQL text when that matters.
#[must_use]
pub fn hash_sql(sql: &str) -> u64 {
    let mut state = DefaultHasher::new();
    Hash::hash(sql, &mut state);
    state.finish()
}
320
/// User-facing settings for statement caching.
#[derive(Debug, Clone)]
pub struct StatementCacheConfig {
    /// Whether statement caching is turned on.
    pub enabled: bool,
    /// Maximum number of statements to cache.
    pub max_size: usize,
}
329
330impl Default for StatementCacheConfig {
331 fn default() -> Self {
332 Self {
333 enabled: true,
334 max_size: DEFAULT_MAX_STATEMENTS,
335 }
336 }
337}
338
339impl StatementCacheConfig {
340 #[must_use]
342 pub fn disabled() -> Self {
343 Self {
344 enabled: false,
345 max_size: 0,
346 }
347 }
348
349 #[must_use]
351 pub fn with_max_size(max_size: usize) -> Self {
352 Self {
353 enabled: true,
354 max_size,
355 }
356 }
357}
358
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Shorthand for building a statement from a borrowed SQL string.
    fn stmt(handle: i32, sql: &str) -> PreparedStatement {
        PreparedStatement::new(handle, sql.to_string())
    }

    #[test]
    fn test_statement_cache_new() {
        let cache = StatementCache::new(10);
        assert!(cache.is_empty());
        assert_eq!(cache.len(), 0);
        assert_eq!(cache.max_size(), 10);
    }

    #[test]
    fn test_statement_cache_insert_and_get() {
        let mut cache = StatementCache::new(10);
        let sql = "SELECT * FROM users";
        cache.insert(stmt(1, sql));

        assert_eq!(cache.len(), 1);
        assert_eq!(cache.get(sql), Some(1));
        assert_eq!((cache.hits(), cache.misses()), (1, 0));
    }

    #[test]
    fn test_statement_cache_miss() {
        let mut cache = StatementCache::new(10);
        assert!(cache.get("SELECT 1").is_none());
        assert_eq!((cache.hits(), cache.misses()), (0, 1));
    }

    #[test]
    fn test_statement_cache_lru_eviction() {
        let mut cache = StatementCache::new(2);
        cache.insert(stmt(1, "SELECT 1"));
        cache.insert(stmt(2, "SELECT 2"));
        assert_eq!(cache.len(), 2);

        // Touch "SELECT 1" so that "SELECT 2" becomes the least recently used entry.
        cache.get("SELECT 1");

        let evicted = cache.insert(stmt(3, "SELECT 3"));
        assert_eq!(evicted.map(|old| old.handle()), Some(2));
        assert_eq!(cache.len(), 2);

        assert_eq!(cache.get("SELECT 1"), Some(1));
        assert_eq!(cache.get("SELECT 2"), None);
        assert_eq!(cache.get("SELECT 3"), Some(3));
    }

    #[test]
    fn test_statement_cache_clear() {
        let mut cache = StatementCache::new(10);
        cache.insert(stmt(1, "SELECT 1"));
        cache.insert(stmt(2, "SELECT 2"));

        assert_eq!(cache.clear().count(), 2);
        assert!(cache.is_empty());
    }

    #[test]
    fn test_statement_cache_remove() {
        let mut cache = StatementCache::new(10);
        cache.insert(stmt(1, "SELECT 1"));
        assert_eq!(cache.len(), 1);

        let removed = cache.remove("SELECT 1");
        assert_eq!(removed.map(|old| old.handle()), Some(1));
        assert!(cache.is_empty());
    }

    #[test]
    fn test_statement_cache_hit_ratio() {
        let mut cache = StatementCache::new(10);
        cache.insert(stmt(1, "SELECT 1"));

        for sql in ["SELECT 1", "SELECT 1", "SELECT 2"] {
            cache.get(sql);
        }

        assert_eq!(cache.hits(), 2);
        assert_eq!(cache.misses(), 1);
        assert!((cache.hit_ratio() - 2.0 / 3.0).abs() < 0.001);
    }

    #[test]
    fn test_hash_sql_consistency() {
        let sql = "SELECT * FROM users WHERE id = @p1";
        assert_eq!(hash_sql(sql), hash_sql(sql));
    }

    #[test]
    fn test_hash_sql_different() {
        assert_ne!(hash_sql("SELECT 1"), hash_sql("SELECT 2"));
    }

    #[test]
    fn test_prepared_statement_age() {
        let statement = stmt(1, "SELECT 1");
        std::thread::sleep(std::time::Duration::from_millis(10));
        assert!(statement.age() >= std::time::Duration::from_millis(10));
    }

    #[test]
    fn test_statement_cache_config_default() {
        let StatementCacheConfig { enabled, max_size } = StatementCacheConfig::default();
        assert!(enabled);
        assert_eq!(max_size, DEFAULT_MAX_STATEMENTS);
    }

    #[test]
    fn test_statement_cache_config_disabled() {
        assert!(!StatementCacheConfig::disabled().enabled);
    }
}