1use dashmap::DashMap;
6use sqlparser::ast::Statement;
7use std::path::PathBuf;
8use std::sync::Arc;
9use std::sync::atomic::{AtomicU64, Ordering};
10
11pub struct PlanCache {
13 l1: Arc<DashMap<String, CachedPlan>>,
15 l2_dir: Option<PathBuf>,
17 max_l1_size: usize,
19 stats: CacheStats,
21}
22
23#[derive(Clone)]
25pub struct CachedPlan {
26 pub statement: Statement,
27 pub hit_count: u64,
28}
29
/// Usage counters for a plan cache.
///
/// Counters are read and written with `Ordering::Relaxed`, so reading several
/// fields is not guaranteed to give a mutually consistent snapshot while other
/// threads are updating them.
#[derive(Debug, Default)]
pub struct CacheStats {
    /// Lookups answered from either tier (L2 hits are included here too).
    pub hits: AtomicU64,
    /// Lookups answered by neither tier.
    pub misses: AtomicU64,
    /// In-memory entries removed to make room for new ones.
    pub evictions: AtomicU64,
    /// Lookups that missed the in-memory tier but were answered from disk.
    pub l2_hits: AtomicU64,
}

impl CacheStats {
    /// Creates a set of counters that all start at zero.
    fn new() -> Self {
        Self::default()
    }

    /// Fraction of lookups that were hits, in `[0.0, 1.0]`.
    ///
    /// Returns `0.0` when no lookup has been recorded yet.
    pub fn hit_rate(&self) -> f64 {
        // Read `hits` exactly once so the numerator and the total agree.
        let hits = self.hits.load(Ordering::Relaxed);
        // Saturate rather than overflow (which would panic in debug builds).
        let total = hits.saturating_add(self.misses.load(Ordering::Relaxed));
        if total == 0 {
            0.0
        } else {
            hits as f64 / total as f64
        }
    }

    /// Total number of recorded lookups (`hits + misses`), saturating at `u64::MAX`.
    pub fn total(&self) -> u64 {
        self.hits
            .load(Ordering::Relaxed)
            .saturating_add(self.misses.load(Ordering::Relaxed))
    }
}
64
65impl PlanCache {
66 pub fn new(max_size: usize) -> Self {
68 Self {
69 l1: Arc::new(DashMap::with_capacity(max_size)),
70 l2_dir: None,
71 max_l1_size: max_size,
72 stats: CacheStats::new(),
73 }
74 }
75
76 pub fn with_default_size() -> Self {
78 Self::new(1_000)
79 }
80
81 pub fn with_l2_cache(mut self, dir: PathBuf) -> Self {
83 if !dir.exists() {
84 let _ = std::fs::create_dir_all(&dir);
85 }
86 self.l2_dir = Some(dir);
87 self
88 }
89
90 pub fn get(&self, sql: &str) -> Option<Statement> {
92 if let Some(mut entry) = self.l1.get_mut(sql) {
94 entry.hit_count += 1;
95 self.stats.hits.fetch_add(1, Ordering::Relaxed);
96 return Some(entry.statement.clone());
97 }
98
99 if let Some(stmt) = self.get_l2(sql) {
101 self.stats.l2_hits.fetch_add(1, Ordering::Relaxed);
102 self.stats.hits.fetch_add(1, Ordering::Relaxed);
103 self.insert_l1(sql.to_string(), stmt.clone());
105 return Some(stmt);
106 }
107
108 self.stats.misses.fetch_add(1, Ordering::Relaxed);
109 None
110 }
111
112 pub fn insert(&self, sql: String, statement: Statement) {
114 self.insert_l1(sql.clone(), statement.clone());
115 self.put_l2(&sql, &statement);
116 }
117
118 fn insert_l1(&self, sql: String, statement: Statement) {
120 if self.l1.len() >= self.max_l1_size
122 && let Some(lru_key) = self.find_lfu_key()
123 {
124 self.l1.remove(&lru_key);
125 self.stats.evictions.fetch_add(1, Ordering::Relaxed);
126 }
127
128 self.l1.insert(
129 sql,
130 CachedPlan {
131 statement,
132 hit_count: 0,
133 },
134 );
135 }
136
137 fn find_lfu_key(&self) -> Option<String> {
139 let mut min_hits = u64::MAX;
140 let mut lfu_key = None;
141 for entry in self.l1.iter() {
142 if entry.value().hit_count < min_hits {
143 min_hits = entry.value().hit_count;
144 lfu_key = Some(entry.key().clone());
145 }
146 }
147 lfu_key
148 }
149
150 fn put_l2(&self, sql: &str, statement: &Statement) {
152 if let Some(dir) = &self.l2_dir {
153 let hash = Self::hash_sql(sql);
154 let path = dir.join(format!("{hash}.plan"));
155 let data = format!("{statement:?}");
157 let _ = std::fs::write(path, data.as_bytes());
158 }
159 }
160
161 fn get_l2(&self, sql: &str) -> Option<Statement> {
163 let dir = self.l2_dir.as_ref()?;
164 let hash = Self::hash_sql(sql);
165 let path = dir.join(format!("{hash}.plan"));
166
167 if path.exists() {
168 use sqlparser::dialect::GenericDialect;
170 use sqlparser::parser::Parser;
171 let dialect = GenericDialect {};
172 Parser::parse_sql(&dialect, sql).ok()?.into_iter().next()
173 } else {
174 None
175 }
176 }
177
178 fn hash_sql(sql: &str) -> u64 {
180 let mut hash: u64 = 0xcbf29ce484222325;
181 for byte in sql.bytes() {
182 hash ^= byte as u64;
183 hash = hash.wrapping_mul(0x100000001b3);
184 }
185 hash
186 }
187
188 pub fn clear(&self) {
190 self.l1.clear();
191 if let Some(dir) = &self.l2_dir {
192 let _ = std::fs::remove_dir_all(dir);
193 let _ = std::fs::create_dir_all(dir);
194 }
195 }
196
197 pub fn len(&self) -> usize {
199 self.l1.len()
200 }
201
202 pub fn is_empty(&self) -> bool {
204 self.l1.is_empty()
205 }
206
207 pub fn stats(&self) -> &CacheStats {
209 &self.stats
210 }
211
212 pub fn contains(&self, sql: &str) -> bool {
214 self.l1.contains_key(sql)
215 }
216}
217
#[cfg(test)]
mod tests {
    use super::*;
    use sqlparser::dialect::GenericDialect;
    use sqlparser::parser::Parser;

