use crate::error::{AmateRSError, ErrorContext, Result};
use parking_lot::RwLock;
use std::collections::{HashMap, VecDeque};
use std::sync::Arc;
/// Identifies one cached block: the SSTable it belongs to plus the block's
/// index within that table.
///
/// Equality and hashing are derived over both fields, so the key can be used
/// directly in hash-based collections.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct BlockCacheKey {
    /// Path of the SSTable file the block was read from.
    pub sstable_path: String,
    /// Index of the block inside that SSTable.
    pub block_index: usize,
}

impl BlockCacheKey {
    /// Builds a key from an SSTable path and a block index.
    pub fn new(sstable_path: String, block_index: usize) -> Self {
        BlockCacheKey {
            block_index,
            sstable_path,
        }
    }
}
/// A block's bytes as held by the cache.
///
/// The payload sits behind an [`Arc`], so cloning a `CachedBlock` bumps a
/// reference count instead of copying the bytes.
#[derive(Debug, Clone)]
pub struct CachedBlock {
    /// Shared, immutable block contents.
    pub data: Arc<Vec<u8>>,
    /// Length of `data` in bytes, captured when the block was constructed.
    pub size: usize,
}

impl CachedBlock {
    /// Wraps `data` in a shared allocation and records its length.
    pub fn new(data: Vec<u8>) -> Self {
        // `size` is computed before `data` is moved into the `Arc`.
        CachedBlock {
            size: data.len(),
            data: Arc::new(data),
        }
    }

    /// Borrows the block contents as a byte slice.
    pub fn as_slice(&self) -> &[u8] {
        self.data.as_slice()
    }
}
/// Tunables for a block cache.
#[derive(Debug, Clone)]
pub struct BlockCacheConfig {
    /// Upper bound, in bytes, on the total size of cached blocks.
    pub max_size_bytes: usize,
    /// Whether hit/miss/eviction counters are maintained.
    pub enable_stats: bool,
}

impl Default for BlockCacheConfig {
    /// Returns a configuration with a 128 MiB capacity and statistics enabled.
    fn default() -> Self {
        const MIB: usize = 1 << 20;
        BlockCacheConfig {
            enable_stats: true,
            max_size_bytes: 128 * MIB,
        }
    }
}
/// Counters and gauges describing cache activity.
#[derive(Debug, Clone, Default)]
pub struct CacheStats {
    /// Number of lookups that found their block.
    pub hits: u64,
    /// Number of lookups that did not find their block.
    pub misses: u64,
    /// Number of blocks evicted to make room for new ones.
    pub evictions: u64,
    /// Number of blocks held when the stats were last updated.
    pub block_count: usize,
    /// Total bytes held when the stats were last updated.
    pub size_bytes: usize,
}

impl CacheStats {
    /// Total number of recorded lookups.
    ///
    /// Saturates instead of overflowing so the rate helpers never panic.
    fn total_lookups(&self) -> u64 {
        self.hits.saturating_add(self.misses)
    }

    /// Fraction of lookups that were hits, in `0.0..=1.0`.
    ///
    /// Returns `0.0` when no lookups have been recorded.
    pub fn hit_rate(&self) -> f64 {
        match self.total_lookups() {
            0 => 0.0,
            total => self.hits as f64 / total as f64,
        }
    }

    /// Fraction of lookups that were misses, in `0.0..=1.0`.
    ///
    /// Returns `0.0` when no lookups have been recorded: an untouched cache
    /// has not missed anything, so it must not report a 100% miss rate.
    pub fn miss_rate(&self) -> f64 {
        if self.total_lookups() == 0 {
            0.0
        } else {
            1.0 - self.hit_rate()
        }
    }
}
/// Pairs a cache key with the block stored under it.
///
/// NOTE(review): nothing in the visible part of this file constructs or reads
/// this type — confirm whether it is still needed.
struct CacheEntry {
// Key the block is stored under.
key: BlockCacheKey,
// The cached block itself.
block: CachedBlock,
}
/// Size-bounded cache of SSTable blocks with least-recently-used eviction.
///
/// Each piece of state sits behind its own `RwLock`. Methods that need more
/// than one lock acquire them in the order `cache`, `lru_order`,
/// `current_size`, `stats`; keep that order when adding new code paths.
pub struct BlockCache {
// Capacity limit and statistics toggle.
config: BlockCacheConfig,
// Key -> block storage.
cache: Arc<RwLock<HashMap<BlockCacheKey, CachedBlock>>>,
// Recency queue: the front is the next eviction candidate, the back receives
// keys that were just inserted or read.
lru_order: Arc<RwLock<VecDeque<BlockCacheKey>>>,
// Sum of `size` over all blocks currently stored in `cache`.
current_size: Arc<RwLock<usize>>,
// Hit/miss/eviction counters; only updated when `config.enable_stats` is set.
stats: Arc<RwLock<CacheStats>>,
}
impl BlockCache {
    /// Creates a cache using [`BlockCacheConfig::default`].
    pub fn new() -> Self {
        Self::with_config(BlockCacheConfig::default())
    }

