use std::collections::hash_map::DefaultHasher;
use std::collections::HashMap;
use std::hash::{Hash, Hasher};
use std::sync::{Arc, RwLock};
use std::time::{Duration, Instant};
use manifoldb_core::Value;
use super::CacheMetrics;
use crate::database::QueryResult;
/// Tuning knobs for a `QueryCache`.
///
/// Build one with [`CacheConfig::new`] / [`Default::default`] and chain the
/// builder methods, e.g. `CacheConfig::new().max_entries(64).ttl(None)`.
#[derive(Debug, Clone)]
pub struct CacheConfig {
    /// Capacity limit used when deciding whether to evict on insert.
    pub max_entries: usize,
    /// Lifetime of an entry measured from its creation; `None` means entries never expire.
    pub ttl: Option<Duration>,
    /// Master switch: when `false` the cache neither stores nor serves results.
    pub enabled: bool,
}

impl Default for CacheConfig {
    /// 1000 entries, a 300-second TTL, caching enabled.
    fn default() -> Self {
        const DEFAULT_MAX_ENTRIES: usize = 1000;
        const DEFAULT_TTL: Duration = Duration::from_secs(300);
        Self { enabled: true, ttl: Some(DEFAULT_TTL), max_entries: DEFAULT_MAX_ENTRIES }
    }
}

impl CacheConfig {
    /// Returns the default configuration (identical to `CacheConfig::default()`).
    #[must_use]
    pub fn new() -> Self {
        <Self as Default>::default()
    }

    /// Builder: replaces the entry limit.
    #[must_use]
    pub const fn max_entries(self, max: usize) -> Self {
        Self { max_entries: max, ..self }
    }

    /// Builder: replaces the TTL (`None` turns expiry off).
    #[must_use]
    pub const fn ttl(self, ttl: Option<Duration>) -> Self {
        Self { ttl, ..self }
    }

    /// Builder: turns caching on or off.
    #[must_use]
    pub const fn enabled(self, enabled: bool) -> Self {
        Self { enabled, ..self }
    }

    /// Default limits with caching switched off.
    #[must_use]
    pub fn disabled() -> Self {
        Self::default().enabled(false)
    }
}
/// Identifies a cached query by a hash of its SQL text and a hash of its bound
/// parameters.
///
/// Only the two 64-bit hashes are stored and compared by the derived `PartialEq`,
/// so a simultaneous collision in both would make two different `(sql, params)`
/// pairs share a key.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct QueryCacheKey {
    /// Hash of the SQL string.
    query_hash: u64,
    /// Hash of the parameter list.
    params_hash: u64,
}

impl QueryCacheKey {
    /// Builds the key for `sql` executed with `params`.
    #[must_use]
    pub fn new(sql: &str, params: &[Value]) -> Self {
        Self { query_hash: hash_string(sql), params_hash: hash_params(params) }
    }

    /// Folds the query hash and the parameter hash (in that order) into one `u64`.
    #[must_use]
    pub fn combined_hash(&self) -> u64 {
        let mut state = DefaultHasher::new();
        // Tuples hash their fields in order with no extra framing.
        (self.query_hash, self.params_hash).hash(&mut state);
        state.finish()
    }
}
/// A cached query result together with its timestamps and the tables it read.
#[derive(Debug, Clone)]
pub struct CacheEntry {
    /// The result handed back to callers on a cache hit.
    pub result: QueryResult,
    /// Creation time; TTL expiry is measured from this instant.
    pub created_at: Instant,
    /// Time of creation or of the most recent [`CacheEntry::touch`].
    pub last_accessed: Instant,
    /// Table names recorded at insert time, used for table-based invalidation.
    pub accessed_tables: Vec<String>,
}

impl CacheEntry {
    /// Creates an entry whose creation and last-access times are both "now".
    #[must_use]
    pub fn new(result: QueryResult, accessed_tables: Vec<String>) -> Self {
        let stamp = Instant::now();
        Self { result, accessed_tables, created_at: stamp, last_accessed: stamp }
    }

    /// Returns `true` once strictly more than `ttl` has passed since creation.
    #[must_use]
    pub fn is_expired(&self, ttl: Duration) -> bool {
        let age = self.created_at.elapsed();
        ttl < age
    }