    /// Parses `sql` with the generic dialect and returns its first statement.
    fn parse_one(sql: &str) -> Statement {
        Parser::parse_sql(&GenericDialect {}, sql)
            .expect("test SQL must parse")
            .into_iter()
            .next()
            .expect("test SQL must contain a statement")
    }

    #[test]
    fn test_plan_cache_basic() {
        let cache = PlanCache::new(10);
        let sql = "SELECT * FROM users";
        let stmt = parse_one(sql);

        cache.insert(sql.to_string(), stmt.clone());

        let cached = cache.get(sql);
        assert!(cached.is_some());
        assert_eq!(cache.stats().hits.load(Ordering::Relaxed), 1);
        assert_eq!(cache.stats().misses.load(Ordering::Relaxed), 0);
    }

    #[test]
    fn test_plan_cache_eviction() {
        let cache = PlanCache::new(2);

        let sql1 = "SELECT * FROM users";
        let sql2 = "SELECT * FROM orders";
        let sql3 = "SELECT * FROM products";

        cache.insert(sql1.to_string(), parse_one(sql1));
        cache.insert(sql2.to_string(), parse_one(sql2));
        cache.insert(sql3.to_string(), parse_one(sql3));

        assert_eq!(cache.len(), 2);
        assert_eq!(cache.stats().evictions.load(Ordering::Relaxed), 1);
    }

    #[test]
    fn test_plan_cache_hit_rate() {
        let cache = PlanCache::new(10);
        let sql = "SELECT * FROM users";
        cache.insert(sql.to_string(), parse_one(sql));

        cache.get(sql);
        cache.get(sql);
        cache.get("SELECT 1");

        assert_eq!(cache.stats().hits.load(Ordering::Relaxed), 2);
        assert_eq!(cache.stats().misses.load(Ordering::Relaxed), 1);
        assert!((cache.stats().hit_rate() - 2.0 / 3.0).abs() < 1e-9);
    }

    #[test]
    fn test_plan_cache_l2_disk() {
        // Per-process directory so concurrent test runs cannot trample each other's files.
        let tmp_dir =
            std::env::temp_dir().join(format!("dbx_plan_cache_test_{}", std::process::id()));
        let _ = std::fs::remove_dir_all(&tmp_dir);

        let cache = PlanCache::new(1).with_l2_cache(tmp_dir.clone());
        let sql1 = "SELECT * FROM users";
        let sql2 = "SELECT * FROM orders";

        cache.insert(sql1.to_string(), parse_one(sql1));
        // Capacity 1: this evicts sql1 from L1, leaving only its L2 file.
        cache.insert(sql2.to_string(), parse_one(sql2));

        let result = cache.get(sql1);
        let l2_hits = cache.stats().l2_hits.load(Ordering::Relaxed);

        // Clean up before asserting so a failure does not leave files behind.
        let _ = std::fs::remove_dir_all(&tmp_dir);

        assert!(result.is_some());
        assert_eq!(l2_hits, 1);
    }

    #[test]
    fn test_plan_cache_contains() {
        let cache = PlanCache::new(10);
        let sql = "SELECT * FROM users";
        assert!(!cache.contains(sql));

        cache.insert(sql.to_string(), parse_one(sql));
        assert!(cache.contains(sql));
    }

    #[test]
    fn test_plan_cache_concurrent_access() {
        use std::thread;

        let cache = Arc::new(PlanCache::new(100));
        let mut handles = vec![];

        for i in 0..8 {
            let cache = Arc::clone(&cache);
            handles.push(thread::spawn(move || {
                let sql = format!("SELECT * FROM table_{i}");
                let stmt = parse_one(&sql);
                cache.insert(sql.clone(), stmt);
                assert!(cache.get(&sql).is_some());
            }));
        }

        for h in handles {
            h.join().unwrap();
        }

        assert_eq!(cache.len(), 8);
    }
}