use lru::LruCache;
use std::borrow::Borrow;
use std::hash::Hash;
use std::mem::size_of;
use std::num::NonZeroUsize;
/// Default cap on the number of cached entries (256).
///
/// Presumably intended as the `max_entries` argument to `BoundedByteCache::new`;
/// callers are outside this view — confirm.
pub const DEFAULT_MAX_ENTRIES: usize = 256;
/// Default byte budget: 64 KiB (65536 bytes).
///
/// Presumably intended as the `max_bytes` argument to `BoundedByteCache::new`;
/// callers are outside this view — confirm.
pub const DEFAULT_MAX_BYTES: usize = 64 * 1024;
pub struct BoundedByteCache<K, V>
where
K: Hash + Eq + AsRef<[u8]>,
{
inner: LruCache<K, V>,
current_bytes: usize,
max_bytes: usize,
}
impl<K, V> BoundedByteCache<K, V>
where
    K: Hash + Eq + AsRef<[u8]>,
{
    /// Fixed number of bytes charged per entry on top of the key bytes and
    /// `size_of::<V>()`. It is independent of `K` and `V`; presumably a rough
    /// allowance for per-node bookkeeping in the underlying map (not derived
    /// from its actual layout here).
    const PER_ENTRY_OVERHEAD: usize = 64;

    /// Creates an empty cache that holds at most `max_entries` entries and at
    /// most `max_bytes` estimated bytes.
    ///
    /// A `max_entries` of `0` is clamped to `1`, because the underlying LRU
    /// requires a non-zero capacity.
    pub fn new(max_entries: usize, max_bytes: usize) -> Self {
        let entry_cap = NonZeroUsize::new(max_entries).unwrap_or(NonZeroUsize::MIN);
        Self {
            inner: LruCache::new(entry_cap),
            current_bytes: 0,
            max_bytes,
        }
    }

    /// Looks up `key` and returns a reference to its value.
    ///
    /// A hit marks the entry as most recently used, which is why this takes
    /// `&mut self`.
    #[inline]
    pub fn get<Q>(&mut self, key: &Q) -> Option<&V>
    where
        K: Borrow<Q>,
        Q: Hash + Eq + ?Sized,
    {
        self.inner.get(key)
    }

    /// Inserts `value` under `key`, replacing any existing value for that key.
    ///
    /// Least-recently-used entries are evicted as needed so that both the byte
    /// budget and the entry-count cap hold afterwards. The inserted entry
    /// becomes the most recently used.
    ///
    /// If the entry's estimated size alone exceeds `max_bytes`, the call is a
    /// silent no-op and the cache is left unchanged.
    pub fn put(&mut self, key: K, value: V) {
        let entry_bytes = Self::entry_size(&key);
        if entry_bytes > self.max_bytes {
            return;
        }

        // Drop any previous entry for this key first so its bytes are not
        // counted twice. The size estimate depends only on the key, so the old
        // entry was charged exactly `entry_bytes`.
        if self.inner.pop(&key).is_some() {
            self.current_bytes = self.current_bytes.saturating_sub(entry_bytes);
        }

        self.evict_until_fits(entry_bytes);

        // `push` can still evict the LRU entry when the entry-count cap is
        // reached. Account for whatever it hands back.
        if let Some((evicted_key, _)) = self.inner.push(key, value) {
            self.release_bytes_for(&evicted_key);
        }
        self.current_bytes += entry_bytes;
    }

    /// Evicts least-recently-used entries until `incoming` additional bytes fit
    /// within `max_bytes`, or until the cache is empty.
    fn evict_until_fits(&mut self, incoming: usize) {
        // Comparing against `max_bytes - incoming` is equivalent to
        // `current_bytes + incoming > max_bytes` whenever `incoming <= max_bytes`
        // (which `put` guarantees). Unlike the addition, it cannot overflow
        // when `max_bytes` is close to `usize::MAX`.
        let limit = self.max_bytes.saturating_sub(incoming);
        while self.current_bytes > limit {
            let Some((evicted_key, _)) = self.inner.pop_lru() else {
                break;
            };
            self.release_bytes_for(&evicted_key);
        }
    }

    /// Subtracts the estimated size of an entry keyed by `key` from the
    /// running byte total.
    fn release_bytes_for(&mut self, key: &K) {
        self.current_bytes = self.current_bytes.saturating_sub(Self::entry_size(key));
    }

    /// Returns the estimated number of bytes held by the cached entries
    /// (test builds only).
    #[cfg(test)]
    #[inline]
    pub fn current_bytes(&self) -> usize {
        self.current_bytes
    }

    /// Returns the number of entries currently cached.
    pub fn len(&self) -> usize {
        self.inner.len()
    }

    /// Returns `true` if the cache holds no entries.
    pub fn is_empty(&self) -> bool {
        self.inner.is_empty()
    }

    /// Estimated bytes charged for an entry keyed by `key`, computed as
    /// `PER_ENTRY_OVERHEAD + key.as_ref().len() + size_of::<V>()`.
    ///
    /// Only the key's referenced bytes and `V`'s inline size are counted. Heap
    /// data owned by `V` and the inline size of `K` itself are not.
    fn entry_size(key: &K) -> usize {
        Self::PER_ENTRY_OVERHEAD + key.as_ref().len() + size_of::<V>()
    }
}
#[cfg(test)]
mod tests {
    // Byte arithmetic used below: each entry is charged
    // 64 (per-entry overhead) + key length + size_of::<bool>() (= 1),
    // so a 4-byte key costs 69 bytes and a 1-byte key costs 66 bytes.
    use super::*;

