1use crate::Stats;
2use parking_lot::{Mutex, RwLock};
3use shard::Shard;
4use std::borrow::Borrow;
5use std::hash::{BuildHasher, Hash};
6use std::num::NonZero;
7use std::time::Instant;
8use std::{cmp, thread};
9
10mod entry;
11mod fixed_size_hash_table;
12mod ring_buffer;
13mod shard;
14pub(crate) mod stats;
15
/// Default hasher used by [`Cache`] when no hasher is supplied: `ahash`'s `RandomState`.
pub(crate) type RandomState = ahash::RandomState;
17
/// A concurrent, capacity-bounded cache whose entries are spread over independently locked shards.
///
/// Each key is hashed with `hash_builder` and routed to exactly one shard. Every shard sits
/// behind its own `RwLock`, so operations that land on different shards take different locks.
/// The struct itself carries no trait bounds; the `impl` blocks add the bounds they need.
#[derive(Debug)]
pub struct Cache<K, V, S = RandomState> {
    /// Hashes keys to choose a shard. Each shard is constructed with a clone of it.
    hash_builder: S,
    /// One lock-protected shard per slot. Empty when the cache was built with zero capacity.
    shards: Vec<RwLock<Shard<K, V, S>>>,
    /// Start of the current stats window: set at construction and reset by every `stats` call.
    metrics_last_accessed: Mutex<Instant>,
}
33
34impl<K, V> Cache<K, V, RandomState>
35where
36 K: Clone + Eq + Hash,
37 V: Clone,
38{
39 pub fn with_capacity(capacity: usize) -> Cache<K, V, RandomState> {
43 Cache::with_capacity_and_hasher(capacity, Default::default())
44 }
45}
46
47impl<K, V, S> Cache<K, V, S>
48where
49 K: Clone + Eq + Hash,
50 V: Clone,
51 S: BuildHasher,
52{
53 pub fn insert(&self, key: K, value: V) -> Option<V> {
59 let hash = self.hash_builder.hash_one(&key);
60 let shard_lock = self.get_shard(hash)?;
61
62 let mut shard = shard_lock.write();
63 shard.insert(key, value)
64 }
65
66 pub fn get<Q>(&self, key: &Q) -> Option<V>
71 where
72 K: Borrow<Q>,
73 Q: ?Sized + Hash + Eq,
74 {
75 let hash = self.hash_builder.hash_one(key);
76 let shard_lock = self.get_shard(hash)?;
77
78 let shard = shard_lock.read();
79 shard.get(key)
80 }
81
82 fn get_shard(&self, hash: u64) -> Option<&RwLock<Shard<K, V, S>>> {
83 let shard_idx = hash as usize % (cmp::max(self.shards.len(), 2) - 1);
84 self.shards.get(shard_idx)
85 }
86}
87
88impl<K, V, S> Cache<K, V, S>
89where
90 K: Clone + Eq + Hash,
91 V: Clone,
92 S: Clone + BuildHasher,
93{
94 pub fn with_capacity_and_hasher(capacity: usize, hash_builder: S) -> Cache<K, V, S> {
99 let available_parallelism = thread::available_parallelism()
100 .map(NonZero::get)
101 .unwrap_or(1);
102
103 let number_of_shards = cmp::min(available_parallelism * 4, capacity);
104
105 let mut shards = Vec::with_capacity(number_of_shards);
106
107 let metrics_last_accessed = Mutex::new(Instant::now());
108
109 if number_of_shards == 0 {
110 return Self {
111 hash_builder,
112 shards,
113 metrics_last_accessed,
114 };
115 }
116
117 let capacity_per_shard = capacity.div_ceil(number_of_shards);
118
119 for _ in 0..number_of_shards {
120 let shard = Shard::with_capacity_and_hasher(capacity_per_shard, hash_builder.clone());
121 shards.push(RwLock::new(shard))
122 }
123
124 Self {
125 hash_builder,
126 shards,
127 metrics_last_accessed,
128 }
129 }
130}
131
132impl<K, V, S> Cache<K, V, S> {
133 pub fn stats(&self) -> Stats {
176 let mut stats = Stats::default();
177
178 let millis_elapsed = {
179 let mut guard = self.metrics_last_accessed.lock();
180 let millis_elapsed = guard.elapsed().as_millis();
181 *guard = Instant::now();
182 millis_elapsed
183 };
184
185 stats.millis_elapsed = millis_elapsed;
186
187 for shard in &self.shards {
188 let shard = shard.read();
189 stats.hit_count += shard.hit_count();
190 stats.miss_count += shard.miss_count();
191 stats.eviction_count += shard.eviction_count();
192 shard.reset_counters();
193 }
194
195 stats
196 }
197}
198
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::thread;

    #[test]
    fn it_inserts_and_gets_basic_values() {
        let cache = Cache::with_capacity(100);

        cache.insert("key1", "value1");

        assert_eq!(cache.get("key1"), Some("value1"));
        assert_eq!(cache.get("key2"), None);
    }

    #[test]
    fn it_updates_existing_value() {
        let cache = Cache::with_capacity(100);
        cache.insert("key1", "value1");

        let old_value = cache.insert("key1", "new_value");

        assert_eq!(old_value, Some("value1"));
        assert_eq!(cache.get("key1"), Some("new_value"));
    }

    #[test]
    fn it_handles_zero_capacity() {
        let cache = Cache::with_capacity(0);

        cache.insert("key1", "value1");

        assert_eq!(cache.get("key1"), None);
    }

    #[test]
    fn it_handles_one_capacity() {
        let cache = Cache::with_capacity(1);

        cache.insert("key1", "value1");

        assert_eq!(cache.get("key1"), Some("value1"));
        assert_eq!(cache.get("key2"), None);
    }

    #[test]
    fn it_works_with_custom_hasher() {
        use std::collections::hash_map::RandomState;

        let cache = Cache::with_capacity_and_hasher(100, RandomState::new());

        cache.insert("key1", "value1");

        assert_eq!(cache.get("key1"), Some("value1"));
    }

    #[test]
    fn it_is_thread_safe() {
        let cache: Arc<Cache<String, String>> = Arc::new(Cache::with_capacity(1_000));
        let mut handles = vec![];

        for i in 0..5 {
            let cache_clone = Arc::clone(&cache);
            let key = format!("key{}", i);
            let value = format!("value{}", i);
            let handle = thread::spawn(move || {
                cache_clone.insert(key.clone(), value.clone());
                assert_eq!(cache_clone.get(&key), Some(value));
            });
            handles.push(handle);
        }

        for handle in handles {
            handle.join().unwrap();
        }

        for i in 0..5 {
            let key = format!("key{}", i);
            let value = format!("value{}", i);
            assert_eq!(cache.get(&key), Some(value));
        }
    }

    #[test]
    fn it_respects_capacity_limits() {
        // Capacity 2 is split across shards, so *which* keys get evicted depends on how they
        // hash to shards. What holds regardless of placement is that no more than two of the
        // four distinct keys can remain resident.
        let cache = Cache::with_capacity(2);
        let keys = ["key1", "key2", "key3", "key4"];

        cache.insert("key1", "value1");
        cache.insert("key2", "value2");
        cache.insert("key3", "value3");
        cache.insert("key4", "value4");

        let resident = keys.iter().filter(|key| cache.get(**key).is_some()).count();
        assert!(
            resident <= 2,
            "{resident} of 4 keys resident in a capacity-2 cache"
        );
    }

    #[test]
    fn it_returns_and_resets_stats() {
        let cache = Cache::with_capacity(1_000);

        for i in 0..10 {
            cache.insert(i, i);
        }

        // Keys 0..5 were inserted above: five hits.
        for i in 0..5 {
            cache.get(&i);
        }

        // Keys 10..15 were never inserted: five misses.
        for i in 10..15 {
            cache.get(&i);
        }

        let stats = cache.stats();
        assert_eq!(stats.hit_count, 5);
        assert_eq!(stats.miss_count, 5);

        // The previous call reset the counters.
        let stats = cache.stats();
        assert_eq!(stats.hit_count, 0);
        assert_eq!(stats.miss_count, 0);
    }
}