Skip to main content

oxistore_cache/
sharded.rs

1//! Sharded concurrent cache.
2//!
3//! [`ShardedCache`] wraps `N` independent [`LruCache`] shards behind a
4//! `Mutex` each, reducing lock contention under parallel workloads.
5//! `N` must be a power of two so that routing can use a fast bitmask instead
6//! of a modulo operation.
7
8use std::collections::hash_map::DefaultHasher;
9use std::hash::{Hash, Hasher};
10use std::sync::Mutex;
11
12use crate::{Cache, LruCache};
13
14/// A concurrent cache backed by N power-of-two LRU shards.
15///
16/// Keys are routed to shards via `hash(key) & (n_shards - 1)`.  Each shard
17/// is independently protected by a `Mutex<LruCache<Vec<u8>, Vec<u8>>>`.
18///
19/// # Panics
20///
21/// `new` panics if `n_shards` is not a power of two, or if `n_shards` is 0.
22pub struct ShardedCache {
23    shards: Vec<Mutex<LruCache<Vec<u8>, Vec<u8>>>>,
24    /// Bitmask = n_shards - 1 (valid because n_shards is power of 2).
25    mask: usize,
26    /// Capacity per shard.
27    shard_cap: usize,
28}
29
30impl ShardedCache {
31    /// Create a new sharded cache.
32    ///
33    /// - `n_shards`: number of shards — must be a power of two.
34    /// - `shard_cap`: capacity per shard (entry count).
35    ///
36    /// # Panics
37    ///
38    /// Panics if `n_shards == 0` or `!n_shards.is_power_of_two()`.
39    pub fn new(n_shards: usize, shard_cap: usize) -> Self {
40        assert!(
41            n_shards > 0 && n_shards.is_power_of_two(),
42            "n_shards must be a positive power of two, got {n_shards}"
43        );
44        let shards = (0..n_shards)
45            .map(|_| Mutex::new(LruCache::new(shard_cap)))
46            .collect();
47        ShardedCache {
48            shards,
49            mask: n_shards - 1,
50            shard_cap,
51        }
52    }
53
54    /// Return the number of shards.
55    #[must_use]
56    pub fn n_shards(&self) -> usize {
57        self.shards.len()
58    }
59
60    /// Capacity per shard.
61    #[must_use]
62    pub fn shard_cap(&self) -> usize {
63        self.shard_cap
64    }
65
66    /// Hash a key to its shard index.
67    fn shard_index(&self, key: &[u8]) -> usize {
68        let mut h = DefaultHasher::new();
69        key.hash(&mut h);
70        (h.finish() as usize) & self.mask
71    }
72
73    /// Acquire the shard mutex for `key`.
74    ///
75    /// # Panics
76    ///
77    /// Panics if the mutex is poisoned.
78    fn shard(&self, key: &[u8]) -> std::sync::MutexGuard<'_, LruCache<Vec<u8>, Vec<u8>>> {
79        let idx = self.shard_index(key);
80        self.shards[idx].lock().expect("shard mutex poisoned")
81    }
82
83    /// Look up `key`, returning a cloned copy of the value if present.
84    ///
85    /// Returns `None` if the key is absent or expired.
86    pub fn get(&self, key: &[u8]) -> Option<Vec<u8>> {
87        self.shard(key).get(&key.to_vec()).cloned()
88    }
89
90    /// Insert or update `key` -> `value`.
91    pub fn put(&self, key: Vec<u8>, value: Vec<u8>) {
92        self.shard(&key).put(key, value);
93    }
94
95    /// Remove `key`, returning its value if present.
96    pub fn remove(&self, key: &[u8]) -> Option<Vec<u8>> {
97        self.shard(key).remove(&key.to_vec())
98    }
99
100    /// Return `true` if `key` is present and not expired.
101    pub fn contains(&self, key: &[u8]) -> bool {
102        self.shard(key).contains_key(&key.to_vec())
103    }
104
105    /// Return the total number of live entries across all shards.
106    #[must_use]
107    pub fn len(&self) -> usize {
108        self.shards
109            .iter()
110            .map(|s| s.lock().expect("shard mutex poisoned").len())
111            .sum()
112    }
113
114    /// Return `true` if no entries are in any shard.
115    #[must_use]
116    pub fn is_empty(&self) -> bool {
117        self.len() == 0
118    }
119
120    /// Clear all shards.
121    pub fn clear(&self) {
122        for shard in &self.shards {
123            shard.lock().expect("shard mutex poisoned").clear();
124        }
125    }
126}
127
#[cfg(test)]
mod tests {
    //! Unit tests for [`ShardedCache`]: constructor preconditions, basic
    //! single-threaded operations, and a multi-threaded insertion check.

