1use std::hash::Hash;
2use std::marker::PhantomData;
3use std::sync::atomic::AtomicBool;
4use std::sync::Arc;
5use std::{
6 borrow::Borrow,
7 collections::{HashMap, VecDeque},
8 hash::BuildHasher,
9};
10
/// Assigns a weight to every cache entry.
///
/// The cache evicts entries while the sum of their weights would exceed its `max_weight`.
/// Without an override every entry weighs 1, which turns `max_weight` into an entry count.
pub trait Weigher<K, V> {
    /// Returns the weight of the entry made of `_key` and `_value`.
    fn weigh(_key: &K, _value: &V) -> usize {
        1
    }
}
16
17#[derive(Debug, Clone)]
18pub struct One;
19impl<K, V> Weigher<K, V> for One {}
20
/// Verdict produced by `Lifecycle::evaluate` for a single cached entry.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum EntryStatus {
    /// Keep the entry.
    Retain,
    /// Treat the entry as expired: lookups skip it and the sieve hand drops it when it
    /// reaches it.
    Evict,
}
29
30pub trait Lifecycle<K, V> {
31 fn on_eviction(&mut self, _key: K, _value: V) {}
34
35 fn evaluate(&self, _key: &K, _value: &V) -> EntryStatus {
38 EntryStatus::Retain
39 }
40}
41
42#[derive(Debug, Clone, Copy, Default)]
43pub struct DefaultLifecycle;
44impl<K, V> Lifecycle<K, V> for DefaultLifecycle {}
45
/// One half of a cache slot.
///
/// The map stores `SieveEntry<V>` (payload = value) and the sieve pool stores
/// `SieveEntry<K>` (payload = key). Both halves of a slot hold the same `visited` flag.
#[derive(Debug)]
struct SieveEntry<D> {
    /// The value (map side) or the key (pool side).
    data: D,
    /// Shared "visited" bit. It is atomic so that lookups can set it through `&self`.
    visited: Arc<AtomicBool>,
}
51
52#[derive(Debug)]
53pub struct Cache<K, V, S, W: Weigher<K, V> = One, L: Lifecycle<K, V> = DefaultLifecycle> {
54 map: HashMap<K, SieveEntry<V>, S>,
55 sieve_pool: VecDeque<SieveEntry<K>>,
56 sieve_hand: usize,
57 max_weight: usize,
58 weight: usize,
59 lifecycle: L,
60 _phantom: PhantomData<W>,
61}
62
63impl<K, V, S, W, L> Cache<K, V, S, W, L>
64where
65 K: Eq + Hash + Clone,
66 S: BuildHasher,
67 W: Weigher<K, V>,
68 L: Lifecycle<K, V> + Default,
69{
70 pub fn new(hasher: S, max_weight: usize) -> Self {
71 Self {
72 map: HashMap::with_hasher(hasher),
73 sieve_pool: VecDeque::new(),
74 sieve_hand: 0,
75 max_weight,
76 weight: 0,
77 lifecycle: Default::default(),
78 _phantom: PhantomData,
79 }
80 }
81}
82
83impl<K, V, S, W, L> Cache<K, V, S, W, L>
84where
85 K: Eq + Hash + Clone,
86 S: BuildHasher,
87 W: Weigher<K, V>,
88 L: Lifecycle<K, V>,
89{
90 pub fn new_with_lifecycle(hasher: S, max_weight: usize, lifecycle: L) -> Self {
91 Self {
92 map: HashMap::with_hasher(hasher),
93 sieve_pool: VecDeque::new(),
94 sieve_hand: 0,
95 max_weight,
96 weight: 0,
97 lifecycle,
98 _phantom: PhantomData,
99 }
100 }
101
102 pub fn put(&mut self, key: K, value: V) {
103 self.walk_hand();
105
106 let new_entry_weight = self.make_room_for(&key, &value);
107 self.weight += new_entry_weight;
108
109 match self.map.entry(key.clone()) {
110 std::collections::hash_map::Entry::Occupied(mut occupied_entry) => {
111 let replaced_weight = W::weigh(&key, &occupied_entry.get().data);
112 self.weight -= replaced_weight; occupied_entry.get_mut().data = value;
115 occupied_entry
116 .get_mut()
117 .visited
118 .store(true, std::sync::atomic::Ordering::Relaxed);
119 }
120 std::collections::hash_map::Entry::Vacant(vacant_entry) => {
121 let visited = Arc::new(AtomicBool::new(true));
126 vacant_entry.insert(SieveEntry {
127 data: value,
128 visited: visited.clone(),
129 });
130 self.sieve_pool.push_back(SieveEntry { data: key, visited });
131 }
132 }
133 }
134
135 pub fn get<Q>(&self, key: &Q) -> Option<&V>
136 where
137 K: Borrow<Q>,
138 Q: Hash + Eq + ?Sized,
139 {
140 let (full_key, entry) = self.map.get_key_value(key)?;
141 if self.lifecycle.evaluate(full_key, &entry.data) == EntryStatus::Evict {
144 return None;
145 }
146 entry
147 .visited
148 .store(true, std::sync::atomic::Ordering::Relaxed);
149 Some(&entry.data)
150 }
151
152 pub fn remove(&mut self, key: &K) -> Option<V> {
153 match self.map.remove(key) {
154 Some(removed) => {
155 let removed_weight = W::weigh(key, &removed.data);
160 match self.weight.checked_sub(removed_weight) {
161 Some(new_weight) => self.weight = new_weight,
162 None => {
163 log::error!("weight underflow");
164 self.weight = 0;
165 }
166 };
167 Some(removed.data)
168 }
169 None => {
170 log::debug!("garbage collecting sieve entry as {}", self.sieve_hand);
171 None
172 }
173 }
174 }
175
176 pub fn evict_all(&mut self) {
181 for (k, v) in self.map.drain() {
182 self.lifecycle.on_eviction(k, v.data);
183 }
184 self.sieve_pool.clear();
185 self.sieve_hand = 0;
186 self.weight = 0;
187 }
188
189 fn walk_hand(&mut self) {
193 for _ in 0..3 {
194 if self.sieve_pool.is_empty() {
195 return;
196 }
197 let status = {
198 let key = &self.sieve_pool[self.sieve_hand].data;
199 match self.map.get(key) {
200 Some(entry) => self.lifecycle.evaluate(key, &entry.data),
201 None => EntryStatus::Evict,
202 }
203 };
204 match status {
205 EntryStatus::Evict => {
206 let sieve_key_entry = self
207 .sieve_pool
208 .swap_remove_back(self.sieve_hand)
209 .expect("the index must be present");
210 if let Some(removed_value) = self.remove(&sieve_key_entry.data) {
211 self.lifecycle
212 .on_eviction(sieve_key_entry.data, removed_value);
213 } else {
214 log::debug!("garbage collecting sieve entry at {}", self.sieve_hand);
215 }
216 if self.sieve_pool.is_empty() {
217 self.sieve_hand = 0;
218 return;
219 }
220 if self.sieve_hand >= self.sieve_pool.len() {
221 self.sieve_hand = 0;
222 }
223 }
224 EntryStatus::Retain => {
225 self.sieve_hand = (self.sieve_hand + 1) % self.sieve_pool.len();
226 }
227 }
228 }
229 }
230
231 fn make_room_for(&mut self, key: &K, value: &V) -> usize {
232 let entry_weight = W::weigh(key, value);
233 while self.max_weight < self.weight + entry_weight {
234 let sieve_entry = &mut self.sieve_pool[self.sieve_hand];
235 let visited = sieve_entry
236 .visited
237 .swap(false, std::sync::atomic::Ordering::Relaxed);
238 if visited {
239 self.sieve_hand = (self.sieve_hand + 1) % self.sieve_pool.len();
240 } else {
241 let sieve_key_entry = self
244 .sieve_pool
245 .swap_remove_back(self.sieve_hand)
246 .expect("the index must be present");
247 let removed = self.remove(&sieve_key_entry.data);
248 if let Some(removed_value) = removed {
249 self.lifecycle
250 .on_eviction(sieve_key_entry.data, removed_value);
251 } else {
252 log::debug!("garbage collecting sieve entry at {}", self.sieve_hand);
256 }
257
258 if self.sieve_hand == self.sieve_pool.len() {
259 self.sieve_hand = 0;
260 }
261 }
262 }
263 entry_weight
264 }
265}
266
#[cfg(test)]
mod test {
    use std::hash::RandomState;

    use super::*;

    /// Overwriting a key swaps its value without adding weight or a second pool slot.
    #[test]
    fn test_put() {
        let mut cache: Cache<String, String, RandomState> = Cache::new(RandomState::new(), 100);

        cache.put(String::from("key1"), String::from("value1"));
        assert_eq!(cache.get("key1").map(String::as_str), Some("value1"));

        cache.put(String::from("key1"), String::from("value2"));
        assert_eq!(cache.get("key1").map(String::as_str), Some("value2"));

        assert_eq!(cache.weight, 1);
        assert_eq!(cache.map.len(), 1);
        assert_eq!(cache.sieve_pool.len(), 1);
    }
}
284}