liquid_cache_storage/cache/cache_policies/
lru.rs

1//! LRU cache policy implementation using a hash map and doubly linked list.
2
3use std::{collections::HashMap, ptr::NonNull};
4
5use crate::{
6    cache::{cached_batch::CachedBatchType, utils::EntryID},
7    sync::Mutex,
8};
9
10use super::{
11    CachePolicy,
12    doubly_linked_list::{DoublyLinkedList, DoublyLinkedNode, drop_boxed_node},
13};
14
15#[derive(Debug)]
16struct LruNode {
17    entry_id: EntryID,
18}
19
20type NodePtr = NonNull<DoublyLinkedNode<LruNode>>;
21
22#[derive(Debug, Default)]
23struct HashList {
24    map: HashMap<EntryID, NodePtr>,
25    list: DoublyLinkedList<LruNode>,
26}
27
28impl HashList {
29    fn tail(&self) -> Option<NodePtr> {
30        self.list.tail()
31    }
32
33    unsafe fn move_to_front(&mut self, node_ptr: NodePtr) {
34        unsafe { self.list.move_to_front(node_ptr) };
35    }
36
37    unsafe fn push_front(&mut self, node_ptr: NodePtr) {
38        unsafe { self.list.push_front(node_ptr) };
39    }
40
41    unsafe fn remove_and_release(&mut self, node_ptr: NodePtr) {
42        unsafe {
43            self.list.unlink(node_ptr);
44            drop_boxed_node(node_ptr);
45        }
46    }
47}
48
49impl Drop for HashList {
50    fn drop(&mut self) {
51        for (_, node_ptr) in self.map.drain() {
52            unsafe {
53                self.list.unlink(node_ptr);
54                drop_boxed_node(node_ptr);
55            }
56        }
57        // Any nodes not tracked in the map (shouldn't happen) get cleaned up here.
58        unsafe {
59            self.list.drop_all();
60        }
61    }
62}
63
64/// The policy that implement the LRU algorithm using a HashMap and a doubly linked list.
65#[derive(Debug, Default)]
66pub struct LruPolicy {
67    state: Mutex<HashList>,
68}
69
70impl LruPolicy {
71    /// Create a new [`LruPolicy`].
72    pub fn new() -> Self {
73        Self {
74            state: Mutex::new(HashList::default()),
75        }
76    }
77}
78
79// SAFETY: The Mutex ensures that only one thread accesses the internal state
80// (hash map and intrusive list containing NonNull pointers) at a time, making it safe
81// to send and share across threads.
82unsafe impl Send for LruPolicy {}
83unsafe impl Sync for LruPolicy {}
84
85impl CachePolicy for LruPolicy {
86    fn find_victim(&self, cnt: usize) -> Vec<EntryID> {
87        let mut state = self.state.lock().unwrap();
88        if cnt == 0 {
89            return vec![];
90        }
91
92        let mut advices = Vec::with_capacity(cnt);
93        for _ in 0..cnt {
94            let Some(tail_ptr) = state.tail() else {
95                break;
96            };
97            let tail_entry_id = unsafe { tail_ptr.as_ref().data.entry_id };
98            let node_ptr = state
99                .map
100                .remove(&tail_entry_id)
101                .expect("tail node not found");
102            unsafe {
103                state.remove_and_release(node_ptr);
104            }
105            advices.push(tail_entry_id);
106        }
107
108        advices
109    }
110
111    fn notify_access(&self, entry_id: &EntryID, _batch_type: CachedBatchType) {
112        let mut state = self.state.lock().unwrap();
113        if let Some(node_ptr) = state.map.get(entry_id).copied() {
114            unsafe { state.move_to_front(node_ptr) };
115        }
116    }
117
118    fn notify_insert(&self, entry_id: &EntryID, _batch_type: CachedBatchType) {
119        let mut state = self.state.lock().unwrap();
120
121        if let Some(existing_node_ptr) = state.map.get(entry_id).copied() {
122            unsafe { state.move_to_front(existing_node_ptr) };
123            return;
124        }
125
126        let node = DoublyLinkedNode::new(LruNode {
127            entry_id: *entry_id,
128        });
129        let node_ptr = NonNull::from(Box::leak(node));
130
131        state.map.insert(*entry_id, node_ptr);
132        unsafe {
133            state.push_front(node_ptr);
134        }
135    }
136}
137
#[cfg(test)]
mod tests {
    use super::*;
    use crate::cache::utils::{EntryID, create_cache_store, create_test_arrow_array};
    use crate::sync::{Arc, Barrier, thread};
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Builds an [`EntryID`] from a plain index.
    fn entry(id: usize) -> EntryID {
        id.into()
    }

    /// Requests exactly one victim and checks that it is `expect_evict`.
    fn assert_evict_advice(policy: &LruPolicy, expect_evict: EntryID) {
        let advice = policy.find_victim(1);
        assert_eq!(advice, vec![expect_evict]);
    }

