1use crate::Stats;
2use parking_lot::{Mutex, RwLock};
3use shard::Shard;
4use std::borrow::Borrow;
5use std::hash::{BuildHasher, Hash};
6use std::num::NonZero;
7use std::time::Instant;
8use std::{cmp, thread};
9
10mod entry;
11mod fixed_size_hash_table;
12mod ring_buffer;
13mod shard;
14pub(crate) mod stats;
15
/// Default hash builder used by [`Cache`] when no hasher is supplied: `ahash`'s `RandomState`.
pub(crate) type RandomState = ahash::RandomState;
17
/// A concurrent, sharded cache.
///
/// Entries are partitioned across `shards`, each guarded by its own [`RwLock`]. Lookups take a
/// shared (read) lock on a single shard, and inserts take an exclusive (write) lock on a single
/// shard. Operations that land on different shards therefore do not contend with each other.
#[derive(Debug)]
pub struct Cache<K, V, S = RandomState> {
    /// Hashes keys to select the shard that owns them; every shard is also constructed with a
    /// clone of this hasher.
    hash_builder: S,
    /// Independently locked partitions of the cache. Empty when the cache was built with zero
    /// capacity, in which case every insert is dropped and every lookup misses.
    shards: Vec<RwLock<Shard<K, V, S>>>,
    /// The moment statistics were last collected by `stats()` (initially, when the cache was
    /// built). Used to report the time window covered by each `Stats` snapshot.
    metrics_last_accessed: Mutex<Instant>,
}
33
34impl<K, V> Cache<K, V, RandomState>
35where
36 K: Clone + Eq + Hash,
37 V: Clone,
38{
39 pub fn with_capacity(capacity: usize) -> Cache<K, V, RandomState> {
43 Cache::with_capacity_and_hasher(capacity, Default::default())
44 }
45}
46
47impl<K, V, S> Cache<K, V, S>
48where
49 K: Clone + Eq + Hash,
50 V: Clone,
51 S: BuildHasher,
52{
53 pub fn insert(&self, key: K, value: V) -> Option<V> {
59 let hash = self.hash_builder.hash_one(&key);
60 let shard_lock = self.get_shard(hash)?;
61
62 let mut shard = shard_lock.write();
63 shard.insert(key, value)
64 }
65
66 pub fn get<Q>(&self, key: &Q) -> Option<V>
71 where
72 K: Borrow<Q>,
73 Q: ?Sized + Hash + Eq,
74 {
75 let hash = self.hash_builder.hash_one(key);
76 let shard_lock = self.get_shard(hash)?;
77
78 let shard = shard_lock.read();
79 shard.get(key)
80 }
81
82 fn get_shard(&self, hash: u64) -> Option<&RwLock<Shard<K, V, S>>> {
83 let shard_idx = hash as usize % (cmp::max(self.shards.len(), 2) - 1);
84 self.shards.get(shard_idx)
85 }
86}
87
88impl<K, V, S> Cache<K, V, S>
89where
90 K: Clone + Eq + Hash,
91 V: Clone,
92 S: Clone + BuildHasher,
93{
94 pub fn with_capacity_and_hasher(capacity: usize, hash_builder: S) -> Cache<K, V, S> {
99 let available_parallelism = thread::available_parallelism()
100 .map(NonZero::get)
101 .unwrap_or(1);
102
103 let number_of_shards = cmp::min(available_parallelism * 4, capacity);
104
105 let mut shards = Vec::with_capacity(number_of_shards);
106
107 let metrics_last_accessed = Mutex::new(Instant::now());
108
109 if number_of_shards == 0 {
110 return Self {
111 hash_builder,
112 shards,
113 metrics_last_accessed,
114 };
115 }
116
117 let capacity_per_shard = capacity.div_ceil(number_of_shards);
118
119 for _ in 0..number_of_shards {
120 let shard = Shard::with_capacity_and_hasher(capacity_per_shard, hash_builder.clone());
121 shards.push(RwLock::new(shard))
122 }
123
124 Self {
125 hash_builder,
126 shards,
127 metrics_last_accessed,
128 }
129 }
130}
131
132impl<K, V, S> Cache<K, V, S> {
133 pub fn stats(&self) -> Stats {
174 let mut stats = Stats::default();
175
176 let millis_elapsed = {
177 let mut guard = self.metrics_last_accessed.lock();
178 let millis_elapsed = guard.elapsed().as_millis();
179 *guard = Instant::now();
180 millis_elapsed
181 };
182
183 stats.millis_elapsed = millis_elapsed;
184
185 for shard in &self.shards {
186 let shard = shard.read();
187 stats.hit_count += shard.hit_count();
188 stats.miss_count += shard.miss_count();
189 stats.eviction_count += shard.eviction_count();
190 shard.reset_counters();
191 }
192
193 stats
194 }
195}
196
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::thread;

    #[test]
    fn it_inserts_and_gets_basic_values() {
        let cache = Cache::with_capacity(100);

        cache.insert("key1", "value1");

        assert_eq!(cache.get("key1"), Some("value1"));
        assert_eq!(cache.get("key2"), None);
    }

    #[test]
    fn it_updates_existing_value() {
        let cache = Cache::with_capacity(100);
        cache.insert("key1", "value1");

        let old_value = cache.insert("key1", "new_value");

        assert_eq!(old_value, Some("value1"));
        assert_eq!(cache.get("key1"), Some("new_value"));
    }

    #[test]
    fn it_handles_zero_capacity() {
        let cache = Cache::with_capacity(0);

        cache.insert("key1", "value1");

        assert_eq!(cache.get("key1"), None);
    }

    #[test]
    fn it_handles_one_capacity() {
        let cache = Cache::with_capacity(1);

        cache.insert("key1", "value1");

        assert_eq!(cache.get("key1"), Some("value1"));
        assert_eq!(cache.get("key2"), None);
    }

    #[test]
    fn it_works_with_custom_hasher() {
        use std::collections::hash_map::RandomState;

        let cache = Cache::with_capacity_and_hasher(100, RandomState::new());

        cache.insert("key1", "value1");

        assert_eq!(cache.get("key1"), Some("value1"));
    }

    #[test]
    fn it_is_thread_safe() {
        let cache: Arc<Cache<String, String>> = Arc::new(Cache::with_capacity(1_000));

        let mut handles = vec![];

        for i in 0..5 {
            let cache_clone = Arc::clone(&cache);
            let key = format!("key{i}");
            let value = format!("value{i}");
            let handle = thread::spawn(move || {
                cache_clone.insert(key.clone(), value.clone());
                assert_eq!(cache_clone.get(&key), Some(value));
            });
            handles.push(handle);
        }

        for handle in handles {
            handle.join().unwrap();
        }

        for i in 0..5 {
            let key = format!("key{i}");
            let value = format!("value{i}");
            assert_eq!(cache.get(&key), Some(value));
        }
    }

    #[test]
    fn it_respects_capacity_limits() {
        // Capacity 2 always yields two shards of capacity 1 each: there are at least four
        // candidate shards per core, so `min(4 * cores, 2) == 2`.
        //
        // Only a handful of follow-up inserts would pass or fail depending on how hashes map
        // to shards. Flooding the cache instead guarantees the shard holding key 0 receives
        // many other keys under any routing.
        let cache = Cache::with_capacity(2);

        cache.insert(0, 0);
        for i in 1..100 {
            cache.insert(i, i);
        }

        assert_eq!(cache.get(&0), None);
    }

    #[test]
    fn it_returns_and_resets_stats() {
        let cache = Cache::with_capacity(1_000);

        for i in 0..10 {
            cache.insert(i, i);
        }

        for i in 0..5 {
            cache.get(&i);
        }

        for i in 10..15 {
            cache.get(&i);
        }

        let stats = cache.stats();
        assert_eq!(stats.hit_count, 5);
        assert_eq!(stats.miss_count, 5);

        let stats = cache.stats();
        assert_eq!(stats.hit_count, 0);
        assert_eq!(stats.miss_count, 0);
    }
}