use crate::freelist::{FreeList, FreeListHandle, WeakFreeListHandle};
use std::{mem, num};
/// A cached value plus the location of the LRU tracking item that owns its
/// strong handle. The location lets the cache find the tracking item again
/// when it only has a weak handle (as `remove`, `touch` and
/// `replace_or_insert` do).
#[cfg_attr(feature = "capture", derive(Serialize))]
#[cfg_attr(feature = "replay", derive(Deserialize))]
#[derive(MallocSizeOf)]
struct LRUCacheEntry<T> {
/// Index of the LRU partition whose tracker holds this entry's strong handle.
partition_index: u8,
/// Slot index of this entry's tracking item within that partition's tracker.
lru_index: ItemIndex,
/// The data supplied by the caller.
value: T,
}
/// A cache of `T` values divided into independent LRU partitions.
///
/// `entries` stores the values. Each entry's strong `FreeListHandle` is owned by
/// the LRU tracker of the partition it belongs to; callers only receive
/// `WeakFreeListHandle`s. Inserting never evicts: callers evict explicitly
/// with `pop_oldest` or `remove`.
#[cfg_attr(feature = "capture", derive(Serialize))]
#[cfg_attr(feature = "replay", derive(Deserialize))]
#[derive(MallocSizeOf)]
pub struct LRUCache<T, M> {
/// Storage for the cached entries, addressed by free-list handles.
entries: FreeList<LRUCacheEntry<T>, M>,
/// One LRU tracker per partition, indexed by a `u8` partition index.
lru: Vec<LRUTracker<FreeListHandle<M>>>,
}
impl<T, M> LRUCache<T, M> {
    /// Construct an empty cache with `lru_partition_count` independent LRU lists.
    ///
    /// # Panics
    ///
    /// Panics if `lru_partition_count` exceeds 256, because partitions are
    /// addressed with a `u8` index.
    pub fn new(lru_partition_count: usize) -> Self {
        assert!(lru_partition_count <= u8::MAX as usize + 1);
        LRUCache {
            entries: FreeList::new(),
            lru: (0..lru_partition_count).map(|_| LRUTracker::new()).collect(),
        }
    }

    /// Insert `value` as the most recently used item of `partition_index`.
    ///
    /// Returns a weak handle for accessing the data. The strong handle is owned
    /// by the partition's LRU tracker.
    ///
    /// # Panics
    ///
    /// Panics if `partition_index` is out of range for this cache.
    pub fn push_new(
        &mut self,
        partition_index: u8,
        value: T,
    ) -> WeakFreeListHandle<M> {
        // The tracking slot index is only known once the tracker owns the strong
        // handle. Insert with placeholder location fields and patch them below.
        let handle = self.entries.insert(LRUCacheEntry {
            partition_index: 0,
            lru_index: ItemIndex(num::NonZeroU32::new(1).unwrap()),
            value
        });

        let weak_handle = handle.weak();

        // Hand the strong handle to the tracker and record where it was placed.
        let entry = self.entries.get_mut(&handle);
        let lru_index = self.lru[partition_index as usize].push_new(handle);
        entry.partition_index = partition_index;
        entry.lru_index = lru_index;

        weak_handle
    }

    /// Get a shared reference to the value behind `handle`, if it is still live.
    /// Does not change the LRU order.
    pub fn get_opt(
        &self,
        handle: &WeakFreeListHandle<M>,
    ) -> Option<&T> {
        self.entries
            .get_opt(handle)
            .map(|entry| {
                &entry.value
            })
    }

    /// Get a mutable reference to the value behind `handle`, if it is still live.
    /// Does not change the LRU order.
    pub fn get_opt_mut(
        &mut self,
        handle: &WeakFreeListHandle<M>,
    ) -> Option<&mut T> {
        self.entries
            .get_opt_mut(handle)
            .map(|entry| {
                &mut entry.value
            })
    }

