extern crate alloc;
use crate::config::LruCacheConfig;
use crate::entry::CacheEntry;
use crate::list::{List, ListEntry};
use crate::metrics::{CacheMetrics, LruCacheMetrics};
use alloc::boxed::Box;
use alloc::collections::BTreeMap;
use alloc::string::String;
use alloc::vec::Vec;
use core::borrow::Borrow;
use core::hash::{BuildHasher, Hash};
use core::num::NonZeroUsize;
#[cfg(feature = "hashbrown")]
use hashbrown::DefaultHashBuilder;
#[cfg(feature = "hashbrown")]
use hashbrown::HashMap;
#[cfg(not(feature = "hashbrown"))]
use std::collections::hash_map::RandomState as DefaultHashBuilder;
#[cfg(not(feature = "hashbrown"))]
use std::collections::HashMap;
/// Single-segment LRU store that backs `LruCache`.
///
/// Entries live in `list`, ordered by recency: lookups move an entry to the front, and
/// eviction takes the last one. `map` indexes those entries by key through raw pointers
/// into `list`.
pub(crate) struct LruSegment<K, V, S = DefaultHashBuilder> {
    /// Entry-count (`capacity`) and total-size (`max_size`) limits.
    config: LruCacheConfig,
    /// Recency-ordered entry storage.
    list: List<CacheEntry<K, V>>,
    /// Key -> list node. A pointer is removed from this map before its node is taken
    /// out of `list`, so every stored pointer refers to a live node.
    map: HashMap<K, *mut ListEntry<CacheEntry<K, V>>, S>,
    /// Hit/miss/eviction counters and byte accounting.
    metrics: LruCacheMetrics,
    /// Sum of the `size` values of all stored entries.
    current_size: u64,
}

// SAFETY: the raw pointers in `map` only refer to nodes held by `list`, which moves with
// the segment. Sending the segment therefore transfers ownership of every `K`, `V`, and
// the hasher `S`. That is sound exactly when all three are `Send` (the same bounds
// `HashMap<K, V, S>` uses).
unsafe impl<K: Send, V: Send, S: Send> Send for LruSegment<K, V, S> {}
// SAFETY: a shared `&LruSegment` exposes `&V` (`peek`), `&K` (key comparisons in
// `contains`/`peek`), and `&S` (hashing). Sharing it across threads is therefore only
// sound when `K`, `V`, and `S` are all `Sync`. The previous `K: Send, V: Send` bounds
// let e.g. a `Cell`-valued cache be shared between threads, which is a data race.
unsafe impl<K: Sync, V: Sync, S: Sync> Sync for LruSegment<K, V, S> {}
impl<K: Hash + Eq, V: Clone, S: BuildHasher> LruSegment<K, V, S> {
    /// Creates an empty segment with the limits in `config`, hashing keys with `hasher`.
    ///
    /// The map is pre-sized for `config.capacity` entries, which is the most `put` ever
    /// stores. `HashMap` already rounds the request up to its own bucket layout. Rounding
    /// to a power of two first (as this used to) could double the allocation.
    // NOTE(review): `init` is called from `LruCache::init`, so this allow may be stale.
    #[allow(dead_code)]
    pub(crate) fn init(config: LruCacheConfig, hasher: S) -> Self {
        let capacity = config.capacity;
        let max_size = config.max_size;
        LruSegment {
            config,
            list: List::new(capacity),
            map: HashMap::with_capacity_and_hasher(capacity.get(), hasher),
            metrics: LruCacheMetrics::new(max_size),
            current_size: 0,
        }
    }

    /// Maximum number of entries the segment holds.
    #[inline]
    pub(crate) fn cap(&self) -> NonZeroUsize {
        self.config.capacity
    }

    /// Number of entries currently stored.
    #[inline]
    pub(crate) fn len(&self) -> usize {
        self.map.len()
    }

    /// Returns `true` when no entries are stored.
    #[inline]
    pub(crate) fn is_empty(&self) -> bool {
        self.map.is_empty()
    }

    /// Sum of the sizes of all stored entries.
    #[inline]
    pub(crate) fn current_size(&self) -> u64 {
        self.current_size
    }

    /// Size budget taken from the configuration.
    #[inline]
    pub(crate) fn max_size(&self) -> u64 {
        self.config.max_size
    }

