use std::collections::HashMap;
use std::hash::Hash;
/// Sentinel index meaning "no neighbour" in the recency list.
const NIL: usize = usize::MAX;

/// Upper bound on how many entries [`LruCache::with_capacity`] reserves up front.
///
/// The cache's `capacity` is an eviction bound, not an allocation request. Reserving the
/// full bound eagerly panics with "capacity overflow" for very large bounds (for example
/// `usize::MAX` used as "effectively unbounded"). It also pins memory that may never be
/// used. Past this many entries, `Vec`/`HashMap` simply grow on demand.
const MAX_PREALLOC: usize = 1024;

/// One cache entry, stored in `LruCache::slots` and linked into the recency list by index.
struct Slot<K, V> {
    key: K,
    val: V,
    /// Index of the neighbouring slot towards `head` (more recently used), or `NIL`.
    prev: usize,
    /// Index of the neighbouring slot towards `tail` (less recently used), or `NIL`.
    next: usize,
}

/// A bounded least-recently-used cache.
///
/// Entries live in `slots` and are threaded into a doubly linked recency list by index:
/// `head` is the most recently used entry and `tail` the least recently used. `map` maps
/// each key to its slot index. Once `capacity` entries are present, inserting a new key
/// reuses the tail slot in place, so `slots` never grows past `capacity`.
pub(crate) struct LruCache<K, V> {
    /// Key -> index into `slots`.
    map: HashMap<K, usize>,
    /// Entry storage. Slots are only appended or overwritten in place, so indices stay
    /// valid until `clear`.
    slots: Vec<Slot<K, V>>,
    /// Most recently used slot, or `NIL` when empty.
    head: usize,
    /// Least recently used slot, or `NIL` when empty.
    tail: usize,
    /// Maximum number of entries retained.
    capacity: usize,
}

impl<K: Hash + Eq + Clone, V> LruCache<K, V> {
    /// Creates an empty cache that retains at most `capacity` entries.
    ///
    /// A `capacity` of 0 yields a cache that stores nothing. At most `MAX_PREALLOC`
    /// entries are reserved up front, so large bounds (including `usize::MAX`) are cheap
    /// to construct.
    pub(crate) fn with_capacity(capacity: usize) -> Self {
        let reserve = capacity.min(MAX_PREALLOC);
        Self {
            map: HashMap::with_capacity(reserve),
            slots: Vec::with_capacity(reserve),
            head: NIL,
            tail: NIL,
            capacity,
        }
    }

    /// Returns the number of entries currently stored.
    pub(crate) fn len(&self) -> usize {
        self.map.len()
    }

    /// Unlinks slot `idx` from the recency list, patching its neighbours and `head`/`tail`.
    ///
    /// The slot's own `prev`/`next` are left stale. Callers must relink it with `push_front`.
    fn detach(&mut self, idx: usize) {
        let (prev, next) = (self.slots[idx].prev, self.slots[idx].next);
        if prev == NIL {
            self.head = next;
        } else {
            self.slots[prev].next = next;
        }
        if next == NIL {
            self.tail = prev;
        } else {
            self.slots[next].prev = prev;
        }
    }

    /// Links the currently unlinked slot `idx` in as the most recently used entry.
    fn push_front(&mut self, idx: usize) {
        let old_head = self.head;
        let slot = &mut self.slots[idx];
        slot.prev = NIL;
        slot.next = old_head;
        if old_head == NIL {
            // List was empty: the new entry is also the least recently used one.
            self.tail = idx;
        } else {
            self.slots[old_head].prev = idx;
        }
        self.head = idx;
    }

    /// Moves the linked slot `idx` to the most-recently-used position.
    fn touch(&mut self, idx: usize) {
        // Relinking the current head would leave the list unchanged, so skip the work.
        if idx != self.head {
            self.detach(idx);
            self.push_front(idx);
        }
    }

    /// Looks up `key`. On a hit, marks it as most recently used.
    pub(crate) fn get(&mut self, key: &K) -> Option<&V> {
        let idx = *self.map.get(key)?;
        self.touch(idx);
        Some(&self.slots[idx].val)
    }

    /// Inserts or updates `key`, marking it as most recently used.
    ///
    /// Returns the entry that could not be kept, if any:
    /// - `None` when there was room, or when `key` was already present (its value is
    ///   replaced and the old value dropped);
    /// - `Some((evicted_key, evicted_val))` when the cache was full and its least recently
    ///   used entry was evicted to make room;
    /// - `Some((key, val))`, the arguments themselves, when `capacity` is 0.
    pub(crate) fn put(&mut self, key: K, val: V) -> Option<(K, V)> {
        if self.capacity == 0 {
            return Some((key, val));
        }
        if let Some(&idx) = self.map.get(&key) {
            self.slots[idx].val = val;
            self.touch(idx);
            return None;
        }
        if self.map.len() == self.capacity {
            // Full: recycle the least recently used slot in place for the new entry.
            let idx = self.tail;
            debug_assert_ne!(idx, NIL, "a full cache with non-zero capacity has a tail");
            self.detach(idx);
            let slot = &mut self.slots[idx];
            // The key is needed both in the slot and as the map key, hence the clone.
            let evicted_key = std::mem::replace(&mut slot.key, key.clone());
            let evicted_val = std::mem::replace(&mut slot.val, val);
            let _removed = self.map.remove(&evicted_key);
            let _prev = self.map.insert(key, idx);
            self.push_front(idx);
            return Some((evicted_key, evicted_val));
        }
        let idx = self.slots.len();
        self.slots.push(Slot {
            key: key.clone(),
            val,
            prev: NIL,
            next: NIL,
        });
        let _prev = self.map.insert(key, idx);
        self.push_front(idx);
        None
    }