    /// Return the least recently used value in `partition_index` without removing it.
    pub fn peek_oldest(&self, partition_index: u8) -> Option<&T> {
        self.lru[partition_index as usize]
            .peek_front()
            .map(|handle| {
                let entry = self.entries.get(handle);
                &entry.value
            })
    }

    /// Remove and return the least recently used value in `partition_index`, if any.
    pub fn pop_oldest(
        &mut self,
        partition_index: u8,
    ) -> Option<T> {
        self.lru[partition_index as usize]
            .pop_front()
            .map(|handle| {
                let entry = self.entries.free(handle);
                entry.value
            })
    }

    /// Store `data` under `handle`.
    ///
    /// If `handle` refers to a live entry, its value is replaced and the old
    /// value is returned. If `partition_index` differs from the entry's current
    /// partition, the entry moves to the most recently used end of the new
    /// partition. Replacing within the same partition leaves the LRU order
    /// unchanged.
    ///
    /// If `handle` does not resolve, `data` is inserted as a new entry,
    /// `*handle` is updated to refer to it, and `None` is returned.
    #[must_use]
    pub fn replace_or_insert(
        &mut self,
        handle: &mut WeakFreeListHandle<M>,
        partition_index: u8,
        data: T,
    ) -> Option<T> {
        match self.entries.get_opt_mut(handle) {
            Some(entry) => {
                if entry.partition_index != partition_index {
                    // Move the strong handle into the new partition's tracker and
                    // update the entry's back-reference to its new slot.
                    let strong_handle = self.lru[entry.partition_index as usize].remove(entry.lru_index);
                    let lru_index = self.lru[partition_index as usize].push_new(strong_handle);
                    entry.partition_index = partition_index;
                    entry.lru_index = lru_index;
                }
                Some(mem::replace(&mut entry.value, data))
            }
            None => {
                *handle = self.push_new(partition_index, data);
                None
            }
        }
    }

    /// Evict the entry behind `handle`, returning its value if it was still live.
    pub fn remove(&mut self, handle: &WeakFreeListHandle<M>) -> Option<T> {
        // Copy the location out so the shared borrow of `entries` ends before we free.
        let (partition_index, lru_index) = {
            let entry = self.entries.get_opt(handle)?;
            (entry.partition_index, entry.lru_index)
        };
        let strong_handle = self.lru[partition_index as usize].remove(lru_index);
        Some(self.entries.free(strong_handle).value)
    }

    /// Mark the entry behind `handle` as the most recently used item of its
    /// partition, and return its value for optional mutation. Returns `None`
    /// (and changes nothing) if the handle is stale.
    pub fn touch(
        &mut self,
        handle: &WeakFreeListHandle<M>,
    ) -> Option<&mut T> {
        // Split the borrow: `lru` is updated while `entries` is borrowed.
        let lru = &mut self.lru;

        self.entries
            .get_opt_mut(handle)
            .map(|entry| {
                lru[entry.partition_index as usize].mark_used(entry.lru_index);
                &mut entry.value
            })
    }

