use std::collections::{BTreeMap, HashMap};
use std::hash::{Hash, Hasher};
use std::time::Instant;
9
/// A single cached SQL statement together with its usage statistics.
#[derive(Debug, Clone)]
pub struct CachedStatement {
    /// SQL text produced by the builder when the entry was first inserted.
    pub sql: String,
    /// Monotonic timestamp of the most recent access through the cache.
    /// Informational only: eviction order is decided by an internal logical clock.
    pub last_used: Instant,
    /// Number of `get_or_insert` calls that returned this entry,
    /// including the call that inserted it.
    pub hit_count: u64,
}

/// A bounded cache of SQL statements with least-recently-used eviction.
///
/// Recency is tracked with a strictly increasing logical clock rather than
/// `Instant`: `Instant::now()` is only guaranteed to be non-decreasing, so two
/// accesses in quick succession can observe the same instant, which would make
/// the eviction victim depend on `HashMap` iteration order.
#[derive(Debug)]
pub struct StatementCache {
    /// key -> (tick of the last access, cached statement).
    cache: HashMap<u64, (u64, CachedStatement)>,
    /// tick -> key; the smallest tick is the least recently used key.
    recency: BTreeMap<u64, u64>,
    /// Logical clock, advanced on every `get_or_insert` call.
    clock: u64,
    /// Maximum number of retained entries (see `get_or_insert` for `0`).
    max_size: usize,
}

impl StatementCache {
    /// Creates an empty cache that holds at most `max_size` statements.
    ///
    /// The initial allocation is capped at 256 slots so very large limits do
    /// not reserve memory up front.
    pub fn new(max_size: usize) -> Self {
        Self {
            cache: HashMap::with_capacity(max_size.min(256)),
            recency: BTreeMap::new(),
            clock: 0,
            max_size,
        }
    }

    /// Returns the SQL cached under `key`, calling `builder` to create it on a miss.
    ///
    /// Every call marks `key` as most recently used and increments its
    /// `hit_count`. `builder` is invoked only on a miss. When a new key is
    /// inserted into a full cache, the least recently used entry is evicted first.
    ///
    /// The key is trusted: if two different statements map to the same key
    /// (for example a hash collision from `cache_key`), the one stored first
    /// is returned.
    ///
    /// Because the returned `&str` borrows from a stored entry, a cache created
    /// with `max_size == 0` still retains the single most recently used entry.
    pub fn get_or_insert(&mut self, key: u64, builder: impl FnOnce() -> String) -> &str {
        if !self.cache.contains_key(&key) && self.cache.len() >= self.max_size {
            self.evict_lru();
        }

        self.clock += 1;
        let tick = self.clock;

        let (last_tick, stmt) = self.cache.entry(key).or_insert_with(|| {
            let stmt = CachedStatement {
                sql: builder(),
                last_used: Instant::now(),
                hit_count: 0,
            };
            (tick, stmt)
        });

        // Move the key to the newest recency slot. On a miss `last_tick == tick`
        // and nothing is indexed under it yet, so the removal is a no-op.
        self.recency.remove(&*last_tick);
        *last_tick = tick;
        self.recency.insert(tick, key);

        stmt.last_used = Instant::now();
        stmt.hit_count += 1;
        &stmt.sql
    }

    /// Returns `true` if a statement is cached under `key`; does not affect recency.
    pub fn contains(&self, key: u64) -> bool {
        self.cache.contains_key(&key)
    }

    /// Number of statements currently cached.
    pub fn len(&self) -> usize {
        self.cache.len()
    }

    /// Returns `true` if no statements are cached.
    pub fn is_empty(&self) -> bool {
        self.cache.is_empty()
    }

    /// Removes every cached statement and its recency record.
    pub fn clear(&mut self) {
        self.cache.clear();
        self.recency.clear();
    }

    /// Removes the least recently used entry, if there is one.
    fn evict_lru(&mut self) {
        let oldest_tick = self.recency.keys().next().copied();
        if let Some(tick) = oldest_tick {
            if let Some(key) = self.recency.remove(&tick) {
                self.cache.remove(&key);
            }
        }
    }
}
107
/// Hashes `value` into a `u64` suitable for use as a `StatementCache` key.
///
/// Accepts unsized values too, so `cache_key("SELECT 1")` and
/// `cache_key(&bytes[..])` work directly. `&T` hashes identically to `T`, so
/// `cache_key(&"x") == cache_key("x")`.
///
/// Uses `DefaultHasher::new()`, which gives the same output for the same input
/// within a build. Its algorithm is unspecified and may change between Rust
/// releases, so do not persist these keys. Distinct values can collide, though
/// this is rare.
pub fn cache_key<T: Hash + ?Sized>(value: &T) -> u64 {
    let mut hasher = std::collections::hash_map::DefaultHasher::new();
    value.hash(&mut hasher);
    hasher.finish()
}
116
117impl Default for StatementCache {
118 fn default() -> Self {
119 Self::new(1024)
120 }
121}
122
#[cfg(test)]
mod tests {
    use super::*;
    use std::cell::Cell;

