1use core::hash::Hash;
4use std::collections::{BTreeMap, HashMap};
5use std::num::NonZeroUsize;
6
7use crate::cache::Cache;
8use crate::error::CacheError;
9use crate::sharding::{self, Sharded};
10use crate::util::MutexExt;
11
12pub struct LfuCache<K, V> {
52 capacity: NonZeroUsize,
53 sharded: Sharded<Inner<K, V>>,
54}
55
/// One cached value together with the bookkeeping that decides its eviction priority.
struct Entry<V> {
    /// Number of recorded uses; starts at 1 on insertion and is bumped with saturation.
    count: u64,
    /// Shard-clock reading at the most recent use; breaks ties between equal counts.
    age: u64,
    /// The cached value itself.
    value: V,
}
61
62struct Inner<K, V> {
63 capacity: NonZeroUsize,
64 map: HashMap<K, Entry<V>>,
65 by_priority: BTreeMap<(u64, u64), K>,
66 clock: u64,
67}
68
69impl<K, V> Inner<K, V>
70where
71 K: Eq + Hash + Clone,
72{
73 fn with_capacity(capacity: NonZeroUsize) -> Self {
74 let cap = capacity.get();
75 Self {
76 capacity,
77 map: HashMap::with_capacity(cap),
78 by_priority: BTreeMap::new(),
79 clock: 0,
80 }
81 }
82
83 fn tick(&mut self) -> u64 {
84 self.clock = self.clock.wrapping_add(1);
85 self.clock
86 }
87}
88
89impl<K, V> LfuCache<K, V>
90where
91 K: Eq + Hash + Clone,
92 V: Clone,
93{
94 pub fn new(capacity: usize) -> Result<Self, CacheError> {
106 let cap = NonZeroUsize::new(capacity).ok_or(CacheError::InvalidCapacity)?;
107 Ok(Self::with_capacity(cap))
108 }
109
110 pub fn with_capacity(capacity: NonZeroUsize) -> Self {
122 let num_shards = sharding::shard_count(capacity);
123 let per_shard = sharding::per_shard_capacity(capacity, num_shards);
124 let sharded = Sharded::from_factory(num_shards, |_| Inner::with_capacity(per_shard));
125 Self { capacity, sharded }
126 }
127}
128
129impl<K, V> Cache<K, V> for LfuCache<K, V>
130where
131 K: Eq + Hash + Clone,
132 V: Clone,
133{
134 fn get(&self, key: &K) -> Option<V> {
135 let mut inner = self.sharded.shard_for(key).lock_recover();
136 let new_age = inner.tick();
137
138 let (old_priority, new_priority, value) = {
139 let entry = inner.map.get_mut(key)?;
140 let old = (entry.count, entry.age);
141 entry.count = entry.count.saturating_add(1);
142 entry.age = new_age;
143 let new = (entry.count, entry.age);
144 (old, new, entry.value.clone())
145 };
146
147 let _ = inner.by_priority.remove(&old_priority);
148 let _ = inner.by_priority.insert(new_priority, key.clone());
149 Some(value)
150 }
151
152 fn insert(&self, key: K, value: V) -> Option<V> {
153 let mut inner = self.sharded.shard_for(&key).lock_recover();
154 let new_age = inner.tick();
155
156 if let Some(entry) = inner.map.get_mut(&key) {
158 let old_priority = (entry.count, entry.age);
159 entry.count = entry.count.saturating_add(1);
160 entry.age = new_age;
161 let new_priority = (entry.count, entry.age);
162 let old_value = core::mem::replace(&mut entry.value, value);
163 let _ = inner.by_priority.remove(&old_priority);
164 let _ = inner.by_priority.insert(new_priority, key);
165 return Some(old_value);
166 }
167
168 if inner.map.len() >= inner.capacity.get() {
170 if let Some((_, victim_key)) = inner.by_priority.pop_first() {
171 let _ = inner.map.remove(&victim_key);
172 }
173 }
174
175 let entry = Entry {
176 value,
177 count: 1,
178 age: new_age,
179 };
180 let priority = (entry.count, entry.age);
181 let _ = inner.map.insert(key.clone(), entry);
182 let _ = inner.by_priority.insert(priority, key);
183 None
184 }
185
186 fn remove(&self, key: &K) -> Option<V> {
187 let mut inner = self.sharded.shard_for(key).lock_recover();
188 let entry = inner.map.remove(key)?;
189 let _ = inner.by_priority.remove(&(entry.count, entry.age));
190 Some(entry.value)
191 }
192
193 fn contains_key(&self, key: &K) -> bool {
194 self.sharded
195 .shard_for(key)
196 .lock_recover()
197 .map
198 .contains_key(key)
199 }
200
201 fn len(&self) -> usize {
202 self.sharded
203 .iter()
204 .map(|m| m.lock_recover().map.len())
205 .sum()
206 }
207
208 fn clear(&self) {
209 for mutex in self.sharded.iter() {
210 let mut inner = mutex.lock_recover();
211 inner.map.clear();
212 inner.by_priority.clear();
213 inner.clock = 0;
214 }
215 }
216
217 fn capacity(&self) -> usize {
218 self.capacity.get()
219 }
220}