oxistore_cache/
sharded.rs1use std::collections::hash_map::DefaultHasher;
9use std::hash::{Hash, Hasher};
10use std::sync::Mutex;
11
12use crate::{Cache, LruCache};
13
14pub struct ShardedCache {
23 shards: Vec<Mutex<LruCache<Vec<u8>, Vec<u8>>>>,
24 mask: usize,
26 shard_cap: usize,
28}
29
30impl ShardedCache {
31 pub fn new(n_shards: usize, shard_cap: usize) -> Self {
40 assert!(
41 n_shards > 0 && n_shards.is_power_of_two(),
42 "n_shards must be a positive power of two, got {n_shards}"
43 );
44 let shards = (0..n_shards)
45 .map(|_| Mutex::new(LruCache::new(shard_cap)))
46 .collect();
47 ShardedCache {
48 shards,
49 mask: n_shards - 1,
50 shard_cap,
51 }
52 }
53
54 #[must_use]
56 pub fn n_shards(&self) -> usize {
57 self.shards.len()
58 }
59
60 #[must_use]
62 pub fn shard_cap(&self) -> usize {
63 self.shard_cap
64 }
65
66 fn shard_index(&self, key: &[u8]) -> usize {
68 let mut h = DefaultHasher::new();
69 key.hash(&mut h);
70 (h.finish() as usize) & self.mask
71 }
72
73 fn shard(&self, key: &[u8]) -> std::sync::MutexGuard<'_, LruCache<Vec<u8>, Vec<u8>>> {
79 let idx = self.shard_index(key);
80 self.shards[idx].lock().expect("shard mutex poisoned")
81 }
82
83 pub fn get(&self, key: &[u8]) -> Option<Vec<u8>> {
87 self.shard(key).get(&key.to_vec()).cloned()
88 }
89
90 pub fn put(&self, key: Vec<u8>, value: Vec<u8>) {
92 self.shard(&key).put(key, value);
93 }
94
95 pub fn remove(&self, key: &[u8]) -> Option<Vec<u8>> {
97 self.shard(key).remove(&key.to_vec())
98 }
99
100 pub fn contains(&self, key: &[u8]) -> bool {
102 self.shard(key).contains_key(&key.to_vec())
103 }
104
105 #[must_use]
107 pub fn len(&self) -> usize {
108 self.shards
109 .iter()
110 .map(|s| s.lock().expect("shard mutex poisoned").len())
111 .sum()
112 }
113
114 #[must_use]
116 pub fn is_empty(&self) -> bool {
117 self.len() == 0
118 }
119
120 pub fn clear(&self) {
122 for shard in &self.shards {
123 shard.lock().expect("shard mutex poisoned").clear();
124 }
125 }
126}
127
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::thread;

    // `expected` pins the panic to the constructor's precondition message; a
    // bare `#[should_panic]` would also pass on any unrelated panic.
    #[test]
    #[should_panic(expected = "n_shards must be a positive power of two, got 3")]
    fn sharded_panics_on_non_power_of_two() {
        let _ = ShardedCache::new(3, 10);
    }

    #[test]
    #[should_panic(expected = "n_shards must be a positive power of two, got 0")]
    fn sharded_panics_on_zero_shards() {
        let _ = ShardedCache::new(0, 10);
    }

    #[test]
    fn sharded_reports_configuration() {
        let cache = ShardedCache::new(8, 32);
        assert_eq!(cache.n_shards(), 8);
        assert_eq!(cache.shard_cap(), 32);
    }

    #[test]
    fn sharded_single_shard_is_allowed() {
        // 1 == 2^0, so one shard (mask 0) is a valid configuration.
        let cache = ShardedCache::new(1, 4);
        assert_eq!(cache.n_shards(), 1);
        cache.put(b"only".to_vec(), b"shard".to_vec());
        assert_eq!(cache.get(b"only"), Some(b"shard".to_vec()));
    }

    #[test]
    fn sharded_routing_is_stable_and_in_range() {
        let cache = ShardedCache::new(16, 4);
        for i in 0..256u32 {
            let key = i.to_le_bytes();
            let idx = cache.shard_index(&key);
            assert!(idx < cache.n_shards(), "shard index {idx} out of range");
            // Routing must be deterministic, otherwise `get` could look in a
            // different shard than the one `put` wrote to.
            assert_eq!(idx, cache.shard_index(&key));
        }
    }

    #[test]
    fn sharded_basic_put_get() {
        let cache = ShardedCache::new(4, 16);
        cache.put(b"hello".to_vec(), b"world".to_vec());
        assert_eq!(cache.get(b"hello"), Some(b"world".to_vec()));
        assert!(cache.get(b"missing").is_none());
    }

    #[test]
    fn sharded_remove() {
        let cache = ShardedCache::new(4, 16);
        cache.put(b"k".to_vec(), b"v".to_vec());
        assert!(cache.contains(b"k"));
        let v = cache.remove(b"k");
        assert_eq!(v, Some(b"v".to_vec()));
        assert!(!cache.contains(b"k"));
    }

    #[test]
    fn sharded_len_and_clear() {
        let cache = ShardedCache::new(4, 16);
        assert!(cache.is_empty());
        cache.put(b"a".to_vec(), b"1".to_vec());
        cache.put(b"b".to_vec(), b"2".to_vec());
        assert_eq!(cache.len(), 2);
        assert!(!cache.is_empty());
        cache.clear();
        assert_eq!(cache.len(), 0);
        assert!(cache.is_empty());
    }

    #[test]
    fn sharded_concurrent_puts() {
        // 8 threads x 32 keys = 256 keys total, and every shard is created with
        // capacity 256, so even if all keys hashed to one shard it would not
        // exceed its capacity; every key should still be present afterwards.
        let cache = Arc::new(ShardedCache::new(8, 256));
        let n_threads = 8;
        let keys_per_thread = 32;

        let handles: Vec<_> = (0..n_threads)
            .map(|t| {
                let cache = Arc::clone(&cache);
                thread::spawn(move || {
                    for i in 0..keys_per_thread {
                        let key = format!("thread{t}_key{i}").into_bytes();
                        let val = format!("val{i}").into_bytes();
                        cache.put(key, val);
                    }
                })
            })
            .collect();

        for h in handles {
            h.join().expect("thread panicked");
        }

        for t in 0..n_threads {
            for i in 0..keys_per_thread {
                let key = format!("thread{t}_key{i}").into_bytes();
                let expected = format!("val{i}").into_bytes();
                assert_eq!(
                    cache.get(&key),
                    Some(expected),
                    "missing key thread{t}_key{i}"
                );
            }
        }
    }
}