    /// Check the internal invariants of every partition.
    ///
    /// Each tracker must be structurally consistent. In addition, every live
    /// tracking slot must refer to an entry whose recorded partition and slot
    /// index point straight back at that slot.
    #[cfg(test)]
    fn validate(&self) {
        for (partition_index, lru) in self.lru.iter().enumerate() {
            lru.validate();

            for (slot_index, item) in lru.items.iter().enumerate() {
                if let Some(handle) = item.handle.as_ref() {
                    let entry = self.entries.get(handle);
                    assert_eq!(entry.partition_index as usize, partition_index);
                    assert_eq!(entry.lru_index.as_usize(), slot_index);
                }
            }
        }
    }
}
/// Index of a slot in an `LRUTracker`'s `items` vector.
///
/// `LRUTracker::new` reserves slot 0 as a placeholder, so real indices are never
/// zero. Storing them as `NonZeroU32` lets `Option<ItemIndex>` stay the same
/// size as `ItemIndex`.
#[cfg_attr(feature = "capture", derive(Serialize))]
#[cfg_attr(feature = "replay", derive(Deserialize))]
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, MallocSizeOf)]
struct ItemIndex(num::NonZeroU32);
impl ItemIndex {
fn as_usize(&self) -> usize {
self.0.get() as usize
}
}
/// A slot in an `LRUTracker`.
///
/// While the slot is linked into the LRU list, `prev`/`next` point at its
/// neighbours and `handle` is `Some`. While it sits on the free list, `next`
/// chains to the next free slot and `handle` is `None`.
#[cfg_attr(feature = "capture", derive(Serialize))]
#[cfg_attr(feature = "replay", derive(Deserialize))]
#[derive(Debug, MallocSizeOf)]
struct Item<H> {
/// Previous (older) item in the LRU list; `None` at the head.
prev: Option<ItemIndex>,
/// Next (newer) item in the LRU list, or the next free slot while on the free list.
next: Option<ItemIndex>,
/// The tracked handle; `None` for free slots and for the placeholder slot 0.
handle: Option<H>,
}
/// A doubly linked LRU list stored in a `Vec`. Slots are recycled through an
/// intrusive free list. The head is the least recently used item; new and
/// touched items are linked at the tail.
#[cfg_attr(feature = "capture", derive(Serialize))]
#[cfg_attr(feature = "replay", derive(Deserialize))]
#[derive(MallocSizeOf)]
struct LRUTracker<H> {
/// Least recently used item, or `None` when the list is empty.
head: Option<ItemIndex>,
/// Most recently used item, or `None` when the list is empty.
tail: Option<ItemIndex>,
/// First free slot (further free slots are chained through `Item::next`).
free_list_head: Option<ItemIndex>,
/// Slot storage; `items[0]` is an unused placeholder so real indices are non-zero.
items: Vec<Item<H>>,
}
impl<H> LRUTracker<H> {
    /// Construct a new, empty LRU tracker.
    fn new() -> Self {
        // Slot 0 is a permanent placeholder that is never linked or freed. It keeps
        // every real slot index non-zero so it can be stored in an `ItemIndex`.
        let items = vec![
            Item {
                prev: None,
                next: None,
                handle: None,
            },
        ];

        LRUTracker {
            head: None,
            tail: None,
            free_list_head: None,
            items,
        }
    }

    /// Link `item_index` at the tail of the list, making it the most recently
    /// used item. The slot must not currently be linked.
    fn link_as_new_tail(
        &mut self,
        item_index: ItemIndex,
    ) {
        match (self.head, self.tail) {
            (Some(..), Some(tail)) => {
                // Append after the current tail.
                self.items[item_index.as_usize()].prev = Some(tail);
                self.items[item_index.as_usize()].next = None;

                self.items[tail.as_usize()].next = Some(item_index);
                self.tail = Some(item_index);
            }
            (None, None) => {
                // Empty list: the item becomes both head and tail.
                self.items[item_index.as_usize()].prev = None;
                self.items[item_index.as_usize()].next = None;

                self.head = Some(item_index);
                self.tail = Some(item_index);
            }
            (Some(..), None) | (None, Some(..)) => {
                // `head` and `tail` are always set and cleared together.
                unreachable!();
            }
        }
    }

    /// Detach `item_index` from the linked list. Its neighbours' links (or
    /// `head`/`tail`, when it sits at an end) are repaired. The detached slot's
    /// own `prev`/`next` fields are left untouched.
    fn unlink(
        &mut self,
        item_index: ItemIndex,
    ) {
        let (next, prev) = {
            let item = &self.items[item_index.as_usize()];
            (item.next, item.prev)
        };

        match next {
            Some(next) => {
                self.items[next.as_usize()].prev = prev;
            }
            None => {
                // No successor: this item must be the tail.
                debug_assert_eq!(self.tail, Some(item_index));
                self.tail = prev;
            }
        }

        match prev {
            Some(prev) => {
                self.items[prev.as_usize()].next = next;
            }
            None => {
                // No predecessor: this item must be the head.
                debug_assert_eq!(self.head, Some(item_index));
                self.head = next;
            }
        }
    }

