Skip to main content

feldera_buffer_cache/
lru.rs

1use crate::{BufferCache, BufferCacheStrategy, CacheEntry};
2use std::any::Any;
3use std::collections::BTreeMap;
4use std::fmt::Debug;
5use std::hash::RandomState;
6use std::marker::PhantomData;
7use std::ops::RangeBounds;
8use std::sync::Mutex;
9
/// A weighted, thread-safe LRU cache.
///
/// All mutable state is kept in one `CacheInner` value behind a single
/// [`Mutex`]; the methods on this type lock it for the duration of each call.
///
/// `S` does not influence storage: the entries live in ordered maps and keys
/// are never hashed here. The parameter only mirrors the hash-builder slot of
/// the shared builder API.
pub struct LruCache<K, V, S = RandomState> {
    /// Mutable cache state guarded by a single mutex.
    inner: Mutex<CacheInner<K, V>>,
    /// Retains the public hash-builder type parameter used by shared builders.
    ///
    /// No `S` value is stored; `with_hasher` discards the one it is given.
    marker: PhantomData<fn() -> S>,
}
17
/// Mutable state for [`LruCache`].
///
/// `cache` and `lru` are two views of the same set of entries and must be
/// updated together: every entry in `cache` records the serial under which
/// its key is filed in `lru`, and both maps always have the same length.
struct CacheInner<K, V> {
    /// Cache contents.
    cache: BTreeMap<K, CacheValue<V>>,
    /// Map from LRU serial number to cache key.
    ///
    /// Serials are handed out in increasing order, so the first entry of this
    /// map is the least recently used key and the next eviction victim.
    lru: BTreeMap<u64, K>,
    /// Serial number to use the next time we touch a key.
    next_serial: u64,
    /// Sum over `cache[*].aux.cost()`.
    cur_cost: usize,
    /// Maximum total cost.
    ///
    /// An entry whose own cost exceeds this limit is still admitted after
    /// everything else has been evicted, so `cur_cost` may exceed `max_cost`
    /// while such an entry is resident.
    max_cost: usize,
}
31
/// Resident value stored by [`LruCache`].
///
/// Pairs the cached payload with the bookkeeping needed to find the entry in
/// the recency queue.
struct CacheValue<V> {
    /// Cached value; its `cost()` is what the entry contributes to the
    /// cache's running total.
    aux: V,
    /// Recency serial used by the LRU queue, i.e. the key under which this
    /// entry is currently filed in `CacheInner::lru`.
    serial: u64,
}
39
impl<K, V, S> LruCache<K, V, S> {
    /// Default shard count reported by [`LruCache::shard_count`].
    ///
    /// This backend keeps all of its state behind one mutex, so it always
    /// reports a single shard.
    pub const DEFAULT_SHARDS: usize = 1;
}
44
45impl<K, V> LruCache<K, V, RandomState>
46where
47    K: Ord + Clone + Debug,
48    V: CacheEntry + Clone,
49{
50    /// Creates a cache with the default hash builder.
51    pub fn new(max_cost: usize) -> Self {
52        Self::with_hasher(max_cost, RandomState::new())
53    }
54}
55
56// explicit allow, we do have `is_empty` in the trait so this is a false positive
57#[allow(clippy::len_without_is_empty)]
58impl<K, V, S> LruCache<K, V, S>
59where
60    K: Ord + Clone + Debug,
61    V: CacheEntry + Clone,
62{
63    /// Creates a cache with an explicit hash builder.
64    pub fn with_hasher(max_cost: usize, _hash_builder: S) -> Self {
65        Self {
66            inner: Mutex::new(CacheInner::new(max_cost)),
67            marker: PhantomData,
68        }
69    }
70
71    /// Inserts or replaces `key` with `value`.
72    pub fn insert(&self, key: K, value: V) {
73        self.inner.lock().unwrap().insert(key, value);
74    }
75
76    /// Looks up `key` and returns a clone of the stored value.
77    pub fn get(&self, key: &K) -> Option<V> {
78        self.inner.lock().unwrap().get(key.clone())
79    }
80
81    /// Removes `key` if present and returns the removed value.
82    pub fn remove(&self, key: &K) -> Option<V> {
83        self.inner.lock().unwrap().remove(key)
84    }
85
86    /// Removes every entry whose key matches `predicate`.
87    pub fn remove_if<F>(&self, predicate: F)
88    where
89        F: Fn(&K) -> bool,
90    {
91        self.inner.lock().unwrap().remove_if(predicate)
92    }
93
94    /// Removes every entry whose key falls within `range`.
95    pub fn remove_range<R>(&self, range: R) -> usize
96    where
97        R: RangeBounds<K>,
98    {
99        self.inner.lock().unwrap().remove_range(range)
100    }
101
102    /// Returns `true` if `key` is currently resident.
103    pub fn contains_key(&self, key: &K) -> bool {
104        self.inner.lock().unwrap().contains_key(key)
105    }
106
107    /// Returns the number of resident entries.
108    pub fn len(&self) -> usize {
109        self.inner.lock().unwrap().len()
110    }
111
112    /// Returns the total resident cost.
113    pub fn total_charge(&self) -> usize {
114        self.inner.lock().unwrap().cur_cost
115    }
116
117    /// Returns the configured total cost capacity.
118    pub fn total_capacity(&self) -> usize {
119        self.inner.lock().unwrap().max_cost
120    }
121
122    /// Returns the number of shards reported by this backend.
123    pub fn shard_count(&self) -> usize {
124        Self::DEFAULT_SHARDS
125    }
126
127    /// Returns `(used_charge, capacity)` for shard `idx`.
128    ///
129    /// # Panics
130    ///
131    /// Panics if `idx != 0`.
132    #[cfg(test)]
133    pub fn shard_usage(&self, idx: usize) -> (usize, usize) {
134        assert_eq!(idx, 0, "shard index out of bounds");
135        let inner = self.inner.lock().unwrap();
136        (inner.cur_cost, inner.max_cost)
137    }
138
139    #[cfg(test)]
140    pub(crate) fn validate_invariants(&self) {
141        self.inner.lock().unwrap().check_invariants();
142    }
143}
144
145impl<K, V, S> BufferCache<K, V> for LruCache<K, V, S>
146where
147    K: Ord + Clone + Debug + Send + Sync + 'static,
148    V: CacheEntry + Clone + Send + Sync + 'static,
149    S: Send + Sync + 'static,
150{
151    fn as_any(&self) -> &dyn Any {
152        self
153    }
154
155    fn strategy(&self) -> BufferCacheStrategy {
156        BufferCacheStrategy::Lru
157    }
158
159    fn insert(&self, key: K, value: V) {
160        self.insert(key, value);
161    }
162
163    fn get(&self, key: K) -> Option<V> {
164        self.inner.lock().unwrap().get(key)
165    }
166
167    fn remove(&self, key: &K) -> Option<V> {
168        self.remove(key)
169    }
170
171    fn remove_if(&self, predicate: &dyn Fn(&K) -> bool) {
172        self.remove_if(|key| predicate(key))
173    }
174
175    fn contains_key(&self, key: &K) -> bool {
176        self.contains_key(key)
177    }
178
179    fn len(&self) -> usize {
180        self.len()
181    }
182
183    fn total_charge(&self) -> usize {
184        self.total_charge()
185    }
186
187    fn total_capacity(&self) -> usize {
188        self.total_capacity()
189    }
190
191    fn shard_count(&self) -> usize {
192        self.shard_count()
193    }
194
195    #[cfg(test)]
196    fn shard_usage(&self, idx: usize) -> (usize, usize) {
197        self.shard_usage(idx)
198    }
199}
200
201impl<K, V> CacheInner<K, V>
202where
203    K: Ord + Clone + Debug,
204    V: CacheEntry + Clone,
205{
206    /// Creates an empty cache with `max_cost` capacity.
207    fn new(max_cost: usize) -> Self {
208        Self {
209            cache: BTreeMap::new(),
210            lru: BTreeMap::new(),
211            next_serial: 0,
212            cur_cost: 0,
213            max_cost,
214        }
215    }
216
217    /// Checks the cache/LRU bookkeeping invariants.
218    #[cfg(any(test, debug_assertions))]
219    fn check_invariants(&self) {
220        assert_eq!(self.cache.len(), self.lru.len());
221        let mut cost = 0;
222        for (key, value) in self.cache.iter() {
223            assert_eq!(self.lru.get(&value.serial), Some(key));
224            cost += value.aux.cost();
225        }
226        for (serial, key) in self.lru.iter() {
227            assert_eq!(self.cache.get(key).unwrap().serial, *serial);
228        }
229        assert_eq!(cost, self.cur_cost);
230    }
231
232    /// Runs invariant checks in debug builds.
233    fn debug_check_invariants(&self) {
234        #[cfg(debug_assertions)]
235        self.check_invariants()
236    }
237
238    /// Looks up `key`, refreshes its recency, and returns the cached value.
239    fn get(&mut self, key: K) -> Option<V> {
240        if let Some(value) = self.cache.get_mut(&key) {
241            self.lru.remove(&value.serial);
242            value.serial = self.next_serial;
243            self.lru.insert(value.serial, key);
244            self.next_serial += 1;
245            Some(value.aux.clone())
246        } else {
247            None
248        }
249    }
250
251    /// Evicts least-recently-used entries until `cur_cost <= max_cost`.
252    fn evict_to(&mut self, max_cost: usize) {
253        while self.cur_cost > max_cost {
254            // lru and cache are kept in sync by all mutating methods;
255            // since cur_cost > max_cost >= 0, at least one entry exists.
256            let (_serial, key) = self.lru.pop_first().unwrap();
257            let value = self.cache.remove(&key).unwrap();
258            self.cur_cost -= value.aux.cost();
259        }
260        self.debug_check_invariants();
261    }
262
263    /// Inserts or replaces `key` with `aux`.
264    fn insert(&mut self, key: K, aux: V) {
265        let cost = aux.cost();
266        self.evict_to(self.max_cost.saturating_sub(cost));
267        if let Some(old_value) = self.cache.insert(
268            key.clone(),
269            CacheValue {
270                aux,
271                serial: self.next_serial,
272            },
273        ) {
274            self.lru.remove(&old_value.serial);
275            self.cur_cost -= old_value.aux.cost();
276        }
277        self.lru.insert(self.next_serial, key);
278        self.cur_cost += cost;
279        self.next_serial += 1;
280        self.debug_check_invariants();
281    }
282
283    /// Removes `key` if it is present and returns the removed value.
284    fn remove(&mut self, key: &K) -> Option<V> {
285        let value = self.cache.remove(key)?;
286        self.lru.remove(&value.serial).unwrap();
287        self.cur_cost -= value.aux.cost();
288        self.debug_check_invariants();
289        Some(value.aux)
290    }
291
292    /// Removes every entry whose key matches `predicate`.
293    fn remove_if<F>(&mut self, predicate: F)
294    where
295        F: Fn(&K) -> bool,
296    {
297        let keys: Vec<K> = self
298            .cache
299            .keys()
300            .filter(|key| predicate(key))
301            .cloned()
302            .collect();
303        for key in keys {
304            let _ = self.remove(&key);
305        }
306    }
307
308    /// Removes every entry whose key falls within `range`.
309    fn remove_range<R>(&mut self, range: R) -> usize
310    where
311        R: RangeBounds<K>,
312    {
313        let victims: Vec<(K, u64)> = self
314            .cache
315            .range(range)
316            .map(|(key, value)| (key.clone(), value.serial))
317            .collect();
318
319        let removed = victims.len();
320        for (key, serial) in victims {
321            self.lru.remove(&serial).unwrap();
322            self.cur_cost -= self.cache.remove(&key).unwrap().aux.cost();
323        }
324        self.debug_check_invariants();
325        removed
326    }
327
328    /// Returns `true` if `key` is resident.
329    fn contains_key(&self, key: &K) -> bool {
330        self.cache.contains_key(key)
331    }
332
333    /// Returns the number of resident entries.
334    fn len(&self) -> usize {
335        self.cache.len()
336    }
337}