lru_cache_rs/
lib.rs

1use std::collections::HashMap;
2use std::hash::Hash;
3use std::ptr::NonNull;
4use std::time::{Duration, Instant};
5
/// Nullable raw pointer from one list node to a neighbouring node.
type NodeLink<K, V> = Option<NonNull<Node<K, V>>>;

/// A single cache entry, stored on the heap and linked into the recency list.
struct Node<K, V> {
    /// Copy of the key this entry is stored under in the lookup map; it lets
    /// the cache remove the map entry when it only holds the node.
    key: K,
    /// The cached value.
    value: V,
    /// Instant at which the entry stops being valid; `None` means it never
    /// expires.
    expires_at: Option<Instant>,
    /// Neighbour closer to the head (more recently used); `None` at the head.
    prev: NodeLink<K, V>,
    /// Neighbour closer to the tail (less recently used); `None` at the tail.
    next: NodeLink<K, V>,
}
13
/// Controls when expired entries are swept out of an [`LruCache`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CleanupMode {
    /// Sweep all expired entries at the start of every `put`, `get` and
    /// `get_mut` call.
    OnAccess,
    /// Sweep only when `evict_expired()` is called explicitly.
    ///
    /// An expired entry that is looked up with `get`/`get_mut` is still
    /// dropped and reported as missing; other expired entries stay in the
    /// cache (and count towards `len()`) until the next manual sweep or until
    /// they are evicted as least recently used.
    OnDemand,
}
20
21pub struct LruCache<K, V> {
22    map: HashMap<K, NonNull<Node<K, V>>>,
23    head: Option<NonNull<Node<K, V>>>,
24    tail: Option<NonNull<Node<K, V>>>,
25    capacity: usize,
26    cleanup_mode: CleanupMode,
27}
28
29impl<K: Eq + Hash + Clone, V> LruCache<K, V> {
30    pub fn new(capacity: usize, cleanup_mode: CleanupMode) -> Self {
31        assert!(capacity > 0);
32        LruCache {
33            map: HashMap::with_capacity(capacity),
34            head: None,
35            tail: None,
36            capacity,
37            cleanup_mode,
38        }
39    }
40
41    pub fn put(&mut self, key: K, value: V, ttl: Option<Duration>) {
42        if matches!(self.cleanup_mode, CleanupMode::OnAccess) {
43            self.evict_expired();
44        }
45        let expires_at = ttl.map(|d| Instant::now() + d);
46
47        if let Some(&node_ptr) = self.map.get(&key) {
48            unsafe {
49                let node = node_ptr.as_ptr().as_mut().unwrap();
50                node.value = value;
51                node.expires_at = expires_at;
52                self.remove_node(node_ptr);
53                self.push_front(node_ptr);
54            }
55            return;
56        }
57
58        if self.map.len() >= self.capacity {
59            self.remove_last();
60        }
61
62        let node = Box::new(Node {
63            key: key.clone(),
64            value,
65            expires_at,
66            next: self.head,
67            prev: None,
68        });
69
70        let node_ptr = unsafe { NonNull::new_unchecked(Box::into_raw(node)) };
71
72        if let Some(mut head) = self.head {
73            unsafe { head.as_mut().prev = Some(node_ptr) };
74        } else {
75            self.tail = Some(node_ptr);
76        }
77
78        self.head = Some(node_ptr);
79        self.map.insert(key, node_ptr);
80    }
81
82    pub fn get(&mut self, key: &K) -> Option<&V> {
83        if matches!(self.cleanup_mode, CleanupMode::OnAccess) {
84            self.evict_expired();
85        }
86
87        // разименовать выгоднее, в противном случае необходим cloned notnull
88        let node_ptr = *self.map.get(key)?;
89
90        unsafe {
91            let node = node_ptr.as_ptr().as_ref().unwrap();
92
93            if node.expired() {
94                self.map.remove(key);
95                self.remove_node(node_ptr);
96                let _ = Box::from_raw(node_ptr.as_ptr());
97                return None;
98            }
99
100            self.remove_node(node_ptr);
101            self.push_front(node_ptr);
102
103            Some(&(*node_ptr.as_ptr()).value)
104        }
105    }
106
107    pub fn get_mut(&mut self, key: &K) -> Option<&mut V> {
108        if matches!(self.cleanup_mode, CleanupMode::OnAccess) {
109            self.evict_expired();
110        }
111
112        let node_ptr = *self.map.get(key)?;
113
114        unsafe {
115            let node = node_ptr.as_ptr().as_mut().unwrap();
116
117            if node.expired() {
118                self.map.remove(key);
119                self.remove_node(node_ptr);
120                let _ = Box::from_raw(node_ptr.as_ptr());
121                return None;
122            }
123
124            self.remove_node(node_ptr);
125            self.push_front(node_ptr);
126
127            Some(&mut (*node_ptr.as_ptr()).value)
128        }
129    }
130
131    fn remove_node(&mut self, node_ptr: NonNull<Node<K, V>>) {
132        unsafe {
133            let node = node_ptr.as_ptr();
134
135            if let Some(prev) = (*node).prev {
136                (*prev.as_ptr()).next = (*node).next;
137            } else {
138                self.head = (*node).next;
139            }
140
141            if let Some(next) = (*node).next {
142                (*next.as_ptr()).prev = (*node).prev;
143            } else {
144                self.tail = (*node).prev;
145            }
146        }
147    }
148
149    fn push_front(&mut self, node_ptr: NonNull<Node<K, V>>) {
150        unsafe {
151            (*node_ptr.as_ptr()).next = self.head;
152            (*node_ptr.as_ptr()).prev = None;
153
154            if let Some(head) = self.head {
155                let head_mut = head.as_ptr() as *mut Node<K, V>;
156                (*head_mut).prev = Some(node_ptr);
157            } else {
158                self.tail = Some(node_ptr);
159            }
160
161            self.head = Some(node_ptr);
162        }
163    }
164
165    fn remove_last(&mut self) {
166        if let Some(tail_ptr) = self.tail {
167            unsafe {
168                let key = &(*tail_ptr.as_ptr()).key;
169                let prev = (*tail_ptr.as_ptr()).prev;
170
171                self.map.remove(key);
172
173                match prev {
174                    Some(prev) => {
175                        let prev_mut = prev.as_ptr() as *mut Node<K, V>;
176                        (*prev_mut).next = None;
177                        self.tail = Some(prev);
178                    }
179                    None => {
180                        self.head = None;
181                        self.tail = None;
182                    }
183                }
184
185                let _ = Box::from_raw(tail_ptr.as_ptr());
186            }
187        }
188    }
189
190    pub fn evict_expired(&mut self) {
191        let now = Instant::now();
192        let mut current = self.head;
193
194        while let Some(node_ptr) = current {
195            unsafe {
196                let node = node_ptr.as_ptr();
197                current = (*node).next;
198
199                if (*node).expired_at(now) {
200                    self.map.remove(&(*node).key);
201                    self.remove_node(node_ptr);
202                    let _ = Box::from_raw(node);
203                }
204            }
205        }
206    }
207
208    pub fn len(&self) -> usize {
209        self.map.len()
210    }
211
212    pub fn is_empty(&self) -> bool {
213        self.map.is_empty()
214    }
215
216    pub fn capacity(&self) -> usize {
217        self.capacity
218    }
219}
220
221impl<K, V> Drop for LruCache<K, V> {
222    fn drop(&mut self) {
223        let mut current = self.head;
224        while let Some(node_ptr) = current {
225            unsafe {
226                current = (*node_ptr.as_ptr()).next;
227                let _ = Box::from_raw(node_ptr.as_ptr());
228            }
229        }
230    }
231}
232
233impl<K, V> Node<K, V> {
234    fn expired(&self) -> bool {
235        self.expires_at.map_or(false, |e| e <= Instant::now())
236    }
237
238    fn expired_at(&self, now: Instant) -> bool {
239        self.expires_at.map_or(false, |e| e <= now)
240    }
241}
242
#[cfg(test)]
mod tests {
    use super::*;
    use std::cell::Cell;
    use std::rc::Rc;
    use std::thread;