    /// Read-only access to the segment's metrics.
    #[inline]
    pub(crate) fn metrics(&self) -> &LruCacheMetrics {
        &self.metrics
    }

    /// Returns the value for `key`, moves it to the most-recently-used position, and
    /// records a hit.
    ///
    /// Misses are not recorded here; report them with `record_miss`.
    pub(crate) fn get<Q>(&mut self, key: &Q) -> Option<&V>
    where
        K: Borrow<Q>,
        Q: ?Sized + Hash + Eq,
    {
        let node = self.map.get(key).copied()?;
        // SAFETY: every pointer in `map` was returned by `list.add` and is removed from
        // `map` before its node leaves `list`, so `node` refers to a live node.
        unsafe {
            self.list.move_to_front(node);
            let entry = (*node).get_value_mut();
            entry.touch();
            self.metrics.core.record_hit(entry.metadata.size);
            Some(&entry.value)
        }
    }

    /// Records a cache miss for an object of `object_size` bytes.
    #[inline]
    pub(crate) fn record_miss(&mut self, object_size: u64) {
        self.metrics.core.record_miss(object_size);
    }

    /// Mutable counterpart of `get`: promotes the entry and records a hit.
    pub(crate) fn get_mut<Q>(&mut self, key: &Q) -> Option<&mut V>
    where
        K: Borrow<Q>,
        Q: ?Sized + Hash + Eq,
    {
        let node = self.map.get(key).copied()?;
        // SAFETY: see `get` — pointers in `map` always refer to live nodes of `list`.
        unsafe {
            self.list.move_to_front(node);
            let entry = (*node).get_value_mut();
            entry.touch();
            self.metrics.core.record_hit(entry.metadata.size);
            Some(&mut entry.value)
        }
    }

    /// Inserts or replaces `key`, charging `size` against the size budget.
    ///
    /// * Existing key: the entry is promoted and its key, value, and size are replaced in
    ///   place. If the new size pushes the segment over `max_size`, other
    ///   least-recently-used entries are evicted. The replaced entry itself is never evicted.
    /// * New key: least-recently-used entries are evicted until there is room for one more
    ///   entry and `size` more bytes fit under `max_size`. An empty segment always admits
    ///   the entry, even one larger than `max_size`.
    ///
    /// Returns the evicted `(key, value)` pairs (least recent first), or `None` if
    /// nothing was evicted. Size arithmetic saturates instead of overflowing.
    pub(crate) fn put(&mut self, key: K, value: V, size: u64) -> Option<Vec<(K, V)>>
    where
        K: Clone + Hash + Eq,
    {
        let max_size = self.config.max_size;
        let evicted = if let Some(node) = self.map.get(&key).copied() {
            // SAFETY: see `get` — pointers in `map` always refer to live nodes of `list`.
            unsafe {
                self.list.move_to_front(node);
                let entry = (*node).get_value_mut();
                let old_size = entry.metadata.size;
                entry.key = key;
                entry.value = value;
                entry.metadata.size = size;
                entry.touch();
                self.current_size = self
                    .current_size
                    .saturating_sub(old_size)
                    .saturating_add(size);
                self.metrics.core.cache_size_bytes = self
                    .metrics
                    .core
                    .cache_size_bytes
                    .saturating_sub(old_size)
                    .saturating_add(size);
                self.metrics.core.bytes_written_to_cache = self
                    .metrics
                    .core
                    .bytes_written_to_cache
                    .saturating_add(size);
            }
            // The replaced entry now sits at the front. While more than one entry remains,
            // the tail is therefore always some other entry.
            self.evict_while(|seg| seg.current_size > max_size && seg.map.len() > 1)
        } else {
            let capacity = self.cap().get();
            let evicted = self.evict_while(|seg| {
                seg.map.len() >= capacity
                    || (!seg.map.is_empty() && seg.current_size.saturating_add(size) > max_size)
            });
            let cache_entry = CacheEntry::new(key.clone(), value, size);
            if let Some(node) = self.list.add(cache_entry) {
                self.map.insert(key, node);
                self.current_size = self.current_size.saturating_add(size);
                self.metrics.core.record_insertion(size);
            }
            evicted
        };

        if evicted.is_empty() {
            None
        } else {
            Some(evicted)
        }
    }

