use std::time::Duration;
use oxisql_core::RowSet;
use crate::lru::LruCache;
use crate::Cache;
/// Canonicalises SQL text into a cache key.
///
/// Outside protected regions, every run of ASCII whitespace becomes a single space,
/// leading and trailing whitespace is dropped, and ASCII letters are upper-cased. So
/// `select  id\nfrom t` and `SELECT id FROM t` share a key.
///
/// Text whose case or spacing can change the meaning of the query is copied verbatim:
///
/// - quoted sections delimited by `'`, `"` or `` ` `` (string literals and quoted
///   identifiers), delimiters included. Inside them a backslash also protects the next
///   character, so `\'` never ends the section.
/// - `--` line comments, up to and including their terminating newline. This means a
///   comment can never absorb the text of the following line.
///
/// Copying text verbatim can only split keys apart, never merge two different queries,
/// so where the rules are unsure they err towards a cache miss. This is not a full SQL
/// lexer: constructs such as `$$…$$` strings or `/* … */` comments are normalised like
/// ordinary text.
fn normalise_query(sql: &str) -> String {
    let mut out = String::with_capacity(sql.len());
    let mut chars = sql.chars().peekable();
    // The separator is only written when the next token arrives, so leading and trailing
    // whitespace never reaches the key.
    let mut pending_space = false;
    while let Some(ch) = chars.next() {
        if ch.is_ascii_whitespace() {
            // Drop whitespace at the very start and right after a line comment's newline.
            pending_space = !out.is_empty() && !out.ends_with('\n');
            continue;
        }
        if pending_space {
            out.push(' ');
            pending_space = false;
        }
        match ch {
            '\'' | '"' | '`' => copy_quoted(ch, &mut chars, &mut out),
            '-' if chars.peek() == Some(&'-') => copy_line_comment(&mut chars, &mut out),
            _ => out.push(ch.to_ascii_uppercase()),
        }
    }
    out
}

/// Copies a quoted section verbatim. `quote` is the opening delimiter that was just
/// consumed. Copying stops after the matching closing delimiter or at end of input.
fn copy_quoted(quote: char, chars: &mut impl Iterator<Item = char>, out: &mut String) {
    out.push(quote);
    while let Some(c) = chars.next() {
        out.push(c);
        if c == '\\' {
            // Keep the escaped character as well, so `\'` cannot close the section.
            if let Some(escaped) = chars.next() {
                out.push(escaped);
            }
        } else if c == quote {
            // A doubled quote (`''`) closes here and reopens on the next character,
            // so its contents are still copied verbatim.
            return;
        }
    }
}

/// Copies a `--` line comment verbatim, including its terminating newline if present.
/// The first `-` has already been consumed; the second is still in `chars`.
fn copy_line_comment(chars: &mut impl Iterator<Item = char>, out: &mut String) {
    out.push('-');
    for c in chars {
        out.push(c);
        if c == '\n' {
            break;
        }
    }
}
/// Point-in-time snapshot of a query cache's counters and occupancy.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct QueryCacheStats {
    /// Lookups that found an entry.
    pub hits: u64,
    /// Lookups that found nothing.
    pub misses: u64,
    /// Number of entries currently stored.
    pub len: usize,
    /// Capacity reported by the backing cache.
    pub cap: usize,
}

impl QueryCacheStats {
    /// Fraction of lookups that were hits, between `0.0` and `1.0`.
    ///
    /// Returns `0.0` when no lookups have been recorded, rather than `NaN`.
    #[must_use]
    pub fn hit_rate(&self) -> f64 {
        if self.hits == 0 {
            // Covers the "no lookups yet" case. The denominator below is summed in f64,
            // so it cannot overflow `u64` the way `hits + misses` could.
            return 0.0;
        }
        self.hits as f64 / (self.hits as f64 + self.misses as f64)
    }
}
/// Cache of query results ([`RowSet`]s) keyed by normalised SQL text, backed by an
/// `LruCache`. Entries may carry a per-entry TTL.
///
/// Only [`SqlQueryCache::get`] updates the hit/miss counters.
pub struct SqlQueryCache {
    /// Lookups that returned nothing.
    misses: u64,
    /// Lookups that returned a cached result.
    hits: u64,
    /// Backing store, keyed by the output of `normalise_query`.
    inner: LruCache<String, RowSet>,
}
impl SqlQueryCache {
    /// Creates an empty cache with room for `capacity` result sets and zeroed counters.
    #[must_use]
    pub fn new(capacity: usize) -> Self {
        Self {
            inner: LruCache::new(capacity),
            hits: 0,
            misses: 0,
        }
    }