    /// Value that bumps a shared counter when it is dropped, so tests can
    /// observe exactly how many cached values have been released.
    struct DropCounter(Rc<Cell<usize>>);

    impl Drop for DropCounter {
        fn drop(&mut self) {
            self.0.set(self.0.get() + 1);
        }
    }

    /// Lookups hit stored keys, miss unknown keys, and inserting into a full
    /// cache evicts the least recently used entry.
    #[test]
    fn test_basic_operations() {
        let mut cache = LruCache::new(2, CleanupMode::OnAccess);
        cache.put("a", 1, None);
        cache.put("b", 2, None);

        assert_eq!(cache.get(&"a"), Some(&1));
        assert_eq!(cache.get(&"b"), Some(&2));
        assert_eq!(cache.get(&"c"), None);

        cache.put("c", 3, None);
        assert_eq!(cache.get(&"a"), None);
        assert_eq!(cache.get(&"b"), Some(&2));
        assert_eq!(cache.get(&"c"), Some(&3));
    }

    /// In `OnAccess` mode an expired entry disappears as a side effect of
    /// touching a different key, without any manual sweep.
    #[test]
    fn test_ttl_expiration_auto() {
        let mut cache = LruCache::new(2, CleanupMode::OnAccess);
        cache.put("a", 1, Some(Duration::from_millis(150)));
        cache.put("b", 2, None);

        assert_eq!(cache.get(&"a"), Some(&1));
        assert_eq!(cache.get(&"b"), Some(&2));

        thread::sleep(Duration::from_millis(200));

        // Accessing "b" triggers the sweep that removes the expired "a".
        assert_eq!(cache.get(&"b"), Some(&2));
        assert_eq!(cache.len(), 1);
        assert_eq!(cache.get(&"a"), None);
    }

