1use std::collections::HashMap;
7use std::hash::{Hash, Hasher};
8use std::time::Instant;
9
/// A generated SQL statement stored in a [`StatementCache`], plus usage statistics.
#[derive(Debug, Clone)]
pub struct CachedStatement {
    /// The SQL text produced by the builder passed to [`StatementCache::get_or_insert`].
    pub sql: String,
    /// Time of the most recent `get_or_insert` call that returned this entry.
    ///
    /// Informational only: eviction order is tracked separately by the cache,
    /// because consecutive `Instant` readings may compare equal.
    pub last_used: Instant,
    /// Number of `get_or_insert` calls that returned this entry, including the
    /// call that first inserted it (so a fresh entry has a count of 1).
    pub hit_count: u64,
}

/// A bounded cache of generated SQL text keyed by a caller-supplied `u64`.
///
/// When the cache is full, inserting a new key evicts the least-recently-used
/// entry. Recency is tracked with a strictly increasing access counter rather
/// than `Instant` timestamps. `Instant` is only guaranteed to be non-decreasing,
/// so two accesses can receive the same timestamp. Eviction would then depend on
/// `HashMap` iteration order, which is randomized.
#[derive(Debug)]
pub struct StatementCache {
    /// Entries paired with the access tick at which each was last used.
    cache: HashMap<u64, (u64, CachedStatement)>,
    /// Maximum number of entries retained (see [`StatementCache::new`]).
    max_size: usize,
    /// Most recently issued access tick; strictly increases on every access.
    tick: u64,
}

impl StatementCache {
    /// Creates an empty cache that holds at most `max_size` statements.
    ///
    /// The initial allocation is capped at 256 slots so a large `max_size` does
    /// not reserve memory up front. A `max_size` of 0 behaves like 1: the entry
    /// just requested is always kept so that
    /// [`get_or_insert`](Self::get_or_insert) can return a reference into the cache.
    pub fn new(max_size: usize) -> Self {
        Self {
            cache: HashMap::with_capacity(max_size.min(256)),
            max_size,
            tick: 0,
        }
    }

    /// Returns the SQL for `key`, calling `builder` to produce it only on a miss.
    ///
    /// A miss on a full cache first evicts the least-recently-used entry. Every
    /// call, hit or miss, marks the entry as most recently used, refreshes its
    /// `last_used` and increments its `hit_count`.
    pub fn get_or_insert(&mut self, key: u64, builder: impl FnOnce() -> String) -> &str {
        if !self.cache.contains_key(&key) && self.cache.len() >= self.max_size {
            self.evict_lru();
        }

        let tick = self.next_tick();
        let now = Instant::now();
        let (last_tick, entry) = self.cache.entry(key).or_insert_with(|| {
            (
                tick,
                CachedStatement {
                    sql: builder(),
                    last_used: now,
                    hit_count: 0,
                },
            )
        });
        *last_tick = tick;
        entry.last_used = now;
        entry.hit_count += 1;
        &entry.sql
    }

    /// Reports whether `key` is cached. Does not affect recency.
    pub fn contains(&self, key: u64) -> bool {
        self.cache.contains_key(&key)
    }

    /// Number of statements currently cached.
    pub fn len(&self) -> usize {
        self.cache.len()
    }

    /// Whether the cache holds no statements.
    pub fn is_empty(&self) -> bool {
        self.cache.is_empty()
    }

    /// Removes every cached statement; the size limit is unchanged.
    pub fn clear(&mut self) {
        self.cache.clear();
    }

    /// Issues the next access tick. Ticks are unique and strictly increasing, so
    /// the LRU ordering never has ties.
    fn next_tick(&mut self) -> u64 {
        self.tick += 1;
        self.tick
    }

