use std::collections::HashMap;
use std::fs::{File, OpenOptions};
use std::io::{Read, Write};
use std::path::PathBuf;
use std::sync::RwLock;
use std::time::Instant;
use bytes::Bytes;
use dashmap::DashMap;
use super::config::{L2Config, StorageBackend};
use super::result::{CachedResult, CacheKey, L2Entry};
/// Warm (L2) result cache: an in-memory map of entries with an optional
/// file-backed store behind it.
#[derive(Debug)]
pub struct L2WarmCache {
/// Cache settings (`enabled`, `size_mb`, `storage`, `mmap_path`) read by the methods.
config: L2Config,
/// In-memory entries, keyed by `CacheKey::hash_value()`.
memory_entries: DashMap<u64, L2Entry>,
/// File-backed store; `Some` only when `config.storage` is
/// `StorageBackend::Mmap` and `config.mmap_path` is set.
mmap_storage: Option<RwLock<MmapStorage>>,
/// Running total of `L2Entry::memory_size` for the entries tracked in `memory_entries`.
memory_usage: std::sync::atomic::AtomicUsize,
}
/// File-backed store of serialized results, addressed through an in-memory index.
#[derive(Debug)]
struct MmapStorage {
/// Path of the backing file.
path: PathBuf,
/// Write handle to the backing file; `None` until the first `put` opens it.
file: Option<File>,
/// Maps an entry hash to the location and expiry of its record in the file.
index: HashMap<u64, MmapEntry>,
/// Size bookkeeping for the backing file, in bytes; updated by `put` and reset by `clear`.
file_size: usize,
}
/// Location and expiry of one serialized record inside the backing file.
#[derive(Debug, Clone)]
struct MmapEntry {
/// Byte offset of the record from the start of the file.
offset: usize,
/// Length of the record in bytes.
size: usize,
/// Unix timestamp (seconds) after which the record is treated as expired.
expires_at: u64,
}
impl L2WarmCache {
    /// Creates a warm cache from `config`.
    ///
    /// The file-backed store is set up only when `config.storage` is
    /// `StorageBackend::Mmap` *and* `config.mmap_path` is set; in every other
    /// case the cache is memory-only.
    pub fn new(config: L2Config) -> Self {
        let mmap_storage = if config.storage == StorageBackend::Mmap {
            config
                .mmap_path
                .as_ref()
                .map(|path| RwLock::new(MmapStorage::new(path.clone())))
        } else {
            None
        };
        Self {
            config,
            memory_entries: DashMap::new(),
            mmap_storage,
            memory_usage: std::sync::atomic::AtomicUsize::new(0),
        }
    }

    /// Looks up `key`, first in memory and then in the file-backed store.
    ///
    /// Returns `None` when the cache is disabled, when the in-memory entry has
    /// expired (the entry is removed and its bytes released), or when neither
    /// tier holds the key. A hit in the file-backed store is copied back into
    /// memory before being returned.
    pub async fn get(&self, key: &CacheKey) -> Option<CachedResult> {
        if !self.config.enabled {
            return None;
        }
        let hash = key.hash_value();

        if let Some(mut entry) = self.memory_entries.get_mut(&hash) {
            if !entry.is_expired() {
                entry.touch();
                return Some(entry.result.clone());
            }
            // The guard must be released before removing the same key.
            drop(entry);
            // Removing through `take_entry` keeps the usage counter in sync.
            let _ = self.take_entry(hash);
            return None;
        }

        // Finish with the storage lock before touching the in-memory map again.
        let from_disk = match self.mmap_storage {
            Some(ref mmap) => match mmap.read() {
                Ok(storage) => storage.get(hash),
                Err(_) => None,
            },
            None => None,
        };
        let result = from_disk?;
        self.promote_to_memory(key, result.clone());
        Some(result)
    }

    /// Stores `result` under `key`, evicting older entries first if the new
    /// entry would push tracked memory usage past the configured limit.
    ///
    /// Does nothing when the cache is disabled. Replacing an existing key
    /// releases the bytes accounted to the previous entry.
    pub async fn put(&self, key: CacheKey, result: CachedResult) {
        if !self.config.enabled {
            return;
        }
        let entry_size = result.size() + std::mem::size_of::<L2Entry>();
        if self.memory_usage().saturating_add(entry_size) > self.max_bytes() {
            self.evict_to_fit(entry_size).await;
        }
        let hash = key.hash_value();
        let fingerprint = format!("{:016x}", hash);
        self.store_entry(hash, L2Entry::new(key, fingerprint, result));
    }

