use std::{mem::offset_of, sync::Arc};
use foyer_common::{
code::{Key, Value},
strict_assert,
};
use intrusive_collections::{intrusive_adapter, LinkedList, LinkedListAtomicLink};
use serde::{Deserialize, Serialize};
use super::{Eviction, Op};
use crate::{
error::{Error, Result},
record::{CacheHint, Record},
};
/// Configuration for the [`Lru`] eviction policy.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LruConfig {
/// Fraction of the total capacity (by weight) reserved for the high-priority pool.
///
/// Must lie within `0.0..=1.0`. `Lru::new` panics on an out-of-range value, and `Lru::update`
/// returns an error for one. High-priority records that overflow the pool are demoted to the
/// low-priority list.
pub high_priority_pool_ratio: f64,
}
impl Default for LruConfig {
fn default() -> Self {
Self {
high_priority_pool_ratio: 0.9,
}
}
}
/// Per-record placement hint for [`Lru`].
///
/// `HighPriority` records are inserted into the high-priority pool. `LowPriority` records go
/// straight to the low-priority list, which is drained first on eviction.
#[derive(Debug, Clone, Default)]
pub enum LruHint {
    /// Insert into the high-priority pool. This is the default hint.
    #[default]
    HighPriority,
    /// Insert into the low-priority list.
    LowPriority,
}
/// Translates the policy-agnostic [`CacheHint`] into the LRU placement hint.
///
/// `Low` maps to the low-priority list; `Normal` maps to the high-priority pool.
impl From<CacheHint> for LruHint {
    fn from(hint: CacheHint) -> Self {
        match hint {
            CacheHint::Low => Self::LowPriority,
            CacheHint::Normal => Self::HighPriority,
        }
    }
}
/// Translates an LRU placement hint back into the policy-agnostic [`CacheHint`].
///
/// This is the inverse of `From<CacheHint> for LruHint`.
impl From<LruHint> for CacheHint {
    fn from(hint: LruHint) -> Self {
        match hint {
            LruHint::LowPriority => Self::Low,
            LruHint::HighPriority => Self::Normal,
        }
    }
}
/// Per-record eviction state embedded in each `Record` and reached through `Record::state()`.
#[derive(Debug, Default)]
pub struct LruState {
/// Intrusive link; a record is linked into at most one of `high_priority_list`, `list` or
/// `pin_list` at a time.
link: LinkedListAtomicLink,
/// Whether the record's weight is counted in `high_priority_weight`. This also decides which list
/// the record returns to when it is released from the pin list.
in_high_priority_pool: bool,
/// Whether the record is currently linked into `pin_list` (set by `acquire`, cleared by `release`).
is_pinned: bool,
}
// Intrusive-list adapter that threads `Arc<Record<Lru<K, V>>>` through the `link` field of the
// `LruState` stored inside each record. The link's offset within the record is the offset of the
// record's state (`Record::STATE_OFFSET`) plus the offset of `link` inside `LruState`.
intrusive_adapter! { Adapter<K, V> = Arc<Record<Lru<K, V>>>: Record<Lru<K, V>> { ?offset = Record::<Lru<K, V>>::STATE_OFFSET + offset_of!(LruState, link) => LinkedListAtomicLink } where K: Key, V: Value }
/// LRU eviction policy with a weight-bounded high-priority pool and a list of pinned records.
///
/// Every tracked record is linked into exactly one of three intrusive lists. New records are
/// appended at the back, and eviction takes from the front.
pub struct Lru<K, V>
where
K: Key,
V: Value,
{
/// High-priority records, oldest at the front. Overflowing records are demoted to `list`.
high_priority_list: LinkedList<Adapter<K, V>>,
/// Low-priority and demoted records, oldest at the front. `pop` drains this list first.
list: LinkedList<Adapter<K, V>>,
/// Records currently acquired. `pop` never returns these.
pin_list: LinkedList<Adapter<K, V>>,
/// Total weight of records flagged `in_high_priority_pool`, including pinned ones.
high_priority_weight: usize,
/// Maximum `high_priority_weight`: `capacity * config.high_priority_pool_ratio`, truncated.
high_priority_weight_capacity: usize,
/// The active configuration.
config: LruConfig,
}
impl<K, V> Lru<K, V>
where
    K: Key,
    V: Value,
{
    /// Demotes records from the front of `high_priority_list` to the back of `list` until
    /// `high_priority_weight` no longer exceeds `high_priority_weight_capacity`.
    ///
    /// Pinned high-priority records sit in `pin_list` but still count toward
    /// `high_priority_weight`. So `high_priority_list` can run empty while the pool is still over
    /// capacity, for example when `update` shrinks the capacity while heavy high-priority records
    /// are pinned. In that case the loop stops instead of panicking. Any remaining excess is
    /// handled by a later call, such as the next high-priority `push`.
    fn may_overflow_high_priority_pool(&mut self) {
        while self.high_priority_weight > self.high_priority_weight_capacity {
            let Some(record) = self.high_priority_list.pop_front() else {
                // Only pinned records are left in the pool; they cannot be demoted from here.
                break;
            };
            // SAFETY: the state cell is only accessed by the eviction container that owns the
            // record, and we hold `&mut self`. NOTE(review): relies on callers serializing access
            // to the container.
            let state = unsafe { &mut *record.state().get() };
            strict_assert!(state.in_high_priority_pool);
            state.in_high_priority_pool = false;
            self.high_priority_weight -= record.weight();
            self.list.push_back(record);
        }
    }
}
impl<K, V> Eviction for Lru<K, V>
where
K: Key,
V: Value,
{
type Config = LruConfig;
type Key = K;
type Value = V;
type Hint = LruHint;
type State = LruState;
fn new(capacity: usize, config: &Self::Config) -> Self
where
Self: Sized,
{
assert!(
(0.0..=1.0).contains(&config.high_priority_pool_ratio),
"high_priority_pool_ratio_percentage must be in 0.0..=1.0, given: {}",
config.high_priority_pool_ratio
);
let config = config.clone();
let high_priority_weight_capacity = (capacity as f64 * config.high_priority_pool_ratio) as usize;
Self {
high_priority_list: LinkedList::new(Adapter::new()),
list: LinkedList::new(Adapter::new()),
pin_list: LinkedList::new(Adapter::new()),
high_priority_weight: 0,
high_priority_weight_capacity,
config,
}
}
fn update(&mut self, capacity: usize, config: Option<&Self::Config>) -> Result<()> {
if let Some(config) = config {
if !(0.0..=1.0).contains(&config.high_priority_pool_ratio) {
return Err(Error::ConfigError(
format!(
"[lru]: high_priority_pool_ratio_percentage must be in 0.0..=1.0, given: {}, new configuration ignored",
config.high_priority_pool_ratio
)
));
}
self.config = config.clone();
}
let high_priority_weight_capacity = (capacity as f64 * self.config.high_priority_pool_ratio) as usize;
self.high_priority_weight_capacity = high_priority_weight_capacity;
self.may_overflow_high_priority_pool();
Ok(())
}
fn push(&mut self, record: Arc<Record<Self>>) {
let state = unsafe { &mut *record.state().get() };
strict_assert!(!state.link.is_linked());
record.set_in_eviction(true);
match record.hint() {
LruHint::HighPriority => {
state.in_high_priority_pool = true;
self.high_priority_weight += record.weight();
self.high_priority_list.push_back(record);
self.may_overflow_high_priority_pool();
}
LruHint::LowPriority => {
state.in_high_priority_pool = false;
self.list.push_back(record);
}
}
}
fn pop(&mut self) -> Option<Arc<Record<Self>>> {
let record = self.list.pop_front().or_else(|| self.high_priority_list.pop_front())?;
let state = unsafe { &mut *record.state().get() };
strict_assert!(!state.link.is_linked());
if state.in_high_priority_pool {
self.high_priority_weight -= record.weight();
state.in_high_priority_pool = false;
}
record.set_in_eviction(false);
Some(record)
}
fn remove(&mut self, record: &Arc<Record<Self>>) {
let state = unsafe { &mut *record.state().get() };
strict_assert!(state.link.is_linked());
match (state.is_pinned, state.in_high_priority_pool) {
(true, false) => unsafe { self.pin_list.remove_from_ptr(Arc::as_ptr(record)) },
(true, true) => unsafe {
self.high_priority_weight -= record.weight();
state.in_high_priority_pool = false;
self.pin_list.remove_from_ptr(Arc::as_ptr(record))
},
(false, true) => {
self.high_priority_weight -= record.weight();
state.in_high_priority_pool = false;
unsafe { self.high_priority_list.remove_from_ptr(Arc::as_ptr(record)) }
}
(false, false) => unsafe { self.list.remove_from_ptr(Arc::as_ptr(record)) },
};
strict_assert!(!state.link.is_linked());
record.set_in_eviction(false);
}
fn clear(&mut self) {
while self.pop().is_some() {}
while let Some(record) = self.pin_list.pop_front() {
let state = unsafe { &mut *record.state().get() };
strict_assert!(!state.link.is_linked());
if state.in_high_priority_pool {
self.high_priority_weight -= record.weight();
state.in_high_priority_pool = false;
}
record.set_in_eviction(false);
}
assert!(self.list.is_empty());
assert!(self.high_priority_list.is_empty());
assert!(self.pin_list.is_empty());
assert_eq!(self.high_priority_weight, 0);
}
fn acquire() -> Op<Self> {
Op::mutable(|this: &mut Self, record| {
if !record.is_in_eviction() {
return;
}
let state = unsafe { &mut *record.state().get() };
assert!(state.link.is_linked());
if state.is_pinned {
return;
}
let r = if state.in_high_priority_pool {
unsafe { this.high_priority_list.remove_from_ptr(Arc::as_ptr(record)) }
} else {
unsafe { this.list.remove_from_ptr(Arc::as_ptr(record)) }
};
this.pin_list.push_back(r);
state.is_pinned = true;
})
}
fn release() -> Op<Self> {
Op::mutable(|this: &mut Self, record| {
if !record.is_in_eviction() {
return;
}
let state = unsafe { &mut *record.state().get() };
assert!(state.link.is_linked());
if !state.is_pinned {
return;
}
unsafe { this.pin_list.remove_from_ptr(Arc::as_ptr(record)) };
if state.in_high_priority_pool {
this.high_priority_list.push_back(record.clone());
} else {
this.list.push_back(record.clone());
}
state.is_pinned = false;
})
}
}
#[cfg(test)]
pub mod tests {
    use itertools::Itertools;

