use std::collections::HashMap;
use std::hash::Hash;
use std::ptr::NonNull;
use std::marker::PhantomData;
use crate::{CachePolicy, PrefetchStrategy};
use crate::prefetch::{PrefetchType, NoPrefetch};
use super::{BenchmarkablePolicy, PolicyType};
/// Adaptive Replacement Cache (ARC) with an optional prefetch strategy.
///
/// Implements the classic four-list ARC layout: `t1`/`t2` hold resident
/// entries as intrusive doubly-linked lists (nodes heap-allocated, indexed
/// by hash map for O(1) lookup), while `b1`/`b2` are "ghost" lists that
/// remember recently evicted keys only (no values). The adaptation
/// parameter `p` shifts capacity between recency (`t1`) and frequency
/// (`t2`) based on ghost-list hits.
pub struct ArcCache<K, V>
where
    K: Hash + Eq + Clone,
    V: Clone,
{
    // Resident entries seen once since admission (recency side).
    t1: HashMap<K, NonNull<Node<K, V>>>,
    // Resident entries seen at least twice (frequency side).
    t2: HashMap<K, NonNull<Node<K, V>>>,
    // Ghost list of keys recently evicted from t1 (keys only).
    b1: HashMap<K, ()>,
    // Ghost list of keys recently evicted from t2 (keys only).
    b2: HashMap<K, ()>,
    // Head = MRU end, tail = LRU end of each intrusive list.
    t1_head: Option<NonNull<Node<K, V>>>,
    t1_tail: Option<NonNull<Node<K, V>>>,
    t2_head: Option<NonNull<Node<K, V>>>,
    t2_tail: Option<NonNull<Node<K, V>>>,
    // ARC adaptation target for |t1|, kept in [0, capacity].
    p: usize,
    // Maximum number of resident entries (t1 + t2).
    capacity: usize,
    // Cached list lengths; the hash maps mirror these but the lists are
    // walked via raw pointers, so sizes are tracked explicitly.
    t1_size: usize,
    t2_size: usize,
    // Pluggable predictor; fed on every hit via perform_prefetch.
    prefetch_strategy: Box<dyn PrefetchStrategy<K>>,
    // Staging area for prefetched values, promoted into the cache on hit.
    prefetch_buffer: HashMap<K, V>,
    // Budget for the staging area: capacity / 4, at least 1.
    prefetch_buffer_size: usize,
    prefetch_stats: super::lru::PrefetchStats,
    // Tells the drop checker this struct owns boxed Nodes reachable only
    // through raw pointers.
    _marker: PhantomData<Box<Node<K, V>>>,
}
/// Intrusive doubly-linked-list node holding one resident cache entry.
/// Allocated with `Box` and handed around as `NonNull` raw pointers; the
/// owning `ArcCache` is responsible for freeing it via `Box::from_raw`.
struct Node<K, V> {
    key: K,
    value: V,
    // Neighbor toward the head (more recently used); None at the head.
    prev: Option<NonNull<Node<K, V>>>,
    // Neighbor toward the tail (least recently used); None at the tail.
    next: Option<NonNull<Node<K, V>>>,
    // Which resident list (T1 or T2) currently owns this node; kept in
    // sync by add_to_front so unlinking knows which head/tail to fix.
    list_type: ListType,
}
/// Identifies which resident ARC list a node currently belongs to.
///
/// `Eq` is derived alongside `PartialEq`: equality on this field-less
/// enum is total, and deriving `Eq` lets the type satisfy any bound that
/// requires full equivalence.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ListType {
    /// Recency list: entries referenced exactly once since admission.
    T1,
    /// Frequency list: entries referenced at least twice.
    T2,
}
impl<K, V> Node<K, V> {
fn new(key: K, value: V, list_type: ListType) -> Self {
Self {
key,
value,
prev: None,
next: None,
list_type,
}
}
}
impl<K, V> ArcCache<K, V>
where
    K: Hash + Eq + Clone,
    V: Clone,
{
    /// Creates an ARC cache with prefetching disabled.
    ///
    /// # Panics
    /// Panics if `capacity == 0`.
    pub fn new(capacity: usize) -> Self {
        Self::with_custom_prefetch(capacity, Box::new(NoPrefetch))
    }

    /// Creates an ARC cache driven by the given prefetch strategy.
    ///
    /// The prefetch buffer is budgeted at a quarter of `capacity`,
    /// but never less than one slot.
    ///
    /// # Panics
    /// Panics if `capacity == 0`.
    pub fn with_custom_prefetch(
        capacity: usize,
        prefetch_strategy: Box<dyn PrefetchStrategy<K>>
    ) -> Self {
        assert!(capacity > 0, "ARC cache capacity must be greater than 0");
        Self {
            t1: HashMap::new(),
            t2: HashMap::new(),
            b1: HashMap::new(),
            b2: HashMap::new(),
            t1_head: None,
            t1_tail: None,
            t2_head: None,
            t2_tail: None,
            p: 0,
            capacity,
            t1_size: 0,
            t2_size: 0,
            prefetch_strategy,
            prefetch_buffer: HashMap::new(),
            prefetch_buffer_size: (capacity / 4).max(1),
            prefetch_stats: super::lru::PrefetchStats::default(),
            _marker: PhantomData,
        }
    }

    /// Returns the accumulated prefetch statistics.
    pub fn prefetch_stats(&self) -> &super::lru::PrefetchStats {
        &self.prefetch_stats
    }

    /// Clears the prefetch statistics and resets the strategy's
    /// learned access-pattern state.
    pub fn reset_prefetch_stats(&mut self) {
        self.prefetch_stats = super::lru::PrefetchStats::default();
        self.prefetch_strategy.reset();
    }

    /// Feeds an access to the prefetch strategy and records how many
    /// keys it predicted.
    ///
    /// NOTE: predicted keys are only *counted*, never materialized.
    /// `PrefetchStrategy` predicts keys but this cache has no value
    /// source to fetch from, so nothing can be inserted into
    /// `prefetch_buffer` here. (The original body checked t1/t2/buffer
    /// membership for each prediction and then did nothing with the
    /// result; that dead check has been removed.)
    fn perform_prefetch(&mut self, accessed_key: &K) {
        self.prefetch_strategy.update_access_pattern(accessed_key);
        let predictions = self.prefetch_strategy.predict_next(accessed_key);
        for _predicted_key in predictions {
            self.prefetch_stats.predictions_made += 1;
        }
        self.trim_prefetch_buffer();
    }

    /// Evicts entries until the prefetch buffer fits its budget.
    /// `HashMap` keeps no access order, so the eviction choice is
    /// arbitrary.
    fn trim_prefetch_buffer(&mut self) {
        while self.prefetch_buffer.len() > self.prefetch_buffer_size {
            if let Some(key) = self.prefetch_buffer.keys().next().cloned() {
                self.prefetch_buffer.remove(&key);
            } else {
                break;
            }
        }
    }

    /// Links a detached node at the head (MRU end) of the chosen list
    /// and stamps its `list_type`.
    ///
    /// # Safety
    /// `node_ptr` must point to a live node owned by this cache that is
    /// not currently linked into either list, and no other references to
    /// that node may be active for the duration of the call.
    unsafe fn add_to_front(&mut self, mut node_ptr: NonNull<Node<K, V>>, list_type: ListType) {
        let node = unsafe { node_ptr.as_mut() };
        node.list_type = list_type;
        node.prev = None;
        // Borrow the head/tail of whichever list we are pushing onto.
        let (head, tail) = match list_type {
            ListType::T1 => (&mut self.t1_head, &mut self.t1_tail),
            ListType::T2 => (&mut self.t2_head, &mut self.t2_tail),
        };
        node.next = *head;
        if let Some(mut old_head) = *head {
            unsafe { old_head.as_mut() }.prev = Some(node_ptr);
        } else {
            // List was empty: the new node is also the tail.
            *tail = Some(node_ptr);
        }
        *head = Some(node_ptr);
    }

    /// Unlinks a node from the list named by its `list_type`, fixing the
    /// neighbor links and the list's head/tail as needed. The node itself
    /// is left with stale prev/next pointers; callers either re-link it
    /// via `add_to_front` or free it.
    ///
    /// # Safety
    /// `node_ptr` must point to a live node currently linked into the
    /// list indicated by its `list_type` field.
    unsafe fn remove_from_list(&mut self, node_ptr: NonNull<Node<K, V>>) {
        let node = unsafe { node_ptr.as_ref() };
        match node.prev {
            Some(mut prev) => unsafe { prev.as_mut() }.next = node.next,
            // No predecessor: node was the head.
            None => match node.list_type {
                ListType::T1 => self.t1_head = node.next,
                ListType::T2 => self.t2_head = node.next,
            },
        }
        match node.next {
            Some(mut next) => unsafe { next.as_mut() }.prev = node.prev,
            // No successor: node was the tail.
            None => match node.list_type {
                ListType::T1 => self.t1_tail = node.prev,
                ListType::T2 => self.t2_tail = node.prev,
            },
        }
    }

    /// ARC's REPLACE step: evicts one resident entry, moving its key into
    /// the matching ghost list.
    ///
    /// T1 is preferred when it exceeds the target size `p` (or equals it
    /// on a B2 ghost hit, per the ARC paper); otherwise T2 is evicted.
    ///
    /// Fix over the original: if the preferred T2 list is empty (possible
    /// once `remove()` has skewed the ARC invariants, e.g. t1 full with
    /// t1_size <= p and t2 drained), fall back to evicting from T1 instead
    /// of silently evicting nothing and letting the cache exceed capacity.
    fn replace(&mut self, in_b2: bool) {
        let prefer_t1 = self.t1_size >= 1
            && (self.t1_size > self.p || (in_b2 && self.t1_size == self.p));
        if prefer_t1 {
            // prefer_t1 implies t1_size >= 1, so this cannot fail.
            self.evict_t1_lru();
        } else if !self.evict_t2_lru() {
            // T2 was empty; evict from T1 so capacity is still honored.
            self.evict_t1_lru();
        }
    }

    /// Evicts the LRU (tail) entry of T1 into ghost list B1.
    /// Returns `false` when T1 is empty and nothing was evicted.
    fn evict_t1_lru(&mut self) -> bool {
        if let Some(lru_ptr) = self.t1_tail {
            // SAFETY: lru_ptr is the live, uniquely-owned tail node of T1;
            // reclaiming it with Box::from_raw both reads its links and
            // frees the allocation exactly once.
            unsafe {
                let lru_node = Box::from_raw(lru_ptr.as_ptr());
                let key = lru_node.key.clone();
                self.t1.remove(&key);
                self.b1.insert(key, ());
                self.t1_tail = lru_node.prev;
                if let Some(mut new_tail) = self.t1_tail {
                    new_tail.as_mut().next = None;
                } else {
                    // Evicted the last node: list is now empty.
                    self.t1_head = None;
                }
                self.t1_size -= 1;
            }
            true
        } else {
            false
        }
    }

    /// Evicts the LRU (tail) entry of T2 into ghost list B2.
    /// Returns `false` when T2 is empty and nothing was evicted.
    fn evict_t2_lru(&mut self) -> bool {
        if let Some(lru_ptr) = self.t2_tail {
            // SAFETY: same reasoning as evict_t1_lru, for the T2 list.
            unsafe {
                let lru_node = Box::from_raw(lru_ptr.as_ptr());
                let key = lru_node.key.clone();
                self.t2.remove(&key);
                self.b2.insert(key, ());
                self.t2_tail = lru_node.prev;
                if let Some(mut new_tail) = self.t2_tail {
                    new_tail.as_mut().next = None;
                } else {
                    self.t2_head = None;
                }
                self.t2_size -= 1;
            }
            true
        } else {
            false
        }
    }

    /// Shifts the adaptation target `p`: a positive `delta` favors
    /// recency (grows the T1 target), a negative one favors frequency.
    /// Clamped to the range `[0, capacity]`.
    fn update_p(&mut self, delta: i32) {
        if delta > 0 {
            self.p = (self.p + delta as usize).min(self.capacity);
        } else {
            self.p = self.p.saturating_sub((-delta) as usize);
        }
    }
}
impl<K, V> CachePolicy<K, V> for ArcCache<K, V>
where
    K: Hash + Eq + Clone,
    V: Clone,
{
    /// Looks up `key`, applying ARC promotion: a T1 hit moves the node to
    /// T2 (second reference), a T2 hit refreshes its recency within T2.
    ///
    /// A hit in the prefetch buffer is counted in the prefetch stats and
    /// the value is promoted into the cache proper. Every hit also feeds
    /// the prefetch strategy via `perform_prefetch`.
    fn get(&mut self, key: &K) -> Option<&V> {
        // Prefetch-buffer hit: move the value into the cache, then re-run
        // the lookup; it will now hit T1 and take the normal path (at most
        // one level of recursion).
        if self.prefetch_buffer.contains_key(key) {
            if let Some(value) = self.prefetch_buffer.remove(key) {
                self.prefetch_stats.cache_hits_from_prefetch += 1;
                self.insert(key.clone(), value);
                return self.get(key);
            }
        }
        // T1 hit: second access overall, so promote to T2.
        if let Some(&node_ptr) = self.t1.get(key) {
            unsafe {
                self.remove_from_list(node_ptr);
                self.t1.remove(key);
                self.t2.insert(key.clone(), node_ptr);
                self.add_to_front(node_ptr, ListType::T2);
                self.t1_size -= 1;
                self.t2_size += 1;
                self.perform_prefetch(key);
                return Some(&node_ptr.as_ref().value);
            }
        }
        // T2 hit: move to the MRU end of T2.
        if let Some(&node_ptr) = self.t2.get(key) {
            unsafe {
                self.remove_from_list(node_ptr);
                self.add_to_front(node_ptr, ListType::T2);
                self.perform_prefetch(key);
                return Some(&node_ptr.as_ref().value);
            }
        }
        None
    }

    /// Inserts or updates `key`, following the ARC admission cases:
    /// update-in-place for a resident key, ghost hits in B1/B2 (which
    /// adapt `p` and admit the key directly into T2), and cold insertion
    /// into T1.
    ///
    /// Fixes over the original:
    /// - the adaptation step is at least 1 even when the opposite ghost
    ///   list is empty (the ARC paper's `max(..., 1)` term; the old
    ///   `ceil(0 / n)` produced a delta of 0 and `p` never adapted);
    /// - ghost-list trimming now also runs when the cache is full, so
    ///   B1/B2 can no longer grow without bound under a scan workload.
    fn insert(&mut self, key: K, value: V) {
        // An explicit insert supersedes any prefetched copy.
        self.prefetch_buffer.remove(&key);
        // Case 0: key already resident — replace the value in place.
        if let Some(&node_ptr) = self.t1.get(&key).or(self.t2.get(&key)) {
            unsafe {
                (*node_ptr.as_ptr()).value = value;
            }
            return;
        }
        // Case II (ARC paper): ghost hit in B1 — recency was undervalued,
        // grow p. The key is still in B1 here, so len() >= 1.
        if self.b1.contains_key(&key) {
            let delta = if self.b1.len() >= self.b2.len() {
                1
            } else {
                (self.b2.len() as f32 / self.b1.len() as f32).ceil() as i32
            };
            self.update_p(delta);
            self.replace(false);
            self.b1.remove(&key);
            let new_node = Box::new(Node::new(key.clone(), value, ListType::T2));
            // SAFETY: Box::into_raw never returns null.
            let node_ptr = unsafe { NonNull::new_unchecked(Box::into_raw(new_node)) };
            self.t2.insert(key, node_ptr);
            unsafe { self.add_to_front(node_ptr, ListType::T2); }
            self.t2_size += 1;
            return;
        }
        // Case III: ghost hit in B2 — frequency was undervalued, shrink p.
        if self.b2.contains_key(&key) {
            let delta = if self.b2.len() >= self.b1.len() {
                1
            } else {
                (self.b1.len() as f32 / self.b2.len() as f32).ceil() as i32
            };
            self.update_p(-delta);
            self.replace(true);
            self.b2.remove(&key);
            let new_node = Box::new(Node::new(key.clone(), value, ListType::T2));
            // SAFETY: Box::into_raw never returns null.
            let node_ptr = unsafe { NonNull::new_unchecked(Box::into_raw(new_node)) };
            self.t2.insert(key, node_ptr);
            unsafe { self.add_to_front(node_ptr, ListType::T2); }
            self.t2_size += 1;
            return;
        }
        // Case IV: brand-new key.
        // Keep total history bounded by capacity. The HashMap ghost lists
        // keep no recency order, so the evicted ghost entry is arbitrary
        // (B2 preferred, matching the original's choice).
        if self.b1.len() + self.b2.len() >= self.capacity {
            if let Some(ghost) = self.b2.keys().next().cloned() {
                self.b2.remove(&ghost);
            } else if let Some(ghost) = self.b1.keys().next().cloned() {
                self.b1.remove(&ghost);
            }
        }
        // Make room if the resident lists are at capacity.
        if self.t1_size + self.t2_size >= self.capacity {
            self.replace(false);
        }
        let new_node = Box::new(Node::new(key.clone(), value, ListType::T1));
        // SAFETY: Box::into_raw never returns null.
        let node_ptr = unsafe { NonNull::new_unchecked(Box::into_raw(new_node)) };
        self.t1.insert(key, node_ptr);
        unsafe { self.add_to_front(node_ptr, ListType::T1); }
        self.t1_size += 1;
    }

    /// Removes `key` from wherever it lives (prefetch buffer, T1, T2, or
    /// the ghost lists) and returns its value if one was held.
    /// Note: `p` is deliberately left unchanged; removal is not an ARC
    /// access.
    fn remove(&mut self, key: &K) -> Option<V> {
        if let Some(value) = self.prefetch_buffer.remove(key) {
            return Some(value);
        }
        if let Some(node_ptr) = self.t1.remove(key) {
            unsafe {
                self.remove_from_list(node_ptr);
                // Reclaim the allocation and move the value out.
                let node = Box::from_raw(node_ptr.as_ptr());
                self.t1_size -= 1;
                return Some(node.value);
            }
        }
        if let Some(node_ptr) = self.t2.remove(key) {
            unsafe {
                self.remove_from_list(node_ptr);
                let node = Box::from_raw(node_ptr.as_ptr());
                self.t2_size -= 1;
                return Some(node.value);
            }
        }
        // Ghost entries carry no value; just forget the key.
        self.b1.remove(key);
        self.b2.remove(key);
        None
    }

    /// Number of resident entries (ghost lists and prefetch buffer are
    /// excluded).
    fn len(&self) -> usize {
        self.t1_size + self.t2_size
    }

    /// Maximum number of resident entries.
    fn capacity(&self) -> usize {
        self.capacity
    }

    /// Frees every node, empties all maps, and resets the adaptation and
    /// prefetch state. Also used by `Drop` to release the raw-pointer
    /// lists.
    fn clear(&mut self) {
        // Walk each list freeing nodes; the hash maps only hold copies of
        // the pointers, so this walk is the single owning release.
        while let Some(node_ptr) = self.t1_head {
            unsafe {
                let node = Box::from_raw(node_ptr.as_ptr());
                self.t1_head = node.next;
            }
        }
        while let Some(node_ptr) = self.t2_head {
            unsafe {
                let node = Box::from_raw(node_ptr.as_ptr());
                self.t2_head = node.next;
            }
        }
        self.t1.clear();
        self.t2.clear();
        self.b1.clear();
        self.b2.clear();
        self.t1_head = None;
        self.t1_tail = None;
        self.t2_head = None;
        self.t2_tail = None;
        self.t1_size = 0;
        self.t2_size = 0;
        self.p = 0;
        self.prefetch_buffer.clear();
    }
}
impl<K, V> BenchmarkablePolicy<K, V> for ArcCache<K, V>
where
    K: Hash + Eq + Clone,
    V: Clone,
{
    /// This cache benchmarks as the ARC policy.
    fn policy_type(&self) -> PolicyType {
        PolicyType::Arc
    }

    /// Benchmark label of the form `<policy>_cap_<capacity>_prefetch`.
    fn benchmark_name(&self) -> String {
        let policy = self.policy_type();
        let cap = self.capacity();
        format!("{}_cap_{}_prefetch", policy.name(), cap)
    }

    /// Restores a pristine cache (contents, adaptation state, and
    /// prefetch statistics) ahead of a benchmark run.
    fn reset_for_benchmark(&mut self) {
        self.clear();
        self.reset_prefetch_stats();
    }
}
impl<K, V> Drop for ArcCache<K, V>
where
    K: Hash + Eq + Clone,
    V: Clone,
{
    /// Releases all heap-allocated list nodes via `clear`; without this,
    /// the nodes reachable only through raw `NonNull` links would leak.
    fn drop(&mut self) {
        self.clear();
    }
}
impl ArcCache<i32, String> {
    /// Convenience constructor for an `i32`-keyed cache using one of the
    /// built-in prefetch strategies.
    ///
    /// # Panics
    /// Panics if `capacity == 0`.
    pub fn with_prefetch_i32(capacity: usize, prefetch_type: PrefetchType) -> Self {
        assert!(capacity > 0, "ARC cache capacity must be greater than 0");
        Self::with_custom_prefetch(
            capacity,
            crate::prefetch::create_prefetch_strategy_i32(prefetch_type),
        )
    }
}
impl ArcCache<i64, String> {
    /// Convenience constructor for an `i64`-keyed cache using one of the
    /// built-in prefetch strategies.
    ///
    /// # Panics
    /// Panics if `capacity == 0`.
    pub fn with_prefetch_i64(capacity: usize, prefetch_type: PrefetchType) -> Self {
        assert!(capacity > 0, "ARC cache capacity must be greater than 0");
        Self::with_custom_prefetch(
            capacity,
            crate::prefetch::create_prefetch_strategy_i64(prefetch_type),
        )
    }
}
impl ArcCache<usize, String> {
    /// Convenience constructor for a `usize`-keyed cache using one of the
    /// built-in prefetch strategies.
    ///
    /// # Panics
    /// Panics if `capacity == 0`.
    pub fn with_prefetch_usize(capacity: usize, prefetch_type: PrefetchType) -> Self {
        assert!(capacity > 0, "ARC cache capacity must be greater than 0");
        Self::with_custom_prefetch(
            capacity,
            crate::prefetch::create_prefetch_strategy_usize(prefetch_type),
        )
    }
}
// SAFETY: the raw `NonNull<Node>` links are uniquely owned by this struct
// and never shared outside it, so moving the whole cache between threads
// is sound as far as the lists are concerned (given K: Send, V: Send).
//
// NOTE(review): these bounds do NOT require the boxed
// `dyn PrefetchStrategy<K>` to be `Send`/`Sync`. A strategy holding e.g.
// an `Rc` could therefore be moved across threads through these impls,
// which would be unsound. Consider changing the field to
// `Box<dyn PrefetchStrategy<K> + Send>` (and `+ Sync` for the Sync impl)
// or adding that bound here — confirm against the PrefetchStrategy
// definition and its implementors.
unsafe impl<K, V> Send for ArcCache<K, V>
where
    K: Hash + Eq + Clone + Send,
    V: Clone + Send,
{
}
// SAFETY: `&ArcCache` only exposes read-only accessors (`prefetch_stats`,
// `len`, `capacity`); all mutation goes through `&mut self`, so shared
// references cannot race (given K: Sync, V: Sync). The same caveat about
// the prefetch strategy's missing `Sync` bound applies as above.
unsafe impl<K, V> Sync for ArcCache<K, V>
where
    K: Hash + Eq + Clone + Sync,
    V: Clone + Sync,
{
}