#[cfg(test)]
use mock_instant::thread_local::Instant;
#[cfg(not(test))]
use std::time::Instant;
use itertools::Itertools;
use linked_hash_map::LinkedHashMap;
use std::time::Duration;
/// Reports how many bytes a value occupies, so that [`LruCache`] can account for it
/// against its `max_memory` limit.
pub trait ByteSize {
    /// Returns the number of bytes attributed to this value when it is stored in the cache.
    fn allocated_size(&self) -> usize;
}
/// Strings are accounted by their whole heap buffer (the capacity), not just by the
/// bytes currently in use.
impl ByteSize for String {
    fn allocated_size(&self) -> usize {
        String::capacity(self)
    }
}
/// A cache bounded both by entry count and by accounted memory, evicting the least
/// recently used entries first, with a soft and a hard time-to-live per entry.
pub struct LruCache<V: ByteSize> {
    // Entries in LRU order; the front is the next eviction candidate.
    map: LinkedHashMap<String, Entry<V>>,

    // Limits: maximum number of entries and maximum accounted bytes.
    capacity: usize,
    max_memory: usize,

    // Current usage, kept in sync with `map`.
    num_entries: usize,
    allocated_memory: usize,

    // Lifetimes applied to entries when they are stored.
    soft_ttl: Duration,
    hard_ttl: Duration,
    refresh_interval: Duration,

    // Usage statistics.
    reads: usize,
    hits: usize,
    writes: usize,
}
/// A cached value together with its bookkeeping data.
struct Entry<V: ByteSize> {
    /// The cached payload.
    value: V,
    /// Additional keys under which this entry can be invalidated.
    secondary_keys: Option<Vec<String>>,
    /// Bytes accounted to this entry (key, value and secondary keys).
    mem_size: usize,
    /// Deadline until which the entry is considered fresh.
    soft_ttl: Instant,
    /// Deadline after which the entry is no longer served at all.
    hard_ttl: Instant,
    /// Earliest instant at which another refresh may be requested for this entry.
    next_refresh_request: Instant,
}
impl<V: ByteSize> LruCache<V> {
    /// Creates an empty cache.
    ///
    /// * `capacity` - maximum number of entries kept in the cache
    /// * `max_memory` - maximum number of bytes all entries may account for (key length +
    ///   [`ByteSize::allocated_size`] of the value + length of all secondary keys)
    /// * `soft_ttl` - how long a stored entry is served by [`LruCache::get`]
    /// * `hard_ttl` - how long a stored entry is served by [`LruCache::extended_get`]
    /// * `refresh_interval` - minimal delay between two refresh requests reported by
    ///   [`LruCache::extended_get`] for the same soft-expired entry
    ///
    /// Note that the underlying map is pre-sized for `capacity` entries.
    pub fn new(
        capacity: usize,
        max_memory: usize,
        soft_ttl: Duration,
        hard_ttl: Duration,
        refresh_interval: Duration,
    ) -> Self {
        LruCache {
            num_entries: 0,
            capacity,
            allocated_memory: 0,
            max_memory,
            soft_ttl,
            hard_ttl,
            refresh_interval,
            reads: 0,
            hits: 0,
            writes: 0,
            map: LinkedHashMap::with_capacity(capacity),
        }
    }

    /// Stores `value` under `key`, replacing any existing entry for that key.
    ///
    /// # Errors
    /// Fails if the entry on its own would exceed `max_memory`.
    pub fn put(&mut self, key: String, value: V) -> anyhow::Result<()> {
        self.put_with_secondaries(key, value, None)
    }