    use super::*;
    use crate::{
        eviction::test_utils::{assert_ptr_eq, assert_ptr_vec_vec_eq, Dump, OpExt},
        record::Data,
    };

    /// Collects the records of an intrusive list, front to back, as owned `Arc`s.
    fn snapshot<K, V>(list: &LinkedList<Adapter<K, V>>) -> Vec<Arc<Record<Lru<K, V>>>>
    where
        K: Key,
        V: Value,
    {
        let mut cursor = list.cursor();
        std::iter::from_fn(move || {
            cursor.move_next();
            cursor.clone_pointer()
        })
        .collect()
    }

    impl<K, V> Dump for Lru<K, V>
    where
        K: Key + Clone,
        V: Value + Clone,
    {
        type Output = Vec<Vec<Arc<Record<Self>>>>;

        /// Returns `[low-priority, high-priority, pinned]`, each ordered front to back.
        fn dump(&self) -> Self::Output {
            vec![
                snapshot(&self.list),
                snapshot(&self.high_priority_list),
                snapshot(&self.pin_list),
            ]
        }
    }

    type TestLru = Lru<u64, u64>;

    /// Builds 20 unit-weight records with keys `0..20`; keys below 10 are high priority.
    fn records() -> Vec<Arc<Record<TestLru>>> {
        (0..20)
            .map(|i| {
                let hint = if i < 10 {
                    LruHint::HighPriority
                } else {
                    LruHint::LowPriority
                };
                Arc::new(Record::new(Data {
                    key: i,
                    value: i,
                    hint,
                    hash: i,
                    weight: 1,
                }))
            })
            .collect_vec()
    }

