1use std::collections::HashMap;
2use std::hash::Hash;
3use std::ptr::NonNull;
4use std::time::{Duration, Instant};
5
/// One entry in the cache's intrusive doubly linked recency list.
///
/// Nodes are heap allocations addressed through raw `NonNull` links; the
/// list runs from the most recently used node (no `prev`) to the least
/// recently used node (no `next`).
struct Node<K, V> {
    /// Copy of the map key, so the map entry can be removed when the node is
    /// reached from the list side (e.g. LRU eviction or expiry sweeps).
    key: K,
    /// The cached value.
    value: V,
    /// Deadline at or after which the entry counts as expired; `None` means
    /// the entry never expires.
    expires_at: Option<Instant>,
    /// Neighbour toward the least recently used end (`None` at the tail).
    next: Option<NonNull<Node<K, V>>>,
    /// Neighbour toward the most recently used end (`None` at the head).
    prev: Option<NonNull<Node<K, V>>>,
}
13
/// Controls when expired entries are purged from an [`LruCache`].
///
/// In both modes, looking up an individual expired key returns `None` and
/// removes that entry; the modes differ only in whether a full sweep happens
/// automatically.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CleanupMode {
    /// Sweep and purge every expired entry at the start of each `put`, `get`
    /// and `get_mut` call. The sweep walks the whole list (O(n) per call).
    OnAccess,
    /// Purge expired entries only when `LruCache::evict_expired` is called.
    /// Until then they still count toward `len()` and toward capacity.
    OnDemand,
}
20
21pub struct LruCache<K, V> {
22 map: HashMap<K, NonNull<Node<K, V>>>,
23 head: Option<NonNull<Node<K, V>>>,
24 tail: Option<NonNull<Node<K, V>>>,
25 capacity: usize,
26 cleanup_mode: CleanupMode,
27}
28
29impl<K: Eq + Hash + Clone, V> LruCache<K, V> {
30 pub fn new(capacity: usize, cleanup_mode: CleanupMode) -> Self {
31 assert!(capacity > 0);
32 LruCache {
33 map: HashMap::with_capacity(capacity),
34 head: None,
35 tail: None,
36 capacity,
37 cleanup_mode,
38 }
39 }
40
41 pub fn put(&mut self, key: K, value: V, ttl: Option<Duration>) {
42 if matches!(self.cleanup_mode, CleanupMode::OnAccess) {
43 self.evict_expired();
44 }
45 let expires_at = ttl.map(|d| Instant::now() + d);
46
47 if let Some(&node_ptr) = self.map.get(&key) {
48 unsafe {
49 let node = node_ptr.as_ptr().as_mut().unwrap();
50 node.value = value;
51 node.expires_at = expires_at;
52 self.remove_node(node_ptr);
53 self.push_front(node_ptr);
54 }
55 return;
56 }
57
58 if self.map.len() >= self.capacity {
59 self.remove_last();
60 }
61
62 let node = Box::new(Node {
63 key: key.clone(),
64 value,
65 expires_at,
66 next: self.head,
67 prev: None,
68 });
69
70 let node_ptr = unsafe { NonNull::new_unchecked(Box::into_raw(node)) };
71
72 if let Some(mut head) = self.head {
73 unsafe { head.as_mut().prev = Some(node_ptr) };
74 } else {
75 self.tail = Some(node_ptr);
76 }
77
78 self.head = Some(node_ptr);
79 self.map.insert(key, node_ptr);
80 }
81
82 pub fn get(&mut self, key: &K) -> Option<&V> {
83 if matches!(self.cleanup_mode, CleanupMode::OnAccess) {
84 self.evict_expired();
85 }
86
87 let node_ptr = *self.map.get(key)?;
89
90 unsafe {
91 let node = node_ptr.as_ptr().as_ref().unwrap();
92
93 if node.expired() {
94 self.map.remove(key);
95 self.remove_node(node_ptr);
96 let _ = Box::from_raw(node_ptr.as_ptr());
97 return None;
98 }
99
100 self.remove_node(node_ptr);
101 self.push_front(node_ptr);
102
103 Some(&(*node_ptr.as_ptr()).value)
104 }
105 }
106
107 pub fn get_mut(&mut self, key: &K) -> Option<&mut V> {
108 if matches!(self.cleanup_mode, CleanupMode::OnAccess) {
109 self.evict_expired();
110 }
111
112 let node_ptr = *self.map.get(key)?;
113
114 unsafe {
115 let node = node_ptr.as_ptr().as_mut().unwrap();
116
117 if node.expired() {
118 self.map.remove(key);
119 self.remove_node(node_ptr);
120 let _ = Box::from_raw(node_ptr.as_ptr());
121 return None;
122 }
123
124 self.remove_node(node_ptr);
125 self.push_front(node_ptr);
126
127 Some(&mut (*node_ptr.as_ptr()).value)
128 }
129 }
130
131 fn remove_node(&mut self, node_ptr: NonNull<Node<K, V>>) {
132 unsafe {
133 let node = node_ptr.as_ptr();
134
135 if let Some(prev) = (*node).prev {
136 (*prev.as_ptr()).next = (*node).next;
137 } else {
138 self.head = (*node).next;
139 }
140
141 if let Some(next) = (*node).next {
142 (*next.as_ptr()).prev = (*node).prev;
143 } else {
144 self.tail = (*node).prev;
145 }
146 }
147 }
148
149 fn push_front(&mut self, node_ptr: NonNull<Node<K, V>>) {
150 unsafe {
151 (*node_ptr.as_ptr()).next = self.head;
152 (*node_ptr.as_ptr()).prev = None;
153
154 if let Some(head) = self.head {
155 let head_mut = head.as_ptr() as *mut Node<K, V>;
156 (*head_mut).prev = Some(node_ptr);
157 } else {
158 self.tail = Some(node_ptr);
159 }
160
161 self.head = Some(node_ptr);
162 }
163 }
164
165 fn remove_last(&mut self) {
166 if let Some(tail_ptr) = self.tail {
167 unsafe {
168 let key = &(*tail_ptr.as_ptr()).key;
169 let prev = (*tail_ptr.as_ptr()).prev;
170
171 self.map.remove(key);
172
173 match prev {
174 Some(prev) => {
175 let prev_mut = prev.as_ptr() as *mut Node<K, V>;
176 (*prev_mut).next = None;
177 self.tail = Some(prev);
178 }
179 None => {
180 self.head = None;
181 self.tail = None;
182 }
183 }
184
185 let _ = Box::from_raw(tail_ptr.as_ptr());
186 }
187 }
188 }
189
190 pub fn evict_expired(&mut self) {
191 let now = Instant::now();
192 let mut current = self.head;
193
194 while let Some(node_ptr) = current {
195 unsafe {
196 let node = node_ptr.as_ptr();
197 current = (*node).next;
198
199 if (*node).expired_at(now) {
200 self.map.remove(&(*node).key);
201 self.remove_node(node_ptr);
202 let _ = Box::from_raw(node);
203 }
204 }
205 }
206 }
207
208 pub fn len(&self) -> usize {
209 self.map.len()
210 }
211
212 pub fn is_empty(&self) -> bool {
213 self.map.is_empty()
214 }
215
216 pub fn capacity(&self) -> usize {
217 self.capacity
218 }
219}
220
221impl<K, V> Drop for LruCache<K, V> {
222 fn drop(&mut self) {
223 let mut current = self.head;
224 while let Some(node_ptr) = current {
225 unsafe {
226 current = (*node_ptr.as_ptr()).next;
227 let _ = Box::from_raw(node_ptr.as_ptr());
228 }
229 }
230 }
231}
232
233impl<K, V> Node<K, V> {
234 fn expired(&self) -> bool {
235 self.expires_at.map_or(false, |e| e <= Instant::now())
236 }
237
238 fn expired_at(&self, now: Instant) -> bool {
239 self.expires_at.map_or(false, |e| e <= now)
240 }
241}
242
243#[cfg(test)]
244mod tests {
245 use super::*;
246 use std::thread;
247
248 #[test]
249 fn test_basic_operations() {
250 let mut cache = LruCache::new(2, CleanupMode::OnAccess);
251 cache.put("a", 1, None);
252 cache.put("b", 2, None);
253
254 assert_eq!(cache.get(&"a"), Some(&1));
255 assert_eq!(cache.get(&"b"), Some(&2));
256 assert_eq!(cache.get(&"c"), None);
257
258 cache.put("c", 3, None);
259 assert_eq!(cache.get(&"a"), None);
260 assert_eq!(cache.get(&"b"), Some(&2));
261 assert_eq!(cache.get(&"c"), Some(&3));
262 }
263
264 #[test]
265 fn test_ttl_expiration_auto() {
266 let mut cache = LruCache::new(2, CleanupMode::OnDemand);
267 cache.put("a", 1, Some(Duration::from_millis(150)));
268 cache.put("b", 2, None);
269
270 assert_eq!(cache.get(&"a"), Some(&1));
271 assert_eq!(cache.get(&"b"), Some(&2));
272
273 thread::sleep(Duration::from_millis(200));
274
275 assert_eq!(cache.get(&"a"), None);
276 assert_eq!(cache.get(&"b"), Some(&2));
277 }
278
279 #[test]
280 fn test_ttl_expiration() {
281 let mut cache = LruCache::new(2, CleanupMode::OnAccess);
282 cache.put("a", 1, Some(Duration::from_millis(150)));
283 cache.put("b", 2, None);
284
285 assert_eq!(cache.get(&"a"), Some(&1));
286 assert_eq!(cache.get(&"b"), Some(&2));
287
288 thread::sleep(Duration::from_millis(200));
289
290 cache.evict_expired();
291
292 assert_eq!(cache.get(&"a"), None);
293 assert_eq!(cache.get(&"b"), Some(&2));
294 }
295
296 #[test]
297 fn test_lru_eviction() {
298 let mut cache = LruCache::new(3, CleanupMode::OnAccess);
299 cache.put("a", 1, None);
300 cache.put("b", 2, None);
301 cache.put("c", 3, None);
302
303 cache.get(&"a");
304 cache.put("d", 4, None);
305
306 assert_eq!(cache.get(&"b"), None);
307 assert_eq!(cache.get(&"a"), Some(&1));
308 assert_eq!(cache.get(&"c"), Some(&3));
309 assert_eq!(cache.get(&"d"), Some(&4));
310 }
311
312 #[test]
313 fn test_no_memory_leaks() {
314 let mut cache = LruCache::new(2, CleanupMode::OnAccess);
315 for i in 0..1000 {
316 cache.put(i, Box::new([0u8; 1024]), None);
317 }
318 assert_eq!(cache.len(), 2);
319 }
320}