use crate::types::{Row, RowId};
use dashmap::DashMap;
use lru::LruCache;
use parking_lot::RwLock;
use std::num::NonZeroUsize;
use std::sync::Arc;
/// Key of a cached row: the owning table's name paired with the row's id.
pub type CacheKey = (String, RowId);
/// Per-table record of recent row accesses, used by `RowCache` to detect runs of
/// accesses that advance by a constant stride.
#[derive(Debug, Clone)]
struct AccessPattern {
/// Row id of the most recent access recorded for the table.
last_row_id: RowId,
/// Signed difference between the two most recent row ids; 0 on first access and after a reset.
stride: i64,
/// Number of consecutive accesses (including the first of the run) that share `stride`.
sequential_count: usize,
/// When the pattern was last updated; older patterns are treated as stale.
last_access: std::time::Instant,
}
/// LRU cache of table rows keyed by `(table name, row id)`, with hit/miss statistics and
/// per-table stride detection used to suggest prefetches.
///
/// All mutable state sits behind `RwLock`/`DashMap`, so every method takes `&self`.
pub struct RowCache {
// LRU map from (table, row id) to a shared row; `get` needs the write lock because it updates recency.
cache: Arc<RwLock<LruCache<CacheKey, Arc<Row>>>>,
// Hit/miss, size and prefetch counters reported by `stats()`.
stats: Arc<RwLock<CacheStats>>,
// Most recent access pattern per table name.
access_patterns: Arc<DashMap<String, AccessPattern>>,
// Settings controlling pattern detection and prefetch suggestions.
prefetch_config: PrefetchConfig,
}
/// Snapshot of `RowCache` counters.
#[derive(Debug, Default, Clone)]
pub struct CacheStats {
    /// Lookups that found the requested row.
    pub hits: u64,
    /// Lookups that did not find the requested row.
    pub misses: u64,
    /// Number of rows cached as of the last insertion/removal.
    pub size: usize,
    /// Maximum number of rows the cache holds.
    pub capacity: usize,
    /// Total number of rows reported as prefetched.
    pub prefetch_triggered: u64,
    /// Number of prefetched rows reported as actually used.
    pub prefetch_useful: u64,
}

impl CacheStats {
    /// Fraction of lookups that were hits, in `[0.0, 1.0]`; `0.0` before any lookup.
    pub fn hit_rate(&self) -> f64 {
        match self.hits + self.misses {
            0 => 0.0,
            lookups => self.hits as f64 / lookups as f64,
        }
    }

    /// Ratio of useful prefetched rows to prefetched rows; `0.0` if nothing was prefetched.
    pub fn prefetch_efficiency(&self) -> f64 {
        match self.prefetch_triggered {
            0 => 0.0,
            triggered => self.prefetch_useful as f64 / triggered as f64,
        }
    }
}
/// Settings for `RowCache`'s sequential-access detection and prefetch suggestions.
#[derive(Debug, Clone)]
pub struct PrefetchConfig {
    /// When `false`, access patterns are not tracked and no prefetch is suggested.
    pub enabled: bool,
    /// Consecutive equal-stride accesses required before a prefetch is suggested.
    pub min_sequential_count: usize,
    /// Row count returned alongside the next row id when a prefetch is suggested.
    pub prefetch_size: usize,
    /// Largest absolute stride still treated as sequential access.
    pub max_stride: i64,
}

