//! LRU eviction policy (liquid_cache_storage/cache/cache_policies/lru.rs).
use std::{collections::HashMap, ptr::NonNull};
4
5use crate::{
6 cache::{cached_batch::CachedBatchType, utils::EntryID},
7 sync::Mutex,
8};
9
10use super::{
11 CachePolicy,
12 doubly_linked_list::{DoublyLinkedList, DoublyLinkedNode, drop_boxed_node},
13};
14
/// Per-entry payload stored in each linked-list node: just the id of the
/// cache entry this node tracks.
#[derive(Debug)]
struct LruNode {
    entry_id: EntryID,
}
19
/// Raw, non-null pointer to a heap-allocated list node. Created via
/// `Box::leak` in `notify_insert`; ownership is tracked through `HashList.map`.
type NodePtr = NonNull<DoublyLinkedNode<LruNode>>;
21
/// LRU bookkeeping: `map` gives O(1) lookup from entry id to its node, while
/// `list` keeps recency order (head = most recently used, tail = LRU victim).
#[derive(Debug, Default)]
struct HashList {
    // Invariant: every pointer stored in `map` is also linked in `list`, and
    // the map entry is the owning handle for the leaked node allocation.
    map: HashMap<EntryID, NodePtr>,
    list: DoublyLinkedList<LruNode>,
}
27
impl HashList {
    /// Returns the least-recently-used node (the back of the list), if any.
    fn tail(&self) -> Option<NodePtr> {
        self.list.tail()
    }

    /// Moves an already-linked node to the front (most-recently-used) slot.
    ///
    /// # Safety
    /// `node_ptr` must point to a live node that is currently linked in
    /// `self.list`.
    unsafe fn move_to_front(&mut self, node_ptr: NodePtr) {
        unsafe { self.list.move_to_front(node_ptr) };
    }

    /// Links a node at the front of the list.
    ///
    /// # Safety
    /// `node_ptr` must point to a live node; presumably it must not already
    /// be linked in any list — confirm against `DoublyLinkedList::push_front`.
    unsafe fn push_front(&mut self, node_ptr: NodePtr) {
        unsafe { self.list.push_front(node_ptr) };
    }

    /// Unlinks `node_ptr` from the list and frees its heap allocation.
    ///
    /// # Safety
    /// `node_ptr` must be a live node linked in `self.list`, and it must not
    /// be used again afterwards — the node is freed here.
    unsafe fn remove_and_release(&mut self, node_ptr: NodePtr) {
        unsafe {
            self.list.unlink(node_ptr);
            drop_boxed_node(node_ptr);
        }
    }
}
48
impl Drop for HashList {
    fn drop(&mut self) {
        // Reclaim every node owned through the map: each pointer was produced
        // by `Box::leak`, so it must be unlinked and freed exactly once.
        for (_, node_ptr) in self.map.drain() {
            unsafe {
                // SAFETY: every pointer stored in `map` is also linked in
                // `list` (see `HashList` invariant); unlink before freeing.
                self.list.unlink(node_ptr);
                drop_boxed_node(node_ptr);
            }
        }
        // NOTE(review): after draining the map the list should already be
        // empty, so this reads as a defensive sweep — confirm that
        // `drop_all` is a no-op on an empty list and cannot double-free.
        unsafe {
            self.list.drop_all();
        }
    }
}
63
/// Classic least-recently-used cache eviction policy.
#[derive(Debug, Default)]
pub struct LruPolicy {
    // All raw-pointer manipulation happens while this mutex is held; that is
    // what makes the manual `Send`/`Sync` impls below sound.
    state: Mutex<HashList>,
}
69
70impl LruPolicy {
71 pub fn new() -> Self {
73 Self {
74 state: Mutex::new(HashList::default()),
75 }
76 }
77}
78
// SAFETY: `HashList` holds `NonNull` pointers, which suppress the automatic
// `Send`/`Sync` impls. The pointers are owned exclusively by the `HashList`
// and are only created, read, or freed while `LruPolicy::state`'s mutex is
// held, so access from multiple threads is fully serialized.
unsafe impl Send for LruPolicy {}
unsafe impl Sync for LruPolicy {}
84
85impl CachePolicy for LruPolicy {
86 fn find_victim(&self, cnt: usize) -> Vec<EntryID> {
87 let mut state = self.state.lock().unwrap();
88 if cnt == 0 {
89 return vec![];
90 }
91
92 let mut advices = Vec::with_capacity(cnt);
93 for _ in 0..cnt {
94 let Some(tail_ptr) = state.tail() else {
95 break;
96 };
97 let tail_entry_id = unsafe { tail_ptr.as_ref().data.entry_id };
98 let node_ptr = state
99 .map
100 .remove(&tail_entry_id)
101 .expect("tail node not found");
102 unsafe {
103 state.remove_and_release(node_ptr);
104 }
105 advices.push(tail_entry_id);
106 }
107
108 advices
109 }
110
111 fn notify_access(&self, entry_id: &EntryID, _batch_type: CachedBatchType) {
112 let mut state = self.state.lock().unwrap();
113 if let Some(node_ptr) = state.map.get(entry_id).copied() {
114 unsafe { state.move_to_front(node_ptr) };
115 }
116 }
117
118 fn notify_insert(&self, entry_id: &EntryID, _batch_type: CachedBatchType) {
119 let mut state = self.state.lock().unwrap();
120
121 if let Some(existing_node_ptr) = state.map.get(entry_id).copied() {
122 unsafe { state.move_to_front(existing_node_ptr) };
123 return;
124 }
125
126 let node = DoublyLinkedNode::new(LruNode {
127 entry_id: *entry_id,
128 });
129 let node_ptr = NonNull::from(Box::leak(node));
130
131 state.map.insert(*entry_id, node_ptr);
132 unsafe {
133 state.push_front(node_ptr);
134 }
135 }
136}
137
#[cfg(test)]
mod tests {
    use super::*;
    use crate::cache::utils::{EntryID, create_cache_store, create_test_arrow_array};
    use crate::sync::{Arc, Barrier, thread};
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Shorthand for building an `EntryID` from a plain integer.
    fn entry(id: usize) -> EntryID {
        id.into()
    }

