Skip to main content

sqlmodel_query/
cache.rs

//! Statement caching for compiled SQL queries.
//!
//! Caches compiled SQL strings keyed by a caller-computed `u64` hash so that
//! repeated queries avoid redundant string building.
5
6use std::collections::HashMap;
7use std::hash::{Hash, Hasher};
8use std::time::Instant;
9
/// A cached compiled SQL statement.
#[derive(Debug, Clone)]
pub struct CachedStatement {
    /// The compiled SQL string.
    pub sql: String,
    /// When this entry was last accessed (inserted or served from the cache).
    pub last_used: Instant,
    /// Number of times this statement has been reused, i.e. served from the
    /// cache *after* its initial insertion. A freshly inserted entry has `0`.
    pub hit_count: u64,
}

/// Internal cache slot: the cached statement plus its logical access time.
#[derive(Debug)]
struct Slot {
    /// The statement handed out to callers.
    statement: CachedStatement,
    /// Value of the cache's access counter at the most recent access.
    ///
    /// Used instead of `CachedStatement::last_used` to pick the eviction
    /// victim: `Instant::now()` is monotonic but two back-to-back calls are
    /// not guaranteed to differ, which would make the LRU choice depend on
    /// `HashMap` iteration order. The counter is strictly increasing.
    recency: u64,
}

/// LRU-style cache for compiled SQL statements.
///
/// Keyed by a `u64` hash that callers compute from their query structure.
/// When the cache exceeds `max_size`, the least-recently-used entry is evicted.
///
/// Entries are identified by the `u64` key alone: two different queries that
/// map to the same key share one entry, and the SQL built first is returned
/// for both.
///
/// # Example
///
/// ```
/// use sqlmodel_query::cache::StatementCache;
///
/// let mut cache = StatementCache::new(100);
///
/// // Cache a compiled query
/// let sql = cache.get_or_insert(12345, || "SELECT * FROM users WHERE id = $1".to_string());
/// assert_eq!(sql, "SELECT * FROM users WHERE id = $1");
///
/// // Second call returns cached version
/// let called = std::cell::Cell::new(false);
/// let sql2 = cache.get_or_insert(12345, || {
///     called.set(true);
///     "SELECT * FROM users WHERE id = $1".to_string()
/// });
/// assert_eq!(sql2, "SELECT * FROM users WHERE id = $1");
/// assert!(!called.get());
/// ```
#[derive(Debug)]
pub struct StatementCache {
    /// Cached statements by caller-supplied key.
    cache: HashMap<u64, Slot>,
    /// Maximum number of entries kept before an insertion triggers eviction.
    max_size: usize,
    /// Logical clock, bumped once per `get_or_insert` call.
    tick: u64,
}

impl StatementCache {
    /// Create a new cache with the given maximum number of entries.
    ///
    /// The initial allocation is capped at 256 slots; the map grows on demand
    /// up to `max_size`.
    ///
    /// A `max_size` of `0` behaves like `1`: [`get_or_insert`](Self::get_or_insert)
    /// returns a reference into the cache, so the most recently inserted
    /// statement is always retained.
    pub fn new(max_size: usize) -> Self {
        Self {
            cache: HashMap::with_capacity(max_size.min(256)),
            max_size,
            tick: 0,
        }
    }

    /// Get a cached statement or compile and insert it.
    ///
    /// The `builder` closure is only called on cache miss. On a miss with a
    /// full cache, the least-recently-used entry is evicted first. On a hit,
    /// the entry's `hit_count` is incremented. Either way the entry becomes
    /// the most recently used one.
    ///
    /// Returns the cached SQL text, borrowed from the cache.
    pub fn get_or_insert(&mut self, key: u64, builder: impl FnOnce() -> String) -> &str {
        // Make room before inserting so the map never exceeds `max_size`
        // (beyond the single entry a zero-sized cache must hold).
        if !self.cache.contains_key(&key) && self.cache.len() >= self.max_size {
            self.evict_lru();
        }

        // One tick per access gives every access a unique recency stamp.
        // A `u64` cannot realistically wrap here.
        self.tick += 1;
        let recency = self.tick;

        let slot = match self.cache.entry(key) {
            std::collections::hash_map::Entry::Occupied(occupied) => {
                let slot = occupied.into_mut();
                slot.statement.last_used = Instant::now();
                slot.statement.hit_count += 1;
                slot.recency = recency;
                slot
            }
            std::collections::hash_map::Entry::Vacant(vacant) => {
                let sql = builder();
                vacant.insert(Slot {
                    statement: CachedStatement {
                        sql,
                        last_used: Instant::now(),
                        // The initial build is a miss, not a reuse.
                        hit_count: 0,
                    },
                    recency,
                })
            }
        };
        &slot.statement.sql
    }

    /// Check if a statement is cached.
    ///
    /// Does not count as an access: recency and hit count are unchanged.
    pub fn contains(&self, key: u64) -> bool {
        self.cache.contains_key(&key)
    }

