use std::collections::HashMap;
use std::num::NonZeroUsize;
use std::sync::Arc;
use lru::LruCache;
use parking_lot::{Condvar, Mutex};
use smallvec::SmallVec;
use crate::error::Result;
/// Identifies a single chunk within a dataset: the dataset's address plus the
/// chunk's offset along each dimension.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ChunkKey {
// Address of the owning dataset — presumably its file/header address,
// unique per dataset. TODO(review): confirm against the code that builds keys.
pub dataset_addr: u64,
// Per-dimension chunk offsets; inline storage avoids a heap allocation
// for datasets with up to 4 dimensions.
pub chunk_offsets: SmallVec<[u64; 4]>,
}
/// Thread-safe LRU cache for chunk data, bounded both by total bytes
/// (`max_bytes`) and by entry count (the inner `LruCache`'s slot capacity).
pub struct ChunkCache {
// All mutable state (LRU map, byte accounting, in-flight loads) behind
// a single mutex.
inner: Mutex<ChunkCacheState>,
// Maximum total bytes of cached chunk data; 0 disables caching entirely
// (see `insert`).
max_bytes: usize,
}
/// Mutable interior of [`ChunkCache`], guarded by `ChunkCache::inner`.
struct ChunkCacheState {
// Slot-bounded LRU map from chunk key to shared chunk bytes.
cache: LruCache<ChunkKey, Arc<Vec<u8>>>,
// Running total of the byte lengths of all values held in `cache`.
current_bytes: usize,
// Loads currently executing in `get_or_insert_with`, keyed by chunk;
// used to deduplicate concurrent loads of the same chunk.
in_flight: HashMap<ChunkKey, Arc<InFlightLoad>>,
}
/// Synchronization handle for one in-progress chunk load: waiting threads
/// block on `ready` until the loading thread sets `completed` and notifies.
struct InFlightLoad {
// Set to true (under the lock) once the load has finished, success or not.
completed: Mutex<bool>,
// Signalled (notify_all) by the loader after `completed` is set.
ready: Condvar,
}
impl ChunkCache {
/// Creates a cache bounded by `max_bytes` of chunk data and at most
/// `max_slots` entries.
///
/// A `max_slots` of 0 falls back to the default slot count (521, matching
/// [`Default`]); a `max_bytes` of 0 disables caching entirely.
pub fn new(max_bytes: usize, max_slots: usize) -> Self {
// LruCache requires a non-zero capacity; evaluate the fallback lazily.
let slots = NonZeroUsize::new(max_slots)
.unwrap_or_else(|| NonZeroUsize::new(521).expect("521 is non-zero"));
ChunkCache {
inner: Mutex::new(ChunkCacheState {
cache: LruCache::new(slots),
current_bytes: 0,
in_flight: HashMap::new(),
}),
max_bytes,
}
}
/// Returns the cached chunk for `key`, promoting it to most-recently-used.
pub fn get(&self, key: &ChunkKey) -> Option<Arc<Vec<u8>>> {
let mut state = self.inner.lock();
state.cache.get(key).cloned()
}
/// Stores `data` under `key`, evicting least-recently-used entries until
/// the byte budget is respected, and returns the data as an `Arc` so the
/// caller can use it whether or not it was admitted to the cache.
///
/// Chunks larger than the entire budget — and all chunks when the cache
/// is disabled (`max_bytes == 0`) — bypass storage entirely.
pub fn insert(&self, key: ChunkKey, data: Vec<u8>) -> Arc<Vec<u8>> {
let data_len = data.len();
let arc = Arc::new(data);
if self.max_bytes == 0 || data_len > self.max_bytes {
return arc;
}
let mut state = self.inner.lock();
// Evict by recency until the new chunk fits in the byte budget.
while state.current_bytes + data_len > self.max_bytes && !state.cache.is_empty() {
if let Some((_, evicted)) = state.cache.pop_lru() {
state.current_bytes = state.current_bytes.saturating_sub(evicted.len());
}
}
state.current_bytes += data_len;
// `push` reports whatever entry it displaced: either the previous value
// stored under this same key (replacement) or the LRU pair dropped
// because the slot capacity was reached. Both must be subtracted from
// the byte accounting. The previous implementation used `put`, which
// silently evicts on slot-capacity overflow without returning the
// victim, so `current_bytes` drifted upward and later inserts
// over-evicted.
if let Some((_, displaced)) = state.cache.push(key, Arc::clone(&arc)) {
state.current_bytes = state.current_bytes.saturating_sub(displaced.len());
}
arc
}
/// Returns the cached chunk for `key`, or runs `load` to produce it.
///
/// Concurrent callers for the same key are deduplicated: exactly one
/// thread runs `load` while the rest block on a condvar and re-check the
/// cache once the load completes. If the load fails, the error is
/// returned to the loading thread and one of the waiters becomes the
/// next loader on its retry pass.
///
/// NOTE(review): if `load` panics, the in-flight marker is never removed
/// and waiters block forever; `load` is expected to report failures via
/// `Err`, not panic.
pub fn get_or_insert_with<F>(&self, key: ChunkKey, load: F) -> Result<Arc<Vec<u8>>>
where
F: FnOnce() -> Result<Vec<u8>>,
{
loop {
let in_flight = {
let mut state = self.inner.lock();
// Fast path: another thread already populated the cache.
if let Some(cached) = state.cache.get(&key).cloned() {
return Ok(cached);
}
if let Some(in_flight) = state.in_flight.get(&key) {
// Someone else is loading this chunk; wait for it below.
Arc::clone(in_flight)
} else {
// We are the loader: register the in-flight marker, drop the
// state lock, and run the (potentially slow) load unlocked.
let in_flight = Arc::new(InFlightLoad {
completed: Mutex::new(false),
ready: Condvar::new(),
});
state.in_flight.insert(key.clone(), Arc::clone(&in_flight));
drop(state);
let result = load().map(|data| self.insert(key.clone(), data));
// Publish completion regardless of success so waiters wake and
// re-check the cache (retrying the load on failure).
let mut state = self.inner.lock();
state.in_flight.remove(&key);
let mut completed = in_flight.completed.lock();
*completed = true;
in_flight.ready.notify_all();
return result;
}
};
// Waiter path: block until the loader signals, then loop back to
// re-check the cache (or become the loader if the load failed).
let mut completed = in_flight.completed.lock();
while !*completed {
in_flight.ready.wait(&mut completed);
}
}
}
}
impl Default for ChunkCache {
/// Default cache: 64 MiB byte budget across 521 LRU slots.
fn default() -> Self {
const DEFAULT_MAX_BYTES: usize = 64 * 1024 * 1024;
const DEFAULT_MAX_SLOTS: usize = 521;
Self::new(DEFAULT_MAX_BYTES, DEFAULT_MAX_SLOTS)
}
}
#[cfg(test)]
mod tests {
use super::*;

/// Builds a `ChunkKey` for `dataset_addr` with the given chunk offsets.
fn key_for(dataset_addr: u64, offsets: &[u64]) -> ChunkKey {
ChunkKey {
dataset_addr,
chunk_offsets: SmallVec::from_slice(offsets),
}
}

#[test]
fn test_cache_insert_and_get() {
let cache = ChunkCache::new(1024, 10);
let key = key_for(100, &[0, 0]);
cache.insert(key.clone(), vec![1, 2, 3]);
let value = cache.get(&key).expect("inserted chunk should be cached");
assert_eq!(&**value, &[1, 2, 3]);
}

#[test]
fn test_cache_eviction() {
// A 10-byte budget holds at most two 4-byte chunks, so inserting five
// must evict the oldest.
let cache = ChunkCache::new(10, 10);
for i in 0..5 {
cache.insert(key_for(100, &[i]), vec![0; 4]);
}
assert!(cache.get(&key_for(100, &[0])).is_none());
}

#[test]
fn test_cache_disabled_bypasses_storage() {
// max_bytes == 0 disables the cache: inserts are never stored.
let cache = ChunkCache::new(0, 10);
let key = key_for(100, &[0]);
cache.insert(key.clone(), vec![1, 2, 3]);
assert!(cache.get(&key).is_none());
}

#[test]
fn test_cache_promotes_on_get() {
let cache = ChunkCache::new(12, 10);
let key_a = key_for(1, &[0]);
let key_b = key_for(2, &[0]);
let key_c = key_for(3, &[0]);
cache.insert(key_a.clone(), vec![0; 4]);
cache.insert(key_b.clone(), vec![0; 4]);
cache.insert(key_c.clone(), vec![0; 4]);
// Touching A makes B the least-recently-used entry.
assert!(cache.get(&key_a).is_some());
cache.insert(key_for(4, &[0]), vec![0; 4]);
assert!(cache.get(&key_a).is_some());
assert!(cache.get(&key_b).is_none());
}

#[test]
fn test_cache_replacement_updates_accounting() {
let cache = ChunkCache::new(8, 10);
let key = key_for(100, &[0]);
cache.insert(key.clone(), vec![1, 2, 3, 4]);
// Replacing with a smaller chunk must shrink the byte accounting so
// the next 4-byte insert still fits without evicting `key`.
cache.insert(key.clone(), vec![5, 6]);
let other = key_for(100, &[1]);
cache.insert(other.clone(), vec![7, 8, 9, 10]);
assert_eq!(&**cache.get(&key).unwrap(), &[5, 6]);
assert!(cache.get(&other).is_some());
}

#[test]
fn test_cache_get_or_insert_with_deduplicates_concurrent_loads() {
use std::sync::atomic::{AtomicUsize, Ordering};
let cache = Arc::new(ChunkCache::new(1024, 10));
let key = key_for(100, &[0, 0]);
let load_count = Arc::new(AtomicUsize::new(0));
std::thread::scope(|scope| {
for _ in 0..8 {
let cache = Arc::clone(&cache);
let key = key.clone();
let load_count = Arc::clone(&load_count);
scope.spawn(move || {
let value = cache
.get_or_insert_with(key, || {
load_count.fetch_add(1, Ordering::SeqCst);
std::thread::sleep(std::time::Duration::from_millis(10));
Ok(vec![1, 2, 3, 4])
})
.unwrap();
assert_eq!(&*value, &[1, 2, 3, 4]);
});
}
});
// Exactly one thread should have executed the load closure.
assert_eq!(load_count.load(Ordering::SeqCst), 1);
}
}