    /// Removes `key` from memory and from the index of the file-backed store.
    pub async fn remove(&self, key: &CacheKey) {
        let hash = key.hash_value();
        let _ = self.take_entry(hash);
        if let Some(ref mmap) = self.mmap_storage {
            if let Ok(mut storage) = mmap.write() {
                storage.remove(hash);
            }
        }
    }

    /// Drops every entry from both tiers and resets the usage counter to zero.
    pub async fn clear(&self) {
        self.memory_entries.clear();
        self.memory_usage
            .store(0, std::sync::atomic::Ordering::Relaxed);
        if let Some(ref mmap) = self.mmap_storage {
            if let Ok(mut storage) = mmap.write() {
                storage.clear();
            }
        }
    }

    /// Number of entries currently held in memory.
    pub fn len(&self) -> usize {
        self.memory_entries.len()
    }

    /// Returns `true` when no entries are held in memory.
    pub fn is_empty(&self) -> bool {
        self.memory_entries.is_empty()
    }

    /// Tracked memory usage of the in-memory entries, in bytes.
    pub fn memory_usage(&self) -> usize {
        self.memory_usage
            .load(std::sync::atomic::Ordering::Relaxed)
    }

    /// Returns a snapshot of counters for the in-memory tier.
    pub fn stats(&self) -> L2CacheStats {
        let total_access: u64 = self.memory_entries.iter().map(|e| e.access_count).sum();
        L2CacheStats {
            entry_count: self.memory_entries.len(),
            memory_usage_bytes: self.memory_usage(),
            max_memory_bytes: self.max_bytes(),
            total_accesses: total_access,
            storage_backend: self.config.storage.clone(),
        }
    }

    /// Configured memory limit in bytes (`size_mb` MiB), saturating instead of
    /// overflowing for very large settings.
    fn max_bytes(&self) -> usize {
        self.config.size_mb.saturating_mul(1024 * 1024)
    }

    /// Subtracts `bytes` from the usage counter, clamping at zero.
    ///
    /// A plain `fetch_sub` would wrap around if it raced with `clear`
    /// resetting the counter, leaving an enormous bogus usage value.
    fn release_bytes(&self, bytes: usize) {
        use std::sync::atomic::Ordering::Relaxed;
        let _ = self
            .memory_usage
            .fetch_update(Relaxed, Relaxed, |used| Some(used.saturating_sub(bytes)));
    }

    /// Inserts `entry` under `hash` and updates the usage counter, releasing
    /// the bytes of any entry that was replaced.
    fn store_entry(&self, hash: u64, entry: L2Entry) {
        let added = entry.memory_size;
        let replaced = self.memory_entries.insert(hash, entry);
        self.memory_usage
            .fetch_add(added, std::sync::atomic::Ordering::Relaxed);
        if let Some(old) = replaced {
            self.release_bytes(old.memory_size);
        }
    }

    /// Removes the in-memory entry for `hash`, releases its bytes and returns it.
    fn take_entry(&self, hash: u64) -> Option<L2Entry> {
        let (_, entry) = self.memory_entries.remove(&hash)?;
        self.release_bytes(entry.memory_size);
        Some(entry)
    }

    /// Frees memory until usage is at most `max_bytes - required_bytes`.
    ///
    /// Expired entries are dropped first. If that is not enough, the remaining
    /// entries are evicted in ascending `last_access` order; each evicted
    /// entry is handed to the file-backed store when one is configured.
    async fn evict_to_fit(&self, required_bytes: usize) {
        let target = self.max_bytes().saturating_sub(required_bytes);

        let expired: Vec<u64> = self
            .memory_entries
            .iter()
            .filter(|e| e.is_expired())
            .map(|e| *e.key())
            .collect();
        for hash in expired {
            let _ = self.take_entry(hash);
        }
        if self.memory_usage() <= target {
            return;
        }

        // Snapshot and sort once instead of rescanning the map per eviction.
        let mut candidates: Vec<_> = self
            .memory_entries
            .iter()
            .map(|e| (e.last_access, *e.key()))
            .collect();
        candidates.sort_unstable_by_key(|&(last_access, _)| last_access);

        for (_, hash) in candidates {
            if self.memory_usage() <= target {
                break;
            }
            // Take the entry out of the map *before* locking the storage:
            // `flush_to_disk` locks the storage first and then walks the map,
            // so holding a map guard while waiting for the storage lock would
            // invert the lock order.
            if let Some(entry) = self.take_entry(hash) {
                self.demote_to_mmap(hash, &entry);
            }
        }
    }