    /// Stores `value` under `key` and additionally tags it with `secondary_keys`, so that
    /// it can later be dropped via [`LruCache::remove_by_secondary`].
    ///
    /// The entry becomes the most recently used one and receives fresh soft/hard TTL
    /// deadlines. Afterwards, least recently used entries are evicted until both the
    /// entry and the memory limit hold again.
    ///
    /// # Errors
    /// Fails if the entry on its own would exceed `max_memory`. In that case nothing is
    /// modified; an entry previously stored under `key` remains in place.
    pub fn put_with_secondaries(
        &mut self,
        key: String,
        value: V,
        secondary_keys: Option<Vec<String>>,
    ) -> anyhow::Result<()> {
        let secondary_size: usize = secondary_keys
            .iter()
            .flatten()
            .map(String::len)
            .sum();
        let mem_size = key.len() + value.allocated_size() + secondary_size;
        if mem_size > self.max_memory {
            return Err(anyhow::anyhow!(
                "The entry to be cached is larger than the whole cache size!"
            ));
        }

        // Derive all deadlines from a single timestamp so they are mutually consistent.
        let now = Instant::now();
        let entry = Entry {
            mem_size,
            soft_ttl: now + self.soft_ttl,
            hard_ttl: now + self.hard_ttl,
            next_refresh_request: now,
            value,
            secondary_keys,
        };

        // Add the new size first and subtract a replaced entry's size afterwards: the
        // running total always contains the stale size, so this never underflows and
        // needs no signed arithmetic.
        self.allocated_memory += mem_size;
        match self.map.insert(key, entry) {
            Some(stale_entry) => self.allocated_memory -= stale_entry.mem_size,
            None => self.num_entries += 1,
        }
        self.writes += 1;
        self.enforce_constraints();
        Ok(())
    }

    /// Evicts entries from the front of the LRU order until both the entry limit
    /// (`capacity`) and the memory limit (`max_memory`) are satisfied.
    fn enforce_constraints(&mut self) {
        while self.num_entries > self.capacity || self.allocated_memory > self.max_memory {
            match self.map.pop_front() {
                Some(lru_entry) => {
                    self.num_entries -= 1;
                    self.allocated_memory -= lru_entry.1.mem_size;
                }
                // Counters are only non-zero while the map holds entries, so an empty
                // map always satisfies both limits.
                None => unreachable!("Failed to enforce constraints of a LRU cache!"),
            }
        }
    }

    /// Returns the value stored for `key` if it exists and its soft TTL has not expired.
    ///
    /// Every call counts as a read; only served values count as hits. A found entry is
    /// moved to the most recently used position even when it is already soft-expired.
    pub fn get(&mut self, key: &str) -> Option<&V> {
        self.reads += 1;
        let now = Instant::now();
        match self.map.get_refresh(key) {
            Some(entry) if entry.soft_ttl > now => {
                self.hits += 1;
                Some(&entry.value)
            }
            _ => None,
        }
    }

    /// Like [`LruCache::get`], but serves entries until their *hard* TTL expires and
    /// reports whether the caller should refresh the value.
    ///
    /// Returns `None` if the key is absent or its hard TTL has expired. Otherwise returns
    /// `(alive, refresh, value)`:
    /// * within the soft TTL: `(true, false, value)`;
    /// * soft-expired and no refresh requested within the last `refresh_interval`:
    ///   `(false, true, value)`. The caller is expected to refresh the entry; the next
    ///   refresh request is scheduled `refresh_interval` from now;
    /// * soft-expired but a refresh is already pending: `(true, false, value)`, so that
    ///   only one caller per interval is asked to refresh.
    pub fn extended_get(&mut self, key: &str) -> Option<(bool, bool, &V)> {
        self.reads += 1;
        let now = Instant::now();
        match self.map.get_refresh(key) {
            Some(entry) if entry.hard_ttl > now => {
                self.hits += 1;
                let mut alive = entry.soft_ttl > now;
                let mut refresh = false;
                if !alive {
                    if entry.next_refresh_request <= now {
                        entry.next_refresh_request = now + self.refresh_interval;
                        refresh = true;
                    } else {
                        alive = true;
                    }
                }
                Some((alive, refresh, &entry.value))
            }
            _ => None,
        }
    }

    /// Removes the entry stored for `key`, if any. Always counts as a write.
    pub fn remove(&mut self, key: &str) {
        self.writes += 1;
        if let Some(entry) = self.map.remove(key) {
            self.num_entries -= 1;
            self.allocated_memory -= entry.mem_size;
        }
    }

    /// Removes every entry tagged with `secondary_key`.
    ///
    /// This scans all entries; each removal counts as a write.
    pub fn remove_by_secondary(&mut self, secondary_key: &str) {
        let keys_to_remove: Vec<String> = self
            .map
            .iter()
            .filter(|(_, entry)| {
                entry
                    .secondary_keys
                    .iter()
                    .flatten()
                    .map(String::as_str)
                    .contains(&secondary_key)
            })
            .map(|(key, _)| key.clone())
            .collect();
        for key in keys_to_remove {
            self.remove(&key);
        }
    }

