use std::collections::HashSet;
use std::io::{Read, Write};
use std::path::PathBuf;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::RwLock;
use dashmap::DashMap;
use super::{CacheEntry, CompressionType, TierStats};
use crate::distribcache::QueryFingerprint;
/// Fixed-size Bloom filter over byte strings.
///
/// `may_contain` returning `false` means the item was definitely not inserted
/// since the last `clear`; `true` means it may have been. Individual items
/// cannot be removed, so the false-positive rate only grows until `clear`.
struct BloomFilter {
    /// Bit array packed into 64-bit words.
    bits: Vec<u64>,
    /// Number of hash positions set / tested per item.
    num_hashes: usize,
}

impl BloomFilter {
    /// Bits allocated per expected item.
    const BITS_PER_ITEM: usize = 10;
    /// Hash positions per item. About 0.69 * BITS_PER_ITEM minimizes false
    /// positives, giving a rate just under 1% once `capacity` items are stored.
    const NUM_HASHES: usize = 7;

    /// Creates an empty filter sized for roughly `capacity` items.
    ///
    /// At least one 64-bit word is always allocated, so a zero `capacity`
    /// yields a tiny but valid filter rather than one whose bit indexing
    /// would divide by zero.
    fn new(capacity: usize) -> Self {
        let num_bits = capacity * Self::BITS_PER_ITEM;
        let num_words = ((num_bits + 63) / 64).max(1);
        Self {
            bits: vec![0; num_words],
            num_hashes: Self::NUM_HASHES,
        }
    }

    /// Records `data` as present.
    fn insert(&mut self, data: &[u8]) {
        for seed in 0..self.num_hashes {
            let (word, mask) = self.locate(data, seed);
            self.bits[word] |= mask;
        }
    }

    /// Returns `false` only if `data` was definitely never inserted.
    fn may_contain(&self, data: &[u8]) -> bool {
        (0..self.num_hashes).all(|seed| {
            let (word, mask) = self.locate(data, seed);
            (self.bits[word] & mask) != 0
        })
    }

    /// Maps `data` under hash function number `seed` to a `(word index, bit mask)` pair.
    fn locate(&self, data: &[u8], seed: usize) -> (usize, u64) {
        let total_bits = (self.bits.len() * 64) as u64;
        // Reduce in u64 first so the full hash is used even where usize is 32-bit.
        let idx = (self.hash(data, seed) % total_bits) as usize;
        (idx / 64, 1u64 << (idx % 64))
    }

    /// Hashes `data` with `seed` mixed in first, yielding `num_hashes`
    /// distinct hash functions from a single hasher type.
    fn hash(&self, data: &[u8], seed: usize) -> u64 {
        use std::collections::hash_map::DefaultHasher;
        use std::hash::{Hash, Hasher};
        let mut hasher = DefaultHasher::new();
        seed.hash(&mut hasher);
        data.hash(&mut hasher);
        hasher.finish()
    }