    /// Removes `key` and returns its value, updating size accounting and metrics.
    pub(crate) fn remove<Q>(&mut self, key: &Q) -> Option<V>
    where
        K: Borrow<Q>,
        Q: ?Sized + Hash + Eq,
    {
        let node = self.map.remove(key)?;
        // SAFETY: `node` came from `map`, so it refers to a live node of `list`. The box
        // returned by `list.remove` is rebuilt exactly once with `Box::from_raw` after the
        // entry is moved out. NOTE(review): this relies on `take_value` leaving the node
        // safe to drop without dropping the entry again — confirm against `ListEntry`.
        unsafe {
            let boxed = self.list.remove(node)?;
            let entry_ptr = Box::into_raw(boxed);
            let cache_entry = (*entry_ptr).take_value();
            let removed_size = cache_entry.metadata.size;
            let _ = Box::from_raw(entry_ptr);
            self.current_size = self.current_size.saturating_sub(removed_size);
            self.metrics.core.record_removal(removed_size);
            Some(cache_entry.value)
        }
    }

    /// Drops every entry and resets the size accounting. Hit/miss counters are kept.
    pub(crate) fn clear(&mut self) {
        self.current_size = 0;
        self.metrics.core.cache_size_bytes = 0;
        // Clear the index first so it never holds pointers to freed nodes.
        self.map.clear();
        self.list.clear();
    }

    /// Membership test that neither promotes the entry nor touches metrics.
    #[inline]
    pub(crate) fn contains<Q>(&self, key: &Q) -> bool
    where
        K: Borrow<Q>,
        Q: ?Sized + Hash + Eq,
    {
        self.map.contains_key(key)
    }

    /// Returns the value for `key` without promoting it or recording a hit.
    pub(crate) fn peek<Q>(&self, key: &Q) -> Option<&V>
    where
        K: Borrow<Q>,
        Q: ?Sized + Hash + Eq,
    {
        let node = self.map.get(key).copied()?;
        // SAFETY: see `get` — pointers in `map` always refer to live nodes of `list`.
        unsafe {
            let entry = (*node).get_value();
            Some(&entry.value)
        }
    }

    /// Repeatedly evicts the least-recently-used entry while `over_budget` returns `true`
    /// for the current state, incrementing the `evictions` counter for each one. Stops
    /// early if nothing is left to evict.
    ///
    /// Returns the evicted pairs, least recently used first.
    fn evict_while(&mut self, mut over_budget: impl FnMut(&Self) -> bool) -> Vec<(K, V)> {
        let mut evicted = Vec::new();
        while over_budget(&*self) {
            match self.evict() {
                Some(pair) => {
                    self.metrics.core.evictions += 1;
                    evicted.push(pair);
                }
                None => break,
            }
        }
        evicted
    }

