mssql_client/
statement_cache.rs

1//! Prepared statement caching with LRU eviction.
2//!
3//! This module provides automatic caching of prepared statements to improve performance
4//! for repeated query execution. The cache uses an LRU (Least Recently Used) eviction
5//! policy to manage memory and server-side resources.
6//!
7//! ## Lifecycle
8//!
9//! 1. First execution of a parameterized query calls `sp_prepare`, returning a handle
10//! 2. The handle is cached by SQL hash; subsequent executions use `sp_execute`
11//! 3. When the cache is full, LRU eviction calls `sp_unprepare` for evicted handles
12//! 4. Pool reset (`sp_reset_connection`) invalidates all handles, clearing the cache
13//! 5. Connection close implicitly releases all server-side handles
14
15// Allow expect() for NonZeroUsize construction with validated input
16#![allow(clippy::expect_used)]
17
18use std::collections::hash_map::DefaultHasher;
19use std::hash::{Hash, Hasher};
20use std::num::NonZeroUsize;
21use std::time::Instant;
22
23use lru::LruCache;
24
/// Number of prepared statements a single connection's cache holds when no
/// explicit size is configured.
pub const DEFAULT_MAX_STATEMENTS: usize = 256;
27
/// A statement that has been prepared on the server and can be cached.
///
/// Pairs the server-assigned handle with the SQL text it belongs to, plus the
/// bookkeeping used for cache lookup and diagnostics.
#[derive(Clone, Debug)]
pub struct PreparedStatement {
    /// Handle the server assigned to this prepared statement.
    handle: i32,
    /// Hash of `sql`; this is the key the statement is cached under.
    sql_hash: u64,
    /// SQL text of the statement (kept for debugging and logging).
    sql: String,
    /// Moment this value was constructed; `age()` is measured from here.
    created_at: Instant,
}
42
43impl PreparedStatement {
44    /// Create a new prepared statement.
45    pub fn new(handle: i32, sql: String) -> Self {
46        Self {
47            handle,
48            sql_hash: hash_sql(&sql),
49            sql,
50            created_at: Instant::now(),
51        }
52    }
53
54    /// Get the server-assigned handle.
55    #[must_use]
56    pub fn handle(&self) -> i32 {
57        self.handle
58    }
59
60    /// Get the SQL hash.
61    #[must_use]
62    pub fn sql_hash(&self) -> u64 {
63        self.sql_hash
64    }
65
66    /// Get the SQL text.
67    #[must_use]
68    pub fn sql(&self) -> &str {
69        &self.sql
70    }
71
72    /// Get the creation timestamp.
73    #[must_use]
74    pub fn created_at(&self) -> Instant {
75        self.created_at
76    }
77
78    /// Get the age of this statement.
79    #[must_use]
80    pub fn age(&self) -> std::time::Duration {
81        self.created_at.elapsed()
82    }
83}
84
85/// LRU cache for prepared statements.
86///
87/// This cache automatically evicts the least recently used statements when
88/// the maximum capacity is reached. Evicted statements should have their
89/// server-side handles released via `sp_unprepare`.
90pub struct StatementCache {
91    /// LRU cache of prepared statements keyed by SQL hash.
92    cache: LruCache<u64, PreparedStatement>,
93    /// Maximum number of cached statements.
94    max_size: usize,
95    /// Total number of cache hits (for metrics).
96    hits: u64,
97    /// Total number of cache misses (for metrics).
98    misses: u64,
99}
100
101impl StatementCache {
102    /// Create a new statement cache with the specified maximum size.
103    ///
104    /// # Panics
105    ///
106    /// Panics if `max_size` is 0.
107    #[must_use]
108    pub fn new(max_size: usize) -> Self {
109        assert!(max_size > 0, "max_size must be greater than 0");
110        Self {
111            cache: LruCache::new(NonZeroUsize::new(max_size).expect("max_size > 0")),
112            max_size,
113            hits: 0,
114            misses: 0,
115        }
116    }
117
118    /// Create a new statement cache with the default maximum size.
119    #[must_use]
120    pub fn with_default_size() -> Self {
121        Self::new(DEFAULT_MAX_STATEMENTS)
122    }
123
124    /// Look up a prepared statement by SQL text.
125    ///
126    /// Returns `Some(handle)` if the statement is cached, `None` otherwise.
127    /// This updates the LRU order.
128    pub fn get(&mut self, sql: &str) -> Option<i32> {
129        let hash = hash_sql(sql);
130        if let Some(stmt) = self.cache.get(&hash) {
131            self.hits += 1;
132            tracing::trace!(sql = sql, handle = stmt.handle, "statement cache hit");
133            Some(stmt.handle)
134        } else {
135            self.misses += 1;
136            tracing::trace!(sql = sql, "statement cache miss");
137            None
138        }
139    }
140
141    /// Peek at a prepared statement without updating LRU order.
142    pub fn peek(&self, sql: &str) -> Option<&PreparedStatement> {
143        let hash = hash_sql(sql);
144        self.cache.peek(&hash)
145    }
146
147    /// Insert a prepared statement into the cache.
148    ///
149    /// Returns the evicted statement if one was removed due to capacity.
150    pub fn insert(&mut self, stmt: PreparedStatement) -> Option<PreparedStatement> {
151        let hash = stmt.sql_hash;
152        tracing::debug!(
153            sql = stmt.sql(),
154            handle = stmt.handle,
155            "caching prepared statement"
156        );
157
158        // Check if we need to evict
159        let evicted = if self.cache.len() >= self.max_size {
160            // Pop least recently used
161            self.cache.pop_lru().map(|(_, stmt)| stmt)
162        } else {
163            None
164        };
165
166        self.cache.put(hash, stmt);
167        evicted
168    }
169
170    /// Remove a prepared statement from the cache.
171    ///
172    /// Returns the removed statement if it was present.
173    pub fn remove(&mut self, sql: &str) -> Option<PreparedStatement> {
174        let hash = hash_sql(sql);
175        self.cache.pop(&hash)
176    }
177
178    /// Clear all cached statements.
179    ///
180    /// Returns an iterator over all removed statements.
181    /// The caller should call `sp_unprepare` for each returned statement.
182    pub fn clear(&mut self) -> impl Iterator<Item = PreparedStatement> + '_ {
183        let mut statements = Vec::with_capacity(self.cache.len());
184        while let Some((_, stmt)) = self.cache.pop_lru() {
185            statements.push(stmt);
186        }
187        tracing::debug!(count = statements.len(), "cleared statement cache");
188        statements.into_iter()
189    }
190
191    /// Get the number of cached statements.
192    #[must_use]
193    pub fn len(&self) -> usize {
194        self.cache.len()
195    }
196
197    /// Check if the cache is empty.
198    #[must_use]
199    pub fn is_empty(&self) -> bool {
200        self.cache.is_empty()
201    }
202
203    /// Get the maximum cache size.
204    #[must_use]
205    pub fn max_size(&self) -> usize {
206        self.max_size
207    }
208
209    /// Get the number of cache hits.
210    #[must_use]
211    pub fn hits(&self) -> u64 {
212        self.hits
213    }
214
215    /// Get the number of cache misses.
216    #[must_use]
217    pub fn misses(&self) -> u64 {
218        self.misses
219    }
220
221    /// Get the cache hit ratio (0.0 to 1.0).
222    #[must_use]
223    pub fn hit_ratio(&self) -> f64 {
224        let total = self.hits + self.misses;
225        if total == 0 {
226            0.0
227        } else {
228            self.hits as f64 / total as f64
229        }
230    }
231
232    /// Reset cache statistics.
233    pub fn reset_stats(&mut self) {
234        self.hits = 0;
235        self.misses = 0;
236    }
237}
238
239impl Default for StatementCache {
240    fn default() -> Self {
241        Self::with_default_size()
242    }
243}
244
245impl std::fmt::Debug for StatementCache {
246    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
247        f.debug_struct("StatementCache")
248            .field("len", &self.cache.len())
249            .field("max_size", &self.max_size)
250            .field("hits", &self.hits)
251            .field("misses", &self.misses)
252            .finish()
253    }
254}
255
/// Compute the cache key for a piece of SQL text.
///
/// Every call hashes with a freshly created `DefaultHasher`, so equal SQL
/// text always produces the same key within a running program. The
/// underlying algorithm is not specified by the standard library, so the
/// values should not be persisted or compared across builds.
#[must_use]
pub fn hash_sql(sql: &str) -> u64 {
    let mut state = DefaultHasher::new();
    Hash::hash(sql, &mut state);
    state.finish()
}
265
/// Settings that control prepared statement caching.
#[derive(Clone, Debug)]
pub struct StatementCacheConfig {
    /// `true` when prepared statements should be cached.
    pub enabled: bool,
    /// Upper bound on the number of cached statements.
    ///
    /// `StatementCache::new` panics when given 0, so this should be non-zero
    /// whenever `enabled` is `true`.
    pub max_size: usize,
}
274
275impl Default for StatementCacheConfig {
276    fn default() -> Self {
277        Self {
278            enabled: true,
279            max_size: DEFAULT_MAX_STATEMENTS,
280        }
281    }
282}
283
284impl StatementCacheConfig {
285    /// Create a new configuration with caching disabled.
286    #[must_use]
287    pub fn disabled() -> Self {
288        Self {
289            enabled: false,
290            max_size: 0,
291        }
292    }
293
294    /// Create a new configuration with a custom max size.
295    #[must_use]
296    pub fn with_max_size(max_size: usize) -> Self {
297        Self {
298            enabled: true,
299            max_size,
300        }
301    }
302}
303
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Shorthand for building a statement from a handle and a SQL literal.
    fn stmt(handle: i32, sql: &str) -> PreparedStatement {
        PreparedStatement::new(handle, sql.to_string())
    }

    #[test]
    fn test_statement_cache_new() {
        let cache = StatementCache::new(10);

        assert_eq!(cache.max_size(), 10);
        assert!(cache.is_empty());
        assert_eq!(cache.len(), 0);
    }

    #[test]
    fn test_statement_cache_insert_and_get() {
        let sql = "SELECT * FROM users";
        let mut cache = StatementCache::new(10);

        cache.insert(stmt(1, sql));

        assert_eq!(cache.len(), 1);
        assert_eq!(cache.get(sql), Some(1));
        assert_eq!(cache.hits(), 1);
        assert_eq!(cache.misses(), 0);
    }

    #[test]
    fn test_statement_cache_miss() {
        let mut cache = StatementCache::new(10);

        let looked_up = cache.get("SELECT 1");

        assert_eq!(looked_up, None);
        assert_eq!(cache.misses(), 1);
        assert_eq!(cache.hits(), 0);
    }

    #[test]
    fn test_statement_cache_lru_eviction() {
        let mut cache = StatementCache::new(2);

        // Fill the cache to capacity.
        cache.insert(stmt(1, "SELECT 1"));
        cache.insert(stmt(2, "SELECT 2"));
        assert_eq!(cache.len(), 2);

        // Touch the first statement so the second becomes least recently used.
        cache.get("SELECT 1");

        // A third statement must push out "SELECT 2".
        let evicted = cache.insert(stmt(3, "SELECT 3"));
        assert!(evicted.is_some());
        assert_eq!(evicted.unwrap().handle(), 2);
        assert_eq!(cache.len(), 2);

        // "SELECT 1" survived because it was used recently.
        assert_eq!(cache.get("SELECT 1"), Some(1));
        // "SELECT 2" is gone.
        assert_eq!(cache.get("SELECT 2"), None);
        // "SELECT 3" is present.
        assert_eq!(cache.get("SELECT 3"), Some(3));
    }

    #[test]
    fn test_statement_cache_clear() {
        let mut cache = StatementCache::new(10);
        cache.insert(stmt(1, "SELECT 1"));
        cache.insert(stmt(2, "SELECT 2"));

        let cleared = cache.clear().collect::<Vec<_>>();

        assert_eq!(cleared.len(), 2);
        assert!(cache.is_empty());
    }

    #[test]
    fn test_statement_cache_remove() {
        let mut cache = StatementCache::new(10);
        cache.insert(stmt(1, "SELECT 1"));
        assert_eq!(cache.len(), 1);

        let removed = cache.remove("SELECT 1");

        assert!(removed.is_some());
        assert_eq!(removed.unwrap().handle(), 1);
        assert!(cache.is_empty());
    }

    #[test]
    fn test_statement_cache_hit_ratio() {
        let mut cache = StatementCache::new(10);
        cache.insert(stmt(1, "SELECT 1"));

        // Two hits followed by one miss.
        for sql in ["SELECT 1", "SELECT 1", "SELECT 2"] {
            cache.get(sql);
        }

        assert_eq!(cache.hits(), 2);
        assert_eq!(cache.misses(), 1);
        assert!((cache.hit_ratio() - 0.666666).abs() < 0.001);
    }

    #[test]
    fn test_hash_sql_consistency() {
        let sql = "SELECT * FROM users WHERE id = @p1";

        assert_eq!(hash_sql(sql), hash_sql(sql));
    }

    #[test]
    fn test_hash_sql_different() {
        assert_ne!(hash_sql("SELECT 1"), hash_sql("SELECT 2"));
    }

    #[test]
    fn test_prepared_statement_age() {
        let pause = std::time::Duration::from_millis(10);
        let statement = stmt(1, "SELECT 1");

        std::thread::sleep(pause);

        assert!(statement.age() >= pause);
    }

    #[test]
    fn test_statement_cache_config_default() {
        let StatementCacheConfig { enabled, max_size } = StatementCacheConfig::default();

        assert!(enabled);
        assert_eq!(max_size, DEFAULT_MAX_STATEMENTS);
    }

    #[test]
    fn test_statement_cache_config_disabled() {
        let config = StatementCacheConfig::disabled();

        assert!(!config.enabled);
    }
}
439        assert!(!config.enabled);
440    }
441}