    /// In `OnDemand` mode an expired entry stays stored until
    /// `evict_expired()` is called.
    #[test]
    fn test_ttl_expiration() {
        let mut cache = LruCache::new(2, CleanupMode::OnDemand);
        cache.put("a", 1, Some(Duration::from_millis(150)));
        cache.put("b", 2, None);

        assert_eq!(cache.get(&"a"), Some(&1));
        assert_eq!(cache.get(&"b"), Some(&2));

        thread::sleep(Duration::from_millis(200));

        // Nothing has swept the cache yet, so the expired entry still counts.
        assert_eq!(cache.len(), 2);

        cache.evict_expired();

        assert_eq!(cache.len(), 1);
        assert_eq!(cache.get(&"a"), None);
        assert_eq!(cache.get(&"b"), Some(&2));
    }

    /// A successful `get` refreshes an entry, so the next eviction removes
    /// the entry that was touched least recently instead.
    #[test]
    fn test_lru_eviction() {
        let mut cache = LruCache::new(3, CleanupMode::OnAccess);
        cache.put("a", 1, None);
        cache.put("b", 2, None);
        cache.put("c", 3, None);

        cache.get(&"a");
        cache.put("d", 4, None);

        assert_eq!(cache.get(&"b"), None);
        assert_eq!(cache.get(&"a"), Some(&1));
        assert_eq!(cache.get(&"c"), Some(&3));
        assert_eq!(cache.get(&"d"), Some(&4));
    }

    /// Every inserted value is dropped exactly once: evicted values while the
    /// cache is in use, the remaining ones when the cache itself is dropped.
    #[test]
    fn test_no_memory_leaks() {
        let drops = Rc::new(Cell::new(0usize));
        {
            let mut cache = LruCache::new(2, CleanupMode::OnAccess);
            for i in 0..1000 {
                cache.put(i, DropCounter(Rc::clone(&drops)), None);
            }
            assert_eq!(cache.len(), 2);
            assert_eq!(drops.get(), 998);
        }
        assert_eq!(drops.get(), 1000);
    }
}