    /// Removes the tail (least-recently-used) entry, updating size accounting and
    /// recording the removal. Does not touch the `evictions` counter; `evict_while` does.
    fn evict(&mut self) -> Option<(K, V)> {
        let old_entry = self.list.remove_last()?;
        // SAFETY: `remove_last` hands back the unlinked tail node as a `Box`. It is rebuilt
        // exactly once with `Box::from_raw` after the entry is moved out (same assumption
        // about `take_value` as in `remove`).
        unsafe {
            let entry_ptr = Box::into_raw(old_entry);
            let cache_entry = (*entry_ptr).take_value();
            let evicted_size = cache_entry.metadata.size;
            self.map.remove(&cache_entry.key);
            self.current_size = self.current_size.saturating_sub(evicted_size);
            self.metrics.core.record_removal(evicted_size);
            let _ = Box::from_raw(entry_ptr);
            Some((cache_entry.key, cache_entry.value))
        }
    }
}
impl<K, V, S> core::fmt::Debug for LruSegment<K, V, S> {
    /// Prints only the configured capacity and the live entry count. Keys and values are
    /// omitted, so no `Debug` bound is required on `K`, `V`, or `S`.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let entries = self.map.len();
        let mut out = f.debug_struct("LruSegment");
        out.field("capacity", &self.config.capacity);
        out.field("len", &entries);
        out.finish()
    }
}
/// Least-recently-used cache bounded by both entry count and total size.
///
/// A thin public wrapper; all bookkeeping lives in the inner `LruSegment`.
#[derive(Debug)]
pub struct LruCache<K, V, S = DefaultHashBuilder> {
    /// The single segment holding every entry.
    segment: LruSegment<K, V, S>,
}
impl<K: Hash + Eq, V: Clone, S: BuildHasher> LruCache<K, V, S> {
#[inline]
pub fn cap(&self) -> NonZeroUsize {
self.segment.cap()
}
#[inline]
pub fn len(&self) -> usize {
self.segment.len()
}
#[inline]
pub fn is_empty(&self) -> bool {
self.segment.is_empty()
}
#[inline]
pub fn current_size(&self) -> u64 {
self.segment.current_size()
}
#[inline]
pub fn max_size(&self) -> u64 {
self.segment.max_size()
}
#[inline]
pub fn get<Q>(&mut self, key: &Q) -> Option<&V>
where
K: Borrow<Q>,
Q: ?Sized + Hash + Eq,
{
self.segment.get(key)
}
#[inline]
pub fn record_miss(&mut self, object_size: u64) {
self.segment.record_miss(object_size);
}
#[inline]
pub fn get_mut<Q>(&mut self, key: &Q) -> Option<&mut V>
where
K: Borrow<Q>,
Q: ?Sized + Hash + Eq,
{
self.segment.get_mut(key)
}
}
impl<K: Hash + Eq + Clone, V: Clone, S: BuildHasher> LruCache<K, V, S> {
    /// Inserts `value` under `key`, charging `size` against the size budget.
    ///
    /// Returns the `(key, value)` pairs evicted to make room, or `None` when nothing was
    /// evicted. This is the only operation that needs `K: Clone`: the segment stores the
    /// key both in its index and in the entry.
    #[inline]
    pub fn put(&mut self, key: K, value: V, size: u64) -> Option<Vec<(K, V)>> {
        self.segment.put(key, value, size)
    }
}

// These operations never clone keys, so they only carry the bounds the segment methods
// they forward to require. Keeping them out of the `K: Clone` impl above makes them
// callable from generic code bounded by `K: Hash + Eq` alone.
impl<K: Hash + Eq, V: Clone, S: BuildHasher> LruCache<K, V, S> {
    /// Removes `key` and returns its value, if it was present.
    #[inline]
    pub fn remove<Q>(&mut self, key: &Q) -> Option<V>
    where
        K: Borrow<Q>,
        Q: ?Sized + Hash + Eq,
    {
        self.segment.remove(key)
    }

    /// Removes every entry and resets `current_size` to zero.
    #[inline]
    pub fn clear(&mut self) {
        self.segment.clear()
    }

    /// Returns `true` if `key` is present, without changing its recency.
    #[inline]
    pub fn contains<Q>(&self, key: &Q) -> bool
    where
        K: Borrow<Q>,
        Q: ?Sized + Hash + Eq,
    {
        self.segment.contains(key)
    }

    /// Returns the value for `key` without changing its recency or recording a hit.
    #[inline]
    pub fn peek<Q>(&self, key: &Q) -> Option<&V>
    where
        K: Borrow<Q>,
        Q: ?Sized + Hash + Eq,
    {
        self.segment.peek(key)
    }

    /// Iterates over the cache's entries.
    ///
    /// # Panics
    ///
    /// Always panics: iteration is not implemented yet.
    pub fn iter(&self) -> Iter<'_, K, V> {
        unimplemented!("Iteration not yet implemented")
    }

    /// Iterates mutably over the cache's entries.
    ///
    /// # Panics
    ///
    /// Always panics: mutable iteration is not implemented yet.
    pub fn iter_mut(&mut self) -> IterMut<'_, K, V> {
        unimplemented!("Mutable iteration not yet implemented")
    }
}
impl<K, V> LruCache<K, V>
where
    K: Hash + Eq,
    V: Clone,
{
    /// Builds an empty cache with the limits in `config`.
    ///
    /// `hasher` seeds key hashing; `None` falls back to `DefaultHashBuilder::default()`.
    pub fn init(
        config: LruCacheConfig,
        hasher: Option<DefaultHashBuilder>,
    ) -> LruCache<K, V, DefaultHashBuilder> {
        let build_hasher = hasher.unwrap_or_default();
        let segment = LruSegment::init(config, build_hasher);
        LruCache { segment }
    }
}
impl<K, V, S> CacheMetrics for LruCache<K, V, S>
where
    K: Hash + Eq,
    V: Clone,
    S: BuildHasher,
{
    /// Snapshot of the segment's metrics, keyed by metric name.
    fn metrics(&self) -> BTreeMap<String, f64> {
        let segment_metrics = self.segment.metrics();
        segment_metrics.metrics()
    }

    /// Algorithm name as reported by the segment's metrics object.
    fn algorithm_name(&self) -> &'static str {
        let segment_metrics = self.segment.metrics();
        segment_metrics.algorithm_name()
    }
}
/// Shared-borrow iterator type returned by `LruCache::iter`.
///
/// Placeholder only: it carries no state, `LruCache::iter` currently panics, and no
/// `Iterator` impl exists for it yet.
pub struct Iter<'a, K, V> {
    /// Ties the type to the borrowed cache's lifetime and its key/value types.
    _marker: core::marker::PhantomData<&'a (K, V)>,
}

