Skip to main content

cache_mod/
lru.rs

//! Least-Recently-Used (LRU) cache — sharded, arena-backed implementation.
2
3use core::hash::Hash;
4use std::collections::HashMap;
5use std::num::NonZeroUsize;
6
7use crate::cache::Cache;
8use crate::error::CacheError;
9use crate::sharding::{self, Sharded};
10use crate::util::MutexExt;
11
12/// A bounded, thread-safe LRU cache.
13///
14/// On insert overflow the least-recently-accessed entry is evicted. Both
15/// [`get`](Cache::get) and [`insert`](Cache::insert) count as accesses and
16/// promote the affected entry to most-recently-used.
17///
18/// # Implementation (0.7.0)
19///
20/// Sharded into up to 16 independent arenas keyed by hash of `K`. Each shard
21/// owns its own doubly-linked list, free-list, and `HashMap`, with its own
22/// `Mutex<Inner>`. Contention on the lock is bounded by the number of
23/// threads routing into the same shard, not by total cache traffic.
24///
25/// **Eviction is approximate.** Once the cache uses more than one shard,
26/// `insert` overflow evicts the local-to-shard least-recently-used entry,
27/// not the global one. Caches with fewer than 32 entries automatically use
28/// a single shard and retain strict global LRU ordering — this keeps small
29/// caches and test fixtures deterministic.
30///
31/// # Example
32///
33/// ```
34/// use cache_mod::{Cache, LruCache};
35///
36/// let cache: LruCache<u32, &'static str> = LruCache::new(2).expect("capacity > 0");
37///
38/// cache.insert(1, "one");
39/// cache.insert(2, "two");
40/// assert_eq!(cache.get(&1), Some("one")); // 1 is now MRU, 2 is LRU
41///
42/// cache.insert(3, "three"); // evicts 2 (single shard at this size)
43/// assert_eq!(cache.get(&2), None);
44/// assert_eq!(cache.get(&1), Some("one"));
45/// assert_eq!(cache.get(&3), Some("three"));
46/// ```
47pub struct LruCache<K, V> {
48    capacity: NonZeroUsize,
49    sharded: Sharded<Inner<K, V>>,
50}
51
/// A single cache entry stored in one slot of a shard's arena.
///
/// `prev` and `next` are arena indices rather than pointers: `prev` leads
/// toward the list head (most recently used) and `next` toward the tail
/// (least recently used). `None` marks the corresponding end of the list.
struct Node<K, V> {
    key: K,
    value: V,
    prev: Option<usize>,
    next: Option<usize>,
}

/// State of one shard: an index-linked recency list stored in an arena, plus
/// the key → slot lookup table.
struct Inner<K, V> {
    /// Per-shard capacity (not the total LruCache capacity).
    capacity: NonZeroUsize,
    /// Arena of list nodes; a `None` slot has been vacated and awaits reuse.
    nodes: Vec<Option<Node<K, V>>>,
    /// Indices of vacated `nodes` slots, handed out again before the arena grows.
    free: Vec<usize>,
    /// Slot of the most-recently-used node, or `None` when the list is empty.
    head: Option<usize>,
    /// Slot of the least-recently-used node, or `None` when the list is empty.
    tail: Option<usize>,
    /// Maps each cached key to the arena slot holding its node.
    map: HashMap<K, usize>,
}