    /// Looks up the result cached for `sql` after normalising it.
    ///
    /// Every call records exactly one hit or one miss. A hit returns a clone of the
    /// stored [`RowSet`].
    pub fn get(&mut self, sql: &str) -> Option<RowSet> {
        let key = normalise_query(sql);
        let found = self.inner.get(&key).cloned();
        if found.is_some() {
            self.hits += 1;
        } else {
            self.misses += 1;
        }
        found
    }

    /// Stores `result` under the normalised form of `sql`, with no TTL.
    pub fn put(&mut self, sql: &str, result: RowSet) {
        let key = normalise_query(sql);
        self.inner.put(key, result);
    }

    /// Stores `result` under the normalised form of `sql`; the entry expires after `ttl`.
    pub fn put_with_ttl(&mut self, sql: &str, result: RowSet, ttl: Duration) {
        let key = normalise_query(sql);
        self.inner.put_with_ttl(key, result, ttl);
    }

    /// Removes the entry for `sql`, if any, and returns whatever the backing cache's
    /// `remove` yields (the stored result when one was present).
    pub fn invalidate(&mut self, sql: &str) -> Option<RowSet> {
        let key = normalise_query(sql);
        self.inner.remove(&key)
    }

    /// Removes every entry. The hit/miss counters are left untouched.
    pub fn clear(&mut self) {
        self.inner.clear();
    }

    /// Reports whether an entry exists for `sql`. This does not touch the hit/miss
    /// counters.
    pub fn contains(&mut self, sql: &str) -> bool {
        let key = normalise_query(sql);
        self.inner.contains_key(&key)
    }

    /// Number of entries currently stored.
    #[must_use]
    pub fn len(&self) -> usize {
        self.inner.len()
    }

    /// Returns `true` when no entries are stored.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.inner.is_empty()
    }

    /// Capacity reported by the backing cache.
    #[must_use]
    pub fn cap(&self) -> usize {
        self.inner.cap()
    }

    /// Number of [`get`](Self::get) calls that returned a cached result.
    #[must_use]
    pub fn hits(&self) -> u64 {
        self.hits
    }

    /// Number of [`get`](Self::get) calls that returned nothing.
    #[must_use]
    pub fn misses(&self) -> u64 {
        self.misses
    }

    /// Fraction of lookups that were hits; `0.0` before any lookup.
    #[must_use]
    pub fn hit_rate(&self) -> f64 {
        self.stats().hit_rate()
    }

    /// Snapshot of the counters together with the current length and capacity.
    #[must_use]
    pub fn stats(&self) -> QueryCacheStats {
        QueryCacheStats {
            hits: self.hits,
            misses: self.misses,
            len: self.inner.len(),
            cap: self.inner.cap(),
        }
    }

    /// Changes the capacity by delegating to the backing cache. Shrinking below the
    /// current length is expected to evict entries, which the `resize_evicts_excess`
    /// test relies on.
    pub fn resize(&mut self, new_cap: usize) {
        self.inner.resize(new_cap);
    }
}
/// Cache of query plans of type `P` keyed by normalised SQL text, backed by an
/// `LruCache`. Entries may carry a per-entry TTL.
///
/// Only [`SqlPlanCache::get`] updates the hit/miss counters.
pub struct SqlPlanCache<P> {
    /// Lookups that returned nothing.
    misses: u64,
    /// Lookups that returned a cached plan.
    hits: u64,
    /// Backing store, keyed by the output of `normalise_query`.
    inner: LruCache<String, P>,
}
impl<P: Clone> SqlPlanCache<P> {
#[must_use]
pub fn new(capacity: usize) -> Self {
Self {
inner: LruCache::new(capacity),
hits: 0,
misses: 0,
}
}
pub fn get(&mut self, sql: &str) -> Option<&P> {
let key = normalise_query(sql);
match self.inner.get(&key) {
Some(p) => {
self.hits += 1;
Some(p)
}
None => {
self.misses += 1;
None
}
}
}
pub fn put(&mut self, sql: &str, plan: P) {
let key = normalise_query(sql);
self.inner.put(key, plan);
}
pub fn put_with_ttl(&mut self, sql: &str, plan: P, ttl: Duration) {
let key = normalise_query(sql);
self.inner.put_with_ttl(key, plan, ttl);
}
pub fn invalidate(&mut self, sql: &str) -> Option<P> {
let key = normalise_query(sql);
self.inner.remove(&key)
}
pub fn clear(&mut self) {
self.inner.clear();
}
pub fn contains(&mut self, sql: &str) -> bool {
let key = normalise_query(sql);
self.inner.contains_key(&key)
}
#[must_use]
pub fn len(&self) -> usize {
self.inner.len()
}
#[must_use]
pub fn is_empty(&self) -> bool {
self.inner.is_empty()
}
#[must_use]
pub fn cap(&self) -> usize {
self.inner.cap()
}
#[must_use]
pub fn hits(&self) -> u64 {
self.hits
}
#[must_use]
pub fn misses(&self) -> u64 {
self.misses
}
#[must_use]
pub fn hit_rate(&self) -> f64 {
let total = self.hits + self.misses;
if total == 0 {
0.0
} else {
self.hits as f64 / total as f64
}
}
pub fn resize(&mut self, new_cap: usize) {
self.inner.resize(new_cap);
}
}
/// Runs SQL through a caller-supplied executor and memoises successful results in a
/// [`SqlQueryCache`].
///
/// On a cache miss, `F` is called with the original, un-normalised SQL text. `E` is the
/// executor's error type. Errors are passed back to the caller and are never cached.
pub struct CachedQueryRunner<F, E> {
    cache: SqlQueryCache,
    executor: F,
    // `E` only appears in `F`'s signature. A struct type parameter must also be used in a
    // field (rustc E0392), and `fn() -> E` marks it as used without the runner owning an
    // `E`, so it has no effect on auto traits or drop checking.
    _error: std::marker::PhantomData<fn() -> E>,
}