    #[test]
    fn test_basic_put_and_get() {
        let mut cache: BoundedByteCache<Vec<u8>, bool> = BoundedByteCache::new(256, 1024);
        cache.put(b"hello".to_vec(), true);
        assert_eq!(cache.get(b"hello".as_ref()), Some(&true));
        assert_eq!(cache.len(), 1);
    }

    #[test]
    fn test_evicts_lru_when_over_byte_budget() {
        // Two 69-byte entries fit in 150; a third does not.
        let budget = 150;
        let mut cache: BoundedByteCache<Vec<u8>, bool> = BoundedByteCache::new(256, budget);
        cache.put(b"aaaa".to_vec(), true);
        cache.put(b"bbbb".to_vec(), false);
        assert_eq!(cache.len(), 2);
        cache.put(b"cccc".to_vec(), true);
        assert_eq!(cache.len(), 2);
        assert_eq!(cache.get(b"aaaa".as_ref()), None);
        assert_eq!(cache.get(b"bbbb".as_ref()), Some(&false));
        assert_eq!(cache.get(b"cccc".as_ref()), Some(&true));
        assert!(cache.current_bytes() <= budget);
    }

    #[test]
    fn test_evicts_lru_when_over_entry_count() {
        let mut cache: BoundedByteCache<Vec<u8>, bool> = BoundedByteCache::new(2, 1024);
        cache.put(b"a".to_vec(), true);
        cache.put(b"b".to_vec(), false);
        cache.put(b"c".to_vec(), true);
        assert_eq!(cache.len(), 2);
        assert_eq!(cache.get(b"a".as_ref()), None);
        assert_eq!(cache.get(b"b".as_ref()), Some(&false));
        assert_eq!(cache.get(b"c".as_ref()), Some(&true));
    }

    #[test]
    fn test_count_eviction_releases_evicted_bytes() {
        // "c" fits the byte budget, so only the entry cap forces "a" out; the
        // bytes charged for "a" must be released when that happens.
        let mut cache: BoundedByteCache<Vec<u8>, bool> = BoundedByteCache::new(2, 1024);
        cache.put(b"a".to_vec(), true);
        cache.put(b"b".to_vec(), false);
        let two_entries = cache.current_bytes();
        cache.put(b"c".to_vec(), true);
        assert_eq!(cache.len(), 2);
        assert_eq!(cache.current_bytes(), two_entries);
    }

