fskit 0.2.0

Abstractions for building read-only, sans-io virtual filesystems (VFS)
Documentation
//! Caching wrappers for byte-range readers.

use std::collections::HashMap;
use std::fmt::Debug;
use std::hash::Hash;
use std::ops::Range;
use std::sync::Arc;
use std::sync::Mutex;

use crate::ReadAt;

/// Default cache budget in payload bytes (20 MiB), used by [`RangeCache::default`].
pub const DEFAULT_MEM_LIMIT: usize = 1024 * 1024 * 20;

/// Cheap-to-clone cached byte buffer.
///
/// Cloning only bumps the reference count of the shared `Arc`; the bytes
/// themselves are never copied.
#[derive(Clone, Debug)]
pub struct CachedData(Arc<Vec<u8>>);

impl AsRef<[u8]> for CachedData {
    fn as_ref(&self) -> &[u8] {
        &self.0
    }
}

/// Bounded cache with largest-first eviction. Generic over cache key `K`.
///
/// The budget (`mem_limit`) counts payload bytes only (the lengths of the
/// stored buffers), not keys or map overhead. When an insert would exceed the
/// budget, the largest entries are evicted first. A single value larger than
/// the whole budget is handed back to the caller without being cached, so the
/// cache never grows past `mem_limit`.
#[derive(Debug)]
pub struct RangeCache<K: Eq + Hash = Range<u64>> {
    entries: HashMap<K, CachedData>,
    mem_limit: usize,
    // Running total of the lengths of every value in `entries`; kept exact so
    // inserts do not have to re-sum the whole map.
    mem_usage: usize,
}

impl<K: Eq + Hash + Clone> RangeCache<K> {
    /// Creates an empty cache with a budget of `mem_limit` payload bytes.
    pub fn new(mem_limit: usize) -> Self {
        Self {
            entries: HashMap::new(),
            mem_limit,
            mem_usage: 0,
        }
    }

    /// Returns a cheap clone of the buffer cached under `key`, if any.
    pub fn get(&self, key: &K) -> Option<CachedData> {
        self.entries.get(key).cloned()
    }

    /// Insert data, evicting largest entries first if over budget.
    ///
    /// Always returns the shared handle wrapping `data`. Any value previously
    /// stored under `key` is replaced. If `data` alone is larger than the
    /// memory limit, it is returned without being cached and no other entry is
    /// evicted.
    pub fn insert(&mut self, key: K, data: Vec<u8>) -> CachedData {
        let cached = CachedData(Arc::new(data));
        let incoming = cached.0.len();

        // The old value for this key is being replaced either way; release its
        // bytes first so they are not counted twice during eviction.
        if let Some(previous) = self.entries.remove(&key) {
            self.mem_usage -= previous.0.len();
        }

        // A value that cannot fit even in an empty cache is not worth flushing
        // every other entry for; hand it back uncached.
        if incoming > self.mem_limit {
            return cached;
        }

        // `incoming <= mem_limit` here, so this subtraction cannot underflow.
        let target = self.mem_limit - incoming;
        if self.mem_usage > target {
            self.evict_largest_until(target);
        }

        self.mem_usage += incoming;
        self.entries.insert(key, cached.clone());
        cached
    }

    /// Evicts entries, largest first, until at most `target` bytes remain cached.
    fn evict_largest_until(&mut self, target: usize) {
        let mut by_size: Vec<(K, usize)> = self
            .entries
            .iter()
            .map(|(k, v)| (k.clone(), v.0.len()))
            .collect();
        // Descending by size, so the biggest entries are evicted first.
        by_size.sort_unstable_by_key(|&(_, size)| std::cmp::Reverse(size));
        for (victim, size) in by_size {
            if self.mem_usage <= target {
                break;
            }
            self.entries.remove(&victim);
            self.mem_usage -= size;
        }
    }
}

impl<K: Eq + Hash + Clone> Default for RangeCache<K> {
    /// Creates an empty cache with a budget of [`DEFAULT_MEM_LIMIT`] bytes.
    fn default() -> Self {
        Self::new(DEFAULT_MEM_LIMIT)
    }
}

/// Key under which the caching readers store a read: the caller's context
/// paired with the exact byte range that was requested.
type CacheKey<Ctx> = (Ctx, Range<u64>);

