1use std::{
2 fmt::Debug,
3 hash::{BuildHasher, Hash},
4 marker::PhantomData,
5 time::{Duration, Instant},
6 cmp::Reverse,
7};
8
9use parking_lot::{Mutex, MutexGuard};
10
11use priority_queue::PriorityQueue;
12
13use super::{
14 linked_arena::{LinkedArena, LinkedNode},
15 Policy, Prune,
16};
17use crate::LightCache;
18
19pub struct LruPolicy<K, V> {
21 inner: Mutex<LruPolicyInner<K>>,
22 phantom: PhantomData<V>,
24}
25
26pub struct LruPolicyInner<K> {
27 capacity: usize,
28 arena: LinkedArena<K, LruNode<K>>,
29 expiring: Option<(Duration, PriorityQueue<K, Reverse<Instant>>)>,
30}
31
32impl<K: Hash + Eq, V> LruPolicy<K, V> {
33 pub fn new(capacity: usize, ttl: Option<Duration>) -> Self {
36 assert!(capacity > 1, "LRU capacity must be greater than 1");
37
38 LruPolicy {
39 inner: Mutex::new(LruPolicyInner {
40 capacity,
41 arena: LinkedArena::new(),
42 expiring: ttl.map(|ttl| (ttl, PriorityQueue::new())),
43 }),
44 phantom: PhantomData,
45 }
46 }
47}
48
49impl<K, V> Policy<K, V> for LruPolicy<K, V>
50where
51 K: Copy + Eq + Hash,
52 V: Clone + Sync,
53{
54 type Inner = LruPolicyInner<K>;
55
56 #[inline]
57 fn lock_inner(&self) -> MutexGuard<'_, Self::Inner> {
58 self.inner.lock()
59 }
60
61 fn get<S: BuildHasher>(&self, key: &K, cache: &LightCache<K, V, S, Self>) -> Option<V> {
62 {
63 let mut inner = self.lock_and_prune(cache);
64
65 if let Some((idx, _)) = inner.arena.get_node_mut(&key) {
66 inner.arena.move_to_head(idx);
67 }
68 }
69
70 cache.get_no_policy(key)
71 }
72
73 fn insert<S: BuildHasher>(&self, key: K, value: V, cache: &LightCache<K, V, S, Self>) -> Option<V> {
74 {
75 let mut inner = self.lock_and_prune(cache);
76
77 if let Some((idx, _)) = inner.arena.get_node_mut(&key) {
79 inner.arena.move_to_head(idx);
80 } else {
81 inner.arena.insert_head(key);
82 }
83
84 if let Some((duration, pq)) = inner.expiring.as_mut() {
85 pq.push(key, Reverse(Instant::now() + *duration));
86 }
87
88 inner.evict(cache);
89 }
90
91 cache.insert_no_policy(key, value)
92 }
93
94 fn remove<S: BuildHasher>(&self, key: &K, cache: &LightCache<K, V, S, Self>) -> Option<V> {
95 {
96 let mut inner = self.lock_and_prune(cache);
97 inner.arena.remove_item(key);
98
99 if let Some((_, pq)) = inner.expiring.as_mut() {
100 pq.remove(key);
101 }
102 }
103
104 cache.remove_no_policy(key)
105 }
106}
107
108impl<K, V> Prune<K, V, LruPolicy<K, V>> for LruPolicyInner<K>
109where
110 K: Copy + Eq + Hash,
111 V: Clone + Sync,
112{
113 #[inline]
114 fn prune<S: BuildHasher>(&mut self, cache: &LightCache<K, V, S, LruPolicy<K, V>>) {
115 if let Some((_, pq)) = self.expiring.as_mut() {
116 while let Some((key, expiry)) = pq.peek() {
117 if expiry.0 < Instant::now() {
118 self.arena.remove_item(key);
119 cache.remove_no_policy(key);
120 pq.pop();
121 } else {
122 break;
123 }
124 }
125 }
126 }
127}
128
129impl<K: Copy + Eq + Hash> LruPolicyInner<K> {
130 #[inline]
131 fn evict<S: BuildHasher, V: Clone + Sync>(&mut self, cache: &LightCache<K, V, S, LruPolicy<K, V>>) {
132 if self.arena.len() > self.capacity {
133 if let Some((idx, _)) = self.arena.tail() {
135 let (_, n) = self.arena.remove(idx);
136
137 cache.remove_no_policy(n.item());
138
139 if let Some((_, pq)) = self.expiring.as_mut() {
140 pq.remove(&n.key);
141 }
142 }
143 }
144 }
145}
146
/// Node of the LRU recency list: a key plus arena indices of its neighbours.
struct LruNode<K> {
    /// Key tracked by this node.
    key: K,
    /// Arena index of the previous node, if any.
    prev: Option<usize>,
    /// Arena index of the next node, if any.
    next: Option<usize>,
}
152
153impl<K> LinkedNode<K> for LruNode<K>
154where
155 K: Copy + Eq + Hash,
156{
157 fn new(key: K, prev: Option<usize>, next: Option<usize>) -> Self {
158 LruNode {
159 key,
160 prev,
161 next,
162 }
163 }
164
165 fn item(&self) -> &K {
166 &self.key
167 }
168
169 fn prev(&self) -> Option<usize> {
170 self.prev
171 }
172
173 fn next(&self) -> Option<usize> {
174 self.next
175 }
176
177 fn set_prev(&mut self, prev: Option<usize>) {
178 self.prev = prev;
179 }
180
181 fn set_next(&mut self, next: Option<usize>) {
182 self.next = next;
183 }
184}
185
186impl<K> Debug for LruNode<K> {
187 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
188 f.debug_struct("LruNode")
189 .field("prev", &self.prev)
190 .field("next", &self.next)
191 .finish()
192 }
193}
194
#[cfg(test)]
mod test {
    use hashbrown::hash_map::DefaultHashBuilder;