    /// Iterates over all keys, starting with the next eviction candidate.
    pub fn keys(&self) -> impl Iterator<Item = &String> + '_ {
        self.map.keys()
    }

    /// Removes all entries and resets the read, write and hit statistics.
    pub fn flush(&mut self) {
        self.map.clear();
        self.allocated_memory = 0;
        self.num_entries = 0;
        self.reads = 0;
        self.writes = 0;
        self.hits = 0;
    }

    /// Returns the number of entries currently stored.
    pub fn len(&self) -> usize {
        self.num_entries
    }

    /// Determines whether the cache holds no entries.
    pub fn is_empty(&self) -> bool {
        self.num_entries == 0
    }

    /// Returns the maximum number of entries.
    pub fn capacity(&self) -> usize {
        self.capacity
    }

    /// Changes the maximum number of entries; shrinking evicts LRU entries immediately.
    pub fn set_capacity(&mut self, capacity: usize) {
        let previous_capacity = self.capacity;
        self.capacity = capacity;
        if previous_capacity > self.capacity {
            self.enforce_constraints();
        }
    }

    /// Returns the maximum number of accounted bytes.
    pub fn max_memory(&self) -> usize {
        self.max_memory
    }

    /// Changes the maximum number of accounted bytes; shrinking evicts LRU entries
    /// immediately.
    pub fn set_max_memory(&mut self, max_memory: usize) {
        let previous_max_memory = self.max_memory;
        self.max_memory = max_memory;
        if previous_max_memory > self.max_memory {
            self.enforce_constraints();
        }
    }

    /// Returns the soft TTL applied to newly stored entries.
    pub fn soft_ttl(&self) -> Duration {
        self.soft_ttl
    }

    /// Changes the soft TTL. Only affects entries stored afterwards; existing entries
    /// keep the deadline computed when they were put.
    pub fn set_soft_ttl(&mut self, soft_ttl: Duration) {
        self.soft_ttl = soft_ttl;
    }

    /// Returns the hard TTL applied to newly stored entries.
    pub fn hard_ttl(&self) -> Duration {
        self.hard_ttl
    }

    /// Changes the hard TTL. Only affects entries stored afterwards; existing entries
    /// keep the deadline computed when they were put.
    pub fn set_hard_ttl(&mut self, hard_ttl: Duration) {
        self.hard_ttl = hard_ttl;
    }

    /// Returns the minimal delay between two refresh requests for the same entry.
    pub fn refresh_interval(&self) -> Duration {
        self.refresh_interval
    }

    /// Changes the minimal delay between two refresh requests for the same entry.
    pub fn set_refresh_interval(&mut self, refresh_interval: Duration) {
        self.refresh_interval = refresh_interval;
    }

    /// Returns the number of bytes accounted to all stored entries.
    pub fn allocated_memory(&self) -> usize {
        self.allocated_memory
    }

    /// Rough estimate of the total memory used by the cache.
    ///
    /// This adds the accounted entry sizes to `map.capacity()` slots, each sized as one
    /// key `String` plus one `Entry`. Any further internal overhead of the map is not
    /// measured.
    pub fn total_allocated_memory(&self) -> usize {
        self.allocated_memory
            + self.map.capacity()
                * (std::mem::size_of::<String>() + std::mem::size_of::<Entry<V>>())
    }

    /// Percentage of the entry capacity in use (0 for a zero-capacity cache).
    pub fn utilization(&self) -> f32 {
        match self.capacity {
            // Avoid reporting NaN (0 / 0) for a cache which cannot hold anything.
            0 => 0.,
            capacity => self.num_entries as f32 / capacity as f32 * 100.,
        }
    }

    /// Percentage of the memory limit in use (0 if `max_memory` is 0).
    pub fn memory_utilization(&self) -> f32 {
        match self.max_memory {
            // Avoid reporting NaN (0 / 0) when no memory may be used at all.
            0 => 0.,
            max_memory => self.allocated_memory as f32 / max_memory as f32 * 100.,
        }
    }

    /// Percentage of reads which returned a value (0 if nothing was read yet).
    pub fn hit_rate(&self) -> f32 {
        match self.reads {
            0 => 0.,
            n => self.hits as f32 / n as f32 * 100.,
        }
    }

    /// Percentage of writes among all reads and writes (100 if nothing was read yet).
    pub fn write_read_ratio(&self) -> f32 {
        match self.reads {
            0 => 100.,
            n => self.writes as f32 / (self.writes + n) as f32 * 100.,
        }
    }

    /// Returns the number of lookups since creation or the last flush.
    pub fn reads(&self) -> usize {
        self.reads
    }

    /// Returns the number of puts and removals since creation or the last flush.
    pub fn writes(&self) -> usize {
        self.writes
    }
}
/// Unit tests for `LruCache`. Time is driven by `mock_instant`'s thread-local clock, so
/// TTL behavior can be tested without sleeping.
#[cfg(test)]
mod tests {
    use crate::lru::LruCache;
    use mock_instant::thread_local::MockClock;
    use std::time::Duration;