    /// Inserts a result read from the file-backed store into the in-memory map.
    fn promote_to_memory(&self, key: &CacheKey, result: CachedResult) {
        let hash = key.hash_value();
        let fingerprint = format!("{:016x}", hash);
        self.store_entry(hash, L2Entry::new(key.clone(), fingerprint, result));
    }

    /// Writes `entry`'s result to the file-backed store, if one is configured.
    /// A poisoned storage lock is ignored (best effort).
    fn demote_to_mmap(&self, hash: u64, entry: &L2Entry) {
        if let Some(ref mmap) = self.mmap_storage {
            if let Ok(mut storage) = mmap.write() {
                storage.put(hash, &entry.result);
            }
        }
    }

    /// Writes every non-expired in-memory entry to the file-backed store and
    /// syncs the file.
    ///
    /// Returns the number of entries handed to the store (`0` when no store is
    /// configured). Individual writes are best effort, so the count is not a
    /// guarantee that each record reached the file.
    ///
    /// # Errors
    ///
    /// Returns an error when the storage lock is poisoned or the sync fails.
    pub fn flush_to_disk(&self) -> Result<usize, std::io::Error> {
        let Some(ref mmap) = self.mmap_storage else {
            return Ok(0);
        };
        let mut storage = mmap
            .write()
            .map_err(|_| std::io::Error::new(std::io::ErrorKind::Other, "Lock poisoned"))?;
        let mut count = 0;
        for entry in self.memory_entries.iter() {
            if !entry.is_expired() {
                storage.put(*entry.key(), &entry.result);
                count += 1;
            }
        }
        storage.sync()?;
        Ok(count)
    }

    /// Reports how many entries the file-backed store currently indexes
    /// (`0` when no store is configured).
    ///
    /// NOTE(review): despite the name, nothing is read from disk here; the
    /// index is only populated by writes made through this instance.
    ///
    /// # Errors
    ///
    /// Returns an error when the storage lock is poisoned.
    pub fn load_from_disk(&self) -> Result<usize, std::io::Error> {
        let Some(ref mmap) = self.mmap_storage else {
            return Ok(0);
        };
        let storage = mmap
            .read()
            .map_err(|_| std::io::Error::new(std::io::ErrorKind::Other, "Lock poisoned"))?;
        Ok(storage.entry_count())
    }
}
impl MmapStorage {
    /// Creates storage bound to `path`. The file is not opened or created
    /// until the first `put`.
    fn new(path: PathBuf) -> Self {
        Self {
            path,
            file: None,
            index: HashMap::new(),
            file_size: 0,
        }
    }

    /// Reads and decodes the record indexed under `hash`.
    ///
    /// Returns `None` when the hash is not indexed, the record has expired,
    /// the system clock is before the Unix epoch, or any I/O or decoding step
    /// fails.
    fn get(&self, hash: u64) -> Option<CachedResult> {
        use std::io::Seek;

        let entry = self.index.get(&hash)?;
        let now = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .ok()?
            .as_secs();
        if now > entry.expires_at {
            return None;
        }

        // A separate handle is opened per call, so seeking here never moves
        // the cursor of the handle `put` writes through.
        let mut file = File::open(&self.path).ok()?;
        file.seek(std::io::SeekFrom::Start(entry.offset as u64))
            .ok()?;
        let mut buffer = vec![0u8; entry.size];
        file.read_exact(&mut buffer).ok()?;
        deserialize_result(&buffer)
    }

    /// Appends `result` to the backing file and indexes it under `hash`.
    ///
    /// Best effort: if the file cannot be opened, positioned or written, the
    /// call returns without indexing anything.
    fn put(&mut self, hash: u64, result: &CachedResult) {
        use std::io::Seek;

        let data = serialize_result(result);

        if self.file.is_none() {
            self.file = OpenOptions::new()
                .create(true)
                .read(true)
                .write(true)
                .open(&self.path)
                .ok();
        }
        let Some(file) = self.file.as_mut() else {
            return;
        };

        // Use the position reported by the seek as the record offset. The file
        // is opened without truncation and a failed write may leave partial
        // bytes behind, so a locally maintained counter can disagree with the
        // real end of file and would make `get` read the wrong bytes.
        let Ok(end) = file.seek(std::io::SeekFrom::End(0)) else {
            return;
        };
        let Ok(offset) = usize::try_from(end) else {
            return;
        };
        if file.write_all(&data).is_err() {
            return;
        }

        let expires_at = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .map(|d| d.as_secs().saturating_add(result.ttl.as_secs()))
            .unwrap_or(0);
        self.index.insert(
            hash,
            MmapEntry {
                offset,
                size: data.len(),
                expires_at,
            },
        );
        self.file_size = offset.saturating_add(data.len());
    }