    use super::*;
    use std::sync::Arc;
    use std::thread;

    /// A shard count that is not a power of two must be rejected with the
    /// constructor's own message (not just "some panic").
    #[test]
    #[should_panic(expected = "n_shards must be a positive power of two")]
    fn sharded_panics_on_non_power_of_two() {
        let _ = ShardedCache::new(3, 10);
    }

    /// Zero shards is documented as a panic; pin it explicitly.
    #[test]
    #[should_panic(expected = "n_shards must be a positive power of two")]
    fn sharded_panics_on_zero_shards() {
        let _ = ShardedCache::new(0, 10);
    }

    /// The accessors report exactly what was passed to `new`.
    #[test]
    fn sharded_reports_configuration() {
        let cache = ShardedCache::new(4, 16);
        assert_eq!(cache.n_shards(), 4);
        assert_eq!(cache.shard_cap(), 16);
    }

    /// With one shard the mask is 0, so every key must route to shard 0.
    #[test]
    fn sharded_single_shard_routes_every_key() {
        let cache = ShardedCache::new(1, 8);
        assert_eq!(cache.n_shards(), 1);
        cache.put(b"a".to_vec(), b"1".to_vec());
        cache.put(b"b".to_vec(), b"2".to_vec());
        assert_eq!(cache.get(b"a"), Some(b"1".to_vec()));
        assert_eq!(cache.get(b"b"), Some(b"2".to_vec()));
    }

    #[test]
    fn sharded_basic_put_get() {
        let cache = ShardedCache::new(4, 16);
        cache.put(b"hello".to_vec(), b"world".to_vec());
        assert_eq!(cache.get(b"hello"), Some(b"world".to_vec()));
        assert!(cache.get(b"missing").is_none());
    }

    #[test]
    fn sharded_remove() {
        let cache = ShardedCache::new(4, 16);
        cache.put(b"k".to_vec(), b"v".to_vec());
        assert!(cache.contains(b"k"));
        let v = cache.remove(b"k");
        assert_eq!(v, Some(b"v".to_vec()));
        assert!(!cache.contains(b"k"));
    }

    #[test]
    fn sharded_len_and_clear() {
        let cache = ShardedCache::new(4, 16);
        cache.put(b"a".to_vec(), b"1".to_vec());
        cache.put(b"b".to_vec(), b"2".to_vec());
        assert_eq!(cache.len(), 2);
        cache.clear();
        assert_eq!(cache.len(), 0);
        assert!(cache.is_empty());
    }

    #[test]
    fn sharded_concurrent_puts() {
        let cache = Arc::new(ShardedCache::new(8, 256));
        let n_threads = 8;
        let keys_per_thread = 32;

        let handles: Vec<_> = (0..n_threads)
            .map(|t| {
                let cache = Arc::clone(&cache);
                thread::spawn(move || {
                    for i in 0..keys_per_thread {
                        let key = format!("thread{t}_key{i}").into_bytes();
                        let val = format!("val{i}").into_bytes();
                        cache.put(key, val);
                    }
                })
            })
            .collect();

        for h in handles {
            h.join().expect("thread panicked");
        }

        // Every key is distinct and the 256 keys in total cannot exceed any
        // shard's capacity of 256, so the entry count must match exactly.
        assert_eq!(cache.len(), n_threads * keys_per_thread);

        // All values should be retrievable.
        for t in 0..n_threads {
            for i in 0..keys_per_thread {
                let key = format!("thread{t}_key{i}").into_bytes();
                let expected = format!("val{i}").into_bytes();
                assert_eq!(
                    cache.get(&key),
                    Some(expected),
                    "missing key thread{t}_key{i}"
                );
            }
        }
    }
}