    #[test]
    fn empty_caches_consume_no_resources() {
        // A modest capacity is used on purpose: `LruCache::new` pre-sizes its map for
        // `capacity` entries. (This used to read `2 ^ 64`, which is bitwise XOR and
        // evaluates to 66, not a power of two.)
        let mut lru = LruCache::new(
            64,
            1024,
            Duration::from_secs(60 * 60),
            Duration::from_secs(60 * 60),
            Duration::from_secs(60),
        );
        lru.put("Hello".to_owned(), "World".to_owned()).unwrap();
        assert!(lru.total_allocated_memory() < 32_000);
    }

    #[test]
    fn capacity_is_enforced() {
        let mut lru = LruCache::new(
            4,
            8192,
            Duration::from_secs(60 * 60),
            Duration::from_secs(60 * 60),
            Duration::from_secs(60),
        );
        lru.put("Hello".to_owned(), "World".to_owned()).unwrap();
        lru.put("Hello1".to_owned(), "World1".to_owned()).unwrap();
        lru.put("Hello2".to_owned(), "World2".to_owned()).unwrap();
        lru.put("Hello3".to_owned(), "World3".to_owned()).unwrap();
        assert_eq!(lru.len(), 4);
        assert_eq!(lru.get("Hello").unwrap(), &"World".to_owned());
        assert_eq!(lru.get("Hello1").unwrap(), &"World1".to_owned());
        assert_eq!(lru.get("Hello2").unwrap(), &"World2".to_owned());
        assert_eq!(lru.get("Hello3").unwrap(), &"World3".to_owned());
        lru.put("Hello4".to_owned(), "World4".to_owned()).unwrap();
        assert_eq!(lru.get("Hello"), None);
        assert_eq!(lru.get("Hello1").unwrap(), &"World1".to_owned());
        assert_eq!(lru.get("Hello2").unwrap(), &"World2".to_owned());
        assert_eq!(lru.get("Hello3").unwrap(), &"World3".to_owned());
        assert_eq!(lru.get("Hello4").unwrap(), &"World4".to_owned());
        let _ = lru.get("Hello1");
        lru.put("Hello5".to_owned(), "World5".to_owned()).unwrap();
        assert_eq!(lru.get("Hello1").unwrap(), &"World1".to_owned());
        assert_eq!(lru.get("Hello2"), None);
        assert_eq!(lru.get("Hello3").unwrap(), &"World3".to_owned());
        assert_eq!(lru.get("Hello4").unwrap(), &"World4".to_owned());
        assert_eq!(lru.get("Hello5").unwrap(), &"World5".to_owned());
        assert_eq!(lru.len(), 4);
        lru.remove("Hello5");
        assert_eq!(lru.len(), 3);
        lru.put("Hello6".to_owned(), "World6".to_owned()).unwrap();
        assert_eq!(lru.get("Hello1").unwrap(), &"World1".to_owned());
        assert_eq!(lru.get("Hello3").unwrap(), &"World3".to_owned());
        assert_eq!(lru.get("Hello4").unwrap(), &"World4".to_owned());
        assert_eq!(lru.get("Hello6").unwrap(), &"World6".to_owned());
        assert_eq!(lru.len(), 4);
    }

