use crate::hashing::HashMap;
use std::hash::Hash;
/// Sentinel index meaning "no neighbour" in the recency list.
/// Slot indices are always below `u32::MAX`, so this never names a real slot.
const NIL: u32 = u32::MAX;

/// Links of one slot in the doubly linked recency list.
///
/// `LruCache::lru_list[i]` holds the neighbours of the entry stored in
/// `LruCache::lru_list_pairs[i]`; `NIL` marks a missing neighbour.
#[derive(Copy, Clone)]
struct LruCacheNode {
    next: u32,
    previous: u32,
}

/// A fixed-capacity least-recently-used cache.
///
/// Every slot is allocated up front and threaded onto a doubly linked list of
/// slot indices (`lru_list`). The head is the most recently used slot and the
/// tail the least recently used one. `remove` parks freed slots at the tail and
/// `insert` always fills the tail, so unused slots stay at the tail end: an
/// occupied tail means the cache is full and holds the entry to evict.
pub struct LruCache<K, V> {
    /// Slot index of the most recently used entry.
    lru_list_head: u32,
    /// Slot index of the least recently used entry (or of a free slot).
    lru_list_tail: u32,
    /// Recency links, indexed by slot.
    lru_list: Vec<LruCacheNode>,
    /// Cached entries, indexed by slot; `None` marks a free slot.
    lru_list_pairs: Vec<Option<(K, V)>>,
    /// Maps every cached key to the slot that stores it.
    lookup: HashMap<K, u32>,
}

impl<K: Clone + PartialEq + Eq + Hash, V> LruCache<K, V> {
    /// Creates an empty cache that holds at most `size` entries.
    ///
    /// # Panics
    ///
    /// Panics if `size` is 0.
    pub fn new(size: u32) -> LruCache<K, V> {
        assert!(size > 0, "LruCache size must be non-zero");
        let last = size - 1;
        // Chain slot i to slots i - 1 and i + 1; both ends point at NIL.
        let lru_list = (0..size)
            .map(|i| LruCacheNode {
                next: if i == last { NIL } else { i + 1 },
                previous: if i == 0 { NIL } else { i - 1 },
            })
            .collect();
        // `vec![None; n]` would require `V: Clone`, so build the slots one by one.
        let lru_list_pairs = (0..size).map(|_| None).collect();
        LruCache {
            lru_list_head: 0,
            lru_list_tail: last,
            lru_list,
            lru_list_pairs,
            lookup: HashMap::default(),
        }
    }

    /// Returns every slot, indexed by slot number; `None` marks an unused slot.
    /// Slot order is not recency order.
    pub fn pairs(&self) -> &Vec<Option<(K, V)>> {
        &self.lru_list_pairs
    }

    /// Mutable access to every slot, indexed by slot number.
    ///
    /// Values may be modified in place, but callers must not add, remove or
    /// re-key entries: the key-to-slot table is only maintained by `insert` and
    /// `remove`, and would fall out of sync.
    pub fn pairs_mut(&mut self) -> &mut Vec<Option<(K, V)>> {
        &mut self.lru_list_pairs
    }

    /// Detaches `node_index` from the recency list, reconnecting its neighbours
    /// and updating head/tail if it was at either end. The node's own links are
    /// left stale; the caller must relink it.
    fn unlink(
        &mut self,
        node_index: u32,
    ) {
        let LruCacheNode { next, previous } = self.lru_list[node_index as usize];
        if previous == NIL {
            debug_assert_eq!(self.lru_list_head, node_index);
            self.lru_list_head = next;
        } else {
            self.lru_list[previous as usize].next = next;
        }
        if next == NIL {
            debug_assert_eq!(self.lru_list_tail, node_index);
            self.lru_list_tail = previous;
        } else {
            self.lru_list[next as usize].previous = previous;
        }
    }

    /// Makes `node_index` the head (most recently used) of the recency list.
    fn move_to_front(
        &mut self,
        node_index: u32,
    ) {
        if node_index == self.lru_list_head {
            return;
        }
        // The node is not the head, so after unlinking the old head is still in place.
        self.unlink(node_index);
        let old_head = self.lru_list_head;
        debug_assert_eq!(self.lru_list[old_head as usize].previous, NIL);
        self.lru_list[old_head as usize].previous = node_index;
        self.lru_list[node_index as usize] = LruCacheNode {
            next: old_head,
            previous: NIL,
        };
        self.lru_list_head = node_index;
    }