    /// Drops `hash` from the index. The record's bytes stay in the file.
    fn remove(&mut self, hash: u64) {
        self.index.remove(&hash);
    }

    /// Empties the index and truncates the backing file if it is open.
    /// A truncation failure is ignored.
    fn clear(&mut self) {
        self.index.clear();
        self.file_size = 0;
        if let Some(ref mut file) = self.file {
            let _ = file.set_len(0);
        }
    }

    /// Flushes file contents and metadata to disk; a no-op when the file has
    /// not been opened yet.
    fn sync(&mut self) -> Result<(), std::io::Error> {
        if let Some(ref file) = self.file {
            file.sync_all()?;
        }
        Ok(())
    }

    /// Number of records currently indexed (expired ones included).
    fn entry_count(&self) -> usize {
        self.index.len()
    }
}
/// Encodes `result` into the on-disk record format.
///
/// Layout (all integers little-endian `u64`):
/// bytes 0..8 TTL in whole seconds, 8..16 row count, 16..24 payload length,
/// followed by the payload bytes. Only `ttl`, `row_count` and `data` are
/// written; the other fields of `CachedResult` are not persisted.
fn serialize_result(result: &CachedResult) -> Vec<u8> {
    // The final size is known up front (24-byte header + payload), so the
    // buffer is allocated once instead of growing while it is filled.
    let mut buffer = Vec::with_capacity(24 + result.data.len());
    buffer.extend_from_slice(&result.ttl.as_secs().to_le_bytes());
    buffer.extend_from_slice(&(result.row_count as u64).to_le_bytes());
    buffer.extend_from_slice(&(result.data.len() as u64).to_le_bytes());
    buffer.extend_from_slice(&result.data);
    buffer
}
/// Decodes a record produced by `serialize_result`.
///
/// Expects a 24-byte header of three little-endian `u64` values (TTL in
/// seconds, row count, payload length) followed by at least that many payload
/// bytes; trailing bytes beyond the payload are ignored.
///
/// Returns `None` when the buffer is shorter than the header, when the
/// declared payload does not fit in the buffer, or when a header value does
/// not fit in `usize`. Fields that are not stored in the record come back as
/// defaults: `cached_at` is the time of decoding, `tables` is empty and
/// `execution_time` is zero.
fn deserialize_result(buffer: &[u8]) -> Option<CachedResult> {
    const HEADER_LEN: usize = 24;

    let header = buffer.get(..HEADER_LEN)?;
    let read_u64 = |at: usize| -> Option<u64> {
        let bytes: [u8; 8] = header.get(at..at + 8)?.try_into().ok()?;
        Some(u64::from_le_bytes(bytes))
    };

    let ttl_secs = read_u64(0)?;
    let row_count = usize::try_from(read_u64(8)?).ok()?;
    let data_len = usize::try_from(read_u64(16)?).ok()?;

    // Checked slicing: a corrupt length field must yield `None`, not an
    // overflowing `24 + data_len` or an out-of-range slice panic.
    let payload = buffer.get(HEADER_LEN..)?.get(..data_len)?;

    Some(CachedResult {
        data: Bytes::copy_from_slice(payload),
        row_count,
        cached_at: Instant::now(),
        ttl: std::time::Duration::from_secs(ttl_secs),
        tables: Vec::new(),
        execution_time: std::time::Duration::from_millis(0),
    })
}
/// Point-in-time counters for the warm cache, as produced by `L2WarmCache::stats`.
#[derive(Debug, Clone)]
pub struct L2CacheStats {
/// Number of entries held in memory.
pub entry_count: usize,
/// Tracked memory usage of the in-memory entries, in bytes.
pub memory_usage_bytes: usize,
/// Configured memory limit (`size_mb`) expressed in bytes.
pub max_memory_bytes: usize,
/// Sum of `access_count` over all in-memory entries.
pub total_accesses: u64,
/// Storage backend selected in the configuration.
pub storage_backend: StorageBackend,
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::cache::normalizer::NormalizedQuery;
    use crate::cache::CacheContext;
    use std::time::Duration;