impl<F, E> CachedQueryRunner<F, E>
where
    F: FnMut(&str) -> Result<RowSet, E>,
{
    /// Creates a runner whose cache holds at most `capacity` results.
    #[must_use]
    pub fn new(capacity: usize, executor: F) -> Self {
        Self {
            cache: SqlQueryCache::new(capacity),
            executor,
            _error: std::marker::PhantomData,
        }
    }

    /// Returns the cached result for `sql`. On a miss it runs the executor and caches a
    /// successful result without a TTL.
    ///
    /// # Errors
    /// Returns the executor's error unchanged; nothing is cached in that case.
    pub fn run(&mut self, sql: &str) -> Result<RowSet, E> {
        self.run_inner(sql, None)
    }

    /// Like [`run`](Self::run), but a freshly executed result is cached with `ttl`.
    ///
    /// `ttl` only applies when the query is actually executed. A cache hit returns the
    /// stored entry without re-inserting it.
    ///
    /// # Errors
    /// Returns the executor's error unchanged; nothing is cached in that case.
    pub fn run_with_ttl(&mut self, sql: &str, ttl: Duration) -> Result<RowSet, E> {
        self.run_inner(sql, Some(ttl))
    }

    /// Shared lookup-or-execute path. `ttl` selects `put_with_ttl` over `put`.
    fn run_inner(&mut self, sql: &str, ttl: Option<Duration>) -> Result<RowSet, E> {
        if let Some(cached) = self.cache.get(sql) {
            return Ok(cached);
        }
        let result = (self.executor)(sql)?;
        match ttl {
            Some(ttl) => self.cache.put_with_ttl(sql, result.clone(), ttl),
            None => self.cache.put(sql, result.clone()),
        }
        Ok(result)
    }

    /// Drops the cached result for `sql`, if any, so the next run re-executes it.
    pub fn invalidate(&mut self, sql: &str) {
        self.cache.invalidate(sql);
    }

    /// Drops every cached result. The hit/miss counters are left untouched.
    pub fn clear(&mut self) {
        self.cache.clear();
    }

    /// Number of runs served from the cache.
    #[must_use]
    pub fn hits(&self) -> u64 {
        self.cache.stats().hits
    }

    /// Number of runs that had to consult the executor.
    #[must_use]
    pub fn misses(&self) -> u64 {
        self.cache.stats().misses
    }

    /// Snapshot of the underlying cache's counters, length and capacity.
    #[must_use]
    pub fn stats(&self) -> QueryCacheStats {
        self.cache.stats()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use oxisql_core::{Row, RowSet, Value};
    use std::sync::atomic::{AtomicU64, Ordering};
    use std::sync::Arc;

    /// Builds a single-column (`id`) result set with `n` rows.
    fn make_rowset(n: i64) -> RowSet {
        let rows: Vec<Row> = (0..n)
            .map(|i| Row::new(vec!["id".into()], vec![Value::I64(i)]))
            .collect();
        RowSet::from_rows(rows)
    }

    /// Sleeps well past a 1 ns TTL. `yield_now` alone does not guarantee that any
    /// measurable time passes, which can make expiry tests flaky when the monotonic
    /// clock has coarse resolution.
    fn wait_past_tiny_ttl() {
        std::thread::sleep(Duration::from_millis(5));
    }

    #[test]
    fn normalise_trims_whitespace() {
        assert_eq!(normalise_query("   select 1   "), "SELECT 1");
    }

    #[test]
    fn normalise_collapses_internal_spaces() {
        assert_eq!(normalise_query("select   id    from  t"), "SELECT ID FROM T");
    }

    #[test]
    fn normalise_uppercase() {
        assert_eq!(normalise_query("select id from t"), "SELECT ID FROM T");
    }

    #[test]
    fn normalise_tabs_and_newlines() {
        assert_eq!(normalise_query("SELECT\tid\nFROM\tt"), "SELECT ID FROM T");
    }

    #[test]
    fn normalise_empty_string() {
        assert_eq!(normalise_query(""), "");
        assert_eq!(normalise_query("   "), "");
    }

    #[test]
    fn put_and_get_basic() {
        let mut cache = SqlQueryCache::new(8);
        let rs = make_rowset(3);
        cache.put("SELECT id FROM t", rs.clone());
        let got = cache.get("SELECT id FROM t");
        assert!(got.is_some());
        assert_eq!(got.unwrap().len(), 3);
    }

    #[test]
    fn get_normalises_key() {
        let mut cache = SqlQueryCache::new(8);
        cache.put("SELECT id FROM t", make_rowset(2));
        assert!(cache.get("select id from t").is_some());
        assert!(cache.get("SELECT\tID\nFROM\tT").is_some());
    }

    #[test]
    fn miss_returns_none() {
        let mut cache = SqlQueryCache::new(8);
        assert!(cache.get("SELECT 1").is_none());
    }

    #[test]
    fn hits_and_misses_counted() {
        let mut cache = SqlQueryCache::new(8);
        cache.put("SELECT 1", make_rowset(1));
        cache.get("SELECT 1");
        cache.get("SELECT 2");
        let stats = cache.stats();
        assert_eq!(stats.hits, 1);
        assert_eq!(stats.misses, 1);
        assert!((stats.hit_rate() - 0.5).abs() < 1e-9);
    }

    #[test]
    fn invalidate_removes_entry() {
        let mut cache = SqlQueryCache::new(8);
        cache.put("SELECT 1", make_rowset(1));
        let removed = cache.invalidate("select 1");
        assert!(removed.is_some());
        assert!(cache.get("SELECT 1").is_none());
    }

    #[test]
    fn clear_empties_cache() {
        let mut cache = SqlQueryCache::new(8);
        cache.put("SELECT 1", make_rowset(1));
        cache.put("SELECT 2", make_rowset(2));
        cache.clear();
        assert!(cache.is_empty());
    }

    #[test]
    fn ttl_expiry() {
        let mut cache = SqlQueryCache::new(8);
        cache.put_with_ttl("SELECT 1", make_rowset(1), Duration::from_nanos(1));
        wait_past_tiny_ttl();
        assert!(cache.get("SELECT 1").is_none());
    }

    #[test]
    fn resize_evicts_excess() {
        let mut cache = SqlQueryCache::new(8);
        for i in 0..6 {
            cache.put(&format!("SELECT {i}"), make_rowset(1));
        }
        assert_eq!(cache.len(), 6);
        cache.resize(3);
        assert!(cache.len() <= 3);
    }

    #[test]
    fn cap_returns_capacity() {
        let cache = SqlQueryCache::new(100);
        assert_eq!(cache.cap(), 100);
    }

    #[test]
    fn plan_cache_put_and_get() {
        let mut cache: SqlPlanCache<Vec<u8>> = SqlPlanCache::new(16);
        cache.put("SELECT 1", vec![0x01, 0x02, 0x03]);
        let plan = cache.get("SELECT 1").unwrap();
        assert_eq!(plan, &[0x01u8, 0x02, 0x03]);
    }

    #[test]
    fn plan_cache_normalises_key() {
        let mut cache: SqlPlanCache<Vec<u8>> = SqlPlanCache::new(16);
        cache.put("SELECT id FROM t", vec![0xAB]);
        assert!(cache.get("select id from t").is_some());
    }

    #[test]
    fn plan_cache_invalidate() {
        let mut cache: SqlPlanCache<Vec<u8>> = SqlPlanCache::new(16);
        cache.put("SELECT 1", vec![1]);
        assert!(cache.invalidate("SELECT 1").is_some());
        assert!(cache.get("SELECT 1").is_none());
    }

    #[test]
    fn plan_cache_hit_miss_stats() {
        let mut cache: SqlPlanCache<Vec<u8>> = SqlPlanCache::new(16);
        cache.put("SELECT 1", vec![1]);
        cache.get("SELECT 1");
        cache.get("SELECT 2");
        assert_eq!(cache.hits(), 1);
        assert_eq!(cache.misses(), 1);
        assert!((cache.hit_rate() - 0.5).abs() < 1e-9);
    }

    #[test]
    fn plan_cache_ttl_expiry() {
        let mut cache: SqlPlanCache<Vec<u8>> = SqlPlanCache::new(16);
        cache.put_with_ttl("SELECT 1", vec![1], Duration::from_nanos(1));
        wait_past_tiny_ttl();
        assert!(cache.get("SELECT 1").is_none());
    }

    #[test]
    fn runner_caches_result() {
        let call_count = Arc::new(AtomicU64::new(0));
        let cc = Arc::clone(&call_count);
        let mut runner = CachedQueryRunner::new(32, move |_sql: &str| -> Result<RowSet, String> {
            cc.fetch_add(1, Ordering::Relaxed);
            Ok(make_rowset(5))
        });
        let r1 = runner.run("SELECT id FROM t").unwrap();
        let r2 = runner.run("SELECT id FROM t").unwrap();
        assert_eq!(r1.len(), r2.len());
        assert_eq!(call_count.load(Ordering::Relaxed), 1);
        assert_eq!(runner.hits(), 1);
        assert_eq!(runner.misses(), 1);
    }

    #[test]
    fn runner_propagates_executor_error() {
        let mut runner = CachedQueryRunner::new(8, |_sql: &str| -> Result<RowSet, String> {
            Err("db error".to_string())
        });
        assert!(runner.run("SELECT 1").is_err());
    }

    #[test]
    fn runner_ttl_invalidates_result() {
        let mut runner = CachedQueryRunner::new(8, |_: &str| -> Result<RowSet, String> {
            Ok(make_rowset(1))
        });
        runner
            .run_with_ttl("SELECT 1", Duration::from_nanos(1))
            .unwrap();
        wait_past_tiny_ttl();
        runner
            .run_with_ttl("SELECT 1", Duration::from_nanos(1))
            .unwrap();
        // Both runs must have executed: the first is a cold miss and the second finds the
        // entry expired.
        assert_eq!(runner.misses(), 2);
    }

    #[test]
    fn runner_invalidate_forces_re_execution() {
        let mut runner = CachedQueryRunner::new(8, |_: &str| -> Result<RowSet, String> {
            Ok(make_rowset(1))
        });
        runner.run("SELECT 1").unwrap();
        // Invalidation goes through the same normalisation, so a lower-case spelling
        // removes the entry.
        runner.invalidate("select 1");
        runner.run("SELECT 1").unwrap();
        assert_eq!(runner.misses(), 2);
        assert_eq!(runner.hits(), 0);
    }
}