    /// Without any access, the first inserted entry is the first victim.
    #[test]
    fn test_lru_policy_insertion_order() {
        let policy = LruPolicy::new();
        let e1 = entry(1);
        let e2 = entry(2);
        let e3 = entry(3);

        policy.notify_insert(&e1, CachedBatchType::MemoryArrow);
        policy.notify_insert(&e2, CachedBatchType::MemoryArrow);
        policy.notify_insert(&e3, CachedBatchType::MemoryArrow);

        assert_evict_advice(&policy, e1);
    }

    /// Accessing an entry protects it from being the next victim.
    #[test]
    fn test_lru_policy_access_moves_to_front() {
        let policy = LruPolicy::new();
        let e1 = entry(1);
        let e2 = entry(2);
        let e3 = entry(3);

        policy.notify_insert(&e1, CachedBatchType::MemoryArrow);
        policy.notify_insert(&e2, CachedBatchType::MemoryArrow);
        policy.notify_insert(&e3, CachedBatchType::MemoryArrow);

        policy.notify_access(&e1, CachedBatchType::MemoryArrow);
        assert_evict_advice(&policy, e2);
        // `e2` has already been evicted, so this access is a no-op.
        policy.notify_access(&e2, CachedBatchType::MemoryArrow);
        assert_evict_advice(&policy, e3);
    }

    /// Re-inserting a tracked entry refreshes it instead of duplicating it.
    #[test]
    fn test_lru_policy_reinsert_moves_to_front() {
        let policy = LruPolicy::new();
        let e1 = entry(1);
        let e2 = entry(2);
        let e3 = entry(3);

        policy.notify_insert(&e1, CachedBatchType::MemoryArrow);
        policy.notify_insert(&e2, CachedBatchType::MemoryArrow);
        policy.notify_insert(&e3, CachedBatchType::MemoryArrow);

        policy.notify_insert(&e1, CachedBatchType::MemoryArrow);
        assert_evict_advice(&policy, e2);
    }

    /// An empty policy has no victim to offer.
    #[test]
    fn test_lru_policy_advise_empty() {
        let policy = LruPolicy::new();
        assert_eq!(policy.find_victim(1), vec![]);
    }

    /// A single tracked entry is its own victim.
    #[test]
    fn test_lru_policy_advise_single_item_self() {
        let policy = LruPolicy::new();
        let e1 = entry(1);
        policy.notify_insert(&e1, CachedBatchType::MemoryArrow);

        assert_evict_advice(&policy, e1);
    }

    /// With a single tracked entry, touching some *other*, untracked entry must
    /// not start tracking it: the tracked entry stays the only victim.
    #[test]
    fn test_lru_policy_advise_single_item_other() {
        let policy = LruPolicy::new();
        let e1 = entry(1);
        policy.notify_insert(&e1, CachedBatchType::MemoryArrow);

        policy.notify_access(&entry(2), CachedBatchType::MemoryArrow);

        assert_evict_advice(&policy, e1);
        assert!(policy.find_victim(1).is_empty());
    }

    /// Accessing an unknown entry leaves the existing order untouched.
    #[test]
    fn test_lru_policy_access_nonexistent() {
        let policy = LruPolicy::new();
        let e1 = entry(1);
        let e2 = entry(2);

        policy.notify_insert(&e1, CachedBatchType::MemoryArrow);
        policy.notify_insert(&e2, CachedBatchType::MemoryArrow);

        policy.notify_access(&entry(99), CachedBatchType::MemoryArrow);

        assert_evict_advice(&policy, e1);
    }

    /// Requesting zero victims returns nothing and evicts nothing.
    #[test]
    fn test_lru_policy_find_victim_zero() {
        let policy = LruPolicy::new();
        let e1 = entry(1);
        policy.notify_insert(&e1, CachedBatchType::MemoryArrow);

        assert!(policy.find_victim(0).is_empty());
        assert_evict_advice(&policy, e1);
    }

    /// Requesting more victims than are tracked returns all of them, least
    /// recently used first, and leaves the policy empty.
    #[test]
    fn test_lru_policy_find_victim_more_than_tracked() {
        let policy = LruPolicy::new();
        let e1 = entry(1);
        let e2 = entry(2);
        policy.notify_insert(&e1, CachedBatchType::MemoryArrow);
        policy.notify_insert(&e2, CachedBatchType::MemoryArrow);

        assert_eq!(policy.find_victim(5), vec![e1, e2]);
        assert!(policy.find_victim(1).is_empty());
    }

    impl HashList {
        /// Asserts that the map and the list (walked in both directions) agree
        /// on the number of tracked entries.
        fn check_integrity(&self) {
            let map_count = self.map.len();
            let forward_count = count_nodes_in_list(self);
            let backward_count = count_nodes_reverse(self);

            assert_eq!(map_count, forward_count);
            assert_eq!(map_count, backward_count);
        }
    }