    /// Stamps the entry as accessed right now.
    pub fn touch(&mut self) {
        let stamp = Instant::now();
        self.last_accessed = stamp;
    }
}
/// Mutable bookkeeping that a `QueryCache` keeps behind its lock.
struct CacheState {
    /// Cached entries by key.
    entries: HashMap<QueryCacheKey, CacheEntry>,
    /// Keys in recency order: the front is evicted first, touched keys move to the back.
    lru_order: Vec<QueryCacheKey>,
    /// For each table name, the keys of the entries that were recorded as reading it.
    table_index: HashMap<String, Vec<QueryCacheKey>>,
}

impl CacheState {
    /// Returns empty bookkeeping.
    fn new() -> Self {
        let entries = HashMap::new();
        let lru_order = Vec::new();
        let table_index = HashMap::new();
        Self { entries, lru_order, table_index }
    }
}
pub struct QueryCache {
config: CacheConfig,
state: RwLock<CacheState>,
metrics: Arc<CacheMetrics>,
}
impl QueryCache {
#[must_use]
pub fn new(config: CacheConfig) -> Self {
Self {
config,
state: RwLock::new(CacheState::new()),
metrics: Arc::new(CacheMetrics::new()),
}
}
#[must_use]
pub fn disabled() -> Self {
Self::new(CacheConfig::disabled())
}
#[must_use]
pub fn config(&self) -> &CacheConfig {
&self.config
}
#[must_use]
pub fn metrics(&self) -> Arc<CacheMetrics> {
Arc::clone(&self.metrics)
}
#[must_use]
pub fn is_enabled(&self) -> bool {
self.config.enabled
}
pub fn get(&self, key: &QueryCacheKey) -> Option<QueryResult> {
if !self.config.enabled {
return None;
}
let state = self.state.read().ok()?;
let entry = match state.entries.get(key) {
Some(e) => e,
None => {
self.metrics.record_miss();
return None;
}
};
if let Some(ttl) = self.config.ttl {
if entry.is_expired(ttl) {
let result_clone = None;
drop(state);
self.remove(key);
self.metrics.record_miss();
return result_clone;
}
}
self.metrics.record_hit();
Some(entry.result.clone())
}
pub fn touch(&self, key: &QueryCacheKey) {
if !self.config.enabled {
return;
}
if let Ok(mut state) = self.state.write() {
if let Some(entry) = state.entries.get_mut(key) {
entry.touch();
if let Some(pos) = state.lru_order.iter().position(|k| k == key) {
state.lru_order.remove(pos);
state.lru_order.push(key.clone());
}
}
}
}
pub fn insert(&self, key: QueryCacheKey, result: QueryResult, accessed_tables: Vec<String>) {
if !self.config.enabled {
return;
}
if let Ok(mut state) = self.state.write() {
while state.entries.len() >= self.config.max_entries && !state.lru_order.is_empty() {
let oldest_key = state.lru_order.remove(0);
self.remove_from_state(&mut state, &oldest_key);
self.metrics.record_eviction();
}
let entry = CacheEntry::new(result, accessed_tables.clone());
for table in &accessed_tables {
state.table_index.entry(table.clone()).or_default().push(key.clone());
}
state.entries.insert(key.clone(), entry);
state.lru_order.push(key);
}
}
pub fn remove(&self, key: &QueryCacheKey) {
if let Ok(mut state) = self.state.write() {
self.remove_from_state(&mut state, key);
}
}
fn remove_from_state(&self, state: &mut CacheState, key: &QueryCacheKey) {
if let Some(entry) = state.entries.remove(key) {
if let Some(pos) = state.lru_order.iter().position(|k| k == key) {
state.lru_order.remove(pos);
}
for table in &entry.accessed_tables {
if let Some(keys) = state.table_index.get_mut(table) {
keys.retain(|k| k != key);
if keys.is_empty() {
state.table_index.remove(table);
}
}
}
}
}
pub fn invalidate_table(&self, table: &str) {
if !self.config.enabled {
return;
}
if let Ok(mut state) = self.state.write() {
if let Some(keys) = state.table_index.remove(table) {
let invalidated_count = keys.len();
for key in keys {
if let Some(entry) = state.entries.remove(&key) {
if let Some(pos) = state.lru_order.iter().position(|k| *k == key) {
state.lru_order.remove(pos);
}
for other_table in &entry.accessed_tables {
if other_table != table {
if let Some(other_keys) = state.table_index.get_mut(other_table) {
other_keys.retain(|k| *k != key);
}
}
}
}
}
self.metrics.record_invalidations(invalidated_count);
}
}
}
pub fn invalidate_tables(&self, tables: &[String]) {
for table in tables {
self.invalidate_table(table);
}
}
pub fn clear(&self) {
if let Ok(mut state) = self.state.write() {
let count = state.entries.len();
state.entries.clear();
state.lru_order.clear();
state.table_index.clear();
self.metrics.record_invalidations(count);
}
}
#[must_use]
pub fn len(&self) -> usize {
self.state.read().map(|s| s.entries.len()).unwrap_or(0)
}
#[must_use]
pub fn is_empty(&self) -> bool {
self.len() == 0
}
}
impl Default for QueryCache {
fn default() -> Self {
Self::new(CacheConfig::default())
}
}
/// Hashes a SQL string with the std `DefaultHasher`.
fn hash_string(s: &str) -> u64 {
    let mut state = DefaultHasher::new();
    Hash::hash(s, &mut state);
    state.finish()
}
/// Hashes a parameter list by feeding each value, in order, into one hasher.
///
/// Every value's encoding starts with its variant discriminant, and
/// variable-length payloads carry their length, so no separate count is written.
fn hash_params(params: &[Value]) -> u64 {
    params
        .iter()
        .fold(DefaultHasher::new(), |mut state, param| {
            hash_value(param, &mut state);
            state
        })
        .finish()
}
/// Feeds one parameter value into `hasher`.
///
/// The variant discriminant is written first, so different variants with
/// equal payloads do not hash alike. Floats are hashed by their bit pattern
/// (`to_bits`), so e.g. `0.0` and `-0.0` produce different hashes.
/// Collections write their length before their elements.
fn hash_value(value: &Value, hasher: &mut DefaultHasher) {
    std::mem::discriminant(value).hash(hasher);
    match value {
        Value::Null => {}
        Value::Bool(flag) => flag.hash(hasher),
        Value::Int(number) => number.hash(hasher),
        Value::Float(number) => number.to_bits().hash(hasher),
        Value::String(text) => text.hash(hasher),
        Value::Bytes(bytes) => bytes.hash(hasher),
        Value::Vector(components) => {
            components.len().hash(hasher);
            for component in components {
                component.to_bits().hash(hasher);
            }
        }
        Value::MultiVector(rows) => {
            rows.len().hash(hasher);
            for row in rows {
                row.len().hash(hasher);
                for component in row {
                    component.to_bits().hash(hasher);
                }
            }
        }
        Value::SparseVector(pairs) => {
            pairs.len().hash(hasher);
            for (index, weight) in pairs {
                index.hash(hasher);
                weight.to_bits().hash(hasher);
            }
        }
        Value::Array(items) => {
            items.len().hash(hasher);
            for item in items {
                hash_value(item, hasher);
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_cache_key_creation() {
        let key1 = QueryCacheKey::new("SELECT * FROM users", &[]);
        let key2 = QueryCacheKey::new("SELECT * FROM users", &[]);
        let key3 = QueryCacheKey::new("SELECT * FROM orders", &[]);
        assert_eq!(key1, key2);
        assert_ne!(key1, key3);
    }

    #[test]
    fn test_cache_key_with_params() {
        let params1 = vec![Value::Int(1), Value::String("test".to_string())];
        let params2 = vec![Value::Int(1), Value::String("test".to_string())];
        let params3 = vec![Value::Int(2), Value::String("test".to_string())];
        let key1 = QueryCacheKey::new("SELECT * FROM users WHERE id = $1", &params1);
        let key2 = QueryCacheKey::new("SELECT * FROM users WHERE id = $1", &params2);
        let key3 = QueryCacheKey::new("SELECT * FROM users WHERE id = $1", &params3);
        assert_eq!(key1, key2);
        assert_ne!(key1, key3);
    }

    #[test]
    fn test_cache_basic_operations() {
        let cache = QueryCache::new(CacheConfig::default());
        let key = QueryCacheKey::new("SELECT * FROM users", &[]);
        let result = QueryResult::empty();
        assert!(cache.get(&key).is_none());
        cache.insert(key.clone(), result.clone(), vec!["users".to_string()]);
        let cached = cache.get(&key);
        assert!(cached.is_some());
        assert_eq!(cache.len(), 1);
    }

    #[test]
    fn test_cache_invalidation() {
        let cache = QueryCache::new(CacheConfig::default());
        let key1 = QueryCacheKey::new("SELECT * FROM users", &[]);
        let key2 = QueryCacheKey::new("SELECT * FROM orders", &[]);
        cache.insert(key1.clone(), QueryResult::empty(), vec!["users".to_string()]);
        cache.insert(key2.clone(), QueryResult::empty(), vec!["orders".to_string()]);
        assert_eq!(cache.len(), 2);
        cache.invalidate_table("users");
        assert!(cache.get(&key1).is_none());
        assert!(cache.get(&key2).is_some());
        assert_eq!(cache.len(), 1);
    }

    #[test]
    fn test_cache_lru_eviction() {
        let config = CacheConfig::default().max_entries(2);
        let cache = QueryCache::new(config);
        let key1 = QueryCacheKey::new("query1", &[]);
        let key2 = QueryCacheKey::new("query2", &[]);
        let key3 = QueryCacheKey::new("query3", &[]);
        cache.insert(key1.clone(), QueryResult::empty(), vec![]);
        cache.insert(key2.clone(), QueryResult::empty(), vec![]);
        // `get` only takes a read lock; `touch` is what refreshes LRU position.
        cache.get(&key1);
        cache.touch(&key1);
        cache.insert(key3.clone(), QueryResult::empty(), vec![]);
        assert!(cache.get(&key1).is_some());
        assert!(cache.get(&key2).is_none());
        assert!(cache.get(&key3).is_some());
    }

    #[test]
    fn test_cache_disabled() {
        let cache = QueryCache::disabled();
        let key = QueryCacheKey::new("SELECT * FROM users", &[]);
        cache.insert(key.clone(), QueryResult::empty(), vec![]);
        assert!(cache.get(&key).is_none());
    }

    #[test]
    fn test_cache_ttl_expiration() {
        let config = CacheConfig::default().ttl(Some(Duration::from_millis(10)));
        let cache = QueryCache::new(config);
        let key = QueryCacheKey::new("SELECT * FROM users", &[]);
        cache.insert(key.clone(), QueryResult::empty(), vec![]);
        assert!(cache.get(&key).is_some());
        std::thread::sleep(Duration::from_millis(20));
        assert!(cache.get(&key).is_none());
    }

    #[test]
    fn test_cache_clear() {
        let cache = QueryCache::new(CacheConfig::default());
        cache.insert(QueryCacheKey::new("query1", &[]), QueryResult::empty(), vec![]);
        cache.insert(QueryCacheKey::new("query2", &[]), QueryResult::empty(), vec![]);
        assert_eq!(cache.len(), 2);
        cache.clear();
        assert!(cache.is_empty());
    }

    #[test]
    fn test_hash_value() {
        let mut hasher1 = DefaultHasher::new();
        let mut hasher2 = DefaultHasher::new();
        hash_value(&Value::Int(42), &mut hasher1);
        hash_value(&Value::Int(42), &mut hasher2);
        assert_eq!(hasher1.finish(), hasher2.finish());
    }

    #[test]
    fn test_hash_vector() {
        let v1 = Value::Vector(vec![1.0, 2.0, 3.0]);
        let v2 = Value::Vector(vec![1.0, 2.0, 3.0]);
        let v3 = Value::Vector(vec![1.0, 2.0, 4.0]);
        let params1 = vec![v1];
        let params2 = vec![v2];
        let params3 = vec![v3];
        assert_eq!(hash_params(&params1), hash_params(&params2));
        assert_ne!(hash_params(&params1), hash_params(&params3));
    }
}