impl Default for PrefetchConfig {
    /// Enabled; runs of 3 accesses, 32-row windows, strides up to 100 in either direction.
    fn default() -> Self {
        let min_sequential_count = 3;
        let prefetch_size = 32;
        let max_stride = 100;
        PrefetchConfig {
            enabled: true,
            max_stride,
            prefetch_size,
            min_sequential_count,
        }
    }
}
impl RowCache {
pub fn new(capacity: usize) -> Self {
Self::with_prefetch_config(capacity, PrefetchConfig::default())
}
pub fn with_prefetch_config(capacity: usize, prefetch_config: PrefetchConfig) -> Self {
let capacity = capacity.max(1);
Self {
cache: Arc::new(RwLock::new(
LruCache::new(NonZeroUsize::new(capacity).unwrap())
)),
stats: Arc::new(RwLock::new(CacheStats {
hits: 0,
misses: 0,
size: 0,
capacity,
prefetch_triggered: 0,
prefetch_useful: 0,
})),
access_patterns: Arc::new(DashMap::new()),
prefetch_config,
}
}
pub fn get(&self, table_name: &str, row_id: RowId) -> Option<Arc<Row>> {
let key = (table_name.to_string(), row_id);
let mut cache = self.cache.write();
if let Some(row) = cache.get(&key) {
let mut stats = self.stats.write();
stats.hits += 1;
drop(stats); self.update_access_pattern(table_name, row_id);
Some(Arc::clone(row))
} else {
let mut stats = self.stats.write();
stats.misses += 1;
drop(stats); self.update_access_pattern(table_name, row_id);
None
}
}
fn update_access_pattern(&self, table_name: &str, row_id: RowId) -> Option<(RowId, usize)> {
if !self.prefetch_config.enabled {
return None;
}
let now = std::time::Instant::now();
let should_prefetch = match self.access_patterns.entry(table_name.to_string()) {
dashmap::mapref::entry::Entry::Occupied(mut entry) => {
let pattern = entry.get_mut();
if now.duration_since(pattern.last_access).as_secs() > 1 {
pattern.last_row_id = row_id;
pattern.stride = 0;
pattern.sequential_count = 1;
pattern.last_access = now;
return None;
}
let stride = row_id as i64 - pattern.last_row_id as i64;
if stride == pattern.stride && stride.abs() <= self.prefetch_config.max_stride {
pattern.sequential_count += 1;
pattern.last_row_id = row_id;
pattern.last_access = now;
if pattern.sequential_count >= self.prefetch_config.min_sequential_count {
let next_row_id = (row_id as i64 + stride) as RowId;
Some((next_row_id, self.prefetch_config.prefetch_size))
} else {
None
}
} else if stride.abs() <= self.prefetch_config.max_stride {
pattern.stride = stride;
pattern.sequential_count = 2; pattern.last_row_id = row_id;
pattern.last_access = now;
None
} else {
pattern.stride = 0;
pattern.sequential_count = 1;
pattern.last_row_id = row_id;
pattern.last_access = now;
None
}
}
dashmap::mapref::entry::Entry::Vacant(entry) => {
entry.insert(AccessPattern {
last_row_id: row_id,
stride: 0,
sequential_count: 1,
last_access: now,
});
None
}
};
should_prefetch
}
pub fn check_prefetch(&self, table_name: &str, row_id: RowId) -> Option<(RowId, usize, i64)> {
if !self.prefetch_config.enabled {
return None;
}
if let Some(pattern_ref) = self.access_patterns.get(table_name) {
let pattern = pattern_ref.value();
if pattern.last_access.elapsed().as_secs() > 1 {
return None;
}
if pattern.sequential_count >= self.prefetch_config.min_sequential_count {
let stride = pattern.stride;
let next_row_id = (row_id as i64 + stride) as RowId;
return Some((next_row_id, self.prefetch_config.prefetch_size, stride));
}
}
None
}
pub fn put(&self, table_name: String, row_id: RowId, row: Row) {
let key = (table_name, row_id);
let row_arc = Arc::new(row);
let mut cache = self.cache.write();
cache.put(key, row_arc);
let mut stats = self.stats.write();
stats.size = cache.len();
}
pub fn put_batch(&self, table_name: &str, rows: Vec<(RowId, Row)>) {
let mut cache = self.cache.write();
for (row_id, row) in rows {
let key = (table_name.to_string(), row_id);
cache.put(key, Arc::new(row));
}
let mut stats = self.stats.write();
stats.size = cache.len();
}
pub fn invalidate(&self, table_name: &str, row_id: RowId) {
let key = (table_name.to_string(), row_id);
let mut cache = self.cache.write();
cache.pop(&key);
let mut stats = self.stats.write();
stats.size = cache.len();
}
pub fn invalidate_table(&self, table_name: &str) {
let mut cache = self.cache.write();
let keys_to_remove: Vec<CacheKey> = cache
.iter()
.filter(|(key, _)| key.0 == table_name)
.map(|(key, _)| key.clone())
.collect();
for key in keys_to_remove {
cache.pop(&key);
}
let mut stats = self.stats.write();
stats.size = cache.len();
}
pub fn clear(&self) {
let mut cache = self.cache.write();
cache.clear();
let mut stats = self.stats.write();
stats.size = 0;
stats.hits = 0;
stats.misses = 0;
stats.prefetch_triggered = 0;
stats.prefetch_useful = 0;
self.access_patterns.clear();
}
pub fn record_prefetch(&self, count: usize) {
let mut stats = self.stats.write();
stats.prefetch_triggered += count as u64;
}
pub fn record_prefetch_hit(&self) {
let mut stats = self.stats.write();
stats.prefetch_useful += 1;
}
pub fn stats(&self) -> CacheStats {
self.stats.read().clone()
}
pub fn print_stats(&self) {
let stats = self.stats();
println!("๐ Row Cache Statistics:");
println!(" Hits: {}, Misses: {}", stats.hits, stats.misses);
println!(" Hit Rate: {:.2}%", stats.hit_rate() * 100.0);
println!(" Size: {}/{} rows", stats.size, stats.capacity);
if self.prefetch_config.enabled {
println!("๐ Prefetch Statistics:");
println!(" Triggered: {} rows", stats.prefetch_triggered);
println!(" Useful: {} rows", stats.prefetch_useful);
println!(" Efficiency: {:.2}%", stats.prefetch_efficiency() * 100.0);
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::Value;

    /// A miss, then a put, then a hit gives one hit, one miss and a 50% hit rate.
    #[test]
    fn test_row_cache_basic() {
        let cache = RowCache::new(100);
        let mut original = Row::new();
        original.push(Value::Integer(1));
        original.push(Value::Text("test".to_string()));

        let before_put = cache.get("users", 1);
        assert!(before_put.is_none());

        cache.put("users".to_string(), 1, original);
        let fetched = cache.get("users", 1).expect("row 1 was just inserted");
        assert_eq!(fetched.len(), 2);

        let snapshot = cache.stats();
        assert_eq!((snapshot.hits, snapshot.misses), (1, 1));
        assert_eq!(snapshot.hit_rate(), 0.5);
    }

    /// An invalidated row is no longer returned.
    #[test]
    fn test_row_cache_invalidation() {
        let cache = RowCache::new(100);
        let mut single = Row::new();
        single.push(Value::Integer(1));

        cache.put("users".to_string(), 1, single);
        assert!(cache.get("users", 1).is_some(), "row should be cached after put");

        cache.invalidate("users", 1);
        assert!(cache.get("users", 1).is_none(), "row should be gone after invalidate");
    }

    /// Inserting past capacity evicts the least recently used row.
    #[test]
    fn test_row_cache_lru_eviction() {
        let cache = RowCache::new(3);
        for i in 1..=4 {
            let mut row = Row::new();
            row.push(Value::Integer(i));
            cache.put("users".to_string(), i as u64, row);
            if i >= 3 {
                assert_eq!(cache.stats().size, 3);
            }
        }
        assert!(cache.get("users", 1).is_none(), "oldest row should be evicted");
        assert!(cache.get("users", 4).is_some(), "newest row should be cached");
    }
}