Skip to main content

mssql_client/
statement_cache.rs

1//! Prepared statement caching with LRU eviction.
2//!
3//! **Status: wired into the buffered `query` path behind the off-by-default
4//! [`Config::statement_cache`](crate::Config::statement_cache) flag.** When the
5//! flag is off (the default), parameterized queries use `sp_executesql` and
6//! this cache is never consulted. Some helper methods (`peek`, `remove`,
7//! `hit_ratio`, `reset_stats`, …) are kept for later increments and are not yet
8//! called, hence the `dead_code` allowance below. See LIMITATIONS.md §
9//! Prepared Statement Cache.
10#![allow(dead_code)]
11//!
12//! This module provides automatic caching of prepared statements to improve performance
13//! for repeated query execution. The cache uses an LRU (Least Recently Used) eviction
14//! policy to manage memory and server-side resources.
15//!
16//! ## Lifecycle
17//!
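//! 1. The first execution of a parameterized query prepares it on the server
//!    (the cold-miss path sends `sp_prepexec`, which prepares and executes in
//!    one round trip), yielding a handle
//! 2. The handle is cached under the SQL text's hash; later executions of the
//!    same text use `sp_execute` with that handle
//! 3. When the cache is full, the least recently used entry is evicted and
//!    handed back to the caller, which should release its handle with `sp_unprepare`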
21//! 4. Pool reset (`sp_reset_connection`) invalidates all handles, clearing the cache
22//! 5. Connection close implicitly releases all server-side handles
23
24// Allow expect() for NonZeroUsize construction with validated input
25#![allow(clippy::expect_used)]
26
27use std::collections::hash_map::DefaultHasher;
28use std::hash::{Hash, Hasher};
29use std::num::NonZeroUsize;
30use std::time::Instant;
31
32use lru::LruCache;
33
/// Per-connection capacity of the prepared-statement cache used when no
/// explicit size is requested.
pub const DEFAULT_MAX_STATEMENTS: usize = 256;
36
37/// A cached prepared statement.
38///
39/// Contains the server-assigned handle and metadata needed for execution.
40#[derive(Debug, Clone)]
41pub struct PreparedStatement {
42    /// Server-assigned handle for this prepared statement.
43    handle: i32,
44    /// Hash of the SQL text for cache lookup.
45    sql_hash: u64,
46    /// The original SQL text (for debugging and logging).
47    sql: String,
48    /// Timestamp when this statement was prepared.
49    created_at: Instant,
50}
51
52impl PreparedStatement {
53    /// Create a new prepared statement.
54    pub fn new(handle: i32, sql: String) -> Self {
55        Self {
56            handle,
57            sql_hash: hash_sql(&sql),
58            sql,
59            created_at: Instant::now(),
60        }
61    }
62
63    /// Get the server-assigned handle.
64    #[must_use]
65    pub fn handle(&self) -> i32 {
66        self.handle
67    }
68
69    /// Get the SQL hash.
70    #[must_use]
71    pub fn sql_hash(&self) -> u64 {
72        self.sql_hash
73    }
74
75    /// Get the SQL text.
76    #[must_use]
77    pub fn sql(&self) -> &str {
78        &self.sql
79    }
80
81    /// Get the creation timestamp.
82    #[must_use]
83    pub fn created_at(&self) -> Instant {
84        self.created_at
85    }
86
87    /// Get the age of this statement.
88    #[must_use]
89    pub fn age(&self) -> std::time::Duration {
90        self.created_at.elapsed()
91    }
92}
93
94/// LRU cache for prepared statements.
95///
96/// This cache automatically evicts the least recently used statements when
97/// the maximum capacity is reached. Evicted statements should have their
98/// server-side handles released via `sp_unprepare`.
99pub struct StatementCache {
100    /// LRU cache of prepared statements keyed by SQL hash.
101    cache: LruCache<u64, PreparedStatement>,
102    /// Maximum number of cached statements.
103    max_size: usize,
104    /// Total number of cache hits (for metrics).
105    hits: u64,
106    /// Total number of cache misses (for metrics).
107    misses: u64,
108    /// Key of an in-flight `sp_prepexec` whose handle has not yet been read off
109    /// the response. Set by the send path on a cold miss, consumed by the read
110    /// path once the `@handle` RETURNVALUE arrives. See [`set_pending`](Self::set_pending).
111    pending_prepare: Option<String>,
112}
113
114impl StatementCache {
115    /// Create a new statement cache with the specified maximum size.
116    ///
117    /// # Panics
118    ///
119    /// Panics if `max_size` is 0.
120    #[must_use]
121    pub fn new(max_size: usize) -> Self {
122        assert!(max_size > 0, "max_size must be greater than 0");
123        Self {
124            cache: LruCache::new(NonZeroUsize::new(max_size).expect("max_size > 0")),
125            max_size,
126            hits: 0,
127            misses: 0,
128            pending_prepare: None,
129        }
130    }
131
132    /// Create a new statement cache with the default maximum size.
133    #[must_use]
134    pub fn with_default_size() -> Self {
135        Self::new(DEFAULT_MAX_STATEMENTS)
136    }
137
138    /// Look up a prepared statement by SQL text.
139    ///
140    /// Returns `Some(handle)` if the statement is cached, `None` otherwise.
141    /// This updates the LRU order.
142    pub fn get(&mut self, sql: &str) -> Option<i32> {
143        let hash = hash_sql(sql);
144        if let Some(stmt) = self.cache.get(&hash) {
145            self.hits += 1;
146            tracing::trace!(sql = sql, handle = stmt.handle, "statement cache hit");
147            Some(stmt.handle)
148        } else {
149            self.misses += 1;
150            tracing::trace!(sql = sql, "statement cache miss");
151            None
152        }
153    }
154
155    /// Record the cache key of an in-flight `sp_prepexec` whose handle will be
156    /// read off the execution response. Pass `None` on a cache hit (no handle
157    /// to capture) so a prior miss's key cannot linger.
158    ///
159    /// Internal plumbing for the cold-miss cache path — not part of the public
160    /// API.
161    pub(crate) fn set_pending(&mut self, key: Option<String>) {
162        self.pending_prepare = key;
163    }
164
165    /// Take the in-flight `sp_prepexec` key, if any, clearing it.
166    pub(crate) fn take_pending(&mut self) -> Option<String> {
167        self.pending_prepare.take()
168    }
169
170    /// Peek at a prepared statement without updating LRU order.
171    pub fn peek(&self, sql: &str) -> Option<&PreparedStatement> {
172        let hash = hash_sql(sql);
173        self.cache.peek(&hash)
174    }
175
176    /// Insert a prepared statement into the cache.
177    ///
178    /// Returns the evicted statement if one was removed due to capacity.
179    pub fn insert(&mut self, stmt: PreparedStatement) -> Option<PreparedStatement> {
180        let hash = stmt.sql_hash;
181        tracing::debug!(
182            sql = stmt.sql(),
183            handle = stmt.handle,
184            "caching prepared statement"
185        );
186
187        // Check if we need to evict
188        let evicted = if self.cache.len() >= self.max_size {
189            // Pop least recently used
190            self.cache.pop_lru().map(|(_, stmt)| stmt)
191        } else {
192            None
193        };
194
195        self.cache.put(hash, stmt);
196        evicted
197    }
198
199    /// Remove a prepared statement from the cache.
200    ///
201    /// Returns the removed statement if it was present.
202    pub fn remove(&mut self, sql: &str) -> Option<PreparedStatement> {
203        let hash = hash_sql(sql);
204        self.cache.pop(&hash)
205    }
206
207    /// Clear all cached statements.
208    ///
209    /// Returns an iterator over all removed statements.
210    /// The caller should call `sp_unprepare` for each returned statement.
211    pub fn clear(&mut self) -> impl Iterator<Item = PreparedStatement> + '_ {
212        let mut statements = Vec::with_capacity(self.cache.len());
213        while let Some((_, stmt)) = self.cache.pop_lru() {
214            statements.push(stmt);
215        }
216        tracing::debug!(count = statements.len(), "cleared statement cache");
217        statements.into_iter()
218    }
219
220    /// Get the number of cached statements.
221    #[must_use]
222    pub fn len(&self) -> usize {
223        self.cache.len()
224    }
225
226    /// Check if the cache is empty.
227    #[must_use]
228    pub fn is_empty(&self) -> bool {
229        self.cache.is_empty()
230    }
231
232    /// Get the maximum cache size.
233    #[must_use]
234    pub fn max_size(&self) -> usize {
235        self.max_size
236    }
237
238    /// Get the number of cache hits.
239    #[must_use]
240    pub fn hits(&self) -> u64 {
241        self.hits
242    }
243
244    /// Get the number of cache misses.
245    #[must_use]
246    pub fn misses(&self) -> u64 {
247        self.misses
248    }
249
250    /// Get the cache hit ratio (0.0 to 1.0).
251    #[must_use]
252    pub fn hit_ratio(&self) -> f64 {
253        let total = self.hits + self.misses;
254        if total == 0 {
255            0.0
256        } else {
257            self.hits as f64 / total as f64
258        }
259    }
260
261    /// Reset cache statistics.
262    pub fn reset_stats(&mut self) {
263        self.hits = 0;
264        self.misses = 0;
265    }
266
267    /// Take a point-in-time snapshot of the cache statistics.
268    #[must_use]
269    pub fn stats(&self) -> StatementCacheStats {
270        StatementCacheStats {
271            hits: self.hits,
272            misses: self.misses,
273            entries: self.cache.len(),
274        }
275    }
276}
277
/// Snapshot of a connection's prepared-statement cache counters, taken at a
/// single point in time.
///
/// Returned by [`Client::statement_cache_stats`](crate::Client::statement_cache_stats).
/// Because [`Config::statement_cache`](crate::Config::statement_cache) is
/// opt-in, these figures are what to inspect when judging whether the cache
/// pays off before it is turned on by default.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct StatementCacheStats {
    /// Lookups that returned a cached handle.
    pub hits: u64,
    /// Lookups that found no cached handle, so the statement had to be
    /// prepared on the server.
    pub misses: u64,
    /// Prepared statements held in the cache at the time of the snapshot.
    pub entries: usize,
}
293
294impl Default for StatementCache {
295    fn default() -> Self {
296        Self::with_default_size()
297    }
298}
299
300impl std::fmt::Debug for StatementCache {
301    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
302        f.debug_struct("StatementCache")
303            .field("len", &self.cache.len())
304            .field("max_size", &self.max_size)
305            .field("hits", &self.hits)
306            .field("misses", &self.misses)
307            .finish()
308    }
309}
310
/// Hash SQL text into the key used for statement-cache lookups.
///
/// Uses [`DefaultHasher::new`], which always starts from the same fixed keys,
/// so equal text yields equal hashes within a process (and across processes
/// built with the same standard library). The standard library explicitly
/// leaves the algorithm unspecified and free to change between Rust releases,
/// so these values must not be persisted or exchanged with other builds.
///
/// Distinct SQL strings can collide; a matching hash alone does not prove the
/// texts are equal.
#[must_use]
pub fn hash_sql(sql: &str) -> u64 {
    let mut hasher = DefaultHasher::new();
    sql.hash(&mut hasher);
    hasher.finish()
}
320
321/// Configuration for statement caching.
322#[derive(Debug, Clone)]
323pub struct StatementCacheConfig {
324    /// Whether statement caching is enabled.
325    pub enabled: bool,
326    /// Maximum number of statements to cache.
327    pub max_size: usize,
328}
329
330impl Default for StatementCacheConfig {
331    fn default() -> Self {
332        Self {
333            enabled: true,
334            max_size: DEFAULT_MAX_STATEMENTS,
335        }
336    }
337}
338
339impl StatementCacheConfig {
340    /// Create a new configuration with caching disabled.
341    #[must_use]
342    pub fn disabled() -> Self {
343        Self {
344            enabled: false,
345            max_size: 0,
346        }
347    }
348
349    /// Create a new configuration with a custom max size.
350    #[must_use]
351    pub fn with_max_size(max_size: usize) -> Self {
352        Self {
353            enabled: true,
354            max_size,
355        }
356    }
357}
358
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Shorthand for building a cache entry in these tests.
    fn stmt(handle: i32, sql: &str) -> PreparedStatement {
        PreparedStatement::new(handle, sql.to_owned())
    }

    #[test]
    fn test_statement_cache_new() {
        let fresh = StatementCache::new(10);
        assert_eq!(fresh.max_size(), 10);
        assert!(fresh.is_empty());
        assert_eq!(fresh.len(), 0);
    }

    #[test]
    fn test_statement_cache_insert_and_get() {
        let mut cache = StatementCache::new(10);
        cache.insert(stmt(1, "SELECT * FROM users"));

        assert_eq!(cache.len(), 1);
        assert_eq!(cache.get("SELECT * FROM users"), Some(1));
        assert_eq!((cache.hits(), cache.misses()), (1, 0));
    }

    #[test]
    fn test_statement_cache_miss() {
        let mut cache = StatementCache::new(10);

        assert!(cache.get("SELECT 1").is_none());
        assert_eq!((cache.hits(), cache.misses()), (0, 1));
    }

    #[test]
    fn test_statement_cache_lru_eviction() {
        let mut cache = StatementCache::new(2);
        for (handle, sql) in [(1, "SELECT 1"), (2, "SELECT 2")] {
            cache.insert(stmt(handle, sql));
        }
        assert_eq!(cache.len(), 2);

        // Touch "SELECT 1" so that "SELECT 2" becomes the least recently used.
        let _ = cache.get("SELECT 1");

        let evicted = cache
            .insert(stmt(3, "SELECT 3"))
            .expect("a full cache must evict on insert");
        assert_eq!(evicted.handle(), 2);
        assert_eq!(cache.len(), 2);

        assert_eq!(cache.get("SELECT 1"), Some(1), "recently used entry survives");
        assert_eq!(cache.get("SELECT 2"), None, "least recently used entry is gone");
        assert_eq!(cache.get("SELECT 3"), Some(3), "new entry is cached");
    }

    #[test]
    fn test_statement_cache_clear() {
        let mut cache = StatementCache::new(10);
        cache.insert(stmt(1, "SELECT 1"));
        cache.insert(stmt(2, "SELECT 2"));

        assert_eq!(cache.clear().count(), 2);
        assert!(cache.is_empty());
    }

    #[test]
    fn test_statement_cache_remove() {
        let mut cache = StatementCache::new(10);
        cache.insert(stmt(1, "SELECT 1"));
        assert_eq!(cache.len(), 1);

        let removed = cache.remove("SELECT 1").unwrap();
        assert_eq!(removed.handle(), 1);
        assert!(cache.is_empty());
    }

    #[test]
    fn test_statement_cache_hit_ratio() {
        let mut cache = StatementCache::new(10);
        cache.insert(stmt(1, "SELECT 1"));

        // Two hits followed by one miss.
        for sql in ["SELECT 1", "SELECT 1", "SELECT 2"] {
            let _ = cache.get(sql);
        }

        assert_eq!(cache.hits(), 2);
        assert_eq!(cache.misses(), 1);
        assert!((cache.hit_ratio() - 2.0 / 3.0).abs() < 0.001);
    }

    #[test]
    fn test_hash_sql_consistency() {
        let sql = "SELECT * FROM users WHERE id = @p1";
        assert_eq!(hash_sql(sql), hash_sql(sql));
    }

    #[test]
    fn test_hash_sql_different() {
        assert_ne!(hash_sql("SELECT 1"), hash_sql("SELECT 2"));
    }

    #[test]
    fn test_prepared_statement_age() {
        let s = stmt(1, "SELECT 1");
        std::thread::sleep(std::time::Duration::from_millis(10));
        assert!(s.age() >= std::time::Duration::from_millis(10));
    }

    #[test]
    fn test_statement_cache_config_default() {
        let StatementCacheConfig { enabled, max_size } = StatementCacheConfig::default();
        assert!(enabled);
        assert_eq!(max_size, DEFAULT_MAX_STATEMENTS);
    }

    #[test]
    fn test_statement_cache_config_disabled() {
        assert!(!StatementCacheConfig::disabled().enabled);
    }
}