use std::cell::RefCell;
use std::sync::Arc;
use rustrails_support::cache::{CacheOptions, CacheStore, MemoryStore};
use serde_json::Value;
thread_local! {
    /// Per-thread stack of active request scopes; the last element is the innermost one.
    /// `QueryCache::with_request_scope` pushes an entry and pops it when its closure ends,
    /// and `QueryCache::uncached` adjusts the innermost entry's `bypass_depth`.
    static QUERY_CACHE_SCOPE: RefCell<Vec<QueryCacheScope>> = const { RefCell::new(Vec::new()) };
}
/// One entry on the thread-local request-scope stack.
#[derive(Clone)]
struct QueryCacheScope {
    /// Cache private to this request scope; a fresh one is created for every
    /// `with_request_scope` call.
    cache: QueryCache,
    /// How many nested `QueryCache::uncached` calls are active on this scope;
    /// lookups bypass `cache` whenever this is non-zero.
    bypass_depth: usize,
}
/// Caches query results keyed by SQL text plus bind values.
///
/// Outside a request scope, results live in this instance's own store. The store is
/// reference-counted, so clones of a `QueryCache` share the same entries.
#[derive(Debug, Clone)]
pub struct QueryCache {
    /// Backing store used when no request scope is active on the current thread.
    store: Arc<MemoryStore>,
}
impl QueryCache {
    /// Creates a query cache backed by a new, empty [`MemoryStore`].
    ///
    /// The store sits behind an [`Arc`], so clones of the returned value share its entries.
    #[must_use]
    pub fn new() -> Self {
        Self {
            store: Arc::new(MemoryStore::new()),
        }
    }

    /// Builds the cache key for a query as `"{sql}::{binds}"`, with `binds` rendered as a
    /// JSON array, so the same SQL run with different bind values is cached separately.
    #[must_use]
    pub fn cache_key(sql: &str, binds: &[Value]) -> String {
        // Serializing `Value`s is not expected to fail. If it ever did, fall back to the
        // `Debug` rendering rather than a constant placeholder: a fixed `"[]"` would make
        // every unserializable bind set share a key with the empty bind set (and with each
        // other), silently serving one query's result for another.
        let binds = serde_json::to_string(binds).unwrap_or_else(|_| format!("{binds:?}"));
        format!("{sql}::{binds}")
    }

    /// Runs `f` inside a new, empty request-scoped cache on the current thread.
    ///
    /// The scope is pushed onto a thread-local stack before `f` runs and popped when `f`
    /// returns or unwinds. While it is the innermost scope, [`QueryCache::fetch`] on any
    /// `QueryCache` on this thread reads through it instead of the instance's own store,
    /// so nested calls each get an isolated cache. `self` itself is not consulted.
    pub fn with_request_scope<F, R>(&self, f: F) -> R
    where
        F: FnOnce() -> R,
    {
        /// Pops the scope pushed below, even if `f` panics.
        struct PopScopeOnDrop;
        impl Drop for PopScopeOnDrop {
            fn drop(&mut self) {
                QUERY_CACHE_SCOPE.with(|scope| {
                    scope.borrow_mut().pop();
                });
            }
        }

        QUERY_CACHE_SCOPE.with(|scope| {
            scope.borrow_mut().push(QueryCacheScope {
                cache: QueryCache::new(),
                bypass_depth: 0,
            });
        });
        let _pop = PopScopeOnDrop;
        f()
    }

    /// Returns the cached result for `sql` + `binds`, calling `loader` on a miss.
    ///
    /// - Inside a request scope, the innermost scope's cache is used.
    /// - Inside [`QueryCache::uncached`] (within a scope), the cache is skipped and
    ///   `loader` is always called.
    /// - Outside any scope, this instance's own store is used.
    pub fn fetch<F>(&self, sql: &str, binds: &[Value], loader: F) -> Value
    where
        F: FnOnce() -> Value,
    {
        let scope = current_scope();
        let store = match &scope {
            // Caching is bypassed: no lookup, so don't bother building the key either.
            Some((_, true)) => return loader(),
            Some((scoped, false)) => &scoped.store,
            None => &self.store,
        };
        store.fetch(&Self::cache_key(sql, binds), CacheOptions::default(), loader)
    }