impl<K, V> Inner<K, V>
where
    K: Eq + Hash + Clone,
{
    /// Builds an empty shard whose arena and map are pre-sized for `capacity`
    /// entries.
    fn with_capacity(capacity: NonZeroUsize) -> Self {
        let cap = capacity.get();
        Self {
            capacity,
            nodes: Vec::with_capacity(cap),
            free: Vec::new(),
            head: None,
            tail: None,
            map: HashMap::with_capacity(cap),
        }
    }

    /// Stores `node` in the arena and returns the slot index it occupies.
    ///
    /// A previously vacated slot is reused when one is available; otherwise
    /// the arena grows by one. The node is *not* linked into the list.
    fn alloc(&mut self, node: Node<K, V>) -> usize {
        match self.free.pop() {
            Some(idx) => {
                self.nodes[idx] = Some(node);
                idx
            }
            None => {
                self.nodes.push(Some(node));
                self.nodes.len() - 1
            }
        }
    }

    /// Vacates slot `idx`, records it for reuse, and returns the node it held.
    ///
    /// The list links are not touched; callers unlink the node first.
    ///
    /// # Panics
    ///
    /// Panics if the slot is already empty or `idx` is out of bounds; both
    /// indicate a broken internal invariant.
    fn dealloc(&mut self, idx: usize) -> Node<K, V> {
        let node = self.nodes[idx]
            .take()
            .unwrap_or_else(|| unreachable!("arena slot must be occupied"));
        self.free.push(idx);
        node
    }

    /// Detaches the node in slot `idx` from the recency list, patching its
    /// neighbours (or `head` / `tail`) around it and clearing its own links.
    ///
    /// The node must currently be linked: a detached node has no neighbours,
    /// so unlinking it would be indistinguishable from unlinking the sole
    /// list element and would reset `head` and `tail`.
    ///
    /// # Panics
    ///
    /// Panics if slot `idx` or one of its neighbours is empty.
    fn unlink(&mut self, idx: usize) {
        // Copy the links out first so the shared borrow ends before mutation.
        let (prev, next) = {
            let n = self.nodes[idx]
                .as_ref()
                .unwrap_or_else(|| unreachable!("unlink target must be occupied"));
            (n.prev, n.next)
        };
        match prev {
            Some(p) => {
                self.nodes[p]
                    .as_mut()
                    .unwrap_or_else(|| unreachable!("predecessor of a linked node must be occupied"))
                    .next = next;
            }
            // No predecessor: the node was the head.
            None => self.head = next,
        }
        match next {
            Some(n) => {
                self.nodes[n]
                    .as_mut()
                    .unwrap_or_else(|| unreachable!("successor of a linked node must be occupied"))
                    .prev = prev;
            }
            // No successor: the node was the tail.
            None => self.tail = prev,
        }
        if let Some(n) = self.nodes[idx].as_mut() {
            n.prev = None;
            n.next = None;
        }
    }

    /// Links the (currently detached) node in slot `idx` in as the new head,
    /// i.e. the most-recently-used position.
    fn push_front(&mut self, idx: usize) {
        let old_head = self.head;
        if let Some(n) = self.nodes[idx].as_mut() {
            n.prev = None;
            n.next = old_head;
        }
        if let Some(h) = old_head {
            if let Some(n) = self.nodes[h].as_mut() {
                n.prev = Some(idx);
            }
        } else {
            // The list was empty, so the new node is also the tail.
            self.tail = Some(idx);
        }
        self.head = Some(idx);
    }

    /// Moves the linked node in slot `idx` to the head of the list; a no-op
    /// when it is already the head.
    fn promote(&mut self, idx: usize) {
        if self.head == Some(idx) {
            return;
        }
        self.unlink(idx);
        self.push_front(idx);
    }
}
158
159impl<K, V> LruCache<K, V>
160where
161    K: Eq + Hash + Clone,
162    V: Clone,
163{
164    /// Creates a cache with the given capacity.
165    ///
166    /// Returns [`CacheError::InvalidCapacity`] if `capacity == 0`.
167    ///
168    /// # Example
169    ///
170    /// ```
171    /// use cache_mod::LruCache;
172    ///
173    /// let cache: LruCache<String, u32> = LruCache::new(128).expect("capacity > 0");
174    /// ```
175    pub fn new(capacity: usize) -> Result<Self, CacheError> {
176        let cap = NonZeroUsize::new(capacity).ok_or(CacheError::InvalidCapacity)?;
177        Ok(Self::with_capacity(cap))
178    }
179
180    /// Creates a cache with the given non-zero capacity. Infallible.
181    ///
182    /// # Example
183    ///
184    /// ```
185    /// use std::num::NonZeroUsize;
186    /// use cache_mod::LruCache;
187    ///
188    /// let cap = NonZeroUsize::new(64).expect("64 != 0");
189    /// let cache: LruCache<String, u32> = LruCache::with_capacity(cap);
190    /// ```
191    pub fn with_capacity(capacity: NonZeroUsize) -> Self {
192        let num_shards = sharding::shard_count(capacity);
193        let per_shard = sharding::per_shard_capacity(capacity, num_shards);
194        let sharded = Sharded::from_factory(num_shards, |_| Inner::with_capacity(per_shard));
195        Self { capacity, sharded }
196    }
197}
198
199impl<K, V> Cache<K, V> for LruCache<K, V>
200where
201    K: Eq + Hash + Clone,
202    V: Clone,
203{
204    fn get(&self, key: &K) -> Option<V> {
205        let mut inner = self.sharded.shard_for(key).lock_recover();
206        let idx = *inner.map.get(key)?;
207        inner.promote(idx);
208        inner.nodes[idx].as_ref().map(|n| n.value.clone())
209    }
210
211    fn insert(&self, key: K, value: V) -> Option<V> {
212        let mut inner = self.sharded.shard_for(&key).lock_recover();
213
214        if let Some(&idx) = inner.map.get(&key) {
215            let old = inner.nodes[idx]
216                .as_mut()
217                .map(|n| core::mem::replace(&mut n.value, value))
218                .unwrap_or_else(|| unreachable!("mapped index must be occupied"));
219            inner.promote(idx);
220            return Some(old);
221        }
222
223        // New entry. Evict the LRU tail if at per-shard capacity.
224        if inner.map.len() >= inner.capacity.get() {
225            if let Some(tail_idx) = inner.tail {
226                inner.unlink(tail_idx);
227                let evicted = inner.dealloc(tail_idx);
228                let _ = inner.map.remove(&evicted.key);
229            }
230        }
231
232        let idx = inner.alloc(Node {
233            key: key.clone(),
234            value,
235            prev: None,
236            next: None,
237        });
238        inner.push_front(idx);
239        let _ = inner.map.insert(key, idx);
240        None
241    }
242
243    fn remove(&self, key: &K) -> Option<V> {
244        let mut inner = self.sharded.shard_for(key).lock_recover();
245        let idx = inner.map.remove(key)?;
246        inner.unlink(idx);
247        let node = inner.dealloc(idx);
248        Some(node.value)
249    }
250
251    fn contains_key(&self, key: &K) -> bool {
252        self.sharded
253            .shard_for(key)
254            .lock_recover()
255            .map
256            .contains_key(key)
257    }
258
259    fn len(&self) -> usize {
260        self.sharded
261            .iter()
262            .map(|m| m.lock_recover().map.len())
263            .sum()
264    }
265
266    fn clear(&self) {
267        for mutex in self.sharded.iter() {
268            let mut inner = mutex.lock_recover();
269            inner.nodes.clear();
270            inner.free.clear();
271            inner.head = None;
272            inner.tail = None;
273            inner.map.clear();
274        }
275    }
276
277    fn capacity(&self) -> usize {
278        self.capacity.get()
279    }
280}