use std::collections::HashMap;
use std::fmt::Debug;
use std::hash::Hash;
use std::ops::Range;
use std::sync::Arc;
use std::sync::Mutex;
use crate::ReadAt;
/// Default memory budget, in bytes, for a [`RangeCache`] (20 MiB).
pub const DEFAULT_MEM_LIMIT: usize = 1024 * 1024 * 20;

/// Shared, immutable byte buffer handed out by [`RangeCache`].
///
/// The bytes live behind an [`Arc`], so cloning a `CachedData` only bumps a
/// reference count and never copies the payload.
#[derive(Clone, Debug)]
pub struct CachedData(Arc<Vec<u8>>);

impl AsRef<[u8]> for CachedData {
    /// Borrows the cached bytes.
    fn as_ref(&self) -> &[u8] {
        &self.0
    }
}

/// Size-bounded map from keys to cached byte buffers.
///
/// Memory accounting only counts the length of each stored buffer; keys and
/// map overhead are not included. When an insertion would push the total past
/// `mem_limit`, existing entries are evicted smallest-first until the total
/// fits again (or no other entries remain).
#[derive(Debug)]
pub struct RangeCache<K: Eq + Hash = Range<u64>> {
    entries: HashMap<K, CachedData>,
    mem_limit: usize,
}

impl<K: Eq + Hash + Clone> RangeCache<K> {
    /// Creates an empty cache that tries to keep at most `mem_limit` bytes of
    /// payload.
    pub fn new(mem_limit: usize) -> Self {
        Self {
            entries: HashMap::new(),
            mem_limit,
        }
    }

    /// Returns a handle to the data stored under `key`, or `None` on a miss.
    pub fn get(&self, key: &K) -> Option<CachedData> {
        self.entries.get(key).cloned()
    }

    /// Stores `data` under `key` and returns a shared handle to it.
    ///
    /// Any entry previously stored under `key` is replaced. If the cache would
    /// exceed its memory limit, other entries are evicted in ascending size
    /// order until the total fits. A buffer that is larger than the limit on
    /// its own is still stored, after every other entry has been evicted.
    pub fn insert(&mut self, key: K, data: Vec<u8>) -> CachedData {
        let cached = CachedData(Arc::new(data));

        // The entry currently stored under `key` (if any) is about to be
        // replaced, so it must not count towards the projected usage.
        // Otherwise re-inserting a key could evict unrelated entries even
        // though the cache would have stayed within its limit.
        self.entries.remove(&key);

        let mut mem_usage: usize =
            self.entries.values().map(|v| v.0.len()).sum::<usize>() + cached.0.len();

        if mem_usage > self.mem_limit {
            let mut by_size: Vec<(K, usize)> = self
                .entries
                .iter()
                .map(|(k, v)| (k.clone(), v.0.len()))
                .collect();
            // Map iteration order is arbitrary, so a stable sort buys nothing.
            by_size.sort_unstable_by_key(|(_, size)| *size);

            for (victim, size) in by_size {
                self.entries.remove(&victim);
                mem_usage -= size;
                if mem_usage <= self.mem_limit {
                    break;
                }
            }
        }

        self.entries.insert(key, cached.clone());
        cached
    }
}

impl<K: Eq + Hash + Clone> Default for RangeCache<K> {
    /// Creates an empty cache limited to [`DEFAULT_MEM_LIMIT`] bytes.
    fn default() -> Self {
        Self::new(DEFAULT_MEM_LIMIT)
    }
}
/// Cache key used by the caching readers: a read context paired with the byte
/// range that was requested.
type CacheKey<Ctx> = (Ctx, Range<u64>);

