1use core::hash::{Hash, Hasher};
4use std::collections::hash_map::DefaultHasher;
5use std::collections::HashMap;
6use std::num::NonZeroUsize;
7
8use crate::cache::Cache;
9use crate::error::CacheError;
10use crate::sharding::{self, Sharded};
11use crate::util::MutexExt;
12
/// Number of counter rows in each shard's count-min sketch; every row is
/// addressed through a differently-seeded hash of the key.
const SKETCH_DEPTH: usize = 4;
/// Lower bound on the number of counters per sketch row, applied before the
/// row width is rounded up to a power of two.
const MIN_SKETCH_WIDTH: usize = 64;
15
/// A sharded cache that keeps entries in LRU order and uses a
/// TinyLFU-style frequency sketch to decide whether a new key may displace
/// the least recently used entry of a full shard.
///
/// Each shard is locked independently (via `MutexExt::lock_recover`), so
/// operations on keys routed to different shards do not contend.
pub struct TinyLfuCache<K, V> {
    /// Total capacity requested at construction; reported by `Cache::capacity`.
    capacity: NonZeroUsize,
    /// Per-shard state; the shard for a key is chosen by `Sharded::shard_for`.
    sharded: Sharded<Inner<K, V>>,
}
59
/// One entry of a shard's doubly linked recency list, stored in an arena slot.
struct Node<K, V> {
    /// Copy of the map key, kept so eviction can remove the matching map entry.
    key: K,
    /// Cached value.
    value: V,
    /// Arena index of the neighbour closer to the head (more recently used).
    prev: Option<usize>,
    /// Arena index of the neighbour closer to the tail (less recently used).
    next: Option<usize>,
}
66
/// State of a single shard: an LRU list stored in a slot arena, an index from
/// key to slot, and the frequency sketch consulted for admission.
struct Inner<K, V> {
    /// Entry limit for this shard; inserting a new key once `map.len()`
    /// reaches it triggers the admission check.
    capacity: NonZeroUsize,
    /// Slot arena; `None` marks a vacant slot whose index is on `free`.
    nodes: Vec<Option<Node<K, V>>>,
    /// Indices of vacant `nodes` slots available for reuse.
    free: Vec<usize>,
    /// Most recently used entry.
    head: Option<usize>,
    /// Least recently used entry, i.e. the eviction candidate.
    tail: Option<usize>,
    /// Maps each cached key to its slot index in `nodes`.
    map: HashMap<K, usize>,
    /// Approximate access frequencies used by the admission policy.
    sketch: CountMinSketch,
}
76
77impl<K, V> Inner<K, V>
78where
79 K: Eq + Hash + Clone,
80{
81 fn with_capacity(capacity: NonZeroUsize) -> Self {
82 let cap = capacity.get();
83 Self {
84 capacity,
85 nodes: Vec::with_capacity(cap),
86 free: Vec::new(),
87 head: None,
88 tail: None,
89 map: HashMap::with_capacity(cap),
90 sketch: CountMinSketch::new(cap),
91 }
92 }
93
94 fn alloc(&mut self, node: Node<K, V>) -> usize {
95 if let Some(idx) = self.free.pop() {
96 self.nodes[idx] = Some(node);
97 idx
98 } else {
99 self.nodes.push(Some(node));
100 self.nodes.len() - 1
101 }
102 }
103
104 fn dealloc(&mut self, idx: usize) -> Node<K, V> {
105 let node = self.nodes[idx]
106 .take()
107 .unwrap_or_else(|| unreachable!("arena slot must be occupied"));
108 self.free.push(idx);
109 node
110 }
111
112 fn unlink(&mut self, idx: usize) {
113 let (prev, next) = {
114 let n = self.nodes[idx]
115 .as_ref()
116 .unwrap_or_else(|| unreachable!("unlink target must be occupied"));
117 (n.prev, n.next)
118 };
119 match prev {
120 Some(p) => {
121 self.nodes[p]
122 .as_mut()
123 .unwrap_or_else(|| unreachable!())
124 .next = next
125 }
126 None => self.head = next,
127 }
128 match next {
129 Some(n) => {
130 self.nodes[n]
131 .as_mut()
132 .unwrap_or_else(|| unreachable!())
133 .prev = prev
134 }
135 None => self.tail = prev,
136 }
137 if let Some(n) = self.nodes[idx].as_mut() {
138 n.prev = None;
139 n.next = None;
140 }
141 }
142
143 fn push_front(&mut self, idx: usize) {
144 let old_head = self.head;
145 if let Some(n) = self.nodes[idx].as_mut() {
146 n.prev = None;
147 n.next = old_head;
148 }
149 if let Some(h) = old_head {
150 if let Some(n) = self.nodes[h].as_mut() {
151 n.prev = Some(idx);
152 }
153 } else {
154 self.tail = Some(idx);
155 }
156 self.head = Some(idx);
157 }
158
159 fn promote(&mut self, idx: usize) {
160 if self.head == Some(idx) {
161 return;
162 }
163 self.unlink(idx);
164 self.push_front(idx);
165 }
166}
167
168impl<K, V> TinyLfuCache<K, V>
169where
170 K: Eq + Hash + Clone,
171 V: Clone,
172{
173 pub fn new(capacity: usize) -> Result<Self, CacheError> {
186 let cap = NonZeroUsize::new(capacity).ok_or(CacheError::InvalidCapacity)?;
187 Ok(Self::with_capacity(cap))
188 }
189
190 pub fn with_capacity(capacity: NonZeroUsize) -> Self {
202 let num_shards = sharding::shard_count(capacity);
203 let per_shard = sharding::per_shard_capacity(capacity, num_shards);
204 let sharded = Sharded::from_factory(num_shards, |_| Inner::with_capacity(per_shard));
205 Self { capacity, sharded }
206 }
207}
208
209impl<K, V> Cache<K, V> for TinyLfuCache<K, V>
210where
211 K: Eq + Hash + Clone,
212 V: Clone,
213{
214 fn get(&self, key: &K) -> Option<V> {
215 let mut inner = self.sharded.shard_for(key).lock_recover();
216 inner.sketch.increment(key);
217 let idx = *inner.map.get(key)?;
218 inner.promote(idx);
219 inner.nodes[idx].as_ref().map(|n| n.value.clone())
220 }
221
222 fn insert(&self, key: K, value: V) -> Option<V> {
223 let mut inner = self.sharded.shard_for(&key).lock_recover();
224 inner.sketch.increment(&key);
225
226 if let Some(&idx) = inner.map.get(&key) {
227 let old = inner.nodes[idx]
228 .as_mut()
229 .map(|n| core::mem::replace(&mut n.value, value))
230 .unwrap_or_else(|| unreachable!("mapped index must be occupied"));
231 inner.promote(idx);
232 return Some(old);
233 }
234
235 if inner.map.len() >= inner.capacity.get() {
236 let candidate_freq = inner.sketch.estimate(&key);
237 let tail_idx = inner.tail?;
238 let victim_key = inner.nodes[tail_idx]
239 .as_ref()
240 .map(|n| n.key.clone())
241 .unwrap_or_else(|| unreachable!("tail must be occupied"));
242 let victim_freq = inner.sketch.estimate(&victim_key);
243 if candidate_freq <= victim_freq {
244 return None;
245 }
246 inner.unlink(tail_idx);
247 let _ = inner.dealloc(tail_idx);
248 let _ = inner.map.remove(&victim_key);
249 }
250
251 let idx = inner.alloc(Node {
252 key: key.clone(),
253 value,
254 prev: None,
255 next: None,
256 });
257 inner.push_front(idx);
258 let _ = inner.map.insert(key, idx);
259 None
260 }
261
262 fn remove(&self, key: &K) -> Option<V> {
263 let mut inner = self.sharded.shard_for(key).lock_recover();
264 let idx = inner.map.remove(key)?;
265 inner.unlink(idx);
266 let node = inner.dealloc(idx);
267 Some(node.value)
268 }
269
270 fn contains_key(&self, key: &K) -> bool {
271 self.sharded
272 .shard_for(key)
273 .lock_recover()
274 .map
275 .contains_key(key)
276 }
277
278 fn len(&self) -> usize {
279 self.sharded
280 .iter()
281 .map(|m| m.lock_recover().map.len())
282 .sum()
283 }
284
285 fn clear(&self) {
286 for mutex in self.sharded.iter() {
287 let mut inner = mutex.lock_recover();
288 inner.nodes.clear();
289 inner.free.clear();
290 inner.head = None;
291 inner.tail = None;
292 inner.map.clear();
293 inner.sketch.reset();
294 }
295 }
296
297 fn capacity(&self) -> usize {
298 self.capacity.get()
299 }
300}
301
/// Count-min sketch of saturating 8-bit counters arranged in `SKETCH_DEPTH`
/// rows. All counters are halved periodically so that old accesses decay.
struct CountMinSketch {
    /// Row-major counters: row `d` occupies indices `d * width .. (d + 1) * width`.
    counters: Vec<u8>,
    /// Counters per row; `CountMinSketch::new` rounds it up to a power of two.
    width: usize,
    /// `width` pre-converted to `u64` for the modulo in `cell`.
    width_u64: u64,
    /// Increments recorded since the last aging pass.
    samples: u64,
    /// Number of increments after which every counter is halved.
    sample_window: u64,
}
313
314impl CountMinSketch {
315 fn new(capacity: usize) -> Self {
316 let mut width = capacity.saturating_mul(2).max(MIN_SKETCH_WIDTH);
317 width = width.next_power_of_two();
318 let sample_window = (capacity as u64).saturating_mul(10).max(64);
319 Self {
320 counters: vec![0; width.saturating_mul(SKETCH_DEPTH)],
321 width,
322 width_u64: width as u64,
323 samples: 0,
324 sample_window,
325 }
326 }
327
328 fn estimate<K: Hash>(&self, key: &K) -> u8 {
329 let mut min = u8::MAX;
330 for d in 0..SKETCH_DEPTH {
331 let idx = self.cell(d, key);
332 let observed = *self.counters.get(idx).unwrap_or(&0);
333 if observed < min {
334 min = observed;
335 }
336 }
337 min
338 }
339
340 fn increment<K: Hash>(&mut self, key: &K) {
341 for d in 0..SKETCH_DEPTH {
342 let idx = self.cell(d, key);
343 if let Some(slot) = self.counters.get_mut(idx) {
344 *slot = slot.saturating_add(1);
345 }
346 }
347 self.samples = self.samples.saturating_add(1);
348 if self.samples >= self.sample_window {
349 self.age();
350 self.samples = 0;
351 }
352 }
353
354 fn reset(&mut self) {
355 for c in self.counters.iter_mut() {
356 *c = 0;
357 }
358 self.samples = 0;
359 }
360
361 fn age(&mut self) {
362 for c in self.counters.iter_mut() {
363 *c >>= 1;
364 }
365 }
366
367 fn cell<K: Hash>(&self, d: usize, key: &K) -> usize {
368 let h = hash_with_seed(key, d as u64);
369 let col = (h % self.width_u64) as usize;
370 d.saturating_mul(self.width).saturating_add(col)
371 }
372}
373
/// Hashes `key` with `DefaultHasher`, mixing in `seed` first so that each
/// sketch row (one seed per row) maps the same key to a different column.
fn hash_with_seed<K: Hash>(key: &K, seed: u64) -> u64 {
    let mut state = DefaultHasher::new();
    Hash::hash(&seed, &mut state);
    Hash::hash(key, &mut state);
    Hasher::finish(&state)
}