    /// Store `handle` in a slot and link it as the most recently used item.
    /// A slot from the free list is reused when one is available. Returns the
    /// slot's index.
    ///
    /// # Panics
    ///
    /// Panics if a new slot would need an index larger than `u32::MAX`.
    fn push_new(
        &mut self,
        handle: H,
    ) -> ItemIndex {
        let item_index = match self.free_list_head {
            Some(index) => {
                // Reuse a free slot; the free list is chained through `next`.
                let item = &mut self.items[index.as_usize()];

                assert!(item.handle.is_none());
                item.handle = Some(handle);

                self.free_list_head = item.next;

                index
            }
            None => {
                // Append a new slot. Check the conversion explicitly: a plain
                // `as u32` would silently wrap and alias an existing slot.
                let slot = self.items.len();
                assert!(slot <= u32::MAX as usize, "LRU tracker slot index overflow");
                // `slot` is at least 1 because of the placeholder at index 0.
                let index = ItemIndex(num::NonZeroU32::new(slot as u32).unwrap());

                self.items.push(Item {
                    prev: None,
                    next: None,
                    handle: Some(handle),
                });

                index
            }
        };

        self.link_as_new_tail(item_index);

        item_index
    }

    /// Return the handle of the least recently used item, if any, without removing it.
    fn peek_front(&self) -> Option<&H> {
        self.head.map(|head| self.items[head.as_usize()].handle.as_ref().unwrap())
    }

    /// Remove and return the handle of the least recently used item, if any.
    fn pop_front(
        &mut self,
    ) -> Option<H> {
        match (self.head, self.tail) {
            (Some(head), Some(..)) => Some(self.remove(head)),
            (None, None) => None,
            (Some(..), None) | (None, Some(..)) => {
                // `head` and `tail` are always set and cleared together.
                unreachable!();
            }
        }
    }

    /// Remove the item in slot `index` from the list and return its handle.
    /// The slot goes onto the free list for reuse by `push_new`.
    ///
    /// # Panics
    ///
    /// Panics if `index` is not a slot currently linked into this tracker.
    fn remove(
        &mut self,
        index: ItemIndex,
    ) -> H {
        // Take the handle before touching any links. A stale index then panics
        // here instead of first corrupting the list in `unlink`.
        let handle = self.items[index.as_usize()]
            .handle
            .take()
            .expect("LRU tracker slot is not in use");

        self.unlink(index);

        // Chain the slot onto the free list through its `next` field.
        self.items[index.as_usize()].next = self.free_list_head;
        self.free_list_head = Some(index);

        handle
    }

    /// Mark slot `index` as used: move it to the tail (most recently used end).
    fn mark_used(
        &mut self,
        index: ItemIndex,
    ) {
        self.unlink(index);
        self.link_as_new_tail(index);
    }