    /// Number of statements currently cached.
    pub fn len(&self) -> usize {
        self.cache.len()
    }

    /// Check if cache is empty.
    pub fn is_empty(&self) -> bool {
        self.cache.is_empty()
    }

    /// Clear all cached statements.
    pub fn clear(&mut self) {
        self.cache.clear();
    }

    /// Evict the least-recently-used entry, if any.
    ///
    /// Linear scan over all entries; a no-op on an empty cache.
    fn evict_lru(&mut self) {
        let victim = self
            .cache
            .iter()
            .min_by_key(|(_, slot)| slot.recency)
            .map(|(&key, _)| key);
        if let Some(key) = victim {
            self.cache.remove(&key);
        }
    }
}
107
/// Compute a hash key for caching from any hashable value.
///
/// Useful for creating cache keys from query components. Accepts unsized
/// values too, so `cache_key("SELECT 1")` and `cache_key(&parts[..])` work
/// without an extra level of reference. A `&str`, a `String` and a `&&str`
/// with the same contents all produce the same key.
///
/// The key is computed with the standard library's `DefaultHasher`, whose
/// algorithm is not guaranteed to stay the same across Rust releases: use the
/// result for in-memory lookups only, never persist it. Distinct values can
/// collide on the same `u64`; callers that key a [`StatementCache`] with it
/// get the first cached SQL for both values in that case.
pub fn cache_key(value: &(impl Hash + ?Sized)) -> u64 {
    let mut hasher = std::collections::hash_map::DefaultHasher::new();
    value.hash(&mut hasher);
    hasher.finish()
}
116
117impl Default for StatementCache {
118    fn default() -> Self {
119        Self::new(1024)
120    }
121}
122
#[cfg(test)]
mod tests {
    use super::*;

    /// A repeated key is served from the cache without re-running the builder.
    #[test]
    fn test_cache_hit() {
        let mut cache = StatementCache::new(10);

        let first = cache.get_or_insert(1, || "SELECT 1".to_string()).to_string();
        assert_eq!(first, "SELECT 1");

        // The second lookup must not invoke its builder.
        let mut rebuilt = false;
        let second = cache
            .get_or_insert(1, || {
                rebuilt = true;
                "SELECT 1".to_string()
            })
            .to_string();
        assert_eq!(second, "SELECT 1");
        assert!(!rebuilt);
    }

    /// Distinct keys each build and store their own statement.
    #[test]
    fn test_cache_miss() {
        let mut cache = StatementCache::new(10);

        let first = cache.get_or_insert(1, || "SELECT 1".to_string()).to_string();
        let second = cache.get_or_insert(2, || "SELECT 2".to_string()).to_string();

        assert_eq!(first, "SELECT 1");
        assert_eq!(second, "SELECT 2");
        assert_eq!(cache.len(), 2);
    }

    /// Inserting into a full cache drops the oldest entry.
    #[test]
    fn test_eviction() {
        let mut cache = StatementCache::new(2);
        cache.get_or_insert(1, || "SELECT 1".to_string());
        cache.get_or_insert(2, || "SELECT 2".to_string());
        // Key 1 is the least recently used, so this insertion pushes it out.
        cache.get_or_insert(3, || "SELECT 3".to_string());

        assert_eq!(cache.len(), 2);
        assert!(!cache.contains(1));
        assert!(cache.contains(2));
        assert!(cache.contains(3));
    }

    /// A lookup refreshes an entry, changing which one gets evicted next.
    #[test]
    fn test_lru_ordering() {
        let mut cache = StatementCache::new(2);
        cache.get_or_insert(1, || "SELECT 1".to_string());
        cache.get_or_insert(2, || "SELECT 2".to_string());

        // Touch key 1 so that key 2 becomes the least recently used.
        let mut rebuilt = false;
        cache.get_or_insert(1, || {
            rebuilt = true;
            "SELECT 1".to_string()
        });
        assert!(!rebuilt);

        // The next insertion should now evict key 2.
        cache.get_or_insert(3, || "SELECT 3".to_string());

        assert!(cache.contains(1));
        assert!(!cache.contains(2));
        assert!(cache.contains(3));
    }

    /// Equal inputs hash to equal keys; different inputs to different keys.
    #[test]
    fn test_cache_key_function() {
        let users_a = cache_key(&"SELECT * FROM users");
        let users_b = cache_key(&"SELECT * FROM users");
        let orders = cache_key(&"SELECT * FROM orders");

        assert_eq!(users_a, users_b);
        assert_ne!(users_a, orders);
    }

    /// `clear` removes every entry.
    #[test]
    fn test_clear() {
        let mut cache = StatementCache::new(10);
        cache.get_or_insert(1, || "SELECT 1".to_string());
        cache.get_or_insert(2, || "SELECT 2".to_string());
        assert_eq!(cache.len(), 2);

        cache.clear();
        assert!(cache.is_empty());
    }
}
217}