    #[test]
    fn max_memory_is_enforced() {
        let mut lru = LruCache::new(
            128,
            12 * 4,
            Duration::from_secs(60 * 60),
            Duration::from_secs(60 * 60),
            Duration::from_secs(60),
        );
        lru.put("Hello0".to_owned(), "World0".to_owned()).unwrap();
        lru.put("Hello1".to_owned(), "World1".to_owned()).unwrap();
        lru.put("Hello2".to_owned(), "World2".to_owned()).unwrap();
        lru.put("Hello3".to_owned(), "World3".to_owned()).unwrap();
        assert_eq!(lru.len(), 4);
        assert_eq!(lru.allocated_memory(), 12 * 4);
        assert_eq!(lru.get("Hello0").unwrap(), &"World0".to_owned());
        assert_eq!(lru.get("Hello1").unwrap(), &"World1".to_owned());
        assert_eq!(lru.get("Hello2").unwrap(), &"World2".to_owned());
        assert_eq!(lru.get("Hello3").unwrap(), &"World3".to_owned());
        lru.remove("Hello0");
        assert_eq!(lru.len(), 3);
        assert_eq!(lru.allocated_memory(), 12 * 3);
        lru.put("Hello1".to_owned(), "".to_owned()).unwrap();
        assert_eq!(lru.allocated_memory(), 12 * 3 - 6);
        lru.put("Hello1".to_owned(), "World1".to_owned()).unwrap();
        assert_eq!(lru.allocated_memory(), 12 * 3);
        lru.put("Hello0".to_owned(), "World01".to_owned()).unwrap();
        assert_eq!(lru.allocated_memory(), 12 * 3 + 1);
        assert_eq!(lru.len(), 3);
        assert_eq!(lru.get("Hello2"), None);
    }

    #[test]
    fn ttls_are_properly_enforced() {
        let mut lru = LruCache::new(
            1024,
            1024,
            Duration::from_secs(15 * 60),
            Duration::from_secs(30 * 60),
            Duration::from_secs(60),
        );
        lru.put("Foo".to_owned(), "Bar".to_owned()).unwrap();
        assert_eq!(lru.get("Foo").unwrap(), "Bar");
        MockClock::advance(Duration::from_secs(16 * 60));
        assert_eq!(lru.get("Foo"), None);
        assert_eq!(
            lru.extended_get("Foo").unwrap(),
            (false, true, &"Bar".to_owned())
        );
        assert_eq!(
            lru.extended_get("Foo").unwrap(),
            (true, false, &"Bar".to_owned())
        );
        MockClock::advance(Duration::from_secs(2 * 60));
        assert_eq!(
            lru.extended_get("Foo").unwrap(),
            (false, true, &"Bar".to_owned())
        );
        MockClock::advance(Duration::from_secs(16 * 60));
        assert_eq!(lru.extended_get("Foo"), None);
    }

    #[test]
    fn ttls_are_discarded_on_put() {
        let mut lru = LruCache::new(
            1024,
            1024,
            Duration::from_secs(15 * 60),
            Duration::from_secs(30 * 60),
            Duration::from_secs(60),
        );
        lru.put("Foo".to_owned(), "Bar".to_owned()).unwrap();
        assert_eq!(lru.get("Foo").unwrap(), "Bar");
        MockClock::advance(Duration::from_secs(16 * 60));
        assert_eq!(lru.get("Foo"), None);
        assert_eq!(
            lru.extended_get("Foo").unwrap(),
            (false, true, &"Bar".to_owned())
        );
        lru.put("Foo".to_owned(), "Bar1".to_owned()).unwrap();
        assert_eq!(lru.get("Foo").unwrap(), "Bar1");
        assert_eq!(
            lru.extended_get("Foo").unwrap(),
            (true, false, &"Bar1".to_owned())
        );
    }

    #[test]
    fn metrics_are_computed_correctly() {
        let mut lru = LruCache::new(
            4,
            10,
            Duration::from_secs(15 * 60),
            Duration::from_secs(30 * 60),
            Duration::from_secs(60),
        );
        lru.put("A".to_owned(), "A".to_owned()).unwrap();
        lru.put("B".to_owned(), "B".to_owned()).unwrap();
        lru.put_with_secondaries(
            "C".to_owned(),
            "C".to_owned(),
            Some(vec!["D".to_owned()]),
        )
        .unwrap();
        assert!(lru.get("A").is_some());
        assert!(lru.get("B").is_some());
        assert!(lru.get("C").is_some());
        assert!(lru.get("D").is_none());
        assert_eq!(lru.writes(), 3);
        assert_eq!(lru.reads(), 4);
        assert_eq!(lru.hit_rate().round() as i32, 75);
        assert!(lru.get("A").is_some());
        assert!(lru.get("B").is_some());
        assert!(lru.get("C").is_some());
        assert_eq!(lru.reads(), 7);
        assert_eq!(lru.write_read_ratio().round() as i32, 30);
        assert_eq!(lru.allocated_memory(), 7);
        assert!(lru.total_allocated_memory() > lru.allocated_memory());
        assert_eq!(lru.utilization().round() as i32, 75);
        assert_eq!(lru.memory_utilization().round() as i32, 70);
    }
}