Skip to main content

sqlmodel_query/
cache.rs

//! Statement caching for compiled SQL queries.
//!
//! Caches compiled SQL strings keyed by a hash so repeated queries
//! avoid redundant string building.
5
6use std::collections::HashMap;
7use std::hash::{Hash, Hasher};
8use std::time::Instant;
9
/// A cached compiled SQL statement.
#[derive(Debug, Clone)]
pub struct CachedStatement {
    /// The compiled SQL string.
    pub sql: String,
    /// When this entry was last accessed (inserted or hit).
    ///
    /// Informational only: `StatementCache` orders evictions with a strictly
    /// increasing counter instead, because two consecutive `Instant::now()`
    /// readings are allowed to compare equal.
    pub last_used: Instant,
    /// Number of times this statement has been reused (cache hits).
    /// The initial insertion, which is a miss, is not counted.
    pub hit_count: u64,
}

/// LRU-style cache for compiled SQL statements.
///
/// Keyed by a `u64` hash that callers compute from their query structure
/// (for example with [`cache_key`]). The key is trusted as the identity of the
/// query: two different queries that produce the same key share one entry.
///
/// When a new key is inserted into a full cache, the least-recently-used entry
/// is evicted first. Finding that entry scans every cached statement, so an
/// eviction costs O(n) in the number of entries.
///
/// # Example
///
/// ```
/// use sqlmodel_query::cache::StatementCache;
///
/// let mut cache = StatementCache::new(100);
///
/// // Cache a compiled query
/// let sql = cache.get_or_insert(12345, || "SELECT * FROM users WHERE id = $1".to_string());
/// assert_eq!(sql, "SELECT * FROM users WHERE id = $1");
///
/// // Second call returns cached version
/// let sql2 = cache.get_or_insert(12345, || panic!("should not be called"));
/// assert_eq!(sql2, "SELECT * FROM users WHERE id = $1");
/// ```
#[derive(Debug)]
pub struct StatementCache {
    /// Cached entries, each tagged with the logical time of its last access.
    cache: HashMap<u64, Slot>,
    /// Maximum number of entries retained (see [`StatementCache::new`]).
    max_size: usize,
    /// Logical clock: the sequence number assigned to the next access.
    next_seq: u64,
}

/// A cached statement plus the logical time of its most recent access.
#[derive(Debug)]
struct Slot {
    stmt: CachedStatement,
    /// Sequence number of the last access. Every access gets a fresh, larger
    /// value, so the smallest `seq` identifies the LRU entry unambiguously.
    seq: u64,
}

impl StatementCache {
    /// Create a new cache with the given maximum number of entries.
    ///
    /// A `max_size` of 0 still retains the most recently inserted statement,
    /// because [`get_or_insert`](Self::get_or_insert) returns a reference into
    /// the cache.
    pub fn new(max_size: usize) -> Self {
        Self {
            // Cap the up-front allocation; large caches grow on demand.
            cache: HashMap::with_capacity(max_size.min(256)),
            max_size,
            next_seq: 0,
        }
    }

    /// Get a cached statement or compile and insert it.
    ///
    /// The `builder` closure is only called on cache miss. If `key` is new and
    /// the cache is full, the least-recently-used entry is evicted before
    /// `builder` runs. Every call marks `key` as the most recently used entry;
    /// only hits increment its `hit_count`.
    pub fn get_or_insert(&mut self, key: u64, builder: impl FnOnce() -> String) -> &str {
        let seq = self.next_seq;
        self.next_seq += 1;
        let now = Instant::now();

        // Make room only for new keys; a hit never evicts anything.
        if !self.cache.contains_key(&key) && self.cache.len() >= self.max_size {
            self.evict_lru();
        }

        let slot = self
            .cache
            .entry(key)
            .and_modify(|existing| existing.stmt.hit_count += 1)
            .or_insert_with(|| Slot {
                stmt: CachedStatement {
                    sql: builder(),
                    last_used: now,
                    hit_count: 0,
                },
                seq,
            });
        slot.seq = seq;
        slot.stmt.last_used = now;
        &slot.stmt.sql
    }

    /// Look up a cached statement without building it or updating its recency.
    ///
    /// Returns `None` if `key` is not cached.
    pub fn peek(&self, key: u64) -> Option<&CachedStatement> {
        self.cache.get(&key).map(|slot| &slot.stmt)
    }

    /// Check if a statement is cached. Does not update its recency.
    pub fn contains(&self, key: u64) -> bool {
        self.cache.contains_key(&key)
    }

    /// Number of statements currently cached.
    pub fn len(&self) -> usize {
        self.cache.len()
    }

    /// Check if cache is empty.
    pub fn is_empty(&self) -> bool {
        self.cache.is_empty()
    }

    /// Clear all cached statements.
    pub fn clear(&mut self) {
        self.cache.clear();
    }

