1#![allow(clippy::doc_markdown)]
4
5use alloc::collections::VecDeque;
34use alloc::string::String;
35use alloc::vec::Vec;
36
37use spg_storage::Value;
38
/// Default upper bound on how many memoized subquery results a cache holds (1024).
pub const DEFAULT_MAX_ENTRIES: usize = 1 << 10;

/// Default upper bound on the *estimated* bytes of memoized results (16 MiB).
pub const DEFAULT_MAX_BYTES: usize = 16 << 20;
46
47#[derive(Debug, Clone, PartialEq)]
52pub struct CacheKey {
53 pub subquery_repr: String,
54 pub outer_values: Vec<Value>,
55}
56
57pub type GroupMap = (
60 spg_sql::ast::ColumnName,
61 alloc::collections::BTreeMap<String, Value>,
62);
63
64pub type ExprPlan = (
75 usize,
76 alloc::vec::Vec<Option<alloc::rc::Rc<GroupMap>>>,
77 spg_sql::ast::Expr,
78);
79
80#[derive(Debug, Clone)]
86pub enum InListSet {
87 Int(alloc::collections::BTreeSet<i64>),
88 Text(alloc::collections::BTreeSet<String>),
89}
90
91#[derive(Debug, Clone)]
92pub struct InListSetEntry {
93 pub set: InListSet,
94 pub has_null: bool,
97}
98
99#[derive(Debug, Clone)]
100pub struct MemoizeCache {
101 entries: VecDeque<(CacheKey, Value)>,
106 pub group_maps: alloc::collections::BTreeMap<String, Option<alloc::rc::Rc<GroupMap>>>,
112 pub expr_plans: alloc::collections::BTreeMap<usize, ExprPlan>,
114 pub in_sets: alloc::collections::BTreeMap<usize, Option<InListSetEntry>>,
119 pub has_subquery: alloc::collections::BTreeMap<usize, bool>,
124 max_entries: usize,
125 max_bytes: usize,
126 current_bytes: usize,
127 pub hit_count: u64,
128 pub miss_count: u64,
129}
130
131impl Default for MemoizeCache {
132 fn default() -> Self {
133 Self::new()
134 }
135}
136
137impl MemoizeCache {
138 pub fn new() -> Self {
139 Self {
140 entries: VecDeque::with_capacity(DEFAULT_MAX_ENTRIES),
141 max_entries: DEFAULT_MAX_ENTRIES,
142 max_bytes: DEFAULT_MAX_BYTES,
143 current_bytes: 0,
144 hit_count: 0,
145 miss_count: 0,
146 group_maps: alloc::collections::BTreeMap::new(),
147 expr_plans: alloc::collections::BTreeMap::new(),
148 in_sets: alloc::collections::BTreeMap::new(),
149 has_subquery: alloc::collections::BTreeMap::new(),
150 }
151 }
152
153 pub const fn with_max_entries(mut self, n: usize) -> Self {
154 self.max_entries = n;
155 self
156 }
157
158 pub const fn with_max_bytes(mut self, b: usize) -> Self {
159 self.max_bytes = b;
160 self
161 }
162
163 pub fn len(&self) -> usize {
164 self.entries.len()
165 }
166
167 pub fn is_empty(&self) -> bool {
168 self.entries.is_empty()
169 }
170
171 pub fn get(&mut self, key: &CacheKey) -> Option<Value> {
175 let pos = self.entries.iter().position(|(k, _)| k == key);
176 if let Some(p) = pos {
177 let (k, v) = self.entries.remove(p)?;
178 self.entries.push_front((k, v.clone()));
179 self.hit_count += 1;
180 Some(v)
181 } else {
182 self.miss_count += 1;
183 None
184 }
185 }
186
187 pub fn insert(&mut self, key: CacheKey, value: Value) {
191 let entry_bytes = approx_bytes(&key) + approx_value_bytes(&value);
192 while !self.entries.is_empty()
193 && (self.entries.len() >= self.max_entries
194 || self.current_bytes + entry_bytes > self.max_bytes)
195 {
196 let Some((k, v)) = self.entries.pop_back() else {
197 break;
198 };
199 self.current_bytes = self
200 .current_bytes
201 .saturating_sub(approx_bytes(&k) + approx_value_bytes(&v));
202 }
203 self.current_bytes = self.current_bytes.saturating_add(entry_bytes);
204 self.entries.push_front((key, value));
205 }
206}
207
208fn approx_bytes(key: &CacheKey) -> usize {
209 key.subquery_repr.len()
210 + key
211 .outer_values
212 .iter()
213 .map(approx_value_bytes)
214 .sum::<usize>()
215 + 16
216}
217
218fn approx_value_bytes(v: &Value) -> usize {
219 match v {
220 Value::Null | Value::Bool(_) | Value::SmallInt(_) => 1,
221 Value::Int(_) => 4,
222 Value::BigInt(_) | Value::Float(_) => 8,
223 Value::Date(_) | Value::Timestamp(_) => 8,
224 Value::Interval { .. } => 16,
225 Value::Numeric { .. } => 16,
226 Value::Text(s) | Value::Json(s) => s.len(),
227 Value::Vector(v) => v.len() * 4,
228 Value::Sq8Vector(q) => q.bytes.len() + 8,
229 Value::HalfVector(h) => h.dim() * 2,
230 _ => 16,
232 }
233}
234
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a cache key from a subquery rendering and its outer values.
    fn make_key(repr: &str, outer: &[Value]) -> CacheKey {
        CacheKey {
            subquery_repr: String::from(repr),
            outer_values: outer.to_vec(),
        }
    }

    #[test]
    fn empty_cache_misses_everything() {
        let mut cache = MemoizeCache::new();
        let probe = make_key("SELECT 1", &[Value::Int(1)]);
        assert_eq!(cache.get(&probe), None);
        assert_eq!((cache.hit_count, cache.miss_count), (0, 1));
    }

    #[test]
    fn insert_then_get_hits() {
        let mut cache = MemoizeCache::new();
        let probe = make_key("SELECT 1", &[Value::Int(1)]);
        cache.insert(probe.clone(), Value::BigInt(42));
        assert_eq!(cache.get(&probe), Some(Value::BigInt(42)));
        assert_eq!(cache.hit_count, 1);
    }

    #[test]
    fn repeated_outer_key_hits_after_first_insert() {
        let mut cache = MemoizeCache::new();
        let repr = "SELECT MAX(x) FROM y WHERE y.k = outer.k";
        for i in 0..100 {
            let probe = make_key(repr, &[Value::Int(i % 5)]);
            if cache.get(&probe).is_some() {
                continue;
            }
            cache.insert(probe, Value::BigInt(i64::from(i)));
        }
        // Only the first occurrence of each of the 5 distinct outer keys misses.
        assert_eq!(cache.miss_count, 5);
        assert_eq!(cache.hit_count, 95);
    }

    #[test]
    fn lru_eviction_at_max_entries() {
        let mut cache = MemoizeCache::new().with_max_entries(3);
        for i in 0..5 {
            cache.insert(make_key("q", &[Value::Int(i)]), Value::BigInt(i64::from(i)));
        }
        assert!(cache.len() <= 3, "len={}", cache.len());
        // The three most recent inserts survive; the oldest has been evicted.
        for survivor in [4, 3, 2] {
            assert!(cache.get(&make_key("q", &[Value::Int(survivor)])).is_some());
        }
        assert!(cache.get(&make_key("q", &[Value::Int(0)])).is_none());
    }

    #[test]
    fn lru_eviction_at_max_bytes() {
        let mut cache = MemoizeCache::new().with_max_bytes(128);
        for i in 0..10 {
            let payload: String = core::iter::repeat_n('x', 64).collect();
            cache.insert(make_key("q", &[Value::Int(i)]), Value::Text(payload));
        }
        assert!(cache.len() < 10, "len={}", cache.len());
    }

    #[test]
    fn distinct_subquery_reprs_dont_collide() {
        let mut cache = MemoizeCache::new();
        let first = make_key("SELECT 1", &[Value::Int(1)]);
        let second = make_key("SELECT 2", &[Value::Int(1)]);
        cache.insert(first.clone(), Value::BigInt(10));
        cache.insert(second.clone(), Value::BigInt(20));
        assert_eq!(cache.get(&first), Some(Value::BigInt(10)));
        assert_eq!(cache.get(&second), Some(Value::BigInt(20)));
    }

    #[test]
    fn miss_then_hit_bumps_promotes_to_lru_front() {
        let mut cache = MemoizeCache::new().with_max_entries(3);
        for i in 0..3 {
            cache.insert(make_key("q", &[Value::Int(i)]), Value::BigInt(i64::from(i)));
        }
        // Touch key 0 so it becomes most recently used; key 1 becomes the LRU entry.
        let _ = cache.get(&make_key("q", &[Value::Int(0)]));
        cache.insert(make_key("q", &[Value::Int(3)]), Value::BigInt(3));
        assert!(cache.get(&make_key("q", &[Value::Int(0)])).is_some());
        assert!(cache.get(&make_key("q", &[Value::Int(1)])).is_none());
    }
}
331}