    /// Removes every entry. The configured capacity is kept.
    pub(crate) fn clear(&mut self) {
        self.map.clear();
        self.slots.clear();
        self.head = NIL;
        self.tail = NIL;
    }
}
#[cfg(test)]
mod tests {
    // Tests may unwrap freely: a panic here is simply a test failure.
    #![allow(clippy::unwrap_used)]
    use super::*;

    #[test]
    fn get_miss_on_empty() {
        let mut cache: LruCache<u32, u32> = LruCache::with_capacity(4);
        assert_eq!(cache.get(&1), None);
        assert_eq!(cache.len(), 0);
    }

    #[test]
    fn put_then_get() {
        let mut cache = LruCache::with_capacity(2);
        assert_eq!(cache.put(1u32, 10u32), None);
        assert_eq!(cache.put(2, 20), None);
        assert_eq!(cache.get(&1), Some(&10));
        assert_eq!(cache.get(&2), Some(&20));
        assert_eq!(cache.len(), 2);
    }

    #[test]
    fn put_updates_existing_without_growth() {
        let mut cache = LruCache::with_capacity(2);
        assert_eq!(cache.put(1u32, 10u32), None);
        assert_eq!(cache.put(1, 11), None);
        assert_eq!(cache.len(), 1);
        assert_eq!(cache.get(&1), Some(&11));
    }

    #[test]
    fn put_update_refreshes_recency() {
        let mut cache = LruCache::with_capacity(2);
        let _ = cache.put(1u32, 10u32);
        let _ = cache.put(2, 20);
        // Re-putting key 1 makes it most recent, so key 2 becomes the eviction victim.
        assert_eq!(cache.put(1, 11), None);
        assert_eq!(cache.put(3, 30), Some((2, 20)));
        assert_eq!(cache.get(&1), Some(&11));
        assert_eq!(cache.get(&3), Some(&30));
        assert_eq!(cache.len(), 2);
    }

    #[test]
    fn evicts_least_recently_used() {
        let mut cache = LruCache::with_capacity(2);
        let _ = cache.put(1u32, 10u32);
        let _ = cache.put(2, 20);
        assert_eq!(cache.get(&1), Some(&10));
        assert_eq!(cache.put(3, 30), Some((2, 20)));
        assert_eq!(cache.get(&2), None);
        assert_eq!(cache.get(&1), Some(&10));
        assert_eq!(cache.get(&3), Some(&30));
        assert_eq!(cache.len(), 2);
    }

    #[test]
    fn evicts_in_recency_order_after_touching_middle() {
        let mut cache = LruCache::with_capacity(3);
        let _ = cache.put(1u32, 10u32);
        let _ = cache.put(2, 20);
        let _ = cache.put(3, 30);
        // Recency (most -> least): 3, 2, 1. Touch the middle entry.
        assert_eq!(cache.get(&2), Some(&20));
        // Recency: 2, 3, 1.
        assert_eq!(cache.put(4, 40), Some((1, 10)));
        // Recency: 4, 2, 3.
        assert_eq!(cache.put(5, 50), Some((3, 30)));
        // Recency: 5, 4, 2.
        assert_eq!(cache.put(6, 60), Some((2, 20)));
        assert_eq!(cache.len(), 3);
        assert_eq!(cache.get(&4), Some(&40));
    }

    #[test]
    fn capacity_one_keeps_only_newest() {
        let mut cache = LruCache::with_capacity(1);
        let _ = cache.put(1u32, 10u32);
        assert_eq!(cache.put(2, 20), Some((1, 10)));
        assert_eq!(cache.get(&1), None);
        assert_eq!(cache.get(&2), Some(&20));
    }

    #[test]
    fn capacity_zero_rejects_everything() {
        let mut cache = LruCache::with_capacity(0);
        assert_eq!(cache.put(1u32, 10u32), Some((1, 10)));
        assert_eq!(cache.get(&1), None);
        assert_eq!(cache.len(), 0);
    }

    #[test]
    fn clear_empties_and_keeps_working() {
        let mut cache = LruCache::with_capacity(2);
        let _ = cache.put(1u32, 10u32);
        let _ = cache.put(2, 20);
        cache.clear();
        assert_eq!(cache.len(), 0);
        assert_eq!(cache.get(&1), None);
        let _ = cache.put(3, 30);
        assert_eq!(cache.get(&3), Some(&30));
    }

    #[test]
    fn never_exceeds_capacity_under_churn() {
        let mut cache = LruCache::with_capacity(8);
        for i in 0..1000u32 {
            let _ = cache.put(i, i);
            assert!(cache.len() <= 8);
        }
        for i in 992..1000u32 {
            assert_eq!(cache.get(&i), Some(&i));
        }
        for i in 0..992u32 {
            assert_eq!(cache.get(&i), None);
        }
    }
}