    #[test]
    fn test_lru() {
        let rs = records();
        let r = |i: usize| rs[i].clone();
        let ids = |xs: &[usize]| -> Vec<_> { xs.iter().map(|&i| r(i)).collect() };

        let config = LruConfig {
            high_priority_pool_ratio: 0.5,
        };
        let mut lru = TestLru::new(8, &config);
        assert_eq!(lru.high_priority_weight_capacity, 4);

        // Fill the high-priority pool exactly to capacity.
        for i in 0..4 {
            lru.push(r(i));
        }
        assert_ptr_vec_vec_eq(lru.dump(), vec![ids(&[]), ids(&[0, 1, 2, 3]), ids(&[])]);

        // One more high-priority record demotes the oldest one.
        lru.push(r(4));
        assert_ptr_vec_vec_eq(lru.dump(), vec![ids(&[0]), ids(&[1, 2, 3, 4]), ids(&[])]);

        lru.push(r(10));
        assert_ptr_vec_vec_eq(lru.dump(), vec![ids(&[0, 10]), ids(&[1, 2, 3, 4]), ids(&[])]);

        // Eviction drains the low-priority list first.
        let popped = lru.pop().unwrap();
        assert_ptr_eq(&r(0), &popped);
        assert_ptr_vec_vec_eq(lru.dump(), vec![ids(&[10]), ids(&[1, 2, 3, 4]), ids(&[])]);

        lru.remove(&rs[2]);
        assert_ptr_vec_vec_eq(lru.dump(), vec![ids(&[10]), ids(&[1, 3, 4]), ids(&[])]);

        lru.push(r(11));
        assert_ptr_vec_vec_eq(lru.dump(), vec![ids(&[10, 11]), ids(&[1, 3, 4]), ids(&[])]);

        lru.push(r(5));
        lru.push(r(6));
        assert_ptr_vec_vec_eq(lru.dump(), vec![ids(&[10, 11, 1]), ids(&[3, 4, 5, 6]), ids(&[])]);

        // Re-inserting a previously popped record behaves like a fresh insert.
        lru.push(r(0));
        assert_ptr_vec_vec_eq(lru.dump(), vec![ids(&[10, 11, 1, 3]), ids(&[4, 5, 6, 0]), ids(&[])]);

        lru.clear();
        assert_ptr_vec_vec_eq(lru.dump(), vec![ids(&[]), ids(&[]), ids(&[])]);
    }