    /// Builds a one-row result with a 60 s TTL whose payload is `payload`.
    fn sample_result(payload: &str) -> CachedResult {
        CachedResult::new(
            Bytes::from(payload.to_string()),
            1,
            Duration::from_secs(60),
            vec!["test".to_string()],
            Duration::from_millis(5),
        )
    }

    /// Builds a cache key for `query_hash` with fixed remaining parts.
    fn key_for(query_hash: u64) -> CacheKey {
        CacheKey::from_parts(query_hash, "test".to_string(), None, None)
    }

    /// A cache built from the default configuration.
    fn default_cache() -> L2WarmCache {
        L2WarmCache::new(L2Config::default())
    }

    #[tokio::test]
    async fn test_basic_get_put() {
        let cache = default_cache();
        let key = key_for(12345);
        let stored = sample_result("test data");

        // Miss before the insert, hit with identical payload afterwards.
        assert!(cache.get(&key).await.is_none());
        cache.put(key.clone(), stored.clone()).await;

        let fetched = cache.get(&key).await;
        assert!(fetched.is_some());
        assert_eq!(fetched.unwrap().data, stored.data);
    }

    #[tokio::test]
    async fn test_different_keys() {
        let cache = default_cache();
        let present = key_for(11111);
        let absent = key_for(22222);
        let stored = sample_result("data");

        cache.put(present.clone(), stored.clone()).await;

        assert!(cache.get(&present).await.is_some());
        assert!(cache.get(&absent).await.is_none());
    }

    #[tokio::test]
    async fn test_expiration() {
        let short_ttl = Duration::from_millis(10);
        let cache = L2WarmCache::new(L2Config {
            ttl: short_ttl,
            ..Default::default()
        });
        let key = key_for(12345);
        let mut stored = sample_result("data");
        stored.ttl = short_ttl;

        cache.put(key.clone(), stored).await;
        assert!(cache.get(&key).await.is_some());

        // Wait past the TTL; the entry must be gone afterwards.
        std::thread::sleep(Duration::from_millis(15));
        assert!(cache.get(&key).await.is_none());
    }

    #[tokio::test]
    async fn test_remove() {
        let cache = default_cache();
        let key = key_for(12345);

        cache.put(key.clone(), sample_result("data")).await;
        assert!(cache.get(&key).await.is_some());

        cache.remove(&key).await;
        assert!(cache.get(&key).await.is_none());
    }

    #[tokio::test]
    async fn test_clear() {
        let cache = default_cache();
        cache.put(key_for(111), sample_result("1")).await;
        cache.put(key_for(222), sample_result("2")).await;
        assert_eq!(cache.len(), 2);

        cache.clear().await;
        assert!(cache.is_empty());
    }

    #[tokio::test]
    async fn test_memory_eviction() {
        let cache = L2WarmCache::new(L2Config {
            size_mb: 1,
            ..Default::default()
        });

        // 15 payloads of 100 KiB each exceed the 1 MiB limit.
        let large_data = "x".repeat(100 * 1024);
        for i in 0..15 {
            cache.put(key_for(i), sample_result(&large_data)).await;
        }

        assert!(cache.memory_usage() <= 1024 * 1024 + 100 * 1024);
    }

    #[tokio::test]
    async fn test_stats() {
        let cache = default_cache();
        cache.put(key_for(111), sample_result("1")).await;
        cache.put(key_for(222), sample_result("2")).await;
        cache.get(&key_for(111)).await;
        cache.get(&key_for(111)).await;

        let snapshot = cache.stats();
        assert_eq!(snapshot.entry_count, 2);
        assert!(snapshot.memory_usage_bytes > 0);
        assert_eq!(snapshot.storage_backend, StorageBackend::Memory);
    }

    #[tokio::test]
    async fn test_disabled_cache() {
        let cache = L2WarmCache::new(L2Config {
            enabled: false,
            ..Default::default()
        });
        let key = key_for(12345);

        // Writes are ignored while the cache is disabled.
        cache.put(key.clone(), sample_result("data")).await;
        assert!(cache.get(&key).await.is_none());
    }

    #[test]
    fn test_serialize_deserialize() {
        let original = sample_result("test data for serialization");
        let encoded = serialize_result(&original);
        let decoded = deserialize_result(&encoded).unwrap();

        assert_eq!(decoded.data, original.data);
        assert_eq!(decoded.row_count, original.row_count);
        assert_eq!(decoded.ttl.as_secs(), original.ttl.as_secs());
    }
}