    /// Creates an empty cache governed by `config`.
    pub fn with_config(config: BlockCacheConfig) -> Self {
        Self {
            config,
            cache: Arc::new(RwLock::new(HashMap::new())),
            lru_order: Arc::new(RwLock::new(VecDeque::new())),
            current_size: Arc::new(RwLock::new(0)),
            stats: Arc::new(RwLock::new(CacheStats::default())),
        }
    }

    /// Looks up `key`, marking the block as most recently used on a hit.
    ///
    /// Returns a cheap clone of the cached block (the payload is shared), or
    /// `None` when the key is not cached. Updates the hit/miss counters when
    /// statistics are enabled.
    pub fn get(&self, key: &BlockCacheKey) -> Option<CachedBlock> {
        let block = {
            let cache = self.cache.read();
            let found = cache.get(key).cloned();
            if found.is_some() {
                // Refresh recency while the map read lock is still held.
                // Eviction needs the map write lock, so it cannot remove the
                // entry between the lookup and the touch and leave a key in
                // the LRU queue that no longer exists in the map.
                let mut lru_order = self.lru_order.write();
                Self::touch(&mut lru_order, key);
            }
            found
        };

        if self.config.enable_stats {
            let mut stats = self.stats.write();
            if block.is_some() {
                stats.hits += 1;
            } else {
                stats.misses += 1;
            }
        }
        block
    }

    /// Inserts `block` under `key`, replacing any previous block for that key
    /// and evicting least-recently-used blocks until the new one fits.
    ///
    /// Replacement, eviction and insertion happen in one critical section, so
    /// concurrent callers cannot push the cache past `max_size_bytes`.
    ///
    /// # Errors
    ///
    /// Returns `AmateRSError::StorageIntegrity` when the block alone is larger
    /// than the configured capacity; the cache is left untouched in that case.
    pub fn put(&self, key: BlockCacheKey, block: CachedBlock) -> Result<()> {
        let block_size = block.size;
        let max_size = self.config.max_size_bytes;
        if block_size > max_size {
            return Err(AmateRSError::StorageIntegrity(ErrorContext::new(format!(
                "Block size {} exceeds cache size {}",
                block_size, max_size
            ))));
        }

        let mut cache = self.cache.write();
        let mut lru_order = self.lru_order.write();
        let mut current_size = self.current_size.write();

        // Drop the entry being replaced first so its bytes count as free
        // space; otherwise unrelated blocks would be evicted needlessly.
        if let Some(replaced) = cache.remove(&key) {
            *current_size -= replaced.size;
            lru_order.retain(|k| k != &key);
        }

        // Evict from the cold end until the new block fits.
        let mut evictions: u64 = 0;
        while max_size.saturating_sub(*current_size) < block_size {
            match lru_order.pop_front() {
                Some(victim) => {
                    // A queue key without a map entry is simply discarded so
                    // it can never stall eviction.
                    if let Some(evicted) = cache.remove(&victim) {
                        *current_size -= evicted.size;
                        evictions += 1;
                    }
                }
                None => break,
            }
        }

        cache.insert(key.clone(), block);
        lru_order.push_back(key);
        *current_size += block_size;

        if self.config.enable_stats {
            let mut stats = self.stats.write();
            stats.evictions += evictions;
            stats.block_count = cache.len();
            stats.size_bytes = *current_size;
        }
        Ok(())
    }

    /// Moves `key` to the most-recently-used end of `lru_order`.
    ///
    /// Callers must hold the map lock so that the key is known to be cached.
    /// This is a linear scan of the queue.
    fn touch(lru_order: &mut VecDeque<BlockCacheKey>, key: &BlockCacheKey) {
        match lru_order.iter().position(|k| k == key) {
            Some(pos) => {
                if let Some(existing) = lru_order.remove(pos) {
                    lru_order.push_back(existing);
                }
            }
            // The key is cached but missing from the queue: re-register it so
            // it stays evictable.
            None => lru_order.push_back(key.clone()),
        }
    }