    #[test]
    fn test_lru_pin() {
        let rs = records();
        let r = |i: usize| rs[i].clone();
        let ids = |xs: &[usize]| -> Vec<_> { xs.iter().map(|&i| r(i)).collect() };

        let config = LruConfig {
            high_priority_pool_ratio: 0.5,
        };
        let mut lru = TestLru::new(8, &config);
        assert_eq!(lru.high_priority_weight_capacity, 4);

        for i in [0, 1, 10, 11] {
            lru.push(r(i));
        }
        assert_ptr_vec_vec_eq(lru.dump(), vec![ids(&[10, 11]), ids(&[0, 1]), ids(&[])]);

        // Acquiring moves records onto the pin list.
        for i in [0_usize, 10] {
            lru.acquire_mutable(&rs[i]);
        }
        assert_ptr_vec_vec_eq(lru.dump(), vec![ids(&[11]), ids(&[1]), ids(&[0, 10])]);

        // Releasing appends them to the back of the list they came from.
        for i in [0_usize, 10] {
            lru.release_mutable(&rs[i]);
        }
        assert_ptr_vec_vec_eq(lru.dump(), vec![ids(&[11, 10]), ids(&[1, 0]), ids(&[])]);

        // Acquiring an already pinned record is a no-op.
        for i in [0_usize, 11, 0, 11] {
            lru.acquire_mutable(&rs[i]);
        }
        assert_ptr_vec_vec_eq(lru.dump(), vec![ids(&[10]), ids(&[1]), ids(&[0, 11])]);

        lru.remove(&rs[11]);
        assert_ptr_vec_vec_eq(lru.dump(), vec![ids(&[10]), ids(&[1]), ids(&[0])]);

        lru.push(r(2));
        lru.acquire_mutable(&rs[2]);
        assert_ptr_vec_vec_eq(lru.dump(), vec![ids(&[10]), ids(&[1]), ids(&[0, 2])]);

        lru.remove(&rs[2]);
        assert_ptr_vec_vec_eq(lru.dump(), vec![ids(&[10]), ids(&[1]), ids(&[0])]);

        // Releasing a record that is no longer tracked is ignored.
        lru.release_mutable(&rs[11]);
        assert_ptr_vec_vec_eq(lru.dump(), vec![ids(&[10]), ids(&[1]), ids(&[0])]);

        // The second release of the same record is a no-op.
        for _ in 0..2 {
            lru.release_mutable(&rs[0]);
        }
        assert_ptr_vec_vec_eq(lru.dump(), vec![ids(&[10]), ids(&[1, 0]), ids(&[])]);

        lru.acquire_mutable(&rs[1]);
        assert_ptr_vec_vec_eq(lru.dump(), vec![ids(&[10]), ids(&[0]), ids(&[1])]);

        lru.clear();
        assert_ptr_vec_vec_eq(lru.dump(), vec![ids(&[]), ids(&[]), ids(&[])]);
    }
}