    /// Counts nodes by following `next` links from the head.
    fn count_nodes_in_list(state: &HashList) -> usize {
        let mut count = 0;
        let mut current = state.list.head();

        while let Some(node_ptr) = current {
            count += 1;
            // SAFETY: the caller holds the state, so linked nodes stay alive
            // while the list is walked.
            current = unsafe { node_ptr.as_ref().next };
        }

        count
    }

    /// Counts nodes by following `prev` links from the tail.
    fn count_nodes_reverse(state: &HashList) -> usize {
        let mut count = 0;
        let mut current = state.list.tail();

        while let Some(node_ptr) = current {
            count += 1;
            // SAFETY: the caller holds the state, so linked nodes stay alive
            // while the list is walked.
            current = unsafe { node_ptr.as_ref().prev };
        }

        count
    }

    /// Mixed inserts, accesses and evictions keep map and list consistent.
    #[test]
    fn test_lru_policy_invariants() {
        let policy = LruPolicy::new();

        for i in 0..10 {
            policy.notify_insert(&entry(i), CachedBatchType::MemoryArrow);
        }
        policy.notify_access(&entry(2), CachedBatchType::MemoryArrow);
        policy.notify_access(&entry(5), CachedBatchType::MemoryArrow);
        // The two oldest untouched entries (0 and 1) are evicted.
        policy.find_victim(1);
        policy.find_victim(1);

        let state = policy.state.lock().unwrap();
        state.check_integrity();

        let map_count = state.map.len();
        assert_eq!(map_count, 8);
        assert!(!state.map.contains_key(&entry(0)));
        assert!(!state.map.contains_key(&entry(1)));
        assert!(state.map.contains_key(&entry(2)));

        // SAFETY: the lock is held and the list is non-empty, so the head node
        // is alive.
        let head_id = unsafe { state.list.head().unwrap().as_ref().data.entry_id };
        assert_eq!(head_id, entry(5));
    }

    #[test]
    fn test_concurrent_lru_operations() {
        concurrent_lru_operations();
    }

    #[cfg(feature = "shuttle")]
    #[test]
    fn shuttle_lru_operations() {
        crate::utils::shuttle_test(concurrent_lru_operations);
    }

    /// Hammers one policy from several threads with a fixed mix of inserts,
    /// accesses and evictions, then checks the structure is still consistent.
    fn concurrent_lru_operations() {
        let policy = Arc::new(LruPolicy::new());
        let num_threads = 4;
        let operations_per_thread = 100;

        let total_inserts = Arc::new(AtomicUsize::new(0));
        let total_evictions = Arc::new(AtomicUsize::new(0));

        // All workers start their loops together to maximise interleaving.
        let barrier = Arc::new(Barrier::new(num_threads));

        let mut handles = vec![];
        for thread_id in 0..num_threads {
            let policy_clone = policy.clone();
            let total_inserts_clone = total_inserts.clone();
            let total_evictions_clone = total_evictions.clone();
            let barrier_clone = barrier.clone();

            let handle = thread::spawn(move || {
                barrier_clone.wait();

                for i in 0..operations_per_thread {
                    let op_type = i % 3;
                    // Each thread works on its own id range, so inserts never
                    // collide across threads.
                    let entry_id = entry(thread_id * operations_per_thread + i);

                    match op_type {
                        0 => {
                            policy_clone.notify_insert(&entry_id, CachedBatchType::MemoryArrow);
                            total_inserts_clone.fetch_add(1, Ordering::SeqCst);
                        }
                        1 => {
                            policy_clone.notify_access(&entry_id, CachedBatchType::MemoryArrow);
                        }
                        _ => {
                            let advised = policy_clone.find_victim(1);
                            if !advised.is_empty() {
                                total_evictions_clone.fetch_add(1, Ordering::SeqCst);
                            }
                        }
                    }
                }
            });

            handles.push(handle);
        }

        for handle in handles {
            handle.join().unwrap();
        }

        let state = policy.state.lock().unwrap();
        state.check_integrity();

        // Every counted eviction removed an entry that some insert had added.
        let inserts = total_inserts.load(Ordering::SeqCst);
        let evictions = total_evictions.load(Ordering::SeqCst);
        assert!(inserts >= evictions);
    }

    /// Smoke test: the policy can be plugged into a cache store and inserted
    /// entries are visible in the store's index.
    #[tokio::test]
    async fn test_lru_integration() {
        let policy = LruPolicy::new();
        let store = create_cache_store(3000, Box::new(policy));

        let entry_id1 = EntryID::from(1);
        let entry_id2 = EntryID::from(2);
        let entry_id3 = EntryID::from(3);

        store.insert(entry_id1, create_test_arrow_array(100)).await;
        store.insert(entry_id2, create_test_arrow_array(100)).await;
        store.insert(entry_id3, create_test_arrow_array(100)).await;

        assert!(store.index().get(&entry_id1).is_some());
        assert!(store.index().get(&entry_id2).is_some());
        assert!(store.index().get(&entry_id3).is_some());
    }
}