/// Wrapper around a reader `T` that remembers the bytes returned for each
/// `(context, range)` pair in a mutex-protected [`RangeCache`].
pub struct CachingReader<T, Ctx = ()>
where
    Ctx: Clone + Eq + Hash,
{
    /// The wrapped reader that serves cache misses.
    inner: T,
    /// Cached results, guarded by a mutex so reads can go through `&self`.
    cache: Mutex<RangeCache<CacheKey<Ctx>>>,
}
impl<T, Ctx> CachingReader<T, Ctx>
where
    Ctx: Clone + Eq + Hash,
{
    /// Wraps `inner` with a cache that uses the default memory limit.
    pub fn new(inner: T) -> Self {
        let cache = Mutex::new(RangeCache::default());
        Self { inner, cache }
    }

    /// Wraps `inner` with a cache limited to `mem_limit` bytes of payload.
    pub fn with_mem_limit(inner: T, mem_limit: usize) -> Self {
        let cache = Mutex::new(RangeCache::new(mem_limit));
        Self { inner, cache }
    }

    /// Borrows the wrapped reader.
    pub fn inner(&self) -> &T {
        &self.inner
    }
}
impl<T, Ctx> Debug for CachingReader<T, Ctx>
where
    T: Debug,
    Ctx: Clone + Eq + Hash,
{
    /// Formats the wrapper showing only `inner`; the cache is omitted and the
    /// output is marked as non-exhaustive.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("CachingReader");
        out.field("inner", &self.inner);
        out.finish_non_exhaustive()
    }
}
impl<T, Ctx> CachingReader<T, Ctx>
where
    T: ReadAt<Ctx>,
    Ctx: Clone + Eq + Hash,
{
    /// Returns the bytes for `range` under `ctx`, consulting the cache first.
    ///
    /// On a miss the inner reader is called, its result is copied into the
    /// cache, and a handle to the cached copy is returned. The cache lock is
    /// released while the inner reader runs, so callers that miss the same key
    /// at the same time may each reach the inner reader.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the inner reader.
    ///
    /// # Panics
    ///
    /// Panics if the cache mutex is poisoned.
    pub fn read_at_with(&self, ctx: &Ctx, range: Range<u64>) -> std::io::Result<CachedData> {
        let key = (ctx.clone(), range.clone());

        // The guard is a temporary of this statement, so the lock is already
        // released by the time the inner reader is invoked below.
        let hit = self.cache.lock().unwrap().get(&key);
        if let Some(found) = hit {
            return Ok(found);
        }

        let bytes = self.inner.read_at(ctx, range)?.as_ref().to_vec();
        let stored = self.cache.lock().unwrap().insert(key, bytes);
        Ok(stored)
    }
}
impl<T> CachingReader<T>
where
    T: ReadAt,
{
    /// Convenience form of `read_at_with` for readers whose context is `()`.
    pub fn read_at(&self, range: Range<u64>) -> std::io::Result<CachedData> {
        let unit_ctx = ();
        self.read_at_with(&unit_ctx, range)
    }
}
impl<T: ReadAt<Ctx>, Ctx: Clone + Eq + Hash> ReadAt<Ctx> for CachingReader<T, Ctx> {
fn read_at(&self, ctx: &Ctx, range: Range<u64>) -> std::io::Result<impl AsRef<[u8]>> {
self.read_at_with(ctx, range)
}
}
/// Wrapper around an asynchronous reader `T` that remembers the bytes returned
/// for each `(context, range)` pair in a mutex-protected [`RangeCache`].
pub struct CachingAsyncReader<T, Ctx = ()>
where
    Ctx: Clone + Eq + Hash,
{
    /// The wrapped reader that serves cache misses.
    inner: T,
    /// Cached results, guarded by a mutex so reads can go through `&self`.
    cache: Mutex<RangeCache<CacheKey<Ctx>>>,
}
impl<T, Ctx> CachingAsyncReader<T, Ctx>
where
    Ctx: Clone + Eq + Hash,
{
    /// Wraps `inner` with a cache that uses the default memory limit.
    pub fn new(inner: T) -> Self {
        let cache = Mutex::new(RangeCache::default());
        Self { inner, cache }
    }

    /// Wraps `inner` with a cache limited to `mem_limit` bytes of payload.
    pub fn with_mem_limit(inner: T, mem_limit: usize) -> Self {
        let cache = Mutex::new(RangeCache::new(mem_limit));
        Self { inner, cache }
    }

    /// Borrows the wrapped reader.
    pub fn inner(&self) -> &T {
        &self.inner
    }
}
impl<T, Ctx> Debug for CachingAsyncReader<T, Ctx>
where
    T: Debug,
    Ctx: Clone + Eq + Hash,
{
    /// Formats the wrapper showing only `inner`; the cache is omitted and the
    /// output is marked as non-exhaustive.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("CachingAsyncReader");
        out.field("inner", &self.inner);
        out.finish_non_exhaustive()
    }
}
impl<T, Ctx> CachingAsyncReader<T, Ctx>
where
    T: crate::AsyncReadAt<Ctx>,
    Ctx: Clone + Eq + Hash + Send + Sync,
{
    /// Returns the bytes for `range` under `ctx`, consulting the cache first.
    ///
    /// On a miss the inner reader is awaited, its result is copied into the
    /// cache, and a handle to the cached copy is returned. The cache lock is
    /// never held across the `.await`, so callers that miss the same key at
    /// the same time may each reach the inner reader.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the inner reader.
    ///
    /// # Panics
    ///
    /// Panics if the cache mutex is poisoned.
    pub async fn read_at_with(&self, ctx: &Ctx, range: Range<u64>) -> std::io::Result<CachedData> {
        let key = (ctx.clone(), range.clone());

        // Look the key up inside its own scope so the guard is dropped before
        // the await point below.
        let hit = {
            let guard = self.cache.lock().unwrap();
            guard.get(&key)
        };
        if let Some(found) = hit {
            return Ok(found);
        }

        let bytes = self.inner.read_at(ctx, range).await?.as_ref().to_vec();
        let stored = self.cache.lock().unwrap().insert(key, bytes);
        Ok(stored)
    }
}
impl<T> CachingAsyncReader<T>
where
    T: crate::AsyncReadAt,
{
    /// Convenience form of `read_at_with` for readers whose context is `()`.
    pub async fn read_at(&self, range: Range<u64>) -> std::io::Result<CachedData> {
        let unit_ctx = ();
        self.read_at_with(&unit_ctx, range).await
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::AtomicUsize;
    use std::sync::atomic::Ordering;

    /// Byte source backed by a `Vec<u8>` that counts every read reaching it.
    struct CountingSource {
        data: Vec<u8>,
        read_count: AtomicUsize,
    }

    impl CountingSource {
        fn new(data: Vec<u8>) -> Self {
            let read_count = AtomicUsize::new(0);
            Self { data, read_count }
        }

        /// Number of reads this source has served so far.
        fn read_count(&self) -> usize {
            self.read_count.load(Ordering::Relaxed)
        }
    }

    impl ReadAt for CountingSource {
        fn read_at(&self, _ctx: &(), range: Range<u64>) -> std::io::Result<impl AsRef<[u8]>> {
            self.read_count.fetch_add(1, Ordering::Relaxed);
            let (start, end) = (range.start as usize, range.end as usize);
            Ok(self.data[start..end].to_vec())
        }
    }

    #[test]
    fn cache_hit_avoids_read() {
        let backing = CountingSource::new(vec![42u8; 1024]);
        let reader = CachingReader::new(backing);

        // A cold read has to reach the source.
        let first = reader.read_at(0..100).unwrap();
        assert_eq!(first.as_ref().len(), 100);
        assert_eq!(reader.inner().read_count(), 1);

        // The same range again is answered from the cache.
        let repeat = reader.read_at(0..100).unwrap();
        assert_eq!(repeat.as_ref().len(), 100);
        assert_eq!(reader.inner().read_count(), 1);

        // A different range is a fresh miss.
        let other = reader.read_at(100..200).unwrap();
        assert_eq!(other.as_ref().len(), 100);
        assert_eq!(reader.inner().read_count(), 2);
    }

    #[test]
    fn eviction_under_pressure() {
        let backing = CountingSource::new(vec![0u8; 4096]);
        let reader = CachingReader::with_mem_limit(backing, 256);

        // Each step is (range to read, total source reads expected afterwards).
        // The 200-byte read pushes the 100-byte entry out of the 256-byte
        // budget, so the final read of 0..100 has to hit the source again.
        let steps = [(0..100, 1), (0..100, 1), (100..300, 2), (0..100, 3)];
        for (range, expected_reads) in steps {
            let _ = reader.read_at(range).unwrap();
            assert_eq!(reader.inner().read_count(), expected_reads);
        }
    }

    #[test]
    fn context_aware_caching() {
        use std::collections::HashMap;

        /// Source holding several named volumes; the context picks the volume.
        struct MultiSource {
            volumes: HashMap<String, Vec<u8>>,
            read_count: AtomicUsize,
        }

        impl ReadAt<String> for MultiSource {
            fn read_at(
                &self,
                volume: &String,
                range: Range<u64>,
            ) -> std::io::Result<impl AsRef<[u8]>> {
                self.read_count.fetch_add(1, Ordering::Relaxed);
                let bytes = self.volumes.get(volume).unwrap();
                let (start, end) = (range.start as usize, range.end as usize);
                Ok(bytes[start..end].to_vec())
            }
        }

        let volumes = HashMap::from([
            ("a.pkg".to_string(), vec![1u8; 1024]),
            ("b.pkg".to_string(), vec![2u8; 1024]),
        ]);
        let backing = MultiSource {
            volumes,
            read_count: AtomicUsize::new(0),
        };
        let reader: CachingReader<_, String> = CachingReader::new(backing);
        let reads = || reader.inner().read_count.load(Ordering::Relaxed);

        let from_a = reader.read_at_with(&"a.pkg".to_string(), 0..100).unwrap();
        assert_eq!(from_a.as_ref()[0], 1);
        assert_eq!(reads(), 1);

        // Same range, different context: must not be served from a.pkg's entry.
        let from_b = reader.read_at_with(&"b.pkg".to_string(), 0..100).unwrap();
        assert_eq!(from_b.as_ref()[0], 2);
        assert_eq!(reads(), 2);

        // Back to the first context: cached, so no additional source read.
        let from_a_again = reader.read_at_with(&"a.pkg".to_string(), 0..100).unwrap();
        assert_eq!(from_a_again.as_ref()[0], 1);
        assert_eq!(reads(), 2);
    }
}