1#![allow(dead_code)]
11#![allow(clippy::expect_used)]
27
28use std::collections::hash_map::DefaultHasher;
29use std::hash::{Hash, Hasher};
30use std::num::NonZeroUsize;
31use std::time::Instant;
32
33use lru::LruCache;
34
/// Default capacity of a [`StatementCache`]: how many prepared statements are
/// retained before the least-recently-used one is evicted.
pub const DEFAULT_MAX_STATEMENTS: usize = 256;
37
38#[derive(Debug, Clone)]
42pub struct PreparedStatement {
43 handle: i32,
45 sql_hash: u64,
47 sql: String,
49 created_at: Instant,
51}
52
53impl PreparedStatement {
54 pub fn new(handle: i32, sql: String) -> Self {
56 Self {
57 handle,
58 sql_hash: hash_sql(&sql),
59 sql,
60 created_at: Instant::now(),
61 }
62 }
63
64 #[must_use]
66 pub fn handle(&self) -> i32 {
67 self.handle
68 }
69
70 #[must_use]
72 pub fn sql_hash(&self) -> u64 {
73 self.sql_hash
74 }
75
76 #[must_use]
78 pub fn sql(&self) -> &str {
79 &self.sql
80 }
81
82 #[must_use]
84 pub fn created_at(&self) -> Instant {
85 self.created_at
86 }
87
88 #[must_use]
90 pub fn age(&self) -> std::time::Duration {
91 self.created_at.elapsed()
92 }
93}
94
95pub struct StatementCache {
101 cache: LruCache<u64, PreparedStatement>,
103 max_size: usize,
105 hits: u64,
107 misses: u64,
109 pending_prepare: Option<String>,
113}
114
115impl StatementCache {
116 #[must_use]
122 pub fn new(max_size: usize) -> Self {
123 assert!(max_size > 0, "max_size must be greater than 0");
124 Self {
125 cache: LruCache::new(NonZeroUsize::new(max_size).expect("max_size > 0")),
126 max_size,
127 hits: 0,
128 misses: 0,
129 pending_prepare: None,
130 }
131 }
132
133 #[must_use]
135 pub fn with_default_size() -> Self {
136 Self::new(DEFAULT_MAX_STATEMENTS)
137 }
138
139 pub fn get(&mut self, sql: &str) -> Option<i32> {
144 let hash = hash_sql(sql);
145 if let Some(stmt) = self.cache.get(&hash) {
146 self.hits += 1;
147 tracing::trace!(sql = sql, handle = stmt.handle, "statement cache hit");
148 Some(stmt.handle)
149 } else {
150 self.misses += 1;
151 tracing::trace!(sql = sql, "statement cache miss");
152 None
153 }
154 }
155
156 pub(crate) fn set_pending(&mut self, key: Option<String>) {
163 self.pending_prepare = key;
164 }
165
166 pub(crate) fn take_pending(&mut self) -> Option<String> {
168 self.pending_prepare.take()
169 }
170
171 pub fn peek(&self, sql: &str) -> Option<&PreparedStatement> {
173 let hash = hash_sql(sql);
174 self.cache.peek(&hash)
175 }
176
177 pub fn insert(&mut self, stmt: PreparedStatement) -> Option<PreparedStatement> {
181 let hash = stmt.sql_hash;
182 tracing::debug!(
183 sql = stmt.sql(),
184 handle = stmt.handle,
185 "caching prepared statement"
186 );
187
188 let evicted = if self.cache.len() >= self.max_size {
190 self.cache.pop_lru().map(|(_, stmt)| stmt)
192 } else {
193 None
194 };
195
196 self.cache.put(hash, stmt);
197 evicted
198 }
199
200 pub fn remove(&mut self, sql: &str) -> Option<PreparedStatement> {
204 let hash = hash_sql(sql);
205 self.cache.pop(&hash)
206 }
207
208 pub fn clear(&mut self) -> impl Iterator<Item = PreparedStatement> + '_ {
213 let mut statements = Vec::with_capacity(self.cache.len());
214 while let Some((_, stmt)) = self.cache.pop_lru() {
215 statements.push(stmt);
216 }
217 tracing::debug!(count = statements.len(), "cleared statement cache");
218 statements.into_iter()
219 }
220
221 #[must_use]
223 pub fn len(&self) -> usize {
224 self.cache.len()
225 }
226
227 #[must_use]
229 pub fn is_empty(&self) -> bool {
230 self.cache.is_empty()
231 }
232
233 #[must_use]
235 pub fn max_size(&self) -> usize {
236 self.max_size
237 }
238
239 #[must_use]
241 pub fn hits(&self) -> u64 {
242 self.hits
243 }
244
245 #[must_use]
247 pub fn misses(&self) -> u64 {
248 self.misses
249 }
250
251 #[must_use]
253 pub fn hit_ratio(&self) -> f64 {
254 let total = self.hits + self.misses;
255 if total == 0 {
256 0.0
257 } else {
258 self.hits as f64 / total as f64
259 }
260 }
261
262 pub fn reset_stats(&mut self) {
264 self.hits = 0;
265 self.misses = 0;
266 }
267
268 #[must_use]
270 pub fn stats(&self) -> StatementCacheStats {
271 StatementCacheStats {
272 hits: self.hits,
273 misses: self.misses,
274 entries: self.cache.len(),
275 }
276 }
277}
278
/// Point-in-time snapshot of a [`StatementCache`]'s counters and size, as
/// produced by [`StatementCache::stats`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct StatementCacheStats {
    /// Lookups that returned a cached handle.
    pub hits: u64,
    /// Lookups that found nothing.
    pub misses: u64,
    /// Number of statements cached when the snapshot was taken.
    pub entries: usize,
}
294
295impl Default for StatementCache {
296 fn default() -> Self {
297 Self::with_default_size()
298 }
299}
300
301impl std::fmt::Debug for StatementCache {
302 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
303 f.debug_struct("StatementCache")
304 .field("len", &self.cache.len())
305 .field("max_size", &self.max_size)
306 .field("hits", &self.hits)
307 .field("misses", &self.misses)
308 .finish()
309 }
310}
311
/// Hashes SQL text into the 64-bit key used by [`StatementCache`].
///
/// Every [`DefaultHasher::new`] instance starts from the same state, so equal
/// text always yields an equal key within a build. The standard library does
/// not specify the algorithm and it may change between Rust releases, so these
/// values must not be persisted or shared across builds.
#[must_use]
pub fn hash_sql(sql: &str) -> u64 {
    let mut state = DefaultHasher::new();
    Hash::hash(sql, &mut state);
    state.finish()
}
321
322#[derive(Debug, Clone)]
324pub struct StatementCacheConfig {
325 pub enabled: bool,
327 pub max_size: usize,
329}
330
331impl Default for StatementCacheConfig {
332 fn default() -> Self {
333 Self {
334 enabled: true,
335 max_size: DEFAULT_MAX_STATEMENTS,
336 }
337 }
338}
339
340impl StatementCacheConfig {
341 #[must_use]
343 pub fn disabled() -> Self {
344 Self {
345 enabled: false,
346 max_size: 0,
347 }
348 }
349
350 #[must_use]
352 pub fn with_max_size(max_size: usize) -> Self {
353 Self {
354 enabled: true,
355 max_size,
356 }
357 }
358}
359
#[cfg(test)]
#[allow(clippy::unwrap_used)] // unwrapping is the clearest way to assert in tests
mod tests {
    use super::*;