    /// Check the tracker's structural invariants, panicking on any violation.
    #[cfg(test)]
    fn validate(&self) {
        use std::collections::HashSet;

        // `head` and `tail` are set together, and both ends are terminated.
        assert_eq!(self.head.is_some(), self.tail.is_some());
        if let Some(head) = self.head {
            assert!(self.items[head.as_usize()].prev.is_none());
        }
        if let Some(tail) = self.tail {
            assert!(self.items[tail.as_usize()].next.is_none());
        }

        // The placeholder slot never holds a handle.
        assert!(self.items[0].handle.is_none());

        // Free list: every slot appears once (no cycles) and holds no handle.
        let mut free_items_set = HashSet::new();
        let mut current = self.free_list_head;
        while let Some(index) = current {
            let item = &self.items[index.as_usize()];
            assert!(item.handle.is_none());
            assert!(free_items_set.insert(index));
            current = item.next;
        }

        // Forward walk from the head: every slot appears once and holds a handle.
        let mut valid_items_front = Vec::new();
        let mut valid_items_front_set = HashSet::new();
        current = self.head;
        while let Some(index) = current {
            let item = &self.items[index.as_usize()];
            assert!(item.handle.is_some());
            valid_items_front.push(index);
            assert!(valid_items_front_set.insert(index));
            current = item.next;
        }

        // Backward walk from the tail: every slot appears once.
        let mut valid_items_reverse = Vec::new();
        let mut valid_items_reverse_set = HashSet::new();
        current = self.tail;
        while let Some(index) = current {
            let item = &self.items[index.as_usize()];
            valid_items_reverse.push(index);
            assert!(valid_items_reverse_set.insert(index));
            current = item.prev;
        }

        // Every slot other than the placeholder is either linked or free, never both.
        assert_eq!(free_items_set.len() + valid_items_front.len() + 1, self.items.len());
        assert!(free_items_set.is_disjoint(&valid_items_front_set));

        // Both walks visit the same slots in opposite order, so every `prev`
        // link mirrors the corresponding `next` link.
        valid_items_reverse.reverse();
        assert_eq!(valid_items_front, valid_items_reverse);
    }
}
#[test]
fn test_lru_tracker_push_peek() {
    // Values are pushed in ascending order, so the oldest is always the
    // smallest value still present.
    struct CacheMarker;
    const NUM_ELEMENTS: usize = 50;

    let mut lru_cache: LRUCache<usize, CacheMarker> = LRUCache::new(1);
    lru_cache.validate();
    assert_eq!(lru_cache.peek_oldest(0), None);

    (0..NUM_ELEMENTS).for_each(|value| {
        lru_cache.push_new(0, value);
    });
    lru_cache.validate();

    // Peeking is non-destructive: repeated peeks see the same item.
    for _ in 0..2 {
        assert_eq!(lru_cache.peek_oldest(0), Some(&0));
    }

    lru_cache.pop_oldest(0);
    assert_eq!(lru_cache.peek_oldest(0), Some(&1));
}
#[test]
fn test_lru_tracker_push_pop() {
    struct CacheMarker;
    const NUM_ELEMENTS: usize = 50;

    let mut lru_cache: LRUCache<usize, CacheMarker> = LRUCache::new(1);
    lru_cache.validate();

    (0..NUM_ELEMENTS).for_each(|value| {
        lru_cache.push_new(0, value);
    });
    lru_cache.validate();

    // Items come back out in insertion order, oldest first.
    for expected in 0..NUM_ELEMENTS {
        let popped = lru_cache.pop_oldest(0);
        assert_eq!(popped, Some(expected));
    }
    lru_cache.validate();

    // A drained partition yields nothing further.
    let leftover = lru_cache.pop_oldest(0);
    assert_eq!(leftover, None);
}
#[test]
fn test_lru_tracker_push_touch_pop() {
    struct CacheMarker;
    const NUM_ELEMENTS: usize = 50;

    let mut lru_cache: LRUCache<usize, CacheMarker> = LRUCache::new(1);
    lru_cache.validate();

    let handles: Vec<_> = (0..NUM_ELEMENTS)
        .map(|value| lru_cache.push_new(0, value))
        .collect();
    lru_cache.validate();

    // Touch every even-numbered item, moving it behind all the odd ones.
    for handle in handles.iter().step_by(2) {
        lru_cache.touch(handle);
    }
    lru_cache.validate();

    // The untouched odd values are now the oldest...
    for odd in (1..NUM_ELEMENTS).step_by(2) {
        assert_eq!(lru_cache.pop_oldest(0), Some(odd));
    }
    lru_cache.validate();

    // ...followed by the touched even values, in the order they were touched.
    for even in (0..NUM_ELEMENTS).step_by(2) {
        assert_eq!(lru_cache.pop_oldest(0), Some(even));
    }
    lru_cache.validate();

    assert_eq!(lru_cache.pop_oldest(0), None);
}
#[test]
fn test_lru_tracker_push_get() {
    struct CacheMarker;
    const NUM_ELEMENTS: usize = 50;

    let mut lru_cache: LRUCache<usize, CacheMarker> = LRUCache::new(1);
    lru_cache.validate();

    let handles: Vec<_> = (0..NUM_ELEMENTS)
        .map(|value| lru_cache.push_new(0, value))
        .collect();
    lru_cache.validate();

    // Each weak handle resolves to the value it was created with.
    for (expected, handle) in handles.iter().enumerate().take(NUM_ELEMENTS / 2) {
        assert!(lru_cache.get_opt(handle) == Some(&expected));
    }
    lru_cache.validate();
}
#[test]
fn test_lru_tracker_push_replace_get() {
    struct CacheMarker;
    const NUM_ELEMENTS: usize = 50;

    let mut cache: LRUCache<usize, CacheMarker> = LRUCache::new(1);
    let mut handles = Vec::new();
    cache.validate();

    for i in 0 .. NUM_ELEMENTS {
        handles.push(cache.push_new(0, i));
    }
    cache.validate();

    // Replacing through a live handle returns the previous value.
    for i in 0 .. NUM_ELEMENTS {
        assert_eq!(cache.replace_or_insert(&mut handles[i], 0, i * 2), Some(i));
    }
    cache.validate();

    // Every handle (not just the first half) now sees its replacement.
    for i in 0 .. NUM_ELEMENTS {
        assert!(cache.get_opt(&handles[i]) == Some(&(i * 2)));
    }
    cache.validate();

    // Replacing within the same partition does not reorder it.
    assert_eq!(cache.peek_oldest(0), Some(&0));

    // A handle that does not resolve inserts a fresh entry and is updated to point at it.
    let mut empty_handle = WeakFreeListHandle::invalid();
    assert_eq!(cache.replace_or_insert(&mut empty_handle, 0, 100), None);
    assert_eq!(cache.get_opt(&empty_handle), Some(&100));
    cache.validate();
}

