1use feldera_types::config::dev_tweaks::BufferCacheStrategy;
2
3use crate::{BufferCache, CacheEntry};
4use std::any::Any;
5use std::collections::BTreeMap;
6use std::fmt::Debug;
7use std::hash::RandomState;
8use std::marker::PhantomData;
9use std::ops::RangeBounds;
10use std::sync::Mutex;
11
/// A least-recently-used cache keyed by `K` with values of type `V`.
///
/// All state lives behind a single [`Mutex`], so this implementation is
/// effectively single-sharded (see `DEFAULT_SHARDS`).  The `S` type
/// parameter exists only for API parity with hashed cache
/// implementations; no hasher value is actually stored or used.
pub struct LruCache<K, V, S = RandomState> {
    // All cache state, guarded by one global lock.
    inner: Mutex<CacheInner<K, V>>,
    // Marks `S` as used without storing a value of it; `fn() -> S`
    // keeps the struct `Send`/`Sync` regardless of `S`.
    marker: PhantomData<fn() -> S>,
}
19
/// Lock-protected interior of [`LruCache`].
///
/// Invariants (checked by `check_invariants`): `cache` and `lru` mirror
/// each other exactly, and `cur_cost` is the sum of all entry costs.
struct CacheInner<K, V> {
    // Primary map from key to its value plus recency metadata.
    cache: BTreeMap<K, CacheValue<V>>,
    // Recency order: maps each entry's serial number back to its key.
    // Smaller serials are older, so the first entry is the LRU victim.
    lru: BTreeMap<u64, K>,
    // Serial number to assign on the next insertion or access.
    next_serial: u64,
    // Current total cost of all cached values.
    cur_cost: usize,
    // Cost budget; eviction drives `cur_cost` back under this.
    max_cost: usize,
}
33
/// A cached value plus the recency stamp needed for LRU eviction.
struct CacheValue<V> {
    // The cached value itself.
    aux: V,
    // Recency stamp; matches this entry's key in `CacheInner::lru`.
    serial: u64,
}
41
impl<K, V, S> LruCache<K, V, S> {
    /// This implementation is unsharded, so there is always one shard.
    pub const DEFAULT_SHARDS: usize = 1;
}
46
47impl<K, V> LruCache<K, V, RandomState>
48where
49 K: Ord + Clone + Debug,
50 V: CacheEntry + Clone,
51{
52 pub fn new(max_cost: usize) -> Self {
54 Self::with_hasher(max_cost, RandomState::new())
55 }
56}
57
58#[allow(clippy::len_without_is_empty)]
60impl<K, V, S> LruCache<K, V, S>
61where
62 K: Ord + Clone + Debug,
63 V: CacheEntry + Clone,
64{
65 pub fn with_hasher(max_cost: usize, _hash_builder: S) -> Self {
67 Self {
68 inner: Mutex::new(CacheInner::new(max_cost)),
69 marker: PhantomData,
70 }
71 }
72
73 pub fn insert(&self, key: K, value: V) {
75 self.inner.lock().unwrap().insert(key, value);
76 }
77
78 pub fn get(&self, key: &K) -> Option<V> {
80 self.inner.lock().unwrap().get(key.clone())
81 }
82
83 pub fn remove(&self, key: &K) -> Option<V> {
85 self.inner.lock().unwrap().remove(key)
86 }
87
88 pub fn remove_if<F>(&self, predicate: F)
90 where
91 F: Fn(&K) -> bool,
92 {
93 self.inner.lock().unwrap().remove_if(predicate)
94 }
95
96 pub fn remove_range<R>(&self, range: R) -> usize
98 where
99 R: RangeBounds<K>,
100 {
101 self.inner.lock().unwrap().remove_range(range)
102 }
103
104 pub fn contains_key(&self, key: &K) -> bool {
106 self.inner.lock().unwrap().contains_key(key)
107 }
108
109 pub fn len(&self) -> usize {
111 self.inner.lock().unwrap().len()
112 }
113
114 pub fn total_charge(&self) -> usize {
116 self.inner.lock().unwrap().cur_cost
117 }
118
119 pub fn total_capacity(&self) -> usize {
121 self.inner.lock().unwrap().max_cost
122 }
123
124 pub fn shard_count(&self) -> usize {
126 Self::DEFAULT_SHARDS
127 }
128
129 #[cfg(test)]
135 pub fn shard_usage(&self, idx: usize) -> (usize, usize) {
136 assert_eq!(idx, 0, "shard index out of bounds");
137 let inner = self.inner.lock().unwrap();
138 (inner.cur_cost, inner.max_cost)
139 }
140
141 #[cfg(test)]
142 pub(crate) fn validate_invariants(&self) {
143 self.inner.lock().unwrap().check_invariants();
144 }
145}
146
147impl<K, V, S> BufferCache<K, V> for LruCache<K, V, S>
148where
149 K: Ord + Clone + Debug + Send + Sync + 'static,
150 V: CacheEntry + Clone + Send + Sync + 'static,
151 S: Send + Sync + 'static,
152{
153 fn as_any(&self) -> &dyn Any {
154 self
155 }
156
157 fn strategy(&self) -> BufferCacheStrategy {
158 BufferCacheStrategy::Lru
159 }
160
161 fn insert(&self, key: K, value: V) {
162 self.insert(key, value);
163 }
164
165 fn get(&self, key: K) -> Option<V> {
166 self.inner.lock().unwrap().get(key)
167 }
168
169 fn remove(&self, key: &K) -> Option<V> {
170 self.remove(key)
171 }
172
173 fn remove_if(&self, predicate: &dyn Fn(&K) -> bool) {
174 self.remove_if(|key| predicate(key))
175 }
176
177 fn contains_key(&self, key: &K) -> bool {
178 self.contains_key(key)
179 }
180
181 fn len(&self) -> usize {
182 self.len()
183 }
184
185 fn total_charge(&self) -> usize {
186 self.total_charge()
187 }
188
189 fn total_capacity(&self) -> usize {
190 self.total_capacity()
191 }
192
193 fn shard_count(&self) -> usize {
194 self.shard_count()
195 }
196
197 #[cfg(test)]
198 fn shard_usage(&self, idx: usize) -> (usize, usize) {
199 self.shard_usage(idx)
200 }
201}
202
impl<K, V> CacheInner<K, V>
where
    K: Ord + Clone + Debug,
    V: CacheEntry + Clone,
{
    /// Creates an empty cache state with cost budget `max_cost`.
    fn new(max_cost: usize) -> Self {
        Self {
            cache: BTreeMap::new(),
            lru: BTreeMap::new(),
            next_serial: 0,
            cur_cost: 0,
            max_cost,
        }
    }

    /// Asserts the structural invariants: `cache` and `lru` are exact
    /// mirrors of each other (same size, serials agree in both
    /// directions) and `cur_cost` equals the sum of all entry costs.
    #[cfg(any(test, debug_assertions))]
    fn check_invariants(&self) {
        assert_eq!(self.cache.len(), self.lru.len());
        let mut cost = 0;
        for (key, value) in self.cache.iter() {
            assert_eq!(self.lru.get(&value.serial), Some(key));
            cost += value.aux.cost();
        }
        for (serial, key) in self.lru.iter() {
            assert_eq!(self.cache.get(key).unwrap().serial, *serial);
        }
        assert_eq!(cost, self.cur_cost);
    }

    /// Runs `check_invariants` in debug builds only; no-op in release.
    fn debug_check_invariants(&self) {
        #[cfg(debug_assertions)]
        self.check_invariants()
    }

    /// Looks up `key` and, on a hit, stamps it as most recently used by
    /// moving it to a fresh (largest) serial in `lru`.
    fn get(&mut self, key: K) -> Option<V> {
        if let Some(value) = self.cache.get_mut(&key) {
            self.lru.remove(&value.serial);
            value.serial = self.next_serial;
            self.lru.insert(value.serial, key);
            self.next_serial += 1;
            Some(value.aux.clone())
        } else {
            None
        }
    }

    /// Evicts least-recently-used entries until `cur_cost <= max_cost`.
    fn evict_to(&mut self, max_cost: usize) {
        while self.cur_cost > max_cost {
            // Smallest serial == least recently used.  `cur_cost > 0`
            // implies at least one entry exists (by the cost
            // invariant), so these unwraps cannot fail.
            let (_serial, key) = self.lru.pop_first().unwrap();
            let value = self.cache.remove(&key).unwrap();
            self.cur_cost -= value.aux.cost();
        }
        self.debug_check_invariants();
    }

    /// Inserts `key` -> `aux`, evicting older entries to make room.
    ///
    /// NOTE(review): eviction runs before any existing entry for `key`
    /// is credited back, so replacing a key may evict more than
    /// strictly necessary — possibly the old entry for `key` itself.
    /// Also, an entry whose cost alone exceeds `max_cost` is still
    /// admitted, leaving `cur_cost > max_cost` until a later eviction.
    /// Confirm both are intended policy.
    fn insert(&mut self, key: K, aux: V) {
        let cost = aux.cost();
        // Make room so that adding `cost` stays within budget.
        self.evict_to(self.max_cost.saturating_sub(cost));
        if let Some(old_value) = self.cache.insert(
            key.clone(),
            CacheValue {
                aux,
                serial: self.next_serial,
            },
        ) {
            // `key` was already cached: retire the old entry's LRU slot
            // and cost before accounting for the new one.
            self.lru.remove(&old_value.serial);
            self.cur_cost -= old_value.aux.cost();
        }
        self.lru.insert(self.next_serial, key);
        self.cur_cost += cost;
        self.next_serial += 1;
        self.debug_check_invariants();
    }

    /// Removes `key`, returning its value if it was cached.
    fn remove(&mut self, key: &K) -> Option<V> {
        let value = self.cache.remove(key)?;
        self.lru.remove(&value.serial).unwrap();
        self.cur_cost -= value.aux.cost();
        self.debug_check_invariants();
        Some(value.aux)
    }

    /// Removes every entry whose key satisfies `predicate`.
    fn remove_if<F>(&mut self, predicate: F)
    where
        F: Fn(&K) -> bool,
    {
        // Collect matching keys first: `cache` cannot be mutated while
        // its key iterator borrows it.
        let keys: Vec<K> = self
            .cache
            .keys()
            .filter(|key| predicate(key))
            .cloned()
            .collect();
        for key in keys {
            let _ = self.remove(&key);
        }
    }

    /// Removes every entry whose key lies in `range`; returns how many
    /// entries were removed.
    fn remove_range<R>(&mut self, range: R) -> usize
    where
        R: RangeBounds<K>,
    {
        // Snapshot (key, serial) pairs first so `cache` is not mutated
        // while the range iterator borrows it.
        let victims: Vec<(K, u64)> = self
            .cache
            .range(range)
            .map(|(key, value)| (key.clone(), value.serial))
            .collect();

        let removed = victims.len();
        for (key, serial) in victims {
            self.lru.remove(&serial).unwrap();
            self.cur_cost -= self.cache.remove(&key).unwrap().aux.cost();
        }
        self.debug_check_invariants();
        removed
    }

    /// Reports whether `key` is cached, without updating recency.
    fn contains_key(&self, key: &K) -> bool {
        self.cache.contains_key(key)
    }

    /// Number of entries currently cached.
    fn len(&self) -> usize {
        self.cache.len()
    }
}