use std::sync::{
Mutex,
atomic::{AtomicU64, Ordering},
};
use bytes::Bytes;
use lru::LruCache;
use objects::{object::ContentHash, sync::LockExt};
/// Capacity, in bytes, that [`BlobCachePool::with_default_capacity`] passes to
/// [`BlobCachePool::with_capacity`]: 256 * 1024 * 1024 bytes (256 MiB).
pub const DEFAULT_BLOB_CACHE_BYTES: usize = 256 * 1024 * 1024;
/// Cache of blob payloads keyed by [`ContentHash`], bounded by total payload
/// bytes rather than by entry count.
///
/// The mutable state (the LRU map and its running byte total) sits behind a
/// `Mutex`; the hit/miss/insert counters are separate atomics and are updated
/// with `Ordering::Relaxed`.
pub struct BlobCachePool {
/// LRU map plus the running total of cached payload bytes.
inner: Mutex<BlobCacheInner>,
/// Upper bound on the total cached payload bytes. `insert` skips any blob
/// larger than this and evicts least-recently-used entries while the total
/// exceeds it.
cap_bytes: usize,
/// Number of `get` calls that found an entry.
hits: AtomicU64,
/// Number of `get` calls that found no entry.
misses: AtomicU64,
/// Number of `insert` calls that stored a blob; blobs skipped for being
/// larger than `cap_bytes` are not counted.
inserts: AtomicU64,
}
/// Mutex-protected state of a [`BlobCachePool`].
struct BlobCacheInner {
/// Cached payloads. The pool creates this with `LruCache::unbounded()`, so
/// the entry count is not limited here; the pool's byte accounting is what
/// triggers eviction.
lru: LruCache<ContentHash, Bytes>,
/// Running sum of the lengths of the payloads held in `lru`, maintained by
/// the pool's `insert` and `clear`.
bytes: usize,
}
impl BlobCachePool {
    /// Creates an empty pool capped at [`DEFAULT_BLOB_CACHE_BYTES`].
    pub fn with_default_capacity() -> Self {
        Self::with_capacity(DEFAULT_BLOB_CACHE_BYTES)
    }

    /// Creates an empty pool that keeps at most `cap_bytes` bytes of payload
    /// resident. All counters start at zero.
    pub fn with_capacity(cap_bytes: usize) -> Self {
        // The LRU itself has no entry limit; eviction is driven purely by the
        // byte total tracked next to it.
        let empty = BlobCacheInner {
            lru: LruCache::unbounded(),
            bytes: 0,
        };
        Self {
            inner: Mutex::new(empty),
            cap_bytes,
            hits: AtomicU64::new(0),
            misses: AtomicU64::new(0),
            inserts: AtomicU64::new(0),
        }
    }

    /// Looks up `hash` and returns a clone of the cached payload, if any.
    ///
    /// The lookup goes through `LruCache::get`, so a hit also refreshes the
    /// entry's recency. Exactly one of the `hits` / `misses` counters is
    /// bumped per call, while the lock is still held.
    pub(crate) fn get(&self, hash: &ContentHash) -> Option<Bytes> {
        let mut state = self.inner.lock_or_poisoned();
        let found = state.lru.get(hash).cloned();
        let counter = if found.is_some() {
            &self.hits
        } else {
            &self.misses
        };
        counter.fetch_add(1, Ordering::Relaxed);
        found
    }

    /// Stores `bytes` under `hash`, evicting least-recently-used entries until
    /// the resident total is no larger than the capacity.
    ///
    /// A payload longer than the capacity is dropped without touching the
    /// cache or the `inserts` counter. Replacing an existing entry swaps its
    /// old length for the new one in the byte total.
    pub(crate) fn insert(&self, hash: ContentHash, bytes: Bytes) {
        let incoming = bytes.len();
        if incoming > self.cap_bytes {
            return;
        }

        let mut state = self.inner.lock_or_poisoned();

        // `put` hands back the previous payload when the key was already
        // present; its length must leave the total before the new one enters.
        let replaced = state.lru.put(hash, bytes).map_or(0, |old| old.len());
        state.bytes = state.bytes.saturating_sub(replaced) + incoming;

        while state.bytes > self.cap_bytes {
            match state.lru.pop_lru() {
                Some((_, victim)) => {
                    state.bytes = state.bytes.saturating_sub(victim.len());
                }
                // Nothing left to evict.
                None => break,
            }
        }

        self.inserts.fetch_add(1, Ordering::Relaxed);
    }