    #[test]
    fn test_oversize_entry_is_rejected() {
        let mut cache: BoundedByteCache<Vec<u8>, bool> = BoundedByteCache::new(256, 32);
        cache.put(b"small".to_vec(), true);
        assert_eq!(cache.len(), 0);
        assert_eq!(cache.current_bytes(), 0);
    }

    #[test]
    fn test_oversize_entry_leaves_existing_entries_alone() {
        // "a" costs 66 <= 100; a 64-byte key costs 129 > 100 and is rejected
        // without evicting anything.
        let mut cache: BoundedByteCache<Vec<u8>, bool> = BoundedByteCache::new(256, 100);
        cache.put(b"a".to_vec(), true);
        let before = cache.current_bytes();
        cache.put(vec![b'x'; 64], false);
        assert_eq!(cache.len(), 1);
        assert_eq!(cache.current_bytes(), before);
        assert_eq!(cache.get(b"a".as_ref()), Some(&true));
    }

    #[test]
    fn test_replacing_key_does_not_double_count() {
        let mut cache: BoundedByteCache<Vec<u8>, bool> = BoundedByteCache::new(256, 1024);
        cache.put(b"k".to_vec(), true);
        let bytes_after_first = cache.current_bytes();
        cache.put(b"k".to_vec(), false);
        assert_eq!(cache.current_bytes(), bytes_after_first);
        assert_eq!(cache.get(b"k".as_ref()), Some(&false));
        assert_eq!(cache.len(), 1);
    }

    #[test]
    fn test_replacing_key_refreshes_recency() {
        // Re-putting "aaaa" makes it most recently used, so "bbbb" is the one
        // evicted when "cccc" overflows the 150-byte budget.
        let mut cache: BoundedByteCache<Vec<u8>, bool> = BoundedByteCache::new(256, 150);
        cache.put(b"aaaa".to_vec(), true);
        cache.put(b"bbbb".to_vec(), true);
        cache.put(b"aaaa".to_vec(), false);
        cache.put(b"cccc".to_vec(), true);
        assert_eq!(cache.get(b"aaaa".as_ref()), Some(&false));
        assert_eq!(cache.get(b"bbbb".as_ref()), None);
        assert_eq!(cache.get(b"cccc".as_ref()), Some(&true));
    }

    #[test]
    fn test_get_bumps_recency() {
        let mut cache: BoundedByteCache<Vec<u8>, bool> = BoundedByteCache::new(256, 150);
        cache.put(b"aaaa".to_vec(), true);
        cache.put(b"bbbb".to_vec(), true);
        let _ = cache.get(b"aaaa".as_ref());
        cache.put(b"cccc".to_vec(), true);
        assert_eq!(cache.get(b"aaaa".as_ref()), Some(&true));
        assert_eq!(cache.get(b"bbbb".as_ref()), None);
    }

    #[test]
    fn test_many_inserts_stay_within_both_limits() {
        let max_entries = 8;
        let max_bytes = 600;
        let mut cache: BoundedByteCache<Vec<u8>, bool> =
            BoundedByteCache::new(max_entries, max_bytes);
        for i in 0u16..1000 {
            cache.put(format!("key-{:04}", i).into_bytes(), i % 2 == 0);
            assert!(cache.current_bytes() <= max_bytes);
            assert!(cache.len() <= max_entries);
        }
    }

    #[test]
    fn test_zero_entries_clamps_to_one() {
        let mut cache: BoundedByteCache<Vec<u8>, bool> = BoundedByteCache::new(0, 1024);
        cache.put(b"a".to_vec(), true);
        cache.put(b"b".to_vec(), false);
        assert_eq!(cache.len(), 1);
        assert_eq!(cache.get(b"a".as_ref()), None);
        assert_eq!(cache.get(b"b".as_ref()), Some(&false));
    }
}