Skip to main content

mssql_client/
statement_cache.rs

//! Prepared statement caching with LRU eviction.
//!
//! This module provides automatic caching of prepared statements to improve performance
//! for repeated query execution. The cache uses an LRU (Least Recently Used) eviction
//! policy to manage memory and server-side resources.
//!
//! **Status: wired into the buffered `query` path behind the off-by-default
//! [`Config::statement_cache`](crate::Config::statement_cache) flag.** When the
//! flag is off (the default), parameterized queries use `sp_executesql` and
//! this cache is never consulted. Some helper methods (`peek`, `remove`,
//! `hit_ratio`, `reset_stats`, …) are kept for later increments and are not yet
//! called, hence the `dead_code` allowance below. See LIMITATIONS.md §
//! Prepared Statement Cache.
//!
//! ## Lifecycle
//!
//! 1. First execution of a parameterized query calls `sp_prepexec` (prepare +
//!    execute in one round-trip), returning a handle
//! 2. The handle is cached under a hash of the SQL text; subsequent executions
//!    use `sp_execute`
//! 3. When the cache is full, inserting a new statement evicts the least
//!    recently used one and hands it back to the caller, which should release
//!    its handle with `sp_unprepare`
//! 4. Pool reset (`sp_reset_connection`) invalidates all handles, clearing the cache
//! 5. Connection close implicitly releases all server-side handles

