// k_cache/cache.rs

use std::hash::Hash;
use std::marker::PhantomData;
use std::sync::atomic::AtomicBool;
use std::sync::atomic::Ordering;
use std::sync::Arc;
use std::{
    borrow::Borrow,
    collections::{HashMap, VecDeque},
    hash::BuildHasher,
};
/// Assigns a weight to a cache entry.
///
/// The cache sums the weights of its entries and compares the total against
/// its configured maximum to decide when eviction is needed. Implementors may
/// rely on the provided default, which counts every entry as weight `1`.
pub trait Weigher<K, V> {
    /// Returns the weight of the entry made of `_key` and `_value`.
    ///
    /// The default implementation ignores both arguments and returns `1`.
    fn weigh(_key: &K, _value: &V) -> usize {
        1
    }
}
17#[derive(Debug, Clone)]
18pub struct One;
19impl<K, V> Weigher<K, V> for One {}
/// A payload paired with a shareable "visited" flag.
///
/// The flag sits behind an `Arc<AtomicBool>` so that it can be shared between
/// several owners and toggled through a shared reference.
#[derive(Debug)]
struct SieveEntry<D> {
    /// The wrapped payload.
    data: D,
    /// Flag recording whether the entry has been visited.
    visited: Arc<AtomicBool>,
}
27#[derive(Debug)]
28pub struct Cache<K, V, S, W: Weigher<K, V> = One> {
29    map: HashMap<K, SieveEntry<V>, S>,
30    sieve_pool: VecDeque<SieveEntry<K>>,
31    sieve_hand: usize,
32    max_weight: usize,
33    weight: usize,
34    _phantom: PhantomData<W>,
35}
37impl<K, V, S, W> Cache<K, V, S, W>
38where
39    K: Eq + Hash + Clone,
40    S: BuildHasher,
41    W: Weigher<K, V>,
42{
43    pub fn new(hasher: S, max_weight: usize) -> Self {
44        Self {
45            map: HashMap::with_hasher(hasher),
46            sieve_pool: VecDeque::new(),
47            sieve_hand: 0,
48            max_weight,
49            weight: 0,
50            _phantom: PhantomData,
51        }
52    }
53
54    pub fn put(&mut self, key: K, value: V) {
55        let new_weight = self.make_room_for(&key, &value);
56        self.weight += new_weight;
57        let visited = Arc::new(AtomicBool::new(true));
58        if let Some(replaced) = self.map.insert(
59            key.clone(),
60            SieveEntry {
61                data: value,
62                visited: visited.clone(),
63            },
64        ) {
65            let replaced_weight = W::weigh(&key, &replaced.data);
66            match self.weight.checked_sub(replaced_weight) {
67                Some(new_weight) => self.weight = new_weight,
68                None => {
69                    log::error!("weight underflow");
70                    self.weight = 0;
71                }
72            }
73        }
74        self.sieve_pool.push_back(SieveEntry {
75            data: key.clone(),
76            visited,
77        });
78    }
79
80    pub fn get<Q>(&self, key: &Q) -> Option<&V>
81    where
82        K: Borrow<Q>,
83        Q: Hash + Eq + ?Sized,
84    {
85        match self.map.get(key) {
86            Some(entry) => {
87                entry
88                    .visited
89                    .store(true, std::sync::atomic::Ordering::Relaxed);
90                Some(&entry.data)
91            }
92            None => None,
93        }
94    }
95
96    fn make_room_for(&mut self, key: &K, value: &V) -> usize {
97        let entry_weight = W::weigh(key, value);
98        while self.max_weight < self.weight + entry_weight {
99            let sieve_entry = &mut self.sieve_pool[self.sieve_hand];
100            let visited = sieve_entry
101                .visited
102                .swap(false, std::sync::atomic::Ordering::Relaxed);
103            if visited {
104                self.sieve_hand = (self.sieve_hand + 1) % self.sieve_pool.len();
105            } else {
106                let sieve_key_entry = self
107                    .sieve_pool
108                    .swap_remove_back(self.sieve_hand)
109                    .expect("the index must be present");
110                match self.map.remove(&sieve_key_entry.data) {
111                    Some(removed) => {
112                        let removed_weight = W::weigh(&sieve_key_entry.data, &removed.data);
113                        match self.weight.checked_sub(removed_weight) {
114                            Some(new_weight) => self.weight = new_weight,
115                            None => {
116                                log::error!("weight underflow");
117                                self.weight = 0;
118                            }
119                        }
120                    }
121                    None => {
122                        log::debug!("garbage collecting sieve entry as {}", self.sieve_hand);
123                    }
124                }
125
126                if self.sieve_hand == self.sieve_pool.len() {
127                    self.sieve_hand = 0;
128                }
129            }
130        }
131        entry_weight
132    }
133}