    /// Asserts that requesting a single victim evicts exactly `expect_evict`.
    fn assert_evict_advice(policy: &LruPolicy, expect_evict: EntryID) {
        let advice = policy.find_victim(1);
        assert_eq!(advice, vec![expect_evict]);
    }

    // With no accesses, the oldest insertion is the first victim.
    #[test]
    fn test_lru_policy_insertion_order() {
        let policy = LruPolicy::new();
        let e1 = entry(1);
        let e2 = entry(2);
        let e3 = entry(3);

        policy.notify_insert(&e1, CachedBatchType::MemoryArrow);
        policy.notify_insert(&e2, CachedBatchType::MemoryArrow);
        policy.notify_insert(&e3, CachedBatchType::MemoryArrow);

        assert_evict_advice(&policy, e1);
    }

    // Accessing an entry refreshes it, so the next-oldest becomes the victim.
    #[test]
    fn test_lru_policy_access_moves_to_front() {
        let policy = LruPolicy::new();
        let e1 = entry(1);
        let e2 = entry(2);
        let e3 = entry(3);

        policy.notify_insert(&e1, CachedBatchType::MemoryArrow);
        policy.notify_insert(&e2, CachedBatchType::MemoryArrow);
        policy.notify_insert(&e3, CachedBatchType::MemoryArrow);

        policy.notify_access(&e1, CachedBatchType::MemoryArrow);
        assert_evict_advice(&policy, e2);
        policy.notify_access(&e2, CachedBatchType::MemoryArrow);
        assert_evict_advice(&policy, e3);
    }

    // Re-inserting an existing entry also refreshes its recency.
    #[test]
    fn test_lru_policy_reinsert_moves_to_front() {
        let policy = LruPolicy::new();
        let e1 = entry(1);
        let e2 = entry(2);
        let e3 = entry(3);

        policy.notify_insert(&e1, CachedBatchType::MemoryArrow);
        policy.notify_insert(&e2, CachedBatchType::MemoryArrow);
        policy.notify_insert(&e3, CachedBatchType::MemoryArrow);

        policy.notify_insert(&e1, CachedBatchType::MemoryArrow);
        assert_evict_advice(&policy, e2);
    }

    // An empty policy has no victims to offer.
    #[test]
    fn test_lru_policy_advise_empty() {
        let policy = LruPolicy::new();
        assert_eq!(policy.find_victim(1), vec![]);
    }

    // A lone tracked entry is itself the victim.
    #[test]
    fn test_lru_policy_advise_single_item_self() {
        let policy = LruPolicy::new();
        let e1 = entry(1);
        policy.notify_insert(&e1, CachedBatchType::MemoryArrow);

        assert_evict_advice(&policy, e1);
    }

    // NOTE(review): byte-for-byte identical to
    // `test_lru_policy_advise_single_item_self`; the name suggests it was
    // meant to exercise advising on behalf of a *different* entry — confirm
    // the intended scenario and differentiate or delete.
    #[test]
    fn test_lru_policy_advise_single_item_other() {
        let policy = LruPolicy::new();
        let e1 = entry(1);
        policy.notify_insert(&e1, CachedBatchType::MemoryArrow);
        assert_evict_advice(&policy, e1);
    }

    // Touching an id the policy never saw must not disturb eviction order.
    #[test]
    fn test_lru_policy_access_nonexistent() {
        let policy = LruPolicy::new();
        let e1 = entry(1);
        let e2 = entry(2);

        policy.notify_insert(&e1, CachedBatchType::MemoryArrow);
        policy.notify_insert(&e2, CachedBatchType::MemoryArrow);

        policy.notify_access(&entry(99), CachedBatchType::MemoryArrow);

        assert_evict_advice(&policy, e1);
    }

    impl HashList {
        /// Verifies that the map and both traversal directions of the list
        /// agree on the node count (catches broken `next`/`prev` links and
        /// map/list desynchronization).
        fn check_integrity(&self) {
            let map_count = self.map.len();
            let forward_count = count_nodes_in_list(self);
            let backward_count = count_nodes_reverse(self);

            assert_eq!(map_count, forward_count);
            assert_eq!(map_count, backward_count);
        }
    }