/// Caching wrapper around a sync [`ReadAt`] source.
///
/// Each successful read is copied into a shared [`CachedData`] buffer and
/// memoized under its exact `(ctx, range)` pair. A repeat of the same pair is
/// answered from the cache without calling the wrapped source. Overlapping or
/// sub-ranges of a cached read are separate keys and are not reused. The cache
/// sits behind a `Mutex`, so all reads take `&self`.
///
/// # Example
///
/// ```
/// use fskit::{ReadAt, cache::CachingReader};
///
/// struct FileSource(Vec<u8>);
///
/// impl ReadAt for FileSource {
///     fn read_at(&self, _ctx: &(), range: std::ops::Range<u64>) -> std::io::Result<impl AsRef<[u8]>> {
///         Ok(&self.0[range.start as usize..range.end as usize])
///     }
/// }
///
/// let source = FileSource(vec![0u8; 1024]);
/// let cached = CachingReader::new(source);
///
/// let data = cached.read_at(0..100).unwrap();
/// assert_eq!(data.as_ref().len(), 100);
/// ```
pub struct CachingReader<T, Ctx = ()>
where
    Ctx: Clone + Eq + Hash,
{
    inner: T,
    cache: Mutex<RangeCache<CacheKey<Ctx>>>,
}

impl<T, Ctx: Clone + Eq + Hash> CachingReader<T, Ctx> {
    pub fn new(inner: T) -> Self {
        Self {
            inner,
            cache: Mutex::new(RangeCache::default()),
        }
    }

    pub fn with_mem_limit(inner: T, mem_limit: usize) -> Self {
        Self {
            inner,
            cache: Mutex::new(RangeCache::new(mem_limit)),
        }
    }

    pub fn inner(&self) -> &T {
        &self.inner
    }
}

impl<T: Debug, Ctx: Clone + Eq + Hash> Debug for CachingReader<T, Ctx> {
    /// Shows only the wrapped source; the cache contents are elided as `..`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("CachingReader");
        out.field("inner", &self.inner);
        out.finish_non_exhaustive()
    }
}

impl<T: ReadAt<Ctx>, Ctx: Clone + Eq + Hash> CachingReader<T, Ctx> {
    pub fn read_at_with(&self, ctx: &Ctx, range: Range<u64>) -> std::io::Result<CachedData> {
        let key = (ctx.clone(), range.clone());

        {
            let cache = self.cache.lock().unwrap();
            if let Some(cached) = cache.get(&key) {
                return Ok(cached);
            }
        }

        let data = self.inner.read_at(ctx, range)?;
        let vec = data.as_ref().to_vec();

        let mut cache = self.cache.lock().unwrap();
        Ok(cache.insert(key, vec))
    }
}

impl<T: ReadAt> CachingReader<T> {
    /// Convenience for context-free sources: reads `range` using the unit
    /// context `()`.
    pub fn read_at(&self, range: Range<u64>) -> std::io::Result<CachedData> {
        Self::read_at_with(self, &(), range)
    }
}

/// A `CachingReader` is itself a [`ReadAt`] source, so it can be used wherever
/// a `ReadAt<Ctx>` is expected.
impl<T, Ctx> ReadAt<Ctx> for CachingReader<T, Ctx>
where
    T: ReadAt<Ctx>,
    Ctx: Clone + Eq + Hash,
{
    fn read_at(&self, ctx: &Ctx, range: Range<u64>) -> std::io::Result<impl AsRef<[u8]>> {
        // Delegate to the inherent, cache-aware method; it yields `CachedData`.
        Self::read_at_with(self, ctx, range)
    }
}

/// Async caching wrapper around an [`AsyncReadAt`](crate::AsyncReadAt) source.
///
/// Uses `std::sync::Mutex` for the cache since lookups and inserts are
/// non-blocking. The lock is released before the inner read is awaited.
///
/// Reads are memoized under their exact `(ctx, range)` pair. The cache budget
/// defaults to [`DEFAULT_MEM_LIMIT`] bytes; use
/// [`CachingAsyncReader::with_mem_limit`] to choose a different one.
pub struct CachingAsyncReader<T, Ctx: Clone + Eq + Hash = ()> {
    inner: T,
    cache: Mutex<RangeCache<CacheKey<Ctx>>>,
}

impl<T, Ctx: Clone + Eq + Hash> CachingAsyncReader<T, Ctx> {
    pub fn new(inner: T) -> Self {
        Self {
            inner,
            cache: Mutex::new(RangeCache::default()),
        }
    }

    pub fn with_mem_limit(inner: T, mem_limit: usize) -> Self {
        Self {
            inner,
            cache: Mutex::new(RangeCache::new(mem_limit)),
        }
    }

    pub fn inner(&self) -> &T {
        &self.inner
    }
}

impl<T: Debug, Ctx: Clone + Eq + Hash> Debug for CachingAsyncReader<T, Ctx> {
    /// Shows only the wrapped source; the cache contents are elided as `..`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("CachingAsyncReader");
        out.field("inner", &self.inner);
        out.finish_non_exhaustive()
    }
}