    use super::*;

    /// Millisecond durations keep the expiry test fast while leaving wide timing margins.
    fn duration_millis(millis: u64) -> Duration {
        Duration::from_millis(millis)
    }

    fn sleep_millis(millis: u64) {
        std::thread::sleep(duration_millis(millis));
    }

    /// Inserts the keys `0..n`, each mapped to itself.
    fn insert_n<S>(cache: &LightCache<i32, i32, S, LruPolicy<i32, i32>>, n: usize)
    where
        S: BuildHasher,
    {
        for i in 0..n {
            cache.insert(i as i32, i as i32);
        }
    }

    /// Builds an LRU cache with the given capacity and per-entry TTL.
    fn cache<K, V>(
        capacity: usize,
        lifetime: Duration,
    ) -> LightCache<K, V, DefaultHashBuilder, LruPolicy<K, V>>
    where
        K: Copy + Eq + Hash,
        V: Clone + Sync,
    {
        LightCache::from_parts(LruPolicy::new(capacity, Some(lifetime)), Default::default())
    }

    #[test]
    fn test_basic_scenario_1() {
        // 250 ms TTL and a 500 ms sleep: every first-round entry is expired by the
        // next insert. The two fresh inserts have 250 ms of slack before they expire.
        let cache = cache::<i32, i32>(5, duration_millis(250));

        insert_n(&cache, 5);

        sleep_millis(500);

        insert_n(&cache, 2);

        assert_eq!(cache.len(), 2);
        let policy = cache.policy().lock_inner();

        assert_eq!(policy.arena.nodes.len(), 2);
        assert_eq!(policy.arena.head, Some(1));
        assert_eq!(policy.arena.tail, Some(0));
    }

    #[test]
    fn test_basic_scenario_2() {
        // No sleeping here, so nothing expires; only capacity-based eviction applies.
        let cache = cache::<i32, i32>(5, duration_millis(2_000));

        insert_n(&cache, 10);
        cache.remove(&8);

        assert_eq!(cache.len(), 4);
    }
}