Skip to main content

feldera_buffer_cache/
lru.rs

1use feldera_types::config::dev_tweaks::BufferCacheStrategy;
2
3use crate::{BufferCache, CacheEntry};
4use std::any::Any;
5use std::collections::BTreeMap;
6use std::fmt::Debug;
7use std::hash::RandomState;
8use std::marker::PhantomData;
9use std::ops::RangeBounds;
10use std::sync::Mutex;
11
12/// A weighted, thread-safe LRU cache.
13pub struct LruCache<K, V, S = RandomState> {
14    /// Mutable cache state guarded by a single mutex.
15    inner: Mutex<CacheInner<K, V>>,
16    /// Retains the public hash-builder type parameter used by shared builders.
17    marker: PhantomData<fn() -> S>,
18}
19
20/// Mutable state for [`LruCache`].
21struct CacheInner<K, V> {
22    /// Cache contents.
23    cache: BTreeMap<K, CacheValue<V>>,
24    /// Map from LRU serial number to cache key.
25    lru: BTreeMap<u64, K>,
26    /// Serial number to use the next time we touch a key.
27    next_serial: u64,
28    /// Sum over `cache[*].aux.cost()`.
29    cur_cost: usize,
30    /// Maximum total cost.
31    max_cost: usize,
32}
33
/// A value resident in [`LruCache`], paired with its recency stamp.
struct CacheValue<V> {
    /// Key of this entry in the LRU index; larger means more recently used.
    serial: u64,
    /// The cached value itself.
    aux: V,
}
41
impl<K, V, S> LruCache<K, V, S> {
    /// Default shard count reported by [`LruCache::shard_count`].
    ///
    /// This backend keeps all entries behind a single mutex, so it always
    /// reports exactly one shard.
    pub const DEFAULT_SHARDS: usize = 1;
}
46
47impl<K, V> LruCache<K, V, RandomState>
48where
49    K: Ord + Clone + Debug,
50    V: CacheEntry + Clone,
51{
52    /// Creates a cache with the default hash builder.
53    pub fn new(max_cost: usize) -> Self {
54        Self::with_hasher(max_cost, RandomState::new())
55    }
56}
57
58// explicit allow, we do have `is_empty` in the trait so this is a false positive
59#[allow(clippy::len_without_is_empty)]
60impl<K, V, S> LruCache<K, V, S>
61where
62    K: Ord + Clone + Debug,
63    V: CacheEntry + Clone,
64{
65    /// Creates a cache with an explicit hash builder.
66    pub fn with_hasher(max_cost: usize, _hash_builder: S) -> Self {
67        Self {
68            inner: Mutex::new(CacheInner::new(max_cost)),
69            marker: PhantomData,
70        }
71    }
72
73    /// Inserts or replaces `key` with `value`.
74    pub fn insert(&self, key: K, value: V) {
75        self.inner.lock().unwrap().insert(key, value);
76    }
77
78    /// Looks up `key` and returns a clone of the stored value.
79    pub fn get(&self, key: &K) -> Option<V> {
80        self.inner.lock().unwrap().get(key.clone())
81    }
82
83    /// Removes `key` if present and returns the removed value.
84    pub fn remove(&self, key: &K) -> Option<V> {
85        self.inner.lock().unwrap().remove(key)
86    }
87
88    /// Removes every entry whose key matches `predicate`.
89    pub fn remove_if<F>(&self, predicate: F)
90    where
91        F: Fn(&K) -> bool,
92    {
93        self.inner.lock().unwrap().remove_if(predicate)
94    }
95
96    /// Removes every entry whose key falls within `range`.
97    pub fn remove_range<R>(&self, range: R) -> usize
98    where
99        R: RangeBounds<K>,
100    {
101        self.inner.lock().unwrap().remove_range(range)
102    }
103
104    /// Returns `true` if `key` is currently resident.
105    pub fn contains_key(&self, key: &K) -> bool {
106        self.inner.lock().unwrap().contains_key(key)
107    }
108
109    /// Returns the number of resident entries.
110    pub fn len(&self) -> usize {
111        self.inner.lock().unwrap().len()
112    }
113
114    /// Returns the total resident cost.
115    pub fn total_charge(&self) -> usize {
116        self.inner.lock().unwrap().cur_cost
117    }
118
119    /// Returns the configured total cost capacity.
120    pub fn total_capacity(&self) -> usize {
121        self.inner.lock().unwrap().max_cost
122    }
123
124    /// Returns the number of shards reported by this backend.
125    pub fn shard_count(&self) -> usize {
126        Self::DEFAULT_SHARDS
127    }
128
129    /// Returns `(used_charge, capacity)` for shard `idx`.
130    ///
131    /// # Panics
132    ///
133    /// Panics if `idx != 0`.
134    #[cfg(test)]
135    pub fn shard_usage(&self, idx: usize) -> (usize, usize) {
136        assert_eq!(idx, 0, "shard index out of bounds");
137        let inner = self.inner.lock().unwrap();
138        (inner.cur_cost, inner.max_cost)
139    }
140
141    #[cfg(test)]
142    pub(crate) fn validate_invariants(&self) {
143        self.inner.lock().unwrap().check_invariants();
144    }
145}
146
147impl<K, V, S> BufferCache<K, V> for LruCache<K, V, S>
148where
149    K: Ord + Clone + Debug + Send + Sync + 'static,
150    V: CacheEntry + Clone + Send + Sync + 'static,
151    S: Send + Sync + 'static,
152{
153    fn as_any(&self) -> &dyn Any {
154        self
155    }
156
157    fn strategy(&self) -> BufferCacheStrategy {
158        BufferCacheStrategy::Lru
159    }
160
161    fn insert(&self, key: K, value: V) {
162        self.insert(key, value);
163    }
164
165    fn get(&self, key: K) -> Option<V> {
166        self.inner.lock().unwrap().get(key)
167    }
168
169    fn remove(&self, key: &K) -> Option<V> {
170        self.remove(key)
171    }
172
173    fn remove_if(&self, predicate: &dyn Fn(&K) -> bool) {
174        self.remove_if(|key| predicate(key))
175    }
176
177    fn contains_key(&self, key: &K) -> bool {
178        self.contains_key(key)
179    }
180
181    fn len(&self) -> usize {
182        self.len()
183    }
184
185    fn total_charge(&self) -> usize {
186        self.total_charge()
187    }
188
189    fn total_capacity(&self) -> usize {
190        self.total_capacity()
191    }
192
193    fn shard_count(&self) -> usize {
194        self.shard_count()
195    }
196
197    #[cfg(test)]
198    fn shard_usage(&self, idx: usize) -> (usize, usize) {
199        self.shard_usage(idx)
200    }
201}
202
203impl<K, V> CacheInner<K, V>
204where
205    K: Ord + Clone + Debug,
206    V: CacheEntry + Clone,
207{
208    /// Creates an empty cache with `max_cost` capacity.
209    fn new(max_cost: usize) -> Self {
210        Self {
211            cache: BTreeMap::new(),
212            lru: BTreeMap::new(),
213            next_serial: 0,
214            cur_cost: 0,
215            max_cost,
216        }
217    }
218
219    /// Checks the cache/LRU bookkeeping invariants.
220    #[cfg(any(test, debug_assertions))]
221    fn check_invariants(&self) {
222        assert_eq!(self.cache.len(), self.lru.len());
223        let mut cost = 0;
224        for (key, value) in self.cache.iter() {
225            assert_eq!(self.lru.get(&value.serial), Some(key));
226            cost += value.aux.cost();
227        }
228        for (serial, key) in self.lru.iter() {
229            assert_eq!(self.cache.get(key).unwrap().serial, *serial);
230        }
231        assert_eq!(cost, self.cur_cost);
232    }
233
234    /// Runs invariant checks in debug builds.
235    fn debug_check_invariants(&self) {
236        #[cfg(debug_assertions)]
237        self.check_invariants()
238    }
239
240    /// Looks up `key`, refreshes its recency, and returns the cached value.
241    fn get(&mut self, key: K) -> Option<V> {
242        if let Some(value) = self.cache.get_mut(&key) {
243            self.lru.remove(&value.serial);
244            value.serial = self.next_serial;
245            self.lru.insert(value.serial, key);
246            self.next_serial += 1;
247            Some(value.aux.clone())
248        } else {
249            None
250        }
251    }
252
253    /// Evicts least-recently-used entries until `cur_cost <= max_cost`.
254    fn evict_to(&mut self, max_cost: usize) {
255        while self.cur_cost > max_cost {
256            // lru and cache are kept in sync by all mutating methods;
257            // since cur_cost > max_cost >= 0, at least one entry exists.
258            let (_serial, key) = self.lru.pop_first().unwrap();
259            let value = self.cache.remove(&key).unwrap();
260            self.cur_cost -= value.aux.cost();
261        }
262        self.debug_check_invariants();
263    }
264
265    /// Inserts or replaces `key` with `aux`.
266    fn insert(&mut self, key: K, aux: V) {
267        let cost = aux.cost();
268        self.evict_to(self.max_cost.saturating_sub(cost));
269        if let Some(old_value) = self.cache.insert(
270            key.clone(),
271            CacheValue {
272                aux,
273                serial: self.next_serial,
274            },
275        ) {
276            self.lru.remove(&old_value.serial);
277            self.cur_cost -= old_value.aux.cost();
278        }
279        self.lru.insert(self.next_serial, key);
280        self.cur_cost += cost;
281        self.next_serial += 1;
282        self.debug_check_invariants();
283    }
284
285    /// Removes `key` if it is present and returns the removed value.
286    fn remove(&mut self, key: &K) -> Option<V> {
287        let value = self.cache.remove(key)?;
288        self.lru.remove(&value.serial).unwrap();
289        self.cur_cost -= value.aux.cost();
290        self.debug_check_invariants();
291        Some(value.aux)
292    }
293
294    /// Removes every entry whose key matches `predicate`.
295    fn remove_if<F>(&mut self, predicate: F)
296    where
297        F: Fn(&K) -> bool,
298    {
299        let keys: Vec<K> = self
300            .cache
301            .keys()
302            .filter(|key| predicate(key))
303            .cloned()
304            .collect();
305        for key in keys {
306            let _ = self.remove(&key);
307        }
308    }
309
310    /// Removes every entry whose key falls within `range`.
311    fn remove_range<R>(&mut self, range: R) -> usize
312    where
313        R: RangeBounds<K>,
314    {
315        let victims: Vec<(K, u64)> = self
316            .cache
317            .range(range)
318            .map(|(key, value)| (key.clone(), value.serial))
319            .collect();
320
321        let removed = victims.len();
322        for (key, serial) in victims {
323            self.lru.remove(&serial).unwrap();
324            self.cur_cost -= self.cache.remove(&key).unwrap().aux.cost();
325        }
326        self.debug_check_invariants();
327        removed
328    }
329
330    /// Returns `true` if `key` is resident.
331    fn contains_key(&self, key: &K) -> bool {
332        self.cache.contains_key(key)
333    }
334
335    /// Returns the number of resident entries.
336    fn len(&self) -> usize {
337        self.cache.len()
338    }
339}