    /// Removes every block and resets the size accounting.
    ///
    /// Hit, miss and eviction counters are preserved.
    pub fn clear(&self) {
        let mut cache = self.cache.write();
        let mut lru_order = self.lru_order.write();
        let mut current_size = self.current_size.write();

        cache.clear();
        lru_order.clear();
        *current_size = 0;

        if self.config.enable_stats {
            let mut stats = self.stats.write();
            stats.block_count = 0;
            stats.size_bytes = 0;
        }
    }

    /// Returns a snapshot of the current statistics.
    pub fn stats(&self) -> CacheStats {
        self.stats.read().clone()
    }

    /// Total size in bytes of all cached blocks.
    pub fn current_size(&self) -> usize {
        *self.current_size.read()
    }

    /// Number of cached blocks.
    pub fn block_count(&self) -> usize {
        self.cache.read().len()
    }

    /// Reports whether `key` is cached, without affecting recency or stats.
    pub fn contains(&self, key: &BlockCacheKey) -> bool {
        self.cache.read().contains_key(key)
    }

    /// Removes and returns the block stored under `key`, if any.
    pub fn remove(&self, key: &BlockCacheKey) -> Option<CachedBlock> {
        let mut cache = self.cache.write();
        let mut lru_order = self.lru_order.write();
        let mut current_size = self.current_size.write();

        let block = cache.remove(key)?;
        lru_order.retain(|k| k != key);
        *current_size -= block.size;

        if self.config.enable_stats {
            let mut stats = self.stats.write();
            stats.block_count = cache.len();
            stats.size_bytes = *current_size;
        }
        Some(block)
    }

