#![allow(clippy::doc_markdown)]
use alloc::collections::VecDeque;
use alloc::string::String;
use alloc::vec::Vec;
use spg_storage::Value;
/// Default cap on the number of entries held by a [`MemoizeCache`].
pub const DEFAULT_MAX_ENTRIES: usize = 1 << 10;
/// Default approximate byte budget of a [`MemoizeCache`] (16 MiB).
pub const DEFAULT_MAX_BYTES: usize = 16 << 20;
/// Lookup key for one memoized subquery evaluation.
///
/// Two keys are equal only when both the subquery text and every outer value
/// compare equal, so different subqueries never share cached results.
#[derive(Clone, Debug, PartialEq)]
pub struct CacheKey {
    /// Textual representation of the subquery being memoized.
    pub subquery_repr: String,
    /// Values supplied from the outer query for this evaluation
    /// (presumably the correlated outer-row columns; verify against callers).
    pub outer_values: Vec<Value>,
}
/// Bounded, least-recently-used cache of subquery results keyed by [`CacheKey`].
///
/// Entries are stored most-recently-used first; eviction removes from the back.
#[derive(Clone, Debug)]
pub struct MemoizeCache {
    /// Cached `(key, value)` pairs, most recently used at the front.
    entries: VecDeque<(CacheKey, Value)>,
    /// Entry-count limit consulted by `insert` when evicting.
    max_entries: usize,
    /// Approximate byte budget consulted by `insert` when evicting.
    max_bytes: usize,
    /// Running approximate byte size of the cached entries.
    current_bytes: usize,
    /// Number of `get` calls that returned a cached value.
    pub hit_count: u64,
    /// Number of `get` calls that found no cached value.
    pub miss_count: u64,
}
impl Default for MemoizeCache {
fn default() -> Self {
Self::new()
}
}
impl MemoizeCache {
    /// Creates an empty cache limited to [`DEFAULT_MAX_ENTRIES`] entries and
    /// [`DEFAULT_MAX_BYTES`] approximate bytes.
    pub fn new() -> Self {
        Self {
            entries: VecDeque::with_capacity(DEFAULT_MAX_ENTRIES),
            max_entries: DEFAULT_MAX_ENTRIES,
            max_bytes: DEFAULT_MAX_BYTES,
            current_bytes: 0,
            hit_count: 0,
            miss_count: 0,
        }
    }

    /// Builder: sets the maximum number of cached entries.
    ///
    /// A limit of `0` disables caching: [`Self::insert`] stores nothing.
    pub const fn with_max_entries(mut self, n: usize) -> Self {
        self.max_entries = n;
        self
    }

    /// Builder: sets the approximate byte budget.
    ///
    /// Entries whose estimated size alone exceeds this budget are not cached.
    pub const fn with_max_bytes(mut self, b: usize) -> Self {
        self.max_bytes = b;
        self
    }

    /// Number of entries currently cached.
    pub fn len(&self) -> usize {
        self.entries.len()
    }

    /// Returns `true` when nothing is cached.
    pub fn is_empty(&self) -> bool {
        self.entries.is_empty()
    }

    /// Looks up `key` and returns a clone of the cached value on a hit.
    ///
    /// A hit moves the entry to the most-recently-used (front) position and
    /// increments `hit_count`. A miss increments `miss_count`. The lookup is a
    /// linear scan over the entries.
    pub fn get(&mut self, key: &CacheKey) -> Option<Value> {
        let Some(pos) = self.find_index(key) else {
            self.miss_count += 1;
            return None;
        };
        let (k, v) = self.entries.remove(pos)?;
        self.entries.push_front((k, v.clone()));
        self.hit_count += 1;
        Some(v)
    }

    /// Caches `value` under `key` as the most-recently-used entry.
    ///
    /// Any existing entry for `key` is replaced rather than duplicated.
    /// Least-recently-used entries are then evicted from the back until both
    /// the entry-count and byte limits leave room for the new entry.
    ///
    /// If the entry can never fit, it is not cached and the configured limits
    /// are never exceeded. This happens when `max_entries == 0` or when the
    /// entry's estimated size exceeds `max_bytes`. In that case the cache is
    /// not flushed to make room.
    pub fn insert(&mut self, key: CacheKey, value: Value) {
        // Drop a stale copy of the same key first. Otherwise it would occupy a
        // slot and byte budget, and inflate `len()`, until it aged out.
        if let Some((old_key, old_value)) = self
            .find_index(&key)
            .and_then(|pos| self.entries.remove(pos))
        {
            self.uncharge(&old_key, &old_value);
        }

        let entry_bytes = Self::entry_cost(&key, &value);
        if self.max_entries == 0 || entry_bytes > self.max_bytes {
            return;
        }

        // Saturating add: `max_bytes` may be configured near `usize::MAX`.
        while self.entries.len() >= self.max_entries
            || self.current_bytes.saturating_add(entry_bytes) > self.max_bytes
        {
            let Some((old_key, old_value)) = self.entries.pop_back() else {
                break;
            };
            self.uncharge(&old_key, &old_value);
        }

        self.current_bytes = self.current_bytes.saturating_add(entry_bytes);
        self.entries.push_front((key, value));
    }

    /// Index of the entry whose key equals `key`, if any.
    fn find_index(&self, key: &CacheKey) -> Option<usize> {
        self.entries.iter().position(|(k, _)| k == key)
    }

    /// Approximate byte cost charged for one `(key, value)` entry.
    fn entry_cost(key: &CacheKey, value: &Value) -> usize {
        approx_bytes(key) + approx_value_bytes(value)
    }