    #[test]
    fn test_cache_hit() {
        let mut cache = StatementCache::new(10);
        let sql = cache
            .get_or_insert(1, || "SELECT 1".to_string())
            .to_string();
        assert_eq!(sql, "SELECT 1");

        // A second lookup with the same key must not rebuild the statement.
        let called = Cell::new(false);
        let sql2 = cache
            .get_or_insert(1, || {
                called.set(true);
                "SELECT 1".to_string()
            })
            .to_string();
        assert_eq!(sql2, "SELECT 1");
        assert!(!called.get());
    }

    #[test]
    fn test_cache_miss() {
        let mut cache = StatementCache::new(10);
        let sql1 = cache
            .get_or_insert(1, || "SELECT 1".to_string())
            .to_string();
        let sql2 = cache
            .get_or_insert(2, || "SELECT 2".to_string())
            .to_string();
        assert_eq!(sql1, "SELECT 1");
        assert_eq!(sql2, "SELECT 2");
        assert_eq!(cache.len(), 2);
    }

    #[test]
    fn test_eviction() {
        let mut cache = StatementCache::new(2);
        cache.get_or_insert(1, || "SELECT 1".to_string());
        cache.get_or_insert(2, || "SELECT 2".to_string());
        cache.get_or_insert(3, || "SELECT 3".to_string());

        assert_eq!(cache.len(), 2);
        assert!(!cache.contains(1));
        assert!(cache.contains(2));
        assert!(cache.contains(3));
    }

    #[test]
    fn test_lru_ordering() {
        let mut cache = StatementCache::new(2);
        cache.get_or_insert(1, || "SELECT 1".to_string());
        cache.get_or_insert(2, || "SELECT 2".to_string());

        // Touch key 1 so key 2 becomes the least recently used.
        let called = Cell::new(false);
        cache.get_or_insert(1, || {
            called.set(true);
            "SELECT 1".to_string()
        });
        assert!(!called.get());

        cache.get_or_insert(3, || "SELECT 3".to_string());

        assert!(cache.contains(1));
        assert!(!cache.contains(2));
        assert!(cache.contains(3));
    }

    #[test]
    fn test_cache_key_function() {
        let key1 = cache_key(&"SELECT * FROM users");
        let key2 = cache_key(&"SELECT * FROM users");
        let key3 = cache_key(&"SELECT * FROM orders");

        assert_eq!(key1, key2);
        assert_ne!(key1, key3);
    }

    #[test]
    fn test_clear() {
        let mut cache = StatementCache::new(10);
        cache.get_or_insert(1, || "SELECT 1".to_string());
        cache.get_or_insert(2, || "SELECT 2".to_string());
        assert_eq!(cache.len(), 2);

        cache.clear();
        assert!(cache.is_empty());
    }

    #[test]
    fn test_repeated_hits_do_not_grow_cache() {
        let mut cache = StatementCache::new(10);
        for _ in 0..5 {
            cache.get_or_insert(7, || "SELECT 7".to_string());
        }
        assert_eq!(cache.len(), 1);
    }

    #[test]
    fn test_builder_called_once_per_distinct_key() {
        let builds = Cell::new(0u32);
        let mut cache = StatementCache::new(10);
        for key in [1u64, 2, 1, 3, 2, 1] {
            cache.get_or_insert(key, || {
                builds.set(builds.get() + 1);
                format!("SELECT {key}")
            });
        }
        assert_eq!(builds.get(), 3);
        assert_eq!(cache.len(), 3);
    }

    #[test]
    fn test_eviction_after_clear() {
        // After `clear`, the size limit must still be enforced.
        let mut cache = StatementCache::new(1);
        cache.get_or_insert(1, || "SELECT 1".to_string());
        cache.clear();
        cache.get_or_insert(2, || "SELECT 2".to_string());
        cache.get_or_insert(3, || "SELECT 3".to_string());
        assert_eq!(cache.len(), 1);
        assert!(!cache.contains(2));
        assert!(cache.contains(3));
    }
}