    /// Drops every cached block that belongs to `sstable_path`.
    pub fn invalidate_sstable(&self, sstable_path: &str) {
        let mut cache = self.cache.write();
        let mut lru_order = self.lru_order.write();
        let mut current_size = self.current_size.write();

        // One pass over the map and one over the queue, instead of one queue
        // scan per removed key.
        let mut freed = 0usize;
        cache.retain(|key, block| {
            if key.sstable_path == sstable_path {
                freed += block.size;
                false
            } else {
                true
            }
        });
        lru_order.retain(|key| key.sstable_path != sstable_path);
        *current_size -= freed;

        if self.config.enable_stats {
            let mut stats = self.stats.write();
            stats.block_count = cache.len();
            stats.size_bytes = *current_size;
        }
    }
}
impl Default for BlockCache {
fn default() -> Self {
Self::new()
}
}
// Unit tests exercising the public `BlockCache` API.
#[cfg(test)]
mod tests {
use super::*;
// A block is absent before `put` and readable, byte for byte, afterwards.
#[test]
fn test_block_cache_basic() -> Result<()> {
let cache = BlockCache::new();
let key = BlockCacheKey::new("test.sst".to_string(), 0);
let block = CachedBlock::new(vec![1, 2, 3, 4, 5]);
assert!(cache.get(&key).is_none());
cache.put(key.clone(), block.clone())?;
let retrieved = cache.get(&key).expect("Block should be in cache after put");
assert_eq!(retrieved.as_slice(), &[1, 2, 3, 4, 5]);
Ok(())
}
// Five 30-byte blocks go into a 100-byte cache: at most three fit, so the two
// oldest blocks (indices 0 and 1) must have been evicted.
#[test]
fn test_block_cache_lru_eviction() -> Result<()> {
let config = BlockCacheConfig {
max_size_bytes: 100,
enable_stats: true,
};
let cache = BlockCache::with_config(config);
for i in 0..5 {
let key = BlockCacheKey::new("test.sst".to_string(), i);
let block = CachedBlock::new(vec![0u8; 30]); cache.put(key, block)?;
}
assert!(cache.current_size() <= 100);
let key0 = BlockCacheKey::new("test.sst".to_string(), 0);
let key1 = BlockCacheKey::new("test.sst".to_string(), 1);
assert!(cache.get(&key0).is_none());
assert!(cache.get(&key1).is_none());
let key4 = BlockCacheKey::new("test.sst".to_string(), 4);
assert!(cache.get(&key4).is_some());
Ok(())
}
// Reading a block refreshes its recency: after key 0 is read, inserting a
// fourth 30-byte block must evict key 1 rather than key 0.
#[test]
fn test_block_cache_touch() -> Result<()> {
let config = BlockCacheConfig {
max_size_bytes: 100,
enable_stats: true,
};
let cache = BlockCache::with_config(config);
for i in 0..3 {
let key = BlockCacheKey::new("test.sst".to_string(), i);
let block = CachedBlock::new(vec![0u8; 30]);
cache.put(key, block)?;
}
let key0 = BlockCacheKey::new("test.sst".to_string(), 0);
// Touch key 0 so that key 1 becomes the least recently used entry.
cache.get(&key0);
let key3 = BlockCacheKey::new("test.sst".to_string(), 3);
let block3 = CachedBlock::new(vec![0u8; 30]);
cache.put(key3, block3)?;
assert!(cache.get(&key0).is_some());
let key1 = BlockCacheKey::new("test.sst".to_string(), 1);
assert!(cache.get(&key1).is_none());
Ok(())
}
// One lookup before the insert (miss) and two after it (hits) must show up in
// the counters and in the derived hit rate.
#[test]
fn test_block_cache_stats() -> Result<()> {
let cache = BlockCache::new();
let key = BlockCacheKey::new("test.sst".to_string(), 0);
let block = CachedBlock::new(vec![1, 2, 3]);
cache.get(&key);
cache.put(key.clone(), block)?;
cache.get(&key);
cache.get(&key);
let stats = cache.stats();
assert_eq!(stats.hits, 2);
assert_eq!(stats.misses, 1);
assert_eq!(stats.hit_rate(), 2.0 / 3.0);
Ok(())
}
// `clear` empties the cache and resets the byte accounting to zero.
#[test]
fn test_block_cache_clear() -> Result<()> {
let cache = BlockCache::new();
for i in 0..5 {
let key = BlockCacheKey::new("test.sst".to_string(), i);
let block = CachedBlock::new(vec![0u8; 100]);
cache.put(key, block)?;
}
assert!(cache.block_count() > 0);
assert!(cache.current_size() > 0);
cache.clear();
assert_eq!(cache.block_count(), 0);
assert_eq!(cache.current_size(), 0);
Ok(())
}
// A removed key is no longer reported by `contains`.
#[test]
fn test_block_cache_remove() -> Result<()> {
let cache = BlockCache::new();
let key = BlockCacheKey::new("test.sst".to_string(), 0);
let block = CachedBlock::new(vec![1, 2, 3]);
cache.put(key.clone(), block)?;
assert!(cache.contains(&key));
cache.remove(&key);
assert!(!cache.contains(&key));
Ok(())
}
// Invalidating one SSTable drops only that table's blocks and leaves blocks of
// other tables in place.
#[test]
fn test_block_cache_invalidate_sstable() -> Result<()> {
let cache = BlockCache::new();
for i in 0..3 {
let key = BlockCacheKey::new("test1.sst".to_string(), i);
let block = CachedBlock::new(vec![0u8; 100]);
cache.put(key, block)?;
}
for i in 0..3 {
let key = BlockCacheKey::new("test2.sst".to_string(), i);
let block = CachedBlock::new(vec![0u8; 100]);
cache.put(key, block)?;
}
assert_eq!(cache.block_count(), 6);
cache.invalidate_sstable("test1.sst");
assert_eq!(cache.block_count(), 3);
let key1 = BlockCacheKey::new("test1.sst".to_string(), 0);
assert!(!cache.contains(&key1));
let key2 = BlockCacheKey::new("test2.sst".to_string(), 0);
assert!(cache.contains(&key2));
Ok(())
}
// Smoke test: four threads each insert and read back 100 distinct blocks on a
// shared cache; every thread must finish and at least one hit must be recorded.
#[test]
fn test_block_cache_concurrent() -> Result<()> {
use std::sync::Arc;
use std::thread;
let cache = Arc::new(BlockCache::new());
let mut handles = vec![];
for thread_id in 0..4 {
let cache = Arc::clone(&cache);
let handle = thread::spawn(move || {
for i in 0..100 {
// Keys are namespaced per thread, so threads never overwrite each other.
let key = BlockCacheKey::new(format!("test_{}.sst", thread_id), i);
let block = CachedBlock::new(vec![thread_id as u8; 100]);
cache
.put(key.clone(), block)
.expect("Cache put should succeed in concurrent test");
cache.get(&key);
}
});
handles.push(handle);
}
for handle in handles {
handle.join().expect("Thread should complete successfully");
}
assert!(cache.block_count() > 0);
let stats = cache.stats();
assert!(stats.hits > 0);
Ok(())
}
}