    /// Removes an evicted or replaced entry's cost from `current_bytes`.
    fn uncharge(&mut self, key: &CacheKey, value: &Value) {
        self.current_bytes = self
            .current_bytes
            .saturating_sub(Self::entry_cost(key, value));
    }
}
/// Approximate byte cost of a [`CacheKey`].
///
/// The cost is the byte length of the subquery text, plus the estimated size
/// of each outer value, plus a flat 16-byte overhead.
fn approx_bytes(key: &CacheKey) -> usize {
    const KEY_OVERHEAD: usize = 16;
    let outer_total: usize = key.outer_values.iter().map(approx_value_bytes).sum();
    key.subquery_repr.len() + outer_total + KEY_OVERHEAD
}
/// Approximate byte cost of a single [`Value`].
///
/// Variable-length payloads scale with their length: text/JSON by byte
/// length, vectors by element count. Scalar variants are charged a small
/// constant. Any variant not listed here is charged a flat 16 bytes.
fn approx_value_bytes(v: &Value) -> usize {
    match v {
        Value::Text(s) | Value::Json(s) => s.len(),
        Value::Vector(elems) => elems.len() * 4,
        Value::HalfVector(h) => h.dim() * 2,
        Value::Sq8Vector(q) => q.bytes.len() + 8,
        Value::Null | Value::Bool(_) | Value::SmallInt(_) => 1,
        Value::Int(_) => 4,
        Value::BigInt(_) | Value::Float(_) | Value::Date(_) | Value::Timestamp(_) => 8,
        Value::Interval { .. } | Value::Numeric { .. } => 16,
        _ => 16,
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a [`CacheKey`] from a subquery string and its outer values.
    fn key(repr: &str, outer: &[Value]) -> CacheKey {
        CacheKey {
            subquery_repr: repr.into(),
            outer_values: outer.to_vec(),
        }
    }

    #[test]
    fn empty_cache_misses_everything() {
        let mut c = MemoizeCache::new();
        let k = key("SELECT 1", &[Value::Int(1)]);
        assert!(c.get(&k).is_none());
        assert_eq!(c.miss_count, 1);
        assert_eq!(c.hit_count, 0);
    }

    #[test]
    fn insert_then_get_hits() {
        let mut c = MemoizeCache::new();
        let k = key("SELECT 1", &[Value::Int(1)]);
        c.insert(k.clone(), Value::BigInt(42));
        let v = c.get(&k);
        assert_eq!(v, Some(Value::BigInt(42)));
        assert_eq!(c.hit_count, 1);
    }

    #[test]
    fn repeated_outer_key_hits_after_first_insert() {
        let mut c = MemoizeCache::new();
        let repr = "SELECT MAX(x) FROM y WHERE y.k = outer.k";
        for i in 0..100 {
            let k = key(repr, &[Value::Int(i % 5)]);
            if c.get(&k).is_none() {
                c.insert(k, Value::BigInt(i64::from(i)));
            }
        }
        assert_eq!(c.miss_count, 5);
        assert_eq!(c.hit_count, 95);
    }

    #[test]
    fn lru_eviction_at_max_entries() {
        let mut c = MemoizeCache::new().with_max_entries(3);
        for i in 0..5 {
            let k = key("q", &[Value::Int(i)]);
            c.insert(k, Value::BigInt(i64::from(i)));
        }
        // Five distinct keys into a 3-slot cache must leave exactly 3.
        assert_eq!(c.len(), 3, "len={}", c.len());
        assert!(c.get(&key("q", &[Value::Int(4)])).is_some());
        assert!(c.get(&key("q", &[Value::Int(3)])).is_some());
        assert!(c.get(&key("q", &[Value::Int(2)])).is_some());
        assert!(c.get(&key("q", &[Value::Int(0)])).is_none());
    }

    #[test]
    fn lru_eviction_at_max_bytes() {
        let mut c = MemoizeCache::new().with_max_bytes(128);
        for i in 0..10 {
            let big_str = alloc::string::String::from_iter(core::iter::repeat_n('x', 64));
            c.insert(key("q", &[Value::Int(i)]), Value::Text(big_str));
        }
        assert!(c.len() < 10, "len={}", c.len());
        // The byte budget holds, and LRU order keeps the newest entry while
        // the oldest has been evicted.
        assert!(c.current_bytes <= 128, "current_bytes={}", c.current_bytes);
        assert!(c.get(&key("q", &[Value::Int(9)])).is_some());
        assert!(c.get(&key("q", &[Value::Int(0)])).is_none());
    }

    #[test]
    fn distinct_subquery_reprs_dont_collide() {
        let mut c = MemoizeCache::new();
        let k1 = key("SELECT 1", &[Value::Int(1)]);
        let k2 = key("SELECT 2", &[Value::Int(1)]);
        c.insert(k1.clone(), Value::BigInt(10));
        c.insert(k2.clone(), Value::BigInt(20));
        assert_eq!(c.get(&k1), Some(Value::BigInt(10)));
        assert_eq!(c.get(&k2), Some(Value::BigInt(20)));
    }

    #[test]
    fn get_hit_promotes_entry_to_lru_front() {
        let mut c = MemoizeCache::new().with_max_entries(3);
        c.insert(key("q", &[Value::Int(0)]), Value::BigInt(0));
        c.insert(key("q", &[Value::Int(1)]), Value::BigInt(1));
        c.insert(key("q", &[Value::Int(2)]), Value::BigInt(2));
        // Touching key 0 makes key 1 the least recently used...
        let _ = c.get(&key("q", &[Value::Int(0)]));
        // ...so it is the one evicted by the next insert.
        c.insert(key("q", &[Value::Int(3)]), Value::BigInt(3));
        assert!(c.get(&key("q", &[Value::Int(0)])).is_some());
        assert!(c.get(&key("q", &[Value::Int(1)])).is_none());
    }
}