    /// Drops every cached entry and resets the resident byte total to zero.
    /// The hit/miss/insert counters are left as they are.
    pub fn clear(&self) {
        let mut state = self.inner.lock_or_poisoned();
        let BlobCacheInner { lru, bytes } = &mut *state;
        lru.clear();
        *bytes = 0;
    }

    /// Returns the byte capacity this pool was created with.
    pub fn cap_bytes(&self) -> usize {
        self.cap_bytes
    }

    /// Returns the tracked total of payload bytes currently cached.
    pub fn resident_bytes(&self) -> usize {
        let state = self.inner.lock_or_poisoned();
        state.bytes
    }

    /// Returns the number of entries currently cached.
    pub fn entry_count(&self) -> usize {
        let state = self.inner.lock_or_poisoned();
        state.lru.len()
    }

    /// Returns the current counter values.
    ///
    /// Each counter is read with its own relaxed load, in the order hits,
    /// misses, inserts.
    pub fn stats(&self) -> BlobCacheStats {
        let read = |counter: &AtomicU64| counter.load(Ordering::Relaxed);
        BlobCacheStats {
            hits: read(&self.hits),
            misses: read(&self.misses),
            inserts: read(&self.inserts),
        }
    }
}
/// Counter values reported by [`BlobCachePool::stats`].
///
/// `stats` reads the three counters with separate relaxed loads.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub struct BlobCacheStats {
/// Lookups that found a cached entry.
pub hits: u64,
/// Lookups that found no cached entry.
pub misses: u64,
/// Insert calls that stored a blob in the cache.
pub inserts: u64,
}
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use super::*;

    /// Builds a hash whose 32 bytes are all `byte`.
    fn h(byte: u8) -> ContentHash {
        ContentHash::from_bytes([byte; 32])
    }

    /// Builds a payload of `len` copies of `byte`.
    fn payload(byte: u8, len: usize) -> Bytes {
        let filled: Vec<u8> = vec![byte; len];
        Bytes::from(filled)
    }

    #[test]
    fn round_trip_get_hit_miss() {
        let pool = BlobCachePool::with_capacity(1024);

        // First lookup precedes the insert, so it must miss.
        assert!(pool.get(&h(1)).is_none());

        pool.insert(h(1), payload(0xAA, 64));
        let hit = pool.get(&h(1)).expect("should hit");
        assert_eq!(&hit[..], [0xAAu8; 64].as_slice());

        let expected = BlobCacheStats {
            hits: 1,
            misses: 1,
            inserts: 1,
        };
        assert_eq!(pool.stats(), expected);
    }

    #[test]
    fn byte_bound_evicts_lru_until_under_cap() {
        let pool = BlobCachePool::with_capacity(300);

        // Three 100-byte blobs fill the pool exactly.
        for tag in 1..=3u8 {
            pool.insert(h(tag), payload(tag, 100));
        }
        assert_eq!(pool.resident_bytes(), 300);

        // A fourth pushes the total over the cap; the oldest entry goes.
        pool.insert(h(4), payload(0x04, 100));
        assert!(pool.get(&h(1)).is_none());
        assert!(pool.get(&h(4)).is_some());
        assert_eq!(pool.resident_bytes(), 300);
    }

    #[test]
    fn oversized_blob_bypasses_cache() {
        let pool = BlobCachePool::with_capacity(256);

        pool.insert(h(1), payload(0x01, 200));
        // 500 bytes exceeds the 256-byte cap, so this insert is skipped.
        pool.insert(h(2), payload(0x02, 500));

        assert!(pool.get(&h(1)).is_some());
        assert!(pool.get(&h(2)).is_none());
        assert_eq!(pool.resident_bytes(), 200);
    }

    #[test]
    fn shared_pool_visible_across_clones() {
        let pool = Arc::new(BlobCachePool::with_capacity(1024));
        let (a, b) = (Arc::clone(&pool), Arc::clone(&pool));

        a.insert(h(1), payload(0xAA, 64));
        assert!(b.get(&h(1)).is_some());
    }
}