// Several helpers are not called yet (see "Status" in the module docs).
#![allow(dead_code)]
// Allow expect() for NonZeroUsize construction with validated input
#![allow(clippy::expect_used)]
27
28use std::collections::hash_map::DefaultHasher;
29use std::hash::{Hash, Hasher};
30use std::num::NonZeroUsize;
31use std::time::Instant;
32
33use lru::LruCache;
34
/// Capacity used by [`StatementCache::with_default_size`]: the number of
/// prepared statements a single connection keeps cached unless told otherwise.
pub const DEFAULT_MAX_STATEMENTS: usize = 256;
37
38/// A cached prepared statement.
39///
40/// Contains the server-assigned handle and metadata needed for execution.
41#[derive(Debug, Clone)]
42pub struct PreparedStatement {
43    /// Server-assigned handle for this prepared statement.
44    handle: i32,
45    /// Hash of the SQL text for cache lookup.
46    sql_hash: u64,
47    /// The original SQL text (for debugging and logging).
48    sql: String,
49    /// Timestamp when this statement was prepared.
50    created_at: Instant,
51}
52
53impl PreparedStatement {
54    /// Create a new prepared statement.
55    pub fn new(handle: i32, sql: String) -> Self {
56        Self {
57            handle,
58            sql_hash: hash_sql(&sql),
59            sql,
60            created_at: Instant::now(),
61        }
62    }
63
64    /// Get the server-assigned handle.
65    #[must_use]
66    pub fn handle(&self) -> i32 {
67        self.handle
68    }
69
70    /// Get the SQL hash.
71    #[must_use]
72    pub fn sql_hash(&self) -> u64 {
73        self.sql_hash
74    }
75
76    /// Get the SQL text.
77    #[must_use]
78    pub fn sql(&self) -> &str {
79        &self.sql
80    }
81
82    /// Get the creation timestamp.
83    #[must_use]
84    pub fn created_at(&self) -> Instant {
85        self.created_at
86    }
87
88    /// Get the age of this statement.
89    #[must_use]
90    pub fn age(&self) -> std::time::Duration {
91        self.created_at.elapsed()
92    }
93}
94
95/// LRU cache for prepared statements.
96///
97/// This cache automatically evicts the least recently used statements when
98/// the maximum capacity is reached. Evicted statements should have their
99/// server-side handles released via `sp_unprepare`.
100pub struct StatementCache {
101    /// LRU cache of prepared statements keyed by SQL hash.
102    cache: LruCache<u64, PreparedStatement>,
103    /// Maximum number of cached statements.
104    max_size: usize,
105    /// Total number of cache hits (for metrics).
106    hits: u64,
107    /// Total number of cache misses (for metrics).
108    misses: u64,
109    /// Key of an in-flight `sp_prepexec` whose handle has not yet been read off
110    /// the response. Set by the send path on a cold miss, consumed by the read
111    /// path once the `@handle` RETURNVALUE arrives. See [`set_pending`](Self::set_pending).
112    pending_prepare: Option<String>,
113}
114
115impl StatementCache {
116    /// Create a new statement cache with the specified maximum size.
117    ///
118    /// # Panics
119    ///
120    /// Panics if `max_size` is 0.
121    #[must_use]
122    pub fn new(max_size: usize) -> Self {
123        assert!(max_size > 0, "max_size must be greater than 0");
124        Self {
125            cache: LruCache::new(NonZeroUsize::new(max_size).expect("max_size > 0")),
126            max_size,
127            hits: 0,
128            misses: 0,
129            pending_prepare: None,
130        }
131    }
132
133    /// Create a new statement cache with the default maximum size.
134    #[must_use]
135    pub fn with_default_size() -> Self {
136        Self::new(DEFAULT_MAX_STATEMENTS)
137    }
138
139    /// Look up a prepared statement by SQL text.
140    ///
141    /// Returns `Some(handle)` if the statement is cached, `None` otherwise.
142    /// This updates the LRU order.
143    pub fn get(&mut self, sql: &str) -> Option<i32> {
144        let hash = hash_sql(sql);
145        if let Some(stmt) = self.cache.get(&hash) {
146            self.hits += 1;
147            tracing::trace!(sql = sql, handle = stmt.handle, "statement cache hit");
148            Some(stmt.handle)
149        } else {
150            self.misses += 1;
151            tracing::trace!(sql = sql, "statement cache miss");
152            None
153        }
154    }
155
156    /// Record the cache key of an in-flight `sp_prepexec` whose handle will be
157    /// read off the execution response. Pass `None` on a cache hit (no handle
158    /// to capture) so a prior miss's key cannot linger.
159    ///
160    /// Internal plumbing for the cold-miss cache path — not part of the public
161    /// API.
162    pub(crate) fn set_pending(&mut self, key: Option<String>) {
163        self.pending_prepare = key;
164    }
165
166    /// Take the in-flight `sp_prepexec` key, if any, clearing it.
167    pub(crate) fn take_pending(&mut self) -> Option<String> {
168        self.pending_prepare.take()
169    }
170
171    /// Peek at a prepared statement without updating LRU order.
172    pub fn peek(&self, sql: &str) -> Option<&PreparedStatement> {
173        let hash = hash_sql(sql);
174        self.cache.peek(&hash)
175    }
176
177    /// Insert a prepared statement into the cache.
178    ///
179    /// Returns the evicted statement if one was removed due to capacity.
180    pub fn insert(&mut self, stmt: PreparedStatement) -> Option<PreparedStatement> {
181        let hash = stmt.sql_hash;
182        tracing::debug!(
183            sql = stmt.sql(),
184            handle = stmt.handle,
185            "caching prepared statement"
186        );
187
188        // Check if we need to evict
189        let evicted = if self.cache.len() >= self.max_size {
190            // Pop least recently used
191            self.cache.pop_lru().map(|(_, stmt)| stmt)
192        } else {
193            None
194        };
195
196        self.cache.put(hash, stmt);
197        evicted
198    }
199
200    /// Remove a prepared statement from the cache.
201    ///
202    /// Returns the removed statement if it was present.
203    pub fn remove(&mut self, sql: &str) -> Option<PreparedStatement> {
204        let hash = hash_sql(sql);
205        self.cache.pop(&hash)
206    }
207
208    /// Clear all cached statements.
209    ///
210    /// Returns an iterator over all removed statements.
211    /// The caller should call `sp_unprepare` for each returned statement.
212    pub fn clear(&mut self) -> impl Iterator<Item = PreparedStatement> + '_ {
213        let mut statements = Vec::with_capacity(self.cache.len());
214        while let Some((_, stmt)) = self.cache.pop_lru() {
215            statements.push(stmt);
216        }
217        tracing::debug!(count = statements.len(), "cleared statement cache");
218        statements.into_iter()
219    }
220
221    /// Get the number of cached statements.
222    #[must_use]
223    pub fn len(&self) -> usize {
224        self.cache.len()
225    }
226
227    /// Check if the cache is empty.
228    #[must_use]
229    pub fn is_empty(&self) -> bool {
230        self.cache.is_empty()
231    }
232
233    /// Get the maximum cache size.
234    #[must_use]
235    pub fn max_size(&self) -> usize {
236        self.max_size
237    }
238
239    /// Get the number of cache hits.
240    #[must_use]
241    pub fn hits(&self) -> u64 {
242        self.hits
243    }
244
245    /// Get the number of cache misses.
246    #[must_use]
247    pub fn misses(&self) -> u64 {
248        self.misses
249    }
250
251    /// Get the cache hit ratio (0.0 to 1.0).
252    #[must_use]
253    pub fn hit_ratio(&self) -> f64 {
254        let total = self.hits + self.misses;
255        if total == 0 {
256            0.0
257        } else {
258            self.hits as f64 / total as f64
259        }
260    }
261
262    /// Reset cache statistics.
263    pub fn reset_stats(&mut self) {
264        self.hits = 0;
265        self.misses = 0;
266    }
267
268    /// Take a point-in-time snapshot of the cache statistics.
269    #[must_use]
270    pub fn stats(&self) -> StatementCacheStats {
271        StatementCacheStats {
272            hits: self.hits,
273            misses: self.misses,
274            entries: self.cache.len(),
275        }
276    }
277}
278
/// A point-in-time snapshot of a connection's prepared-statement cache.
///
/// Obtained via [`Client::statement_cache_stats`](crate::Client::statement_cache_stats).
/// Useful for measuring the cache's effectiveness (the
/// [`Config::statement_cache`](crate::Config::statement_cache) flag is opt-in
/// precisely so these numbers can be gathered before defaulting it on).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct StatementCacheStats {
    /// Number of lookups that found a cached handle.
    pub hits: u64,
    /// Number of lookups that missed (the statement then has to be prepared,
    /// via `sp_prepexec` on the cold-miss path).
    pub misses: u64,
    /// Number of prepared statements currently cached.
    pub entries: usize,
}
294
295impl Default for StatementCache {
296    fn default() -> Self {
297        Self::with_default_size()
298    }
299}
300
301impl std::fmt::Debug for StatementCache {
302    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
303        f.debug_struct("StatementCache")
304            .field("len", &self.cache.len())
305            .field("max_size", &self.max_size)
306            .field("hits", &self.hits)
307            .field("misses", &self.misses)
308            .finish()
309    }
310}
311
/// Hash SQL text for cache lookup.
///
/// Uses [`DefaultHasher::new`]. All hashers created that way hash identically,
/// so the same text always maps to the same value within one build of the
/// program. The underlying algorithm is *not* specified and may change between
/// Rust releases, so these values must never be persisted or compared across
/// builds. Distinct strings can in principle collide, so callers that need an
/// exact match must also compare the full SQL text.
#[must_use]
pub fn hash_sql(sql: &str) -> u64 {
    let mut hasher = DefaultHasher::new();
    sql.hash(&mut hasher);
    hasher.finish()
}
321
322/// Configuration for statement caching.
323#[derive(Debug, Clone)]
324pub struct StatementCacheConfig {
325    /// Whether statement caching is enabled.
326    pub enabled: bool,
327    /// Maximum number of statements to cache.
328    pub max_size: usize,
329}
330
331impl Default for StatementCacheConfig {
332    fn default() -> Self {
333        Self {
334            enabled: true,
335            max_size: DEFAULT_MAX_STATEMENTS,
336        }
337    }
338}
339
340impl StatementCacheConfig {
341    /// Create a new configuration with caching disabled.
342    #[must_use]
343    pub fn disabled() -> Self {
344        Self {
345            enabled: false,
346            max_size: 0,
347        }
348    }
349
350    /// Create a new configuration with a custom max size.
351    #[must_use]
352    pub fn with_max_size(max_size: usize) -> Self {
353        Self {
354            enabled: true,
355            max_size,
356        }
357    }
358}
359
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    #[test]
    fn test_statement_cache_new() {
        let cache = StatementCache::new(10);
        assert_eq!(cache.max_size(), 10);
        assert!(cache.is_empty());
        assert_eq!(cache.len(), 0);
    }

    #[test]
    fn test_statement_cache_default_size() {
        let cache = StatementCache::default();
        assert_eq!(cache.max_size(), DEFAULT_MAX_STATEMENTS);
        assert!(cache.is_empty());
    }

    #[test]
    #[should_panic(expected = "max_size must be greater than 0")]
    fn test_statement_cache_zero_size_panics() {
        let _ = StatementCache::new(0);
    }

    #[test]
    fn test_statement_cache_insert_and_get() {
        let mut cache = StatementCache::new(10);

        let stmt = PreparedStatement::new(1, "SELECT * FROM users".to_string());
        cache.insert(stmt);

        assert_eq!(cache.len(), 1);
        assert_eq!(cache.get("SELECT * FROM users"), Some(1));
        assert_eq!(cache.hits(), 1);
        assert_eq!(cache.misses(), 0);
    }

    #[test]
    fn test_statement_cache_miss() {
        let mut cache = StatementCache::new(10);

        assert_eq!(cache.get("SELECT 1"), None);
        assert_eq!(cache.misses(), 1);
        assert_eq!(cache.hits(), 0);
    }

    #[test]
    fn test_statement_cache_lru_eviction() {
        let mut cache = StatementCache::new(2);

        // Insert 2 statements
        cache.insert(PreparedStatement::new(1, "SELECT 1".to_string()));
        cache.insert(PreparedStatement::new(2, "SELECT 2".to_string()));
        assert_eq!(cache.len(), 2);

        // Access the first statement to make it recently used
        cache.get("SELECT 1");

        // Insert a third statement - should evict "SELECT 2"
        let evicted = cache.insert(PreparedStatement::new(3, "SELECT 3".to_string()));

        assert!(evicted.is_some());
        assert_eq!(evicted.unwrap().handle(), 2);
        assert_eq!(cache.len(), 2);

        // Verify "SELECT 1" is still cached (was accessed recently)
        assert_eq!(cache.get("SELECT 1"), Some(1));
        // Verify "SELECT 2" was evicted
        assert_eq!(cache.get("SELECT 2"), None);
        // Verify "SELECT 3" is cached
        assert_eq!(cache.get("SELECT 3"), Some(3));
    }

    #[test]
    fn test_statement_cache_peek_does_not_touch_lru_or_stats() {
        let mut cache = StatementCache::new(2);
        cache.insert(PreparedStatement::new(1, "SELECT 1".to_string()));
        cache.insert(PreparedStatement::new(2, "SELECT 2".to_string()));

        assert_eq!(cache.peek("SELECT 1").map(PreparedStatement::handle), Some(1));
        assert_eq!(cache.peek("SELECT 9").map(PreparedStatement::handle), None);
        assert_eq!(cache.hits(), 0);
        assert_eq!(cache.misses(), 0);

        // Peeking did not promote "SELECT 1", so it is still the LRU entry.
        let evicted = cache.insert(PreparedStatement::new(3, "SELECT 3".to_string()));
        assert_eq!(evicted.unwrap().handle(), 1);
    }

    #[test]
    fn test_statement_cache_clear() {
        let mut cache = StatementCache::new(10);

        cache.insert(PreparedStatement::new(1, "SELECT 1".to_string()));
        cache.insert(PreparedStatement::new(2, "SELECT 2".to_string()));

        let cleared: Vec<_> = cache.clear().collect();
        assert_eq!(cleared.len(), 2);
        assert!(cache.is_empty());
    }

    #[test]
    fn test_statement_cache_remove() {
        let mut cache = StatementCache::new(10);

        cache.insert(PreparedStatement::new(1, "SELECT 1".to_string()));
        assert_eq!(cache.len(), 1);

        let removed = cache.remove("SELECT 1");
        assert!(removed.is_some());
        assert_eq!(removed.unwrap().handle(), 1);
        assert!(cache.is_empty());
    }

    #[test]
    fn test_statement_cache_hit_ratio() {
        let mut cache = StatementCache::new(10);

        cache.insert(PreparedStatement::new(1, "SELECT 1".to_string()));

        // 2 hits, 1 miss
        cache.get("SELECT 1");
        cache.get("SELECT 1");
        cache.get("SELECT 2");

        assert_eq!(cache.hits(), 2);
        assert_eq!(cache.misses(), 1);
        assert!((cache.hit_ratio() - 0.666666).abs() < 0.001);
    }

    #[test]
    fn test_statement_cache_stats_and_reset() {
        let mut cache = StatementCache::new(4);
        assert!(cache.hit_ratio().abs() < f64::EPSILON);

        cache.insert(PreparedStatement::new(1, "SELECT 1".to_string()));
        cache.get("SELECT 1");
        cache.get("SELECT 2");
        assert_eq!(
            cache.stats(),
            StatementCacheStats { hits: 1, misses: 1, entries: 1 }
        );

        // Resetting clears counters but keeps cached entries.
        cache.reset_stats();
        assert_eq!(
            cache.stats(),
            StatementCacheStats { hits: 0, misses: 0, entries: 1 }
        );
    }

    #[test]
    fn test_statement_cache_pending_prepare() {
        let mut cache = StatementCache::new(4);
        assert_eq!(cache.take_pending(), None);

        cache.set_pending(Some("SELECT 1".to_string()));
        assert_eq!(cache.take_pending().as_deref(), Some("SELECT 1"));
        // Taking the key clears it.
        assert_eq!(cache.take_pending(), None);

        // A cache hit passes `None`, which must overwrite a stale key.
        cache.set_pending(Some("SELECT 1".to_string()));
        cache.set_pending(None);
        assert_eq!(cache.take_pending(), None);
    }

    #[test]
    fn test_hash_sql_consistency() {
        let sql = "SELECT * FROM users WHERE id = @p1";
        let hash1 = hash_sql(sql);
        let hash2 = hash_sql(sql);
        assert_eq!(hash1, hash2);
    }

    #[test]
    fn test_hash_sql_different() {
        let hash1 = hash_sql("SELECT 1");
        let hash2 = hash_sql("SELECT 2");
        assert_ne!(hash1, hash2);
    }

    #[test]
    fn test_prepared_statement_age() {
        let stmt = PreparedStatement::new(1, "SELECT 1".to_string());
        std::thread::sleep(std::time::Duration::from_millis(10));
        assert!(stmt.age().as_millis() >= 10);
    }

    #[test]
    fn test_statement_cache_config_default() {
        let config = StatementCacheConfig::default();
        assert!(config.enabled);
        assert_eq!(config.max_size, DEFAULT_MAX_STATEMENTS);
    }

    #[test]
    fn test_statement_cache_config_disabled() {
        let config = StatementCacheConfig::disabled();
        assert!(!config.enabled);
        assert_eq!(config.max_size, 0);
    }

    #[test]
    fn test_statement_cache_config_with_max_size() {
        let config = StatementCacheConfig::with_max_size(32);
        assert!(config.enabled);
        assert_eq!(config.max_size, 32);
    }
}