/// Mutable-borrow iterator type returned by `LruCache::iter_mut`.
///
/// Placeholder only: it carries no state, `LruCache::iter_mut` currently panics, and no
/// `Iterator` impl exists for it yet.
pub struct IterMut<'a, K, V> {
    /// Ties the type to the mutably borrowed cache's lifetime and its key/value types.
    _marker: core::marker::PhantomData<&'a mut (K, V)>,
}
// Unit tests for `LruCache` and `LruSegment`: basic get/put/remove/clear semantics, eviction
// order and counting, size tracking, metrics, and use behind `Arc<Mutex<_>>` from threads.
#[cfg(test)]
mod tests {
use super::*;
use crate::config::LruCacheConfig;
use alloc::string::String;
use alloc::vec;
// Builds a cache holding at most `cap` entries with an effectively unlimited size budget
// (`max_size = u64::MAX`) and the default hasher.
fn make_cache<K: Hash + Eq + Clone, V: Clone>(cap: usize) -> LruCache<K, V> {
let config = LruCacheConfig {
capacity: NonZeroUsize::new(cap).unwrap(),
max_size: u64::MAX,
};
LruCache::init(config, None)
}
// Put/get round-trip. Updating an existing key returns `None` and promotes it, so the
// next insert evicts `banana` (the least recently used) and returns its value.
#[test]
fn test_lru_get_put() {
let mut cache = make_cache(2);
assert_eq!(cache.put("apple", 1, 1), None);
assert_eq!(cache.put("banana", 2, 1), None);
assert_eq!(cache.get(&"apple"), Some(&1));
assert_eq!(cache.get(&"banana"), Some(&2));
assert_eq!(cache.get(&"cherry"), None);
assert!(cache.put("apple", 3, 1).is_none());
assert_eq!(cache.get(&"apple"), Some(&3));
assert_eq!(cache.put("cherry", 4, 1).unwrap()[0].1, 2);
assert_eq!(cache.get(&"banana"), None);
assert_eq!(cache.get(&"apple"), Some(&3));
assert_eq!(cache.get(&"cherry"), Some(&4));
}
// Writes through `get_mut` are visible to later reads. The access also promotes `apple`,
// so inserting `cherry` evicts `banana`.
#[test]
fn test_lru_get_mut() {
let mut cache = make_cache(2);
cache.put("apple", 1, 1);
cache.put("banana", 2, 1);
if let Some(v) = cache.get_mut(&"apple") {
*v = 3;
}
assert_eq!(cache.get(&"apple"), Some(&3));
cache.put("cherry", 4, 1);
assert_eq!(cache.get(&"banana"), None);
assert_eq!(cache.get(&"apple"), Some(&3));
assert_eq!(cache.get(&"cherry"), Some(&4));
}
// `remove` returns the value and frees a slot, so the next insert evicts nothing.
// Removing a missing key returns `None`.
#[test]
fn test_lru_remove() {
let mut cache = make_cache(2);
cache.put("apple", 1, 1);
cache.put("banana", 2, 1);
assert_eq!(cache.get(&"apple"), Some(&1));
assert_eq!(cache.get(&"banana"), Some(&2));
assert_eq!(cache.get(&"cherry"), None);
assert_eq!(cache.remove(&"apple"), Some(1));
assert_eq!(cache.get(&"apple"), None);
assert_eq!(cache.len(), 1);
assert_eq!(cache.remove(&"cherry"), None);
let evicted = cache.put("cherry", 3, 1);
assert_eq!(evicted, None);
assert_eq!(cache.get(&"banana"), Some(&2));
assert_eq!(cache.get(&"cherry"), Some(&3));
}
// `clear` empties the cache, and the cache remains usable afterwards.
#[test]
fn test_lru_clear() {
let mut cache = make_cache(2);
cache.put("apple", 1, 1);
cache.put("banana", 2, 1);
assert_eq!(cache.len(), 2);
cache.clear();
assert_eq!(cache.len(), 0);
assert!(cache.is_empty());
cache.put("cherry", 3, 1);
assert_eq!(cache.get(&"cherry"), Some(&3));
}
// With no intervening reads, a third insert into a capacity-2 cache evicts the oldest entry.
#[test]
fn test_lru_capacity_limits() {
let mut cache = make_cache(2);
cache.put("apple", 1, 1);
cache.put("banana", 2, 1);
cache.put("cherry", 3, 1);
assert_eq!(cache.len(), 2);
assert_eq!(cache.get(&"apple"), None);
assert_eq!(cache.get(&"banana"), Some(&2));
assert_eq!(cache.get(&"cherry"), Some(&3));
}
// `String` keys can be looked up by `&String` or by `&str` (via `K: Borrow<Q>`).
#[test]
fn test_lru_string_keys() {
let mut cache = make_cache(2);
let key1 = String::from("apple");
let key2 = String::from("banana");
cache.put(key1.clone(), 1, 1);
cache.put(key2.clone(), 2, 1);
assert_eq!(cache.get(&key1), Some(&1));
assert_eq!(cache.get(&key2), Some(&2));
assert_eq!(cache.get("apple"), Some(&1));
assert_eq!(cache.get("banana"), Some(&2));
drop(cache);
}
// Non-trivial value type used to check that evicted values are returned intact.
#[derive(Debug, Clone, Eq, PartialEq)]
struct ComplexValue {
val: i32,
description: String,
}
// Evicting an entry hands back the whole value. Afterwards the evicted key is gone, so
// `remove` returns `None`.
#[test]
fn test_lru_complex_values() {
let mut cache = make_cache(2);
let key1 = String::from("apple");
let key2 = String::from("banana");
let fruit1 = ComplexValue {
val: 1,
description: String::from("First fruit"),
};
let fruit2 = ComplexValue {
val: 2,
description: String::from("Second fruit"),
};
let fruit3 = ComplexValue {
val: 3,
description: String::from("Third fruit"),
};
cache.put(key1.clone(), fruit1.clone(), 1);
cache.put(key2.clone(), fruit2.clone(), 1);
assert_eq!(cache.get(&key1).unwrap().val, fruit1.val);
assert_eq!(cache.get(&key2).unwrap().val, fruit2.val);
let evicted = cache.put(String::from("cherry"), fruit3.clone(), 1);
let evicted_fruit = evicted.unwrap();
assert_eq!(evicted_fruit[0].1, fruit1);
let removed = cache.remove(&key1);
assert_eq!(removed, None);
}
// Metrics checks:
// - counters start at zero;
// - `get` counts hits;
// - `record_miss` counts both a miss and a request;
// - an eviction is counted;
// - the algorithm reports itself as "LRU".
#[test]
fn test_lru_metrics() {
use crate::metrics::CacheMetrics;
let mut cache = make_cache(2);
let metrics = cache.metrics();
assert_eq!(metrics.get("requests").unwrap(), &0.0);
assert_eq!(metrics.get("cache_hits").unwrap(), &0.0);
assert_eq!(metrics.get("cache_misses").unwrap(), &0.0);
cache.put("apple", 1, 1);
cache.put("banana", 2, 1);
cache.get(&"apple");
cache.get(&"banana");
let metrics = cache.metrics();
assert_eq!(metrics.get("cache_hits").unwrap(), &2.0);
cache.record_miss(64);
let metrics = cache.metrics();
assert_eq!(metrics.get("cache_misses").unwrap(), &1.0);
assert_eq!(metrics.get("requests").unwrap(), &3.0);
cache.put("cherry", 3, 1);
let metrics = cache.metrics();
assert_eq!(metrics.get("evictions").unwrap(), &1.0);
assert!(metrics.get("bytes_written_to_cache").unwrap() > &0.0);
assert_eq!(cache.algorithm_name(), "LRU");
}
// Exercises `LruSegment` directly, without the `LruCache` wrapper.
#[test]
fn test_lru_segment_directly() {
let config = LruCacheConfig {
capacity: NonZeroUsize::new(2).unwrap(),
max_size: u64::MAX,
};
let mut segment: LruSegment<&str, i32, DefaultHashBuilder> =
LruSegment::init(config, DefaultHashBuilder::default());
assert_eq!(segment.len(), 0);
assert!(segment.is_empty());
assert_eq!(segment.cap().get(), 2);
segment.put("a", 1, 1);
segment.put("b", 2, 1);
assert_eq!(segment.len(), 2);
assert_eq!(segment.get(&"a"), Some(&1));
assert_eq!(segment.get(&"b"), Some(&2));
}
// Writer and reader threads share one cache behind `Arc<Mutex<_>>`. The cache must end
// non-empty and within capacity. This also requires `LruCache` to be `Send`.
#[test]
fn test_lru_concurrent_access() {
extern crate std;
use std::sync::{Arc, Mutex};
use std::thread;
use std::vec::Vec;
let cache = Arc::new(Mutex::new(make_cache::<String, i32>(100)));
let num_threads = 4;
let ops_per_thread = 100;
let mut handles: Vec<std::thread::JoinHandle<()>> = Vec::new();
for t in 0..num_threads {
let cache = Arc::clone(&cache);
handles.push(thread::spawn(move || {
for i in 0..ops_per_thread {
let key = std::format!("thread_{}_key_{}", t, i);
let mut guard = cache.lock().unwrap();
guard.put(key, t * 1000 + i, 1);
}
}));
}
for t in 0..num_threads {
let cache = Arc::clone(&cache);
handles.push(thread::spawn(move || {
for i in 0..ops_per_thread {
let key = std::format!("thread_{}_key_{}", t, i);
let mut guard = cache.lock().unwrap();
let _ = guard.get(&key);
}
}));
}
for handle in handles {
handle.join().unwrap();
}
let mut guard = cache.lock().unwrap();
assert!(guard.len() <= 100);
assert!(!guard.is_empty());
guard.clear(); }
// Eight threads alternate put/get over 100 distinct keys in a capacity-50 cache. The
// entry count must never exceed capacity.
#[test]
fn test_lru_high_contention() {
extern crate std;
use std::sync::{Arc, Mutex};
use std::thread;
use std::vec::Vec;
let cache = Arc::new(Mutex::new(make_cache::<String, i32>(50)));
let num_threads = 8;
let ops_per_thread = 500;
let mut handles: Vec<std::thread::JoinHandle<()>> = Vec::new();
for t in 0..num_threads {
let cache = Arc::clone(&cache);
handles.push(thread::spawn(move || {
for i in 0..ops_per_thread {
let key = std::format!("key_{}", i % 100); let mut guard = cache.lock().unwrap();
if i % 2 == 0 {
guard.put(key, t * 1000 + i, 1);
} else {
let _ = guard.get(&key);
}
}
}));
}
for handle in handles {
handle.join().unwrap();
}
let mut guard = cache.lock().unwrap();
assert!(guard.len() <= 50);
guard.clear(); }
// Concurrent mix of put/get/get_mut/remove, plus one `clear` mid-run from thread 0.
// The cache must stay within capacity.
#[test]
fn test_lru_concurrent_mixed_operations() {
extern crate std;
use std::sync::{Arc, Mutex};
use std::thread;
use std::vec::Vec;
let cache = Arc::new(Mutex::new(make_cache::<String, i32>(100)));
let num_threads = 8;
let ops_per_thread = 1000;
let mut handles: Vec<std::thread::JoinHandle<()>> = Vec::new();
for t in 0..num_threads {
let cache = Arc::clone(&cache);
handles.push(thread::spawn(move || {
for i in 0..ops_per_thread {
let key = std::format!("key_{}", i % 200);
let mut guard = cache.lock().unwrap();
match i % 4 {
0 => {
guard.put(key, i, 1);
}
1 => {
let _ = guard.get(&key);
}
2 => {
let _ = guard.get_mut(&key);
}
3 => {
let _ = guard.remove(&key);
}
_ => unreachable!(),
}
if i == 500 && t == 0 {
guard.clear();
}
}
}));
}
for handle in handles {
handle.join().unwrap();
}
let mut guard = cache.lock().unwrap();
assert!(guard.len() <= 100);
guard.clear(); }
// `current_size` is the sum of the sizes passed to `put`; `clear` resets it to zero.
#[test]
fn test_lru_size_aware_tracking() {
let mut cache = make_cache(10);
assert_eq!(cache.current_size(), 0);
assert_eq!(cache.max_size(), u64::MAX);
cache.put("a", 1, 100);
cache.put("b", 2, 200);
cache.put("c", 3, 150);
assert_eq!(cache.current_size(), 450);
assert_eq!(cache.len(), 3);
cache.clear();
assert_eq!(cache.current_size(), 0);
}
// `LruCache::init` yields an empty cache that reports the configured `max_size`.
#[test]
fn test_lru_init_constructor() {
let config = LruCacheConfig {
capacity: NonZeroUsize::new(1000).unwrap(),
max_size: 1024 * 1024,
};
let cache: LruCache<String, i32> = LruCache::init(config, None);
assert_eq!(cache.current_size(), 0);
assert_eq!(cache.max_size(), 1024 * 1024);
assert_eq!(cache.len(), 0);
}
// Despite the name, this exercises `LruCache::init` (no `with_limits` constructor is
// defined in this file). It checks that both limits are reported back.
#[test]
fn test_lru_with_limits_constructor() {
let config = LruCacheConfig {
capacity: NonZeroUsize::new(100).unwrap(),
max_size: 1024 * 1024,
};
let cache: LruCache<String, String> = LruCache::init(config, None);
assert_eq!(cache.current_size(), 0);
assert_eq!(cache.max_size(), 1024 * 1024);
assert_eq!(cache.cap().get(), 100);
}
// `contains` does not promote: `a` stays least recently used and is the one evicted
// when `c` is inserted.
#[test]
fn test_lru_contains_non_promoting() {
let mut cache = make_cache(2);
cache.put("a", 1, 1);
cache.put("b", 2, 1);
assert!(cache.contains(&"a"));
assert!(cache.contains(&"b"));
assert!(!cache.contains(&"c"));
cache.put("c", 3, 1);
assert!(!cache.contains(&"a")); assert!(cache.contains(&"b")); assert!(cache.contains(&"c")); }
// Each capacity-driven eviction increments the `evictions` counter exactly once.
#[test]
fn test_put_eviction_increments_eviction_count() {
let mut cache = make_cache(2);
cache.put("a", 1, 1);
cache.put("b", 2, 1);
assert_eq!(cache.segment.metrics().core.evictions, 0);
cache.put("c", 3, 1);
assert_eq!(cache.segment.metrics().core.evictions, 1);
cache.put("d", 4, 1);
assert_eq!(cache.segment.metrics().core.evictions, 2);
}
// `put` returns `None` while there is still room.
#[test]
fn test_put_returns_none_when_no_eviction() {
let mut cache = make_cache(10);
assert!(cache.put("a", 1, 1).is_none());
assert!(cache.put("b", 2, 1).is_none());
}
// A single eviction is reported as a one-element vector holding the evicted pair.
#[test]
fn test_put_returns_single_eviction() {
let mut cache = make_cache(2);
cache.put("a", 1, 1);
cache.put("b", 2, 1);
let result = cache.put("c", 3, 1);
assert_eq!(result, Some(vec![("a", 1)]));
}
// Replacing an existing key (with no size pressure) evicts nothing and stores the new value.
#[test]
fn test_put_replacement_returns_none() {
let mut cache = make_cache(10);
cache.put("a", 1, 1);
let result = cache.put("a", 2, 1);
assert!(result.is_none());
assert_eq!(cache.get(&"a"), Some(&2));
}
// Ten 10-byte entries exactly fill `max_size = 100`. A 50-byte insert evicts one entry for
// the count limit, then more until 50 bytes fit: five evictions in total.
#[test]
fn test_put_returns_multiple_evictions_size_based() {
let config = LruCacheConfig {
capacity: NonZeroUsize::new(10).unwrap(),
max_size: 100,
};
let mut cache = LruCache::init(config, None);
for i in 0..10 {
cache.put(i, i, 10);
}
let result = cache.put(99, 99, 50);
let evicted = result.unwrap();
assert_eq!(evicted.len(), 5);
}
}