    /// Counts nodes by walking head-to-tail via `next` pointers.
    fn count_nodes_in_list(state: &HashList) -> usize {
        let mut count = 0;
        let mut current = state.list.head();

        while let Some(node_ptr) = current {
            count += 1;
            current = unsafe { node_ptr.as_ref().next };
        }

        count
    }

    /// Counts nodes by walking tail-to-head via `prev` pointers.
    fn count_nodes_reverse(state: &HashList) -> usize {
        let mut count = 0;
        let mut current = state.list.tail();

        while let Some(node_ptr) = current {
            count += 1;
            current = unsafe { node_ptr.as_ref().prev };
        }

        count
    }

    // Ten inserts, two touches, two evictions: the two oldest untouched
    // entries (0 and 1) are evicted, and the last-touched entry (5) ends up
    // at the head of the recency list.
    #[test]
    fn test_lru_policy_invariants() {
        let policy = LruPolicy::new();

        for i in 0..10 {
            policy.notify_insert(&entry(i), CachedBatchType::MemoryArrow);
        }
        policy.notify_access(&entry(2), CachedBatchType::MemoryArrow);
        policy.notify_access(&entry(5), CachedBatchType::MemoryArrow);
        policy.find_victim(1);
        policy.find_victim(1);

        let state = policy.state.lock().unwrap();
        state.check_integrity();

        let map_count = state.map.len();
        assert_eq!(map_count, 8);
        assert!(!state.map.contains_key(&entry(0)));
        assert!(!state.map.contains_key(&entry(1)));
        assert!(state.map.contains_key(&entry(2)));

        let head_id = unsafe { state.list.head().unwrap().as_ref().data.entry_id };
        assert_eq!(head_id, entry(5));
    }

    #[test]
    fn test_concurrent_lru_operations() {
        concurrent_lru_operations();
    }

    // Same scenario driven by the shuttle model checker when the feature is
    // enabled (crate::sync then resolves to shuttle's primitives).
    #[cfg(feature = "shuttle")]
    #[test]
    fn shuttle_lru_operations() {
        crate::utils::shuttle_test(concurrent_lru_operations);
    }

    /// Hammers the policy from several threads with interleaved inserts,
    /// accesses, and evictions, then checks structural integrity and that
    /// evictions never outnumber inserts.
    fn concurrent_lru_operations() {
        let policy = Arc::new(LruPolicy::new());
        let num_threads = 4;
        let operations_per_thread = 100;

        let total_inserts = Arc::new(AtomicUsize::new(0));
        let total_evictions = Arc::new(AtomicUsize::new(0));

        // Barrier so all threads start operating at the same moment,
        // maximizing interleaving.
        let barrier = Arc::new(Barrier::new(num_threads));

        let mut handles = vec![];
        for thread_id in 0..num_threads {
            let policy_clone = policy.clone();
            let total_inserts_clone = total_inserts.clone();
            let total_evictions_clone = total_evictions.clone();
            let barrier_clone = barrier.clone();

            let handle = thread::spawn(move || {
                barrier_clone.wait();

                for i in 0..operations_per_thread {
                    // Cycle through insert / access / evict per iteration;
                    // entry ids are disjoint across threads.
                    let op_type = i % 3;
                    let entry_id = entry(thread_id * operations_per_thread + i);

                    match op_type {
                        0 => {
                            policy_clone.notify_insert(&entry_id, CachedBatchType::MemoryArrow);
                            total_inserts_clone.fetch_add(1, Ordering::SeqCst);
                        }
                        1 => {
                            policy_clone.notify_access(&entry_id, CachedBatchType::MemoryArrow);
                        }
                        _ => {
                            let advised = policy_clone.find_victim(1);
                            if !advised.is_empty() {
                                total_evictions_clone.fetch_add(1, Ordering::SeqCst);
                            }
                        }
                    }
                }
            });

            handles.push(handle);
        }

        for handle in handles {
            handle.join().unwrap();
        }

        let state = policy.state.lock().unwrap();
        state.check_integrity();

        let inserts = total_inserts.load(Ordering::SeqCst);
        let evictions = total_evictions.load(Ordering::SeqCst);
        assert!(inserts >= evictions);
    }

    // End-to-end: with the policy plugged into a cache store sized to hold
    // all three batches, every inserted entry stays resident.
    #[tokio::test]
    async fn test_lru_integration() {
        let policy = LruPolicy::new();
        let store = create_cache_store(3000, Box::new(policy));

        let entry_id1 = EntryID::from(1);
        let entry_id2 = EntryID::from(2);
        let entry_id3 = EntryID::from(3);

        store.insert(entry_id1, create_test_arrow_array(100)).await;
        store.insert(entry_id2, create_test_arrow_array(100)).await;
        store.insert(entry_id3, create_test_arrow_array(100)).await;

        assert!(store.index().get(&entry_id1).is_some());
        assert!(store.index().get(&entry_id2).is_some());
        assert!(store.index().get(&entry_id3).is_some());
    }
}