    /// Makes `node_index` the tail (least recently used) of the recency list.
    fn move_to_back(
        &mut self,
        node_index: u32,
    ) {
        if node_index == self.lru_list_tail {
            return;
        }
        // The node is not the tail, so after unlinking the old tail is still in place.
        self.unlink(node_index);
        let old_tail = self.lru_list_tail;
        debug_assert_eq!(self.lru_list[old_tail as usize].next, NIL);
        self.lru_list[old_tail as usize].next = node_index;
        self.lru_list[node_index as usize] = LruCacheNode {
            next: NIL,
            previous: old_tail,
        };
        self.lru_list_tail = node_index;
    }

    /// Looks up `k`. When `mark_as_recently_used` is true, a hit becomes the
    /// most recently used entry; otherwise recency order is left untouched.
    pub fn get(
        &mut self,
        k: &K,
        mark_as_recently_used: bool,
    ) -> Option<&V> {
        let node_index = *self.lookup.get(k)?;
        if mark_as_recently_used {
            self.move_to_front(node_index);
        }
        self.lru_list_pairs[node_index as usize]
            .as_ref()
            .map(|(_, v)| v)
    }

    /// Like [`get`](Self::get), but returns a mutable reference to the value.
    pub fn get_mut(
        &mut self,
        k: &K,
        mark_as_recently_used: bool,
    ) -> Option<&mut V> {
        let node_index = *self.lookup.get(k)?;
        if mark_as_recently_used {
            self.move_to_front(node_index);
        }
        self.lru_list_pairs[node_index as usize]
            .as_mut()
            .map(|(_, v)| v)
    }

    /// Inserts `k -> v` as the most recently used entry.
    ///
    /// If `k` is already cached, its slot is updated in place and nothing is
    /// evicted. Otherwise, if the cache is full, the least recently used entry
    /// is evicted to make room.
    pub fn insert(
        &mut self,
        k: K,
        v: V,
    ) {
        if let Some(&node_index) = self.lookup.get(&k) {
            // Reuse the existing slot. Taking a fresh one would leave a stale
            // duplicate of `k` behind and needlessly evict another entry.
            self.move_to_front(node_index);
            self.lru_list_pairs[node_index as usize] = Some((k, v));
            return;
        }
        // Free slots are kept at the tail, so the tail is either free or the LRU entry.
        let node_index = self.lru_list_tail;
        if let Some((evicted_key, _)) = self.lru_list_pairs[node_index as usize].take() {
            self.lookup.remove(&evicted_key);
        }
        self.move_to_front(node_index);
        self.lookup.insert(k.clone(), node_index);
        self.lru_list_pairs[node_index as usize] = Some((k, v));
    }

    /// Removes `k` from the cache, returning its value if it was present.
    pub fn remove(
        &mut self,
        k: &K,
    ) -> Option<V> {
        let node_index = self.lookup.remove(k)?;
        // Park the freed slot at the tail so `insert` reuses it before evicting anything.
        self.move_to_back(node_index);
        self.lru_list_pairs[node_index as usize]
            .take()
            .map(|(_, v)| v)
    }
}
#[cfg(test)]
mod test {
    use super::*;

    /// Reports whether `key` is cached without refreshing its recency.
    fn holds(
        cache: &mut LruCache<i32, i32>,
        key: i32,
    ) -> bool {
        cache.get(&key, false).is_some()
    }

    /// Builds a capacity-3 cache filled with `0 -> 0`, `1 -> 1`, `2 -> 2`
    /// (inserted in that order) and checks that all three are present.
    fn filled_cache() -> LruCache<i32, i32> {
        let mut cache = LruCache::new(3);
        (0..3).for_each(|key| cache.insert(key, key));
        assert!((0..3).all(|key| holds(&mut cache, key)));
        cache
    }

    #[test]
    fn check_lru_gets_full() {
        let mut cache = filled_cache();
        cache.insert(3, 3);
        assert!(!holds(&mut cache, 0));
        assert!((1..=3).all(|key| holds(&mut cache, key)));
    }

    #[test]
    fn check_lru_deletes_least_recently_used() {
        let mut cache = filled_cache();
        // Refresh 0 so that 1 becomes the eviction candidate.
        cache.get(&0, true);
        cache.insert(3, 3);
        assert!(!holds(&mut cache, 1));
        for key in [0, 2, 3] {
            assert!(holds(&mut cache, key));
        }
    }

    #[test]
    fn check_remove() {
        let mut cache = filled_cache();
        for key in [0, 2, 1] {
            cache.remove(&key);
        }
        cache.insert(3, 3);
        assert!((0..3).all(|key| !holds(&mut cache, key)));
        assert!(holds(&mut cache, 3));
    }
}