    /// Runs a write `operation`, then invalidates every cache it could have made stale,
    /// returning the operation's result.
    ///
    /// Invalidation clears this instance's own store and the cache of *every* request
    /// scope on the current thread's stack, not just the innermost one. Otherwise results
    /// cached before the write could be served after it, either by an enclosing scope or
    /// once all scopes have ended. Invalidation also runs if `operation` panics, because a
    /// failed write may still have changed data.
    ///
    /// Stores of other `QueryCache` instances, and request scopes on other threads, are
    /// not touched.
    pub fn execute_write<F, R>(&self, operation: F) -> R
    where
        F: FnOnce() -> R,
    {
        /// Invalidates on drop, i.e. after `operation` returns or while it unwinds.
        struct InvalidateOnDrop<'a>(&'a QueryCache);
        impl Drop for InvalidateOnDrop<'_> {
            fn drop(&mut self) {
                self.0.invalidate_all();
            }
        }

        let _invalidate = InvalidateOnDrop(self);
        operation()
    }

    /// Clears this instance's store and the cache of every request scope on this thread.
    fn invalidate_all(&self) {
        // Clone the (Arc-backed) scope caches out first so the RefCell borrow is released
        // before any store is cleared.
        let scoped: Vec<QueryCache> = QUERY_CACHE_SCOPE.with(|stack| {
            stack.borrow().iter().map(|scope| scope.cache.clone()).collect()
        });
        for cache in &scoped {
            cache.store.clear();
        }
        self.store.clear();
    }

    /// Runs `f` with query caching bypassed in the current request scope.
    ///
    /// Inside `f`, [`QueryCache::fetch`] calls its loader directly instead of consulting
    /// the innermost scope's cache. Calls nest: caching resumes only after the outermost
    /// `uncached` returns or unwinds.
    ///
    /// NOTE: the bypass is recorded on the innermost request scope. It therefore has no
    /// effect outside [`QueryCache::with_request_scope`], and a request scope opened
    /// inside `f` starts with caching enabled again.
    pub fn uncached<F, R>(f: F) -> R
    where
        F: FnOnce() -> R,
    {
        /// Undoes the bypass increment below, even if `f` panics.
        struct EndBypassOnDrop;
        impl Drop for EndBypassOnDrop {
            fn drop(&mut self) {
                QUERY_CACHE_SCOPE.with(|scope| {
                    if let Some(current) = scope.borrow_mut().last_mut() {
                        current.bypass_depth = current.bypass_depth.saturating_sub(1);
                    }
                });
            }
        }

        QUERY_CACHE_SCOPE.with(|scope| {
            if let Some(current) = scope.borrow_mut().last_mut() {
                current.bypass_depth += 1;
            }
        });
        let _end_bypass = EndBypassOnDrop;
        f()
    }
}
impl Default for QueryCache {
fn default() -> Self {
Self::new()
}
}
/// Returns the innermost request scope on this thread, if any.
///
/// The result pairs a handle to that scope's cache with a flag saying whether a
/// `QueryCache::uncached` bypass is active on it. The cache is cloned out (an `Arc`
/// clone of its store), so no `RefCell` borrow outlives this call.
fn current_scope() -> Option<(QueryCache, bool)> {
    QUERY_CACHE_SCOPE.with(|stack| {
        let stack = stack.borrow();
        let innermost = stack.last()?;
        let bypassed = innermost.bypass_depth != 0;
        Some((innermost.cache.clone(), bypassed))
    })
}
#[cfg(test)]
mod tests {
    use std::sync::atomic::{AtomicUsize, Ordering};

    use serde_json::{json, Value};

    use super::QueryCache;

    /// Fetches through `cache` with a loader that bumps `calls` and yields `value`,
    /// so each test can count how often the loader actually ran.
    fn counted_fetch(
        cache: &QueryCache,
        sql: &str,
        binds: &[Value],
        calls: &AtomicUsize,
        value: Value,
    ) -> Value {
        cache.fetch(sql, binds, || {
            calls.fetch_add(1, Ordering::SeqCst);
            value
        })
    }

    /// Number of loader invocations recorded so far.
    fn loads(calls: &AtomicUsize) -> usize {
        calls.load(Ordering::SeqCst)
    }

    #[test]
    fn cache_key_depends_on_sql_and_binds() {
        let bound_to_one = QueryCache::cache_key("SELECT 1", &[json!(1)]);
        let bound_to_two = QueryCache::cache_key("SELECT 1", &[json!(2)]);
        assert_ne!(bound_to_one, bound_to_two);
    }

    #[test]
    fn fetch_uses_cached_value_within_request_scope() {
        const SQL: &str = "SELECT * FROM users WHERE id = ?";
        let cache = QueryCache::new();
        let calls = AtomicUsize::new(0);
        cache.with_request_scope(|| {
            let first = counted_fetch(&cache, SQL, &[json!(1)], &calls, json!({"id": 1}));
            let second = counted_fetch(&cache, SQL, &[json!(1)], &calls, json!({"id": 2}));
            assert_eq!(first, second);
        });
        assert_eq!(loads(&calls), 1);
    }

    #[test]
    fn request_scopes_are_isolated() {
        let cache = QueryCache::new();
        let calls = AtomicUsize::new(0);
        for _ in 0..2 {
            cache.with_request_scope(|| {
                let _ = counted_fetch(&cache, "SELECT 1", &[], &calls, json!(1));
            });
        }
        assert_eq!(loads(&calls), 2);
    }

    #[test]
    fn uncached_bypasses_request_cache() {
        let cache = QueryCache::new();
        let calls = AtomicUsize::new(0);
        cache.with_request_scope(|| {
            let _ = counted_fetch(&cache, "SELECT 1", &[], &calls, json!(1));
            let _ = QueryCache::uncached(|| {
                counted_fetch(&cache, "SELECT 1", &[], &calls, json!(2))
            });
        });
        assert_eq!(loads(&calls), 2);
    }

    #[test]
    fn execute_write_invalidates_current_scope() {
        let cache = QueryCache::new();
        let calls = AtomicUsize::new(0);
        cache.with_request_scope(|| {
            let _ = counted_fetch(&cache, "SELECT 1", &[], &calls, json!(1));
            cache.execute_write(|| ());
            let _ = counted_fetch(&cache, "SELECT 1", &[], &calls, json!(1));
        });
        assert_eq!(loads(&calls), 2);
    }

    #[test]
    fn execute_write_invalidates_non_scoped_cache() {
        let cache = QueryCache::new();
        let calls = AtomicUsize::new(0);
        let _ = counted_fetch(&cache, "SELECT 1", &[], &calls, json!(1));
        cache.execute_write(|| ());
        let _ = counted_fetch(&cache, "SELECT 1", &[], &calls, json!(1));
        assert_eq!(loads(&calls), 2);
    }
}