    /// Removes the entry with the smallest access tick, if any.
    ///
    /// This is a linear scan over all entries, so each eviction is O(n).
    fn evict_lru(&mut self) {
        let lru_key = self
            .cache
            .iter()
            .min_by_key(|(_, (tick, _))| *tick)
            .map(|(&key, _)| key);
        if let Some(key) = lru_key {
            self.cache.remove(&key);
        }
    }
}
102
/// Hashes `value` into a `u64` suitable for use as a [`StatementCache`] key.
///
/// Uses a hasher from `DefaultHasher::new()`, so equal values produce equal
/// keys within one build of the program. The standard library leaves the
/// algorithm unspecified and it may change between Rust releases. Do not
/// persist these keys or share them with processes built by a different
/// toolchain.
///
/// Distinct values can collide on the same `u64`. The cache is keyed only by
/// this number, so a collision would hand back the other statement's SQL.
///
/// Unsized values such as `str` and `[T]` are accepted directly:
/// `cache_key("SELECT 1")` yields the same key as `cache_key(&"SELECT 1")`.
pub fn cache_key<T: Hash + ?Sized>(value: &T) -> u64 {
    let mut hasher = std::collections::hash_map::DefaultHasher::new();
    value.hash(&mut hasher);
    hasher.finish()
}
111
112impl Default for StatementCache {
113 fn default() -> Self {
114 Self::new(1024)
115 }
116}
117
#[cfg(test)]
mod tests {
    use super::*;

    /// Looks up `key`, building `sql` on a miss, and returns an owned copy.
    fn fetch(cache: &mut StatementCache, key: u64, sql: &str) -> String {
        cache.get_or_insert(key, || sql.to_string()).to_string()
    }

    #[test]
    fn test_cache_hit() {
        let mut cache = StatementCache::new(10);
        let first = fetch(&mut cache, 1, "SELECT 1");
        assert_eq!(first, "SELECT 1");

        // Once the key is cached, the builder must not run again.
        let second = cache
            .get_or_insert(1, || panic!("should not be called"))
            .to_string();
        assert_eq!(second, "SELECT 1");
    }

    #[test]
    fn test_cache_miss() {
        let mut cache = StatementCache::new(10);
        let first = fetch(&mut cache, 1, "SELECT 1");
        let second = fetch(&mut cache, 2, "SELECT 2");
        assert_eq!(first, "SELECT 1");
        assert_eq!(second, "SELECT 2");
        assert_eq!(cache.len(), 2);
    }

    #[test]
    fn test_eviction() {
        let mut cache = StatementCache::new(2);
        for &(key, sql) in &[(1, "SELECT 1"), (2, "SELECT 2"), (3, "SELECT 3")] {
            fetch(&mut cache, key, sql);
        }

        // Key 1 was the oldest entry when key 3 overflowed the capacity.
        assert_eq!(cache.len(), 2);
        assert!(!cache.contains(1));
        assert!(cache.contains(2));
        assert!(cache.contains(3));
    }

    #[test]
    fn test_lru_ordering() {
        let mut cache = StatementCache::new(2);
        fetch(&mut cache, 1, "SELECT 1");
        fetch(&mut cache, 2, "SELECT 2");

        // Touching key 1 leaves key 2 as the least recently used entry.
        cache.get_or_insert(1, || panic!("should not rebuild"));

        fetch(&mut cache, 3, "SELECT 3");

        assert!(cache.contains(1));
        assert!(!cache.contains(2));
        assert!(cache.contains(3));
    }

    #[test]
    fn test_cache_key_function() {
        let users = "SELECT * FROM users";
        let orders = "SELECT * FROM orders";

        assert_eq!(cache_key(&users), cache_key(&users));
        assert_ne!(cache_key(&users), cache_key(&orders));
    }

    #[test]
    fn test_clear() {
        let mut cache = StatementCache::new(10);
        fetch(&mut cache, 1, "SELECT 1");
        fetch(&mut cache, 2, "SELECT 2");
        assert_eq!(cache.len(), 2);

        cache.clear();
        assert!(cache.is_empty());
    }
}