    /// Resets the filter to empty while keeping its size.
    fn clear(&mut self) {
        self.bits.fill(0);
    }
}
pub struct WarmCache {
index: DashMap<u64, EntryMetadata>,
data: DashMap<u64, Vec<u8>>,
bloom: RwLock<BloomFilter>,
table_index: DashMap<String, HashSet<u64>>,
compression: CompressionType,
_path: PathBuf,
current_size: AtomicU64,
max_size: u64,
hits: AtomicU64,
misses: AtomicU64,
compressed_size: AtomicU64,
uncompressed_size: AtomicU64,
}
#[derive(Debug, Clone)]
struct EntryMetadata {
compressed_size: usize,
uncompressed_size: usize,
created_at: u64,
ttl_secs: u64,
tables: Vec<String>,
}
impl WarmCache {
pub fn new(max_size: u64, path: PathBuf, compression: CompressionType) -> Self {
Self {
index: DashMap::new(),
data: DashMap::new(),
bloom: RwLock::new(BloomFilter::new(100_000)),
table_index: DashMap::new(),
compression,
_path: path,
current_size: AtomicU64::new(0),
max_size,
hits: AtomicU64::new(0),
misses: AtomicU64::new(0),
compressed_size: AtomicU64::new(0),
uncompressed_size: AtomicU64::new(0),
}
}
pub fn get(&self, fingerprint: &QueryFingerprint) -> Option<CacheEntry> {
let key = self.fingerprint_to_hash(fingerprint);
let key_bytes = key.to_le_bytes();
{
let bloom = self.bloom.read().ok()?;
if !bloom.may_contain(&key_bytes) {
self.misses.fetch_add(1, Ordering::Relaxed);
return None;
}
}
let metadata = self.index.get(&key)?;
let now = std::time::SystemTime::now()
.duration_since(std::time::SystemTime::UNIX_EPOCH)
.unwrap_or_default()
.as_secs();
if now > metadata.created_at + metadata.ttl_secs {
drop(metadata);
self.remove_entry(key);
self.misses.fetch_add(1, Ordering::Relaxed);
return None;
}
let compressed = self.data.get(&key)?;
let decompressed = self.decompress(&compressed)?;
let entry: CacheEntry = bincode::deserialize(&decompressed).ok()?;
self.hits.fetch_add(1, Ordering::Relaxed);
Some(entry)
}
pub fn insert(&self, fingerprint: QueryFingerprint, entry: CacheEntry) {
let key = self.fingerprint_to_hash(&fingerprint);
let serialized = match bincode::serialize(&entry) {
Ok(s) => s,
Err(_) => return,
};
let uncompressed_size = serialized.len();
let compressed = match self.compress(&serialized) {
Some(c) => c,
None => return,
};
let compressed_size = compressed.len();
while self.current_size.load(Ordering::Relaxed) + compressed_size as u64 > self.max_size {
if !self.evict_oldest() {
break;
}
}
self.remove_entry(key);
let metadata = EntryMetadata {
compressed_size,
uncompressed_size,
created_at: entry.created_at,
ttl_secs: entry.ttl_secs,
tables: entry.tables.clone(),
};
for table in &entry.tables {
self.table_index
.entry(table.clone())
.or_default()
.insert(key);
}
{
if let Ok(mut bloom) = self.bloom.write() {
bloom.insert(&key.to_le_bytes());
}
}
self.index.insert(key, metadata);
self.data.insert(key, compressed);
self.current_size.fetch_add(compressed_size as u64, Ordering::Relaxed);
self.compressed_size.fetch_add(compressed_size as u64, Ordering::Relaxed);
self.uncompressed_size.fetch_add(uncompressed_size as u64, Ordering::Relaxed);
}
pub fn invalidate_by_table(&self, table: &str) {
if let Some((_, keys)) = self.table_index.remove(table) {
for key in keys {
self.remove_entry(key);
}
}
}
pub fn invalidate(&self, fingerprint: &QueryFingerprint) {
let key = self.fingerprint_to_hash(fingerprint);
self.remove_entry(key);
}
fn remove_entry(&self, key: u64) {
if let Some((_, metadata)) = self.index.remove(&key) {
self.data.remove(&key);
self.current_size.fetch_sub(metadata.compressed_size as u64, Ordering::Relaxed);
for table in &metadata.tables {
if let Some(mut keys) = self.table_index.get_mut(table) {
keys.remove(&key);
}
}
}
}
fn evict_oldest(&self) -> bool {
let mut oldest_key = None;
let mut oldest_time = u64::MAX;
for entry in self.index.iter() {
if entry.created_at < oldest_time {
oldest_time = entry.created_at;
oldest_key = Some(*entry.key());
}
}
if let Some(key) = oldest_key {
self.remove_entry(key);
return true;
}
false
}
fn compress(&self, data: &[u8]) -> Option<Vec<u8>> {
match self.compression {
CompressionType::None => {
let mut output = Vec::with_capacity(data.len() + 1);
output.push(0x00); output.extend_from_slice(data);
Some(output)
}
CompressionType::Lz4 => {
let mut output = Vec::with_capacity(data.len() + 1);
output.push(0x01); output.extend_from_slice(data);
Some(output)
}
CompressionType::Zstd => {
let compressed = zstd::stream::encode_all(data, 3).ok()?;
let mut output = Vec::with_capacity(compressed.len() + 1);
output.push(0x02); output.extend_from_slice(&compressed);
Some(output)
}
}
}
fn decompress(&self, data: &[u8]) -> Option<Vec<u8>> {
if data.is_empty() {
return None;
}
let marker = data[0];
let payload = &data[1..];
match marker {
0x00 => Some(payload.to_vec()), 0x01 => Some(payload.to_vec()), 0x02 => {
zstd::stream::decode_all(payload).ok()
}
_ => Some(data.to_vec()), }
}
fn fingerprint_to_hash(&self, fingerprint: &QueryFingerprint) -> u64 {
use std::hash::{Hash, Hasher};
use std::collections::hash_map::DefaultHasher;
let mut hasher = DefaultHasher::new();
fingerprint.template.hash(&mut hasher);
if let Some(param) = fingerprint.param_hash {
param.hash(&mut hasher);
}
hasher.finish()
}
pub fn stats(&self) -> TierStats {
let compressed = self.compressed_size.load(Ordering::Relaxed);
let uncompressed = self.uncompressed_size.load(Ordering::Relaxed);
TierStats {
size_bytes: self.current_size.load(Ordering::Relaxed),
max_size_bytes: self.max_size,
entry_count: self.index.len() as u64,
hits: self.hits.load(Ordering::Relaxed),
misses: self.misses.load(Ordering::Relaxed),
evictions: 0,
compression_ratio: if compressed > 0 {
Some(uncompressed as f64 / compressed as f64)
} else {
None
},
peer_count: None,
healthy_peers: None,
}
}
pub fn clear(&self) {
self.index.clear();
self.data.clear();
self.table_index.clear();
if let Ok(mut bloom) = self.bloom.write() {
bloom.clear();
}
self.current_size.store(0, Ordering::Relaxed);
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// TTL attached to every entry these tests insert.
    const TTL: Duration = Duration::from_secs(300);
    /// One mebibyte, used to size the caches under test.
    const MIB: u64 = 1024 * 1024;

    /// Builds a cache on the scratch path shared by all tests in this module.
    fn make_cache(max_size: u64, compression: CompressionType) -> WarmCache {
        WarmCache::new(max_size, PathBuf::from("/tmp/test-cache"), compression)
    }

    #[test]
    fn test_warm_cache_insert_get() {
        let cache = make_cache(1024 * MIB, CompressionType::Lz4);
        let users = QueryFingerprint::from_query("SELECT * FROM users");
        let payload =
            CacheEntry::new(vec![1, 2, 3], vec!["users".to_string()], 1).with_ttl(TTL);

        cache.insert(users.clone(), payload);

        let fetched = cache.get(&users);
        assert!(fetched.is_some());
        assert_eq!(fetched.unwrap().data, vec![1, 2, 3]);
    }

    #[test]
    fn test_warm_cache_bloom_filter() {
        let cache = make_cache(MIB, CompressionType::None);
        let stored = QueryFingerprint::from_query("SELECT * FROM users");
        let never_stored = QueryFingerprint::from_query("SELECT * FROM orders");

        cache.insert(stored.clone(), CacheEntry::new(vec![1], vec![], 1).with_ttl(TTL));

        assert!(cache.get(&stored).is_some());
        assert!(cache.get(&never_stored).is_none());
    }

    #[test]
    fn test_warm_cache_invalidate_by_table() {
        let cache = make_cache(MIB, CompressionType::None);
        let users = QueryFingerprint::from_query("SELECT * FROM users");
        let orders = QueryFingerprint::from_query("SELECT * FROM orders");

        cache.insert(
            users.clone(),
            CacheEntry::new(vec![1], vec!["users".to_string()], 1).with_ttl(TTL),
        );
        cache.insert(
            orders.clone(),
            CacheEntry::new(vec![2], vec!["orders".to_string()], 1).with_ttl(TTL),
        );

        cache.invalidate_by_table("users");

        assert!(cache.get(&users).is_none());
        assert!(cache.get(&orders).is_some());
    }

    #[test]
    fn test_warm_cache_stats() {
        let cache = make_cache(MIB, CompressionType::Lz4);
        let present = QueryFingerprint::from_query("SELECT * FROM users");

        cache.insert(present.clone(), CacheEntry::new(vec![1], vec![], 1).with_ttl(TTL));

        // One lookup that should hit, then one that should miss.
        cache.get(&present);
        let absent = QueryFingerprint::from_query("SELECT * FROM orders");
        cache.get(&absent);

        let stats = cache.stats();
        assert_eq!(stats.hits, 1);
        assert_eq!(stats.misses, 1);
        assert!(stats.compression_ratio.is_some());
    }
}