    /// Evict the least-recently-used entry, if any.
    fn evict_lru(&mut self) {
        // Sequence numbers are unique per access, so the minimum is never tied.
        let lru_key = self
            .cache
            .iter()
            .min_by_key(|(_, slot)| slot.seq)
            .map(|(&key, _)| key);
        if let Some(key) = lru_key {
            self.cache.remove(&key);
        }
    }
}
102
/// Hash any `Hash` value into a `u64` suitable as a [`StatementCache`] key.
///
/// Handy for deriving keys from query components (SQL text, parameter
/// shapes, tuples of identifiers, ...). The standard library's
/// `DefaultHasher` is used; its algorithm is not guaranteed to stay the same
/// across Rust releases, so these keys should not be persisted.
pub fn cache_key(value: &impl Hash) -> u64 {
    let mut state = std::collections::hash_map::DefaultHasher::default();
    Hash::hash(value, &mut state);
    state.finish()
}
111
112impl Default for StatementCache {
113    fn default() -> Self {
114        Self::new(1024)
115    }
116}
117
#[cfg(test)]
mod tests {
    // Unit tests for `StatementCache` and `cache_key`, using only the public API.
    use super::*;

    #[test]
    fn test_cache_hit() {
        let mut cache = StatementCache::new(10);
        let sql = cache
            .get_or_insert(1, || "SELECT 1".to_string())
            .to_string();
        assert_eq!(sql, "SELECT 1");

        // Should return cached value
        let sql2 = cache
            .get_or_insert(1, || panic!("should not be called"))
            .to_string();
        assert_eq!(sql2, "SELECT 1");
    }

    #[test]
    fn test_cache_miss() {
        let mut cache = StatementCache::new(10);
        let sql1 = cache
            .get_or_insert(1, || "SELECT 1".to_string())
            .to_string();
        let sql2 = cache
            .get_or_insert(2, || "SELECT 2".to_string())
            .to_string();
        assert_eq!(sql1, "SELECT 1");
        assert_eq!(sql2, "SELECT 2");
        assert_eq!(cache.len(), 2);
    }

    #[test]
    fn test_eviction() {
        let mut cache = StatementCache::new(2);
        cache.get_or_insert(1, || "SELECT 1".to_string());
        cache.get_or_insert(2, || "SELECT 2".to_string());
        // This should evict key 1 (LRU)
        cache.get_or_insert(3, || "SELECT 3".to_string());

        assert_eq!(cache.len(), 2);
        assert!(!cache.contains(1));
        assert!(cache.contains(2));
        assert!(cache.contains(3));
    }

    #[test]
    fn test_lru_ordering() {
        let mut cache = StatementCache::new(2);
        cache.get_or_insert(1, || "SELECT 1".to_string());
        cache.get_or_insert(2, || "SELECT 2".to_string());

        // Access key 1 to make it recently used
        cache.get_or_insert(1, || panic!("should not rebuild"));

        // Eviction should remove key 2 (now LRU)
        cache.get_or_insert(3, || "SELECT 3".to_string());

        assert!(cache.contains(1));
        assert!(!cache.contains(2));
        assert!(cache.contains(3));
    }

    #[test]
    fn test_hit_at_capacity_does_not_evict() {
        let mut cache = StatementCache::new(2);
        cache.get_or_insert(1, || "SELECT 1".to_string());
        cache.get_or_insert(2, || "SELECT 2".to_string());

        // Re-requesting an existing key in a full cache must not evict anything.
        cache.get_or_insert(2, || panic!("should not rebuild"));

        assert_eq!(cache.len(), 2);
        assert!(cache.contains(1));
        assert!(cache.contains(2));
    }

    #[test]
    fn test_evicted_entry_is_rebuilt() {
        let mut cache = StatementCache::new(1);
        cache.get_or_insert(1, || "SELECT 1".to_string());
        cache.get_or_insert(2, || "SELECT 2".to_string());
        assert!(!cache.contains(1));

        // Key 1 was evicted, so its builder has to run again.
        let mut rebuilt = false;
        let sql = cache
            .get_or_insert(1, || {
                rebuilt = true;
                "SELECT 1".to_string()
            })
            .to_string();
        assert!(rebuilt);
        assert_eq!(sql, "SELECT 1");
        assert_eq!(cache.len(), 1);
    }

    #[test]
    fn test_zero_capacity_keeps_only_latest() {
        let mut cache = StatementCache::new(0);
        let first = cache
            .get_or_insert(1, || "SELECT 1".to_string())
            .to_string();
        let second = cache
            .get_or_insert(2, || "SELECT 2".to_string())
            .to_string();

        assert_eq!(first, "SELECT 1");
        assert_eq!(second, "SELECT 2");
        assert_eq!(cache.len(), 1);
        assert!(!cache.contains(1));
        assert!(cache.contains(2));
    }

    #[test]
    fn test_cache_key_function() {
        let key1 = cache_key(&"SELECT * FROM users");
        let key2 = cache_key(&"SELECT * FROM users");
        let key3 = cache_key(&"SELECT * FROM orders");

        assert_eq!(key1, key2);
        assert_ne!(key1, key3);
    }

    #[test]
    fn test_clear() {
        let mut cache = StatementCache::new(10);
        cache.get_or_insert(1, || "SELECT 1".to_string());
        cache.get_or_insert(2, || "SELECT 2".to_string());
        assert_eq!(cache.len(), 2);

        cache.clear();
        assert!(cache.is_empty());
        assert!(!cache.contains(1));

        // Cleared entries must be rebuilt on the next request.
        let mut rebuilt = false;
        let sql = cache
            .get_or_insert(1, || {
                rebuilt = true;
                "SELECT 1".to_string()
            })
            .to_string();
        assert!(rebuilt);
        assert_eq!(sql, "SELECT 1");
    }

    #[test]
    fn test_default_is_empty() {
        let cache = StatementCache::default();
        assert!(cache.is_empty());
        assert_eq!(cache.len(), 0);
    }
}