1use std::hash::Hash;
2use std::marker::PhantomData;
3use std::sync::atomic::AtomicBool;
4use std::sync::Arc;
5use std::{
6 borrow::Borrow,
7 collections::{HashMap, VecDeque},
8 hash::BuildHasher,
9};
10
/// Assigns a weight to a cache entry.
///
/// `Cache` calls [`Weigher::weigh`] to account for entries against its
/// weight budget. Implementations are selected purely by type: `weigh` has no
/// receiver, so a weigher carries no state.
pub trait Weigher<K, V> {
    /// Returns the weight of the `(key, value)` pair.
    ///
    /// The default implementation ignores both arguments and reports `1`,
    /// which turns the weight budget into a plain entry count.
    fn weigh(_k: &K, _v: &V) -> usize {
        1
    }
}
16
17#[derive(Debug, Clone)]
18pub struct One;
19impl<K, V> Weigher<K, V> for One {}
20
/// Hook that observes entries leaving the cache through eviction.
pub trait Lifecycle<K, V> {
    /// Receives ownership of an evicted key and its value.
    ///
    /// The default implementation does nothing, so both are simply dropped.
    fn on_eviction(&self, _key: K, _value: V) {}
}
24
25#[derive(Debug, Clone, Copy, Default)]
26pub struct DefaultLifecycle;
27impl<K, V> Lifecycle<K, V> for DefaultLifecycle {}
28
/// A payload paired with the "visited" flag used by the eviction sweep.
///
/// `Cache` keeps a `SieveEntry<V>` in its map and a `SieveEntry<K>` in its
/// sieve pool for each inserted key; both hold clones of the same `Arc`, so a
/// flag set through the map side is seen by the pool side.
#[derive(Debug)]
struct SieveEntry<D> {
    /// The stored payload (a value in the map, a key in the sieve pool).
    data: D,
    /// Set to `true` on insertion, overwrite and lookup; the eviction sweep
    /// swaps it back to `false` when it passes over the entry.
    visited: Arc<AtomicBool>,
}
34
35#[derive(Debug)]
36pub struct Cache<K, V, S, W: Weigher<K, V> = One, L: Lifecycle<K, V> = DefaultLifecycle> {
37 map: HashMap<K, SieveEntry<V>, S>,
38 sieve_pool: VecDeque<SieveEntry<K>>,
39 sieve_hand: usize,
40 max_weight: usize,
41 weight: usize,
42 lifecycle: L,
43 _phantom: PhantomData<W>,
44}
45
46impl<K, V, S, W, L> Cache<K, V, S, W, L>
47where
48 K: Eq + Hash + Clone,
49 S: BuildHasher,
50 W: Weigher<K, V>,
51 L: Lifecycle<K, V> + Default,
52{
53 pub fn new(hasher: S, max_weight: usize) -> Self {
54 Self {
55 map: HashMap::with_hasher(hasher),
56 sieve_pool: VecDeque::new(),
57 sieve_hand: 0,
58 max_weight,
59 weight: 0,
60 lifecycle: Default::default(),
61 _phantom: PhantomData,
62 }
63 }
64}
65
66impl<K, V, S, W, L> Cache<K, V, S, W, L>
67where
68 K: Eq + Hash + Clone,
69 S: BuildHasher,
70 W: Weigher<K, V>,
71 L: Lifecycle<K, V>,
72{
73 pub fn new_with_lifecycle(hasher: S, max_weight: usize, lifecycle: L) -> Self {
74 Self {
75 map: HashMap::with_hasher(hasher),
76 sieve_pool: VecDeque::new(),
77 sieve_hand: 0,
78 max_weight,
79 weight: 0,
80 lifecycle,
81 _phantom: PhantomData,
82 }
83 }
84
85 pub fn put(&mut self, key: K, value: V) {
86 let new_entry_weight = self.make_room_for(&key, &value);
87 self.weight += new_entry_weight;
88
89 match self.map.entry(key.clone()) {
90 std::collections::hash_map::Entry::Occupied(mut occupied_entry) => {
91 let replaced_weight = W::weigh(&key, &occupied_entry.get().data);
92 self.weight -= replaced_weight; occupied_entry.get_mut().data = value;
95 occupied_entry
96 .get_mut()
97 .visited
98 .store(true, std::sync::atomic::Ordering::Relaxed);
99 }
100 std::collections::hash_map::Entry::Vacant(vacant_entry) => {
101 let visited = Arc::new(AtomicBool::new(true));
106 vacant_entry.insert(SieveEntry {
107 data: value,
108 visited: visited.clone(),
109 });
110 self.sieve_pool.push_back(SieveEntry { data: key, visited });
111 }
112 }
113 }
114
115 pub fn get<Q>(&self, key: &Q) -> Option<&V>
116 where
117 K: Borrow<Q>,
118 Q: Hash + Eq + ?Sized,
119 {
120 match self.map.get(key) {
121 Some(entry) => {
122 entry
123 .visited
124 .store(true, std::sync::atomic::Ordering::Relaxed);
125 Some(&entry.data)
126 }
127 None => None,
128 }
129 }
130
131 pub fn remove(&mut self, key: &K) -> Option<V> {
132 match self.map.remove(key) {
133 Some(removed) => {
134 let removed_weight = W::weigh(key, &removed.data);
139 match self.weight.checked_sub(removed_weight) {
140 Some(new_weight) => self.weight = new_weight,
141 None => {
142 log::error!("weight underflow");
143 self.weight = 0;
144 }
145 };
146 Some(removed.data)
147 }
148 None => {
149 log::debug!("garbage collecting sieve entry as {}", self.sieve_hand);
150 None
151 }
152 }
153 }
154
155 pub fn evict_all(&mut self) {
160 for (k, v) in self.map.drain() {
161 self.lifecycle.on_eviction(k, v.data);
162 }
163 self.sieve_pool.clear();
164 self.sieve_hand = 0;
165 self.weight = 0;
166 }
167
168 fn make_room_for(&mut self, key: &K, value: &V) -> usize {
169 let entry_weight = W::weigh(key, value);
170 while self.max_weight < self.weight + entry_weight {
171 let sieve_entry = &mut self.sieve_pool[self.sieve_hand];
172 let visited = sieve_entry
173 .visited
174 .swap(false, std::sync::atomic::Ordering::Relaxed);
175 if visited {
176 self.sieve_hand = (self.sieve_hand + 1) % self.sieve_pool.len();
177 } else {
178 let sieve_key_entry = self
179 .sieve_pool
180 .swap_remove_back(self.sieve_hand)
181 .expect("the index must be present");
182 let removed = self.remove(&sieve_key_entry.data);
183 if let Some(removed_value) = removed {
184 self.lifecycle
185 .on_eviction(sieve_key_entry.data, removed_value);
186 } else {
187 log::debug!("garbage collecting sieve entry at {}", self.sieve_hand);
191 }
192
193 if self.sieve_hand == self.sieve_pool.len() {
194 self.sieve_hand = 0;
195 }
196 }
197 }
198 entry_weight
199 }
200}
201
#[cfg(test)]
mod test {
    use std::hash::RandomState;

    use super::*;

    /// Writing the same key twice must replace the value while the cache
    /// keeps exactly one entry, one pool slot and a weight of one.
    #[test]
    fn test_put() {
        let mut cache: Cache<String, String, RandomState> = Cache::new(RandomState::new(), 100);

        for value in ["value1", "value2"] {
            cache.put("key1".to_string(), value.to_string());
            assert_eq!(cache.get("key1").map(String::as_str), Some(value));
        }

        assert_eq!(cache.weight, 1);
        assert_eq!(cache.map.len(), 1);
        assert_eq!(cache.sieve_pool.len(), 1);
    }
}