/// Covers moving an entry between partitions via `replace_or_insert`, and
/// manual eviction via `remove`.
#[test]
fn test_lru_cache_partition_move_and_remove() {
    struct CacheMarker;

    let mut cache: LRUCache<usize, CacheMarker> = LRUCache::new(2);
    let mut a = cache.push_new(0, 1);
    let b = cache.push_new(0, 2);
    let c = cache.push_new(1, 3);
    cache.validate();
    assert_eq!(cache.peek_oldest(0), Some(&1));
    assert_eq!(cache.peek_oldest(1), Some(&3));

    // Moving `a` into partition 1 places it at that partition's newest end.
    assert_eq!(cache.replace_or_insert(&mut a, 1, 10), Some(1));
    cache.validate();
    assert_eq!(cache.peek_oldest(0), Some(&2));
    assert_eq!(cache.peek_oldest(1), Some(&3));

    // Manual removal frees the entry, and its weak handle stops resolving.
    assert_eq!(cache.remove(&b), Some(2));
    assert_eq!(cache.get_opt(&b), None);
    assert_eq!(cache.remove(&b), None);
    assert_eq!(cache.peek_oldest(0), None);
    cache.validate();

    assert_eq!(cache.pop_oldest(1), Some(3));
    assert_eq!(cache.pop_oldest(1), Some(10));
    assert_eq!(cache.pop_oldest(1), None);
    assert_eq!(cache.get_opt(&a), None);
    assert_eq!(cache.get_opt(&c), None);
    cache.validate();
}