impl<T: crate::AsyncReadAt<Ctx>, Ctx: Clone + Eq + Hash + Send + Sync> CachingAsyncReader<T, Ctx> {
    pub async fn read_at_with(&self, ctx: &Ctx, range: Range<u64>) -> std::io::Result<CachedData> {
        let key = (ctx.clone(), range.clone());

        {
            let cache = self.cache.lock().unwrap();
            if let Some(cached) = cache.get(&key) {
                return Ok(cached);
            }
        }

        let data = self.inner.read_at(ctx, range).await?;
        let vec = data.as_ref().to_vec();

        let mut cache = self.cache.lock().unwrap();
        Ok(cache.insert(key, vec))
    }
}

impl<T: crate::AsyncReadAt> CachingAsyncReader<T> {
    /// Convenience for context-free sources: reads `range` using the unit
    /// context `()`.
    pub async fn read_at(&self, range: Range<u64>) -> std::io::Result<CachedData> {
        Self::read_at_with(self, &(), range).await
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// In-memory source that counts how many reads actually reach it.
    struct CountingSource {
        data: Vec<u8>,
        read_count: AtomicUsize,
    }

    impl CountingSource {
        fn new(data: Vec<u8>) -> Self {
            let read_count = AtomicUsize::new(0);
            Self { data, read_count }
        }

        fn read_count(&self) -> usize {
            self.read_count.load(Ordering::Relaxed)
        }
    }

    impl ReadAt for CountingSource {
        fn read_at(&self, _ctx: &(), range: Range<u64>) -> std::io::Result<impl AsRef<[u8]>> {
            self.read_count.fetch_add(1, Ordering::Relaxed);
            let (start, end) = (range.start as usize, range.end as usize);
            Ok(self.data[start..end].to_vec())
        }
    }

    #[test]
    fn cache_hit_avoids_read() {
        let reader: CachingReader<_> = CachingReader::new(CountingSource::new(vec![42u8; 1024]));

        // Each step: (range to read, backend reads expected afterwards).
        let steps: [(Range<u64>, usize); 3] = [(0..100, 1), (0..100, 1), (100..200, 2)];
        for (range, expected_reads) in steps {
            let data = reader.read_at(range).unwrap();
            assert_eq!(data.as_ref().len(), 100);
            assert_eq!(reader.inner().read_count(), expected_reads);
        }
    }

    #[test]
    fn eviction_under_pressure() {
        let reader: CachingReader<_> =
            CachingReader::with_mem_limit(CountingSource::new(vec![0u8; 4096]), 256);
        let reads_after = |range: Range<u64>| {
            let _ = reader.read_at(range).unwrap();
            reader.inner().read_count()
        };

        assert_eq!(reads_after(0..100), 1);
        // A repeated range is served from the cache.
        assert_eq!(reads_after(0..100), 1);
        // 100 + 200 bytes exceeds the 256-byte budget, forcing an eviction.
        assert_eq!(reads_after(100..300), 2);
        // The evicted range must be fetched from the source again.
        assert_eq!(reads_after(0..100), 3);
    }

    #[test]
    fn context_aware_caching() {
        // Source holding several named volumes, counting backend reads.
        struct MultiSource {
            volumes: HashMap<String, Vec<u8>>,
            read_count: AtomicUsize,
        }

        impl ReadAt<String> for MultiSource {
            fn read_at(
                &self,
                volume: &String,
                range: Range<u64>,
            ) -> std::io::Result<impl AsRef<[u8]>> {
                self.read_count.fetch_add(1, Ordering::Relaxed);
                let bytes = &self.volumes[volume];
                Ok(bytes[range.start as usize..range.end as usize].to_vec())
            }
        }

        let volumes = HashMap::from([
            ("a.pkg".to_string(), vec![1u8; 1024]),
            ("b.pkg".to_string(), vec![2u8; 1024]),
        ]);
        let reader: CachingReader<_, String> = CachingReader::new(MultiSource {
            volumes,
            read_count: AtomicUsize::new(0),
        });
        let backend_reads = || reader.inner().read_count.load(Ordering::Relaxed);
        let (vol_a, vol_b) = ("a.pkg".to_string(), "b.pkg".to_string());

        let first = reader.read_at_with(&vol_a, 0..100).unwrap();
        assert_eq!(first.as_ref()[0], 1);
        assert_eq!(backend_reads(), 1);

        // Same range, different volume: must not be answered from vol_a's entry.
        let second = reader.read_at_with(&vol_b, 0..100).unwrap();
        assert_eq!(second.as_ref()[0], 2);
        assert_eq!(backend_reads(), 2);

        let third = reader.read_at_with(&vol_a, 0..100).unwrap();
        assert_eq!(third.as_ref()[0], 1);
        assert_eq!(backend_reads(), 2);
    }
}