Skip to main content

cache_mod/
tinylfu.rs

//! TinyLFU cache — a sharded, arena-backed LRU main space with a per-shard Count-Min Sketch for admission.
2
3use core::hash::{Hash, Hasher};
4use std::collections::hash_map::DefaultHasher;
5use std::collections::HashMap;
6use std::num::NonZeroUsize;
7
8use crate::cache::Cache;
9use crate::error::CacheError;
10use crate::sharding::{self, Sharded};
11use crate::util::MutexExt;
12
/// Number of independent hash rows in every per-shard Count-Min Sketch.
const SKETCH_DEPTH: usize = 4;
/// Smallest number of counters allowed in a single sketch row.
const MIN_SKETCH_WIDTH: usize = 1 << 6;
15
16/// A bounded, thread-safe cache with **admission control**.
17///
18/// `TinyLfuCache` tracks the access frequency of *every* key it observes —
19/// including keys that aren't (yet) in the cache — using a Count-Min Sketch.
20/// On capacity overflow, an incoming key is **admitted only if its
21/// estimated frequency exceeds the LRU victim's**. One-hit-wonders are
22/// rejected at the door instead of evicting hot entries.
23///
24/// A successful [`insert`](Cache::insert) call **does not guarantee** the
25/// value is in the cache. The admission filter may reject it. Callers that
26/// need strict insertion guarantees should use `LruCache` or `LfuCache`.
27///
28/// # Implementation
29///
30/// Sharded into up to 16 independent arenas keyed by hash of `K`. **Each
31/// shard owns its own Count-Min Sketch** — the frequency signal is
32/// per-shard, not global. This is a deliberate trade-off: a global sketch
33/// would force every access to lock a shared structure, defeating the
34/// point of sharding. Per-shard sketches still capture the local frequency
35/// signal accurately, which is what the local admission decision needs.
36///
37/// Eviction is approximate (per-shard LRU). Tiny caches (< 32 entries)
38/// use a single shard and retain strict global semantics.
39///
40/// # Example
41///
42/// ```
43/// use cache_mod::{Cache, TinyLfuCache};
44///
45/// let cache: TinyLfuCache<&'static str, u32> =
46///     TinyLfuCache::new(4).expect("capacity > 0");
47///
48/// for _ in 0..16 {
49///     let _ = cache.get(&"hot");
50///     let _ = cache.insert("hot", 1);
51/// }
52///
53/// assert_eq!(cache.get(&"hot"), Some(1));
54/// ```
55pub struct TinyLfuCache<K, V> {
56    capacity: NonZeroUsize,
57    sharded: Sharded<Inner<K, V>>,
58}
59
/// One cache entry stored in a shard's arena.
///
/// `prev` and `next` are arena indices that thread entries into the shard's
/// recency list. `prev` points toward the head (most-recently-used end) and
/// `next` toward the tail.
struct Node<K, V> {
    /// Arena index of the neighbour closer to the head, if any.
    prev: Option<usize>,
    /// Arena index of the neighbour closer to the tail, if any.
    next: Option<usize>,
    /// The entry's key (a copy also lives in the shard's index map).
    key: K,
    /// The cached value.
    value: V,
}
66
67struct Inner<K, V> {
68    capacity: NonZeroUsize,
69    nodes: Vec<Option<Node<K, V>>>,
70    free: Vec<usize>,
71    head: Option<usize>,
72    tail: Option<usize>,
73    map: HashMap<K, usize>,
74    sketch: CountMinSketch,
75}
76
77impl<K, V> Inner<K, V>
78where
79    K: Eq + Hash + Clone,
80{
81    fn with_capacity(capacity: NonZeroUsize) -> Self {
82        let cap = capacity.get();
83        Self {
84            capacity,
85            nodes: Vec::with_capacity(cap),
86            free: Vec::new(),
87            head: None,
88            tail: None,
89            map: HashMap::with_capacity(cap),
90            sketch: CountMinSketch::new(cap),
91        }
92    }
93
94    fn alloc(&mut self, node: Node<K, V>) -> usize {
95        if let Some(idx) = self.free.pop() {
96            self.nodes[idx] = Some(node);
97            idx
98        } else {
99            self.nodes.push(Some(node));
100            self.nodes.len() - 1
101        }
102    }
103
104    fn dealloc(&mut self, idx: usize) -> Node<K, V> {
105        let node = self.nodes[idx]
106            .take()
107            .unwrap_or_else(|| unreachable!("arena slot must be occupied"));
108        self.free.push(idx);
109        node
110    }
111
112    fn unlink(&mut self, idx: usize) {
113        let (prev, next) = {
114            let n = self.nodes[idx]
115                .as_ref()
116                .unwrap_or_else(|| unreachable!("unlink target must be occupied"));
117            (n.prev, n.next)
118        };
119        match prev {
120            Some(p) => {
121                self.nodes[p]
122                    .as_mut()
123                    .unwrap_or_else(|| unreachable!())
124                    .next = next
125            }
126            None => self.head = next,
127        }
128        match next {
129            Some(n) => {
130                self.nodes[n]
131                    .as_mut()
132                    .unwrap_or_else(|| unreachable!())
133                    .prev = prev
134            }
135            None => self.tail = prev,
136        }
137        if let Some(n) = self.nodes[idx].as_mut() {
138            n.prev = None;
139            n.next = None;
140        }
141    }
142
143    fn push_front(&mut self, idx: usize) {
144        let old_head = self.head;
145        if let Some(n) = self.nodes[idx].as_mut() {
146            n.prev = None;
147            n.next = old_head;
148        }
149        if let Some(h) = old_head {
150            if let Some(n) = self.nodes[h].as_mut() {
151                n.prev = Some(idx);
152            }
153        } else {
154            self.tail = Some(idx);
155        }
156        self.head = Some(idx);
157    }
158
159    fn promote(&mut self, idx: usize) {
160        if self.head == Some(idx) {
161            return;
162        }
163        self.unlink(idx);
164        self.push_front(idx);
165    }
166}
167
168impl<K, V> TinyLfuCache<K, V>
169where
170    K: Eq + Hash + Clone,
171    V: Clone,
172{
173    /// Creates a cache with the given entry-count capacity.
174    ///
175    /// Returns [`CacheError::InvalidCapacity`] if `capacity == 0`.
176    ///
177    /// # Example
178    ///
179    /// ```
180    /// use cache_mod::TinyLfuCache;
181    ///
182    /// let cache: TinyLfuCache<String, u32> =
183    ///     TinyLfuCache::new(256).expect("capacity > 0");
184    /// ```
185    pub fn new(capacity: usize) -> Result<Self, CacheError> {
186        let cap = NonZeroUsize::new(capacity).ok_or(CacheError::InvalidCapacity)?;
187        Ok(Self::with_capacity(cap))
188    }
189
190    /// Creates a cache with the given non-zero capacity. Infallible.
191    ///
192    /// # Example
193    ///
194    /// ```
195    /// use std::num::NonZeroUsize;
196    /// use cache_mod::TinyLfuCache;
197    ///
198    /// let cap = NonZeroUsize::new(256).expect("256 != 0");
199    /// let cache: TinyLfuCache<String, u32> = TinyLfuCache::with_capacity(cap);
200    /// ```
201    pub fn with_capacity(capacity: NonZeroUsize) -> Self {
202        let num_shards = sharding::shard_count(capacity);
203        let per_shard = sharding::per_shard_capacity(capacity, num_shards);
204        let sharded = Sharded::from_factory(num_shards, |_| Inner::with_capacity(per_shard));
205        Self { capacity, sharded }
206    }
207}
208
209impl<K, V> Cache<K, V> for TinyLfuCache<K, V>
210where
211    K: Eq + Hash + Clone,
212    V: Clone,
213{
214    fn get(&self, key: &K) -> Option<V> {
215        let mut inner = self.sharded.shard_for(key).lock_recover();
216        inner.sketch.increment(key);
217        let idx = *inner.map.get(key)?;
218        inner.promote(idx);
219        inner.nodes[idx].as_ref().map(|n| n.value.clone())
220    }
221
222    fn insert(&self, key: K, value: V) -> Option<V> {
223        let mut inner = self.sharded.shard_for(&key).lock_recover();
224        inner.sketch.increment(&key);
225
226        if let Some(&idx) = inner.map.get(&key) {
227            let old = inner.nodes[idx]
228                .as_mut()
229                .map(|n| core::mem::replace(&mut n.value, value))
230                .unwrap_or_else(|| unreachable!("mapped index must be occupied"));
231            inner.promote(idx);
232            return Some(old);
233        }
234
235        if inner.map.len() >= inner.capacity.get() {
236            let candidate_freq = inner.sketch.estimate(&key);
237            let tail_idx = inner.tail?;
238            let victim_key = inner.nodes[tail_idx]
239                .as_ref()
240                .map(|n| n.key.clone())
241                .unwrap_or_else(|| unreachable!("tail must be occupied"));
242            let victim_freq = inner.sketch.estimate(&victim_key);
243            if candidate_freq <= victim_freq {
244                return None;
245            }
246            inner.unlink(tail_idx);
247            let _ = inner.dealloc(tail_idx);
248            let _ = inner.map.remove(&victim_key);
249        }
250
251        let idx = inner.alloc(Node {
252            key: key.clone(),
253            value,
254            prev: None,
255            next: None,
256        });
257        inner.push_front(idx);
258        let _ = inner.map.insert(key, idx);
259        None
260    }
261
262    fn remove(&self, key: &K) -> Option<V> {
263        let mut inner = self.sharded.shard_for(key).lock_recover();
264        let idx = inner.map.remove(key)?;
265        inner.unlink(idx);
266        let node = inner.dealloc(idx);
267        Some(node.value)
268    }
269
270    fn contains_key(&self, key: &K) -> bool {
271        self.sharded
272            .shard_for(key)
273            .lock_recover()
274            .map
275            .contains_key(key)
276    }
277
278    fn len(&self) -> usize {
279        self.sharded
280            .iter()
281            .map(|m| m.lock_recover().map.len())
282            .sum()
283    }
284
285    fn clear(&self) {
286        for mutex in self.sharded.iter() {
287            let mut inner = mutex.lock_recover();
288            inner.nodes.clear();
289            inner.free.clear();
290            inner.head = None;
291            inner.tail = None;
292            inner.map.clear();
293            inner.sketch.reset();
294        }
295    }
296
297    fn capacity(&self) -> usize {
298        self.capacity.get()
299    }
300}
301
// -----------------------------------------------------------------------------
// Count-Min Sketch (per-shard)
// -----------------------------------------------------------------------------
305
306struct CountMinSketch {
307    counters: Vec<u8>,
308    width: usize,
309    width_u64: u64,
310    samples: u64,
311    sample_window: u64,
312}
313
314impl CountMinSketch {
315    fn new(capacity: usize) -> Self {
316        let mut width = capacity.saturating_mul(2).max(MIN_SKETCH_WIDTH);
317        width = width.next_power_of_two();
318        let sample_window = (capacity as u64).saturating_mul(10).max(64);
319        Self {
320            counters: vec![0; width.saturating_mul(SKETCH_DEPTH)],
321            width,
322            width_u64: width as u64,
323            samples: 0,
324            sample_window,
325        }
326    }
327
328    fn estimate<K: Hash>(&self, key: &K) -> u8 {
329        let mut min = u8::MAX;
330        for d in 0..SKETCH_DEPTH {
331            let idx = self.cell(d, key);
332            let observed = *self.counters.get(idx).unwrap_or(&0);
333            if observed < min {
334                min = observed;
335            }
336        }
337        min
338    }
339
340    fn increment<K: Hash>(&mut self, key: &K) {
341        for d in 0..SKETCH_DEPTH {
342            let idx = self.cell(d, key);
343            if let Some(slot) = self.counters.get_mut(idx) {
344                *slot = slot.saturating_add(1);
345            }
346        }
347        self.samples = self.samples.saturating_add(1);
348        if self.samples >= self.sample_window {
349            self.age();
350            self.samples = 0;
351        }
352    }
353
354    fn reset(&mut self) {
355        for c in self.counters.iter_mut() {
356            *c = 0;
357        }
358        self.samples = 0;
359    }
360
361    fn age(&mut self) {
362        for c in self.counters.iter_mut() {
363            *c >>= 1;
364        }
365    }
366
367    fn cell<K: Hash>(&self, d: usize, key: &K) -> usize {
368        let h = hash_with_seed(key, d as u64);
369        let col = (h % self.width_u64) as usize;
370        d.saturating_mul(self.width).saturating_add(col)
371    }
372}
373
/// Hashes `key` with `DefaultHasher` after first feeding it `seed`.
///
/// Each `seed` value selects a different member of a family of hash
/// functions that is deterministic within a build.
fn hash_with_seed<K: Hash>(key: &K, seed: u64) -> u64 {
    let mut state = DefaultHasher::new();
    Hash::hash(&seed, &mut state);
    Hash::hash(key, &mut state);
    state.finish()
}