    #[test]
    fn test_statement_cache_new() {
        let cache = StatementCache::new(10);
        assert_eq!(cache.max_size(), 10);
        assert!(cache.is_empty());
        assert_eq!(cache.len(), 0);
    }

    #[test]
    #[should_panic(expected = "max_size must be greater than 0")]
    fn test_statement_cache_new_zero_panics() {
        let _ = StatementCache::new(0);
    }

    #[test]
    fn test_statement_cache_default() {
        let cache = StatementCache::default();
        assert_eq!(cache.max_size(), DEFAULT_MAX_STATEMENTS);
        assert!(cache.is_empty());
    }

    #[test]
    fn test_statement_cache_insert_and_get() {
        let mut cache = StatementCache::new(10);

        let stmt = PreparedStatement::new(1, "SELECT * FROM users".to_string());
        cache.insert(stmt);

        assert_eq!(cache.len(), 1);
        assert_eq!(cache.get("SELECT * FROM users"), Some(1));
        assert_eq!(cache.hits(), 1);
        assert_eq!(cache.misses(), 0);
    }

    #[test]
    fn test_statement_cache_miss() {
        let mut cache = StatementCache::new(10);

        assert_eq!(cache.get("SELECT 1"), None);
        assert_eq!(cache.misses(), 1);
        assert_eq!(cache.hits(), 0);
    }

    #[test]
    fn test_statement_cache_lru_eviction() {
        let mut cache = StatementCache::new(2);

        cache.insert(PreparedStatement::new(1, "SELECT 1".to_string()));
        cache.insert(PreparedStatement::new(2, "SELECT 2".to_string()));
        assert_eq!(cache.len(), 2);

        // Touch "SELECT 1" so "SELECT 2" becomes the least recently used.
        cache.get("SELECT 1");

        let evicted = cache.insert(PreparedStatement::new(3, "SELECT 3".to_string()));

        assert!(evicted.is_some());
        assert_eq!(evicted.unwrap().handle(), 2);
        assert_eq!(cache.len(), 2);

        assert_eq!(cache.get("SELECT 1"), Some(1));
        assert_eq!(cache.get("SELECT 2"), None);
        assert_eq!(cache.get("SELECT 3"), Some(3));
    }

    #[test]
    fn test_statement_cache_peek_does_not_promote_or_count() {
        let mut cache = StatementCache::new(2);
        cache.insert(PreparedStatement::new(1, "SELECT 1".to_string()));
        cache.insert(PreparedStatement::new(2, "SELECT 2".to_string()));

        assert_eq!(cache.peek("SELECT 1").map(PreparedStatement::handle), Some(1));
        assert!(cache.peek("SELECT 9").is_none());
        assert_eq!(cache.hits(), 0);
        assert_eq!(cache.misses(), 0);

        // "SELECT 1" was only peeked, so it is still the LRU entry.
        let evicted = cache.insert(PreparedStatement::new(3, "SELECT 3".to_string()));
        assert_eq!(evicted.map(|stmt| stmt.handle()), Some(1));
    }

    #[test]
    fn test_statement_cache_clear() {
        let mut cache = StatementCache::new(10);

        cache.insert(PreparedStatement::new(1, "SELECT 1".to_string()));
        cache.insert(PreparedStatement::new(2, "SELECT 2".to_string()));

        let cleared: Vec<_> = cache.clear().collect();
        assert_eq!(cleared.len(), 2);
        assert!(cache.is_empty());
    }

    #[test]
    fn test_statement_cache_clear_returns_lru_first() {
        let mut cache = StatementCache::new(10);
        cache.insert(PreparedStatement::new(1, "SELECT 1".to_string()));
        cache.insert(PreparedStatement::new(2, "SELECT 2".to_string()));
        cache.insert(PreparedStatement::new(3, "SELECT 3".to_string()));
        cache.get("SELECT 1");

        let handles: Vec<i32> = cache.clear().map(|stmt| stmt.handle()).collect();
        assert_eq!(handles, vec![2, 3, 1]);
    }

    #[test]
    fn test_statement_cache_remove() {
        let mut cache = StatementCache::new(10);

        cache.insert(PreparedStatement::new(1, "SELECT 1".to_string()));
        assert_eq!(cache.len(), 1);

        let removed = cache.remove("SELECT 1");
        assert!(removed.is_some());
        assert_eq!(removed.unwrap().handle(), 1);
        assert!(cache.is_empty());

        assert!(cache.remove("SELECT 1").is_none());
    }

    #[test]
    fn test_statement_cache_hit_ratio() {
        let mut cache = StatementCache::new(10);

        cache.insert(PreparedStatement::new(1, "SELECT 1".to_string()));

        cache.get("SELECT 1");
        cache.get("SELECT 1");
        cache.get("SELECT 2");

        assert_eq!(cache.hits(), 2);
        assert_eq!(cache.misses(), 1);
        assert!((cache.hit_ratio() - 2.0 / 3.0).abs() < f64::EPSILON);
    }

    #[test]
    fn test_statement_cache_hit_ratio_without_lookups() {
        let cache = StatementCache::new(10);
        assert!(cache.hit_ratio().abs() < f64::EPSILON);
    }

    #[test]
    fn test_statement_cache_stats_and_reset() {
        let mut cache = StatementCache::new(10);
        cache.insert(PreparedStatement::new(1, "SELECT 1".to_string()));
        cache.get("SELECT 1");
        cache.get("SELECT 2");

        assert_eq!(
            cache.stats(),
            StatementCacheStats {
                hits: 1,
                misses: 1,
                entries: 1
            }
        );

        cache.reset_stats();
        assert_eq!(
            cache.stats(),
            StatementCacheStats {
                hits: 0,
                misses: 0,
                entries: 1
            }
        );
    }

    #[test]
    fn test_hash_sql_consistency() {
        let sql = "SELECT * FROM users WHERE id = @p1";
        let hash1 = hash_sql(sql);
        let hash2 = hash_sql(sql);
        assert_eq!(hash1, hash2);
    }

    #[test]
    fn test_hash_sql_different() {
        let hash1 = hash_sql("SELECT 1");
        let hash2 = hash_sql("SELECT 2");
        assert_ne!(hash1, hash2);
    }

    #[test]
    fn test_prepared_statement_age() {
        let stmt = PreparedStatement::new(1, "SELECT 1".to_string());
        std::thread::sleep(std::time::Duration::from_millis(10));
        assert!(stmt.age().as_millis() >= 10);
    }

    #[test]
    fn test_prepared_statement_accessors() {
        let stmt = PreparedStatement::new(7, "SELECT 7".to_string());
        assert_eq!(stmt.handle(), 7);
        assert_eq!(stmt.sql(), "SELECT 7");
        assert_eq!(stmt.sql_hash(), hash_sql("SELECT 7"));
    }

    #[test]
    fn test_statement_cache_config_default() {
        let config = StatementCacheConfig::default();
        assert!(config.enabled);
        assert_eq!(config.max_size, DEFAULT_MAX_STATEMENTS);
    }

    #[test]
    fn test_statement_cache_config_disabled() {
        let config = StatementCacheConfig::disabled();
        assert!(!config.enabled);
        assert_eq!(config.max_size, 0);
    }

    #[test]
    fn test_statement_cache_config_with_max_size() {
        let config = StatementCacheConfig::with_max_size(42);
        assert!(config.enabled);
        assert_eq!(config.max_size, 42);
    }
}