Skip to main content

cache_mod/
lfu.rs

1//! Least-Frequently-Used (LFU) cache — sharded, per-shard BTreeMap priority index.
2
3use core::hash::Hash;
4use std::collections::{BTreeMap, HashMap};
5use std::num::NonZeroUsize;
6
7use crate::cache::Cache;
8use crate::error::CacheError;
9use crate::sharding::{self, Sharded};
10use crate::util::MutexExt;
11
12/// A bounded, thread-safe LFU cache.
13///
14/// Each entry carries a counter that is incremented on every [`get`](Cache::get)
15/// or [`insert`](Cache::insert) of an already-present key. On overflow, the
16/// entry with the **lowest counter** is evicted; ties are broken in favour of
17/// evicting the **least-recently-accessed** entry.
18///
19/// [`contains_key`](Cache::contains_key) is a query and does not increment
20/// the counter or touch access order.
21///
22/// # Implementation
23///
24/// Sharded into up to 16 independent stores keyed by hash of `K`. Each
25/// shard pairs a `HashMap<K, Entry<V>>` for value lookup with a
26/// `BTreeMap<(count, age), K>` ordered priority index for O(log n) eviction.
27///
28/// Eviction is **per-shard approximate** LFU — the entry evicted on
29/// overflow is the lowest-counter entry in the affected shard, not
30/// necessarily the lowest-counter entry globally. Tiny caches (< 32 entries)
31/// use a single shard and retain strict global semantics.
32///
33/// # Example
34///
35/// ```
36/// use cache_mod::{Cache, LfuCache};
37///
38/// let cache: LfuCache<&'static str, u32> = LfuCache::new(2).expect("capacity > 0");
39///
40/// cache.insert("a", 1);
41/// cache.insert("b", 2);
42///
43/// assert_eq!(cache.get(&"a"), Some(1));
44/// assert_eq!(cache.get(&"a"), Some(1));
45///
46/// cache.insert("c", 3);
47/// assert_eq!(cache.get(&"b"), None);
48/// assert_eq!(cache.get(&"a"), Some(1));
49/// assert_eq!(cache.get(&"c"), Some(3));
50/// ```
51pub struct LfuCache<K, V> {
52    capacity: NonZeroUsize,
53    sharded: Sharded<Inner<K, V>>,
54}
55
/// A cached value together with the bookkeeping that ranks it for eviction.
struct Entry<V> {
    /// The payload handed back to callers.
    value: V,
    /// Use counter; first component of the entry's priority.
    count: u64,
    /// Shard clock reading at the entry's last access; second component of
    /// the priority, used to order entries whose counters are equal.
    age: u64,
}
61
62struct Inner<K, V> {
63    capacity: NonZeroUsize,
64    map: HashMap<K, Entry<V>>,
65    by_priority: BTreeMap<(u64, u64), K>,
66    clock: u64,
67}
68
69impl<K, V> Inner<K, V>
70where
71    K: Eq + Hash + Clone,
72{
73    fn with_capacity(capacity: NonZeroUsize) -> Self {
74        let cap = capacity.get();
75        Self {
76            capacity,
77            map: HashMap::with_capacity(cap),
78            by_priority: BTreeMap::new(),
79            clock: 0,
80        }
81    }
82
83    fn tick(&mut self) -> u64 {
84        self.clock = self.clock.wrapping_add(1);
85        self.clock
86    }
87}
88
89impl<K, V> LfuCache<K, V>
90where
91    K: Eq + Hash + Clone,
92    V: Clone,
93{
94    /// Creates a cache with the given capacity.
95    ///
96    /// Returns [`CacheError::InvalidCapacity`] if `capacity == 0`.
97    ///
98    /// # Example
99    ///
100    /// ```
101    /// use cache_mod::LfuCache;
102    ///
103    /// let cache: LfuCache<String, u32> = LfuCache::new(128).expect("capacity > 0");
104    /// ```
105    pub fn new(capacity: usize) -> Result<Self, CacheError> {
106        let cap = NonZeroUsize::new(capacity).ok_or(CacheError::InvalidCapacity)?;
107        Ok(Self::with_capacity(cap))
108    }
109
110    /// Creates a cache with the given non-zero capacity. Infallible.
111    ///
112    /// # Example
113    ///
114    /// ```
115    /// use std::num::NonZeroUsize;
116    /// use cache_mod::LfuCache;
117    ///
118    /// let cap = NonZeroUsize::new(64).expect("64 != 0");
119    /// let cache: LfuCache<String, u32> = LfuCache::with_capacity(cap);
120    /// ```
121    pub fn with_capacity(capacity: NonZeroUsize) -> Self {
122        let num_shards = sharding::shard_count(capacity);
123        let per_shard = sharding::per_shard_capacity(capacity, num_shards);
124        let sharded = Sharded::from_factory(num_shards, |_| Inner::with_capacity(per_shard));
125        Self { capacity, sharded }
126    }
127}
128
129impl<K, V> Cache<K, V> for LfuCache<K, V>
130where
131    K: Eq + Hash + Clone,
132    V: Clone,
133{
134    fn get(&self, key: &K) -> Option<V> {
135        let mut inner = self.sharded.shard_for(key).lock_recover();
136        let new_age = inner.tick();
137
138        let (old_priority, new_priority, value) = {
139            let entry = inner.map.get_mut(key)?;
140            let old = (entry.count, entry.age);
141            entry.count = entry.count.saturating_add(1);
142            entry.age = new_age;
143            let new = (entry.count, entry.age);
144            (old, new, entry.value.clone())
145        };
146
147        let _ = inner.by_priority.remove(&old_priority);
148        let _ = inner.by_priority.insert(new_priority, key.clone());
149        Some(value)
150    }
151
152    fn insert(&self, key: K, value: V) -> Option<V> {
153        let mut inner = self.sharded.shard_for(&key).lock_recover();
154        let new_age = inner.tick();
155
156        // Live update.
157        if let Some(entry) = inner.map.get_mut(&key) {
158            let old_priority = (entry.count, entry.age);
159            entry.count = entry.count.saturating_add(1);
160            entry.age = new_age;
161            let new_priority = (entry.count, entry.age);
162            let old_value = core::mem::replace(&mut entry.value, value);
163            let _ = inner.by_priority.remove(&old_priority);
164            let _ = inner.by_priority.insert(new_priority, key);
165            return Some(old_value);
166        }
167
168        // New key — evict if at per-shard capacity.
169        if inner.map.len() >= inner.capacity.get() {
170            if let Some((_, victim_key)) = inner.by_priority.pop_first() {
171                let _ = inner.map.remove(&victim_key);
172            }
173        }
174
175        let entry = Entry {
176            value,
177            count: 1,
178            age: new_age,
179        };
180        let priority = (entry.count, entry.age);
181        let _ = inner.map.insert(key.clone(), entry);
182        let _ = inner.by_priority.insert(priority, key);
183        None
184    }
185
186    fn remove(&self, key: &K) -> Option<V> {
187        let mut inner = self.sharded.shard_for(key).lock_recover();
188        let entry = inner.map.remove(key)?;
189        let _ = inner.by_priority.remove(&(entry.count, entry.age));
190        Some(entry.value)
191    }
192
193    fn contains_key(&self, key: &K) -> bool {
194        self.sharded
195            .shard_for(key)
196            .lock_recover()
197            .map
198            .contains_key(key)
199    }
200
201    fn len(&self) -> usize {
202        self.sharded
203            .iter()
204            .map(|m| m.lock_recover().map.len())
205            .sum()
206    }
207
208    fn clear(&self) {
209        for mutex in self.sharded.iter() {
210            let mut inner = mutex.lock_recover();
211            inner.map.clear();
212            inner.by_priority.clear();
213            inner.clock = 0;
214        }
215    }
216
217    fn capacity(&self) -> usize {
218        self.capacity.get()
219    }
220}