1use ahash::RandomState;
16use std::borrow::Borrow;
17use std::hash::Hash;
18use std::marker::PhantomData;
19use std::time::{Duration, Instant};
20
21use tinyufo::TinyUfo;
22
23mod read_through;
24pub use read_through::{Lookup, MultiLookup, RTCache};
25
/// Outcome of a cache lookup, returned alongside the (optional) value.
///
/// The type is `Copy`: its only payload is a [`Duration`], so a status can be
/// freely kept, logged, or forwarded after the value has been consumed.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CacheStatus {
    /// The key was found and its entry had not expired.
    Hit,
    /// The key was not found in the cache.
    Miss,
    /// The key was found but its entry had expired, so no value was returned.
    Expired,
    /// Not produced by `MemoryCache` lookups themselves. Presumably set by the
    /// read-through layer when a value was obtained after waiting on a lock
    /// (TODO confirm against `read_through`). Counted as a hit by `is_hit`.
    LockHit,
    /// The entry had expired but its value was still returned (from
    /// `get_stale`). The [`Duration`] is how long ago the entry expired.
    Stale(Duration),
}
41
42impl CacheStatus {
43 pub fn as_str(&self) -> &str {
45 match self {
46 Self::Hit => "hit",
47 Self::Miss => "miss",
48 Self::Expired => "expired",
49 Self::LockHit => "lock_hit",
50 Self::Stale(_) => "stale",
51 }
52 }
53
54 pub fn is_hit(&self) -> bool {
56 match self {
57 CacheStatus::Hit | CacheStatus::LockHit | CacheStatus::Stale(_) => true,
58 CacheStatus::Miss | CacheStatus::Expired => false,
59 }
60 }
61
62 pub fn stale(&self) -> Option<Duration> {
64 match self {
65 CacheStatus::Stale(time) => Some(*time),
66 _ => None,
67 }
68 }
69}
70
/// A cached value together with its optional expiration deadline.
///
/// The `T: Clone` requirement lives on the `impl` blocks that need it rather
/// than on the type itself. `derive(Clone)` still implements `Clone` for
/// `Node<T>` only when `T: Clone`.
#[derive(Debug, Clone)]
struct Node<T> {
    /// The cached value.
    pub value: T,
    /// Instant at which the entry expires; `None` means it never expires.
    expire_on: Option<Instant>,
}
76
77impl<T: Clone> Node<T> {
78 fn new(value: T, ttl: Option<Duration>) -> Self {
79 let expire_on = match ttl {
80 Some(t) => Instant::now().checked_add(t),
81 None => None,
82 };
83 Node { value, expire_on }
84 }
85
86 fn will_expire_at(&self, time: &Instant) -> bool {
87 self.stale_duration(time).is_some()
88 }
89
90 fn is_expired(&self) -> bool {
91 self.will_expire_at(&Instant::now())
92 }
93
94 fn stale_duration(&self, time: &Instant) -> Option<Duration> {
95 let expire_time = self.expire_on?;
96 if &expire_time <= time {
97 Some(time.duration_since(expire_time))
98 } else {
99 None
100 }
101 }
102}
103
104pub struct MemoryCache<K: Hash, T: Clone> {
106 store: TinyUfo<u64, Node<T>>,
107 _key_type: PhantomData<K>,
108 pub(crate) hasher: RandomState,
109}
110
111impl<K: Hash, T: Clone + Send + Sync + 'static> MemoryCache<K, T> {
112 pub fn new(size: usize) -> Self {
114 MemoryCache {
115 store: TinyUfo::new(size, size),
116 _key_type: PhantomData,
117 hasher: RandomState::new(),
118 }
119 }
120
121 pub fn get<Q>(&self, key: &Q) -> (Option<T>, CacheStatus)
123 where
124 K: Borrow<Q>,
125 Q: Hash + ?Sized,
126 {
127 let hashed_key = self.hasher.hash_one(key);
128
129 if let Some(n) = self.store.get(&hashed_key) {
130 if !n.is_expired() {
131 (Some(n.value), CacheStatus::Hit)
132 } else {
133 (None, CacheStatus::Expired)
134 }
135 } else {
136 (None, CacheStatus::Miss)
137 }
138 }
139
140 pub fn get_stale<Q>(&self, key: &Q) -> (Option<T>, CacheStatus)
145 where
146 K: Borrow<Q>,
147 Q: Hash + ?Sized,
148 {
149 let hashed_key = self.hasher.hash_one(key);
150
151 if let Some(n) = self.store.get(&hashed_key) {
152 let stale_duration = n.stale_duration(&Instant::now());
153 if let Some(stale_duration) = stale_duration {
154 (Some(n.value), CacheStatus::Stale(stale_duration))
155 } else {
156 (Some(n.value), CacheStatus::Hit)
157 }
158 } else {
159 (None, CacheStatus::Miss)
160 }
161 }
162
163 pub fn put<Q>(&self, key: &Q, value: T, ttl: Option<Duration>)
167 where
168 K: Borrow<Q>,
169 Q: Hash + ?Sized,
170 {
171 if let Some(t) = ttl {
172 if t.is_zero() {
173 return;
174 }
175 }
176 let hashed_key = self.hasher.hash_one(key);
177 let node = Node::new(value, ttl);
178 self.store.put(hashed_key, node, 1);
180 }
181
182 pub fn remove<Q>(&self, key: &Q)
184 where
185 K: Borrow<Q>,
186 Q: Hash + ?Sized,
187 {
188 let hashed_key = self.hasher.hash_one(key);
189 self.store.remove(&hashed_key);
190 }
191
192 pub(crate) fn force_put(&self, key: &K, value: T, ttl: Option<Duration>) {
193 if let Some(t) = ttl {
194 if t.is_zero() {
195 return;
196 }
197 }
198 let hashed_key = self.hasher.hash_one(key);
199 let node = Node::new(value, ttl);
200 self.store.force_put(hashed_key, node, 1);
202 }
203
204 pub fn multi_get<'a, I, Q>(&self, keys: I) -> Vec<(Option<T>, CacheStatus)>
206 where
207 I: Iterator<Item = &'a Q>,
208 Q: Hash + ?Sized + 'a,
209 K: Borrow<Q> + 'a,
210 {
211 let mut resp = Vec::with_capacity(keys.size_hint().0);
212 for key in keys {
213 resp.push(self.get(key));
214 }
215 resp
216 }
217
218 pub fn multi_get_with_miss<'a, I, Q>(
220 &self,
221 keys: I,
222 ) -> (Vec<(Option<T>, CacheStatus)>, Vec<&'a Q>)
223 where
224 I: Iterator<Item = &'a Q>,
225 Q: Hash + ?Sized + 'a,
226 K: Borrow<Q> + 'a,
227 {
228 let mut resp = Vec::with_capacity(keys.size_hint().0);
229 let mut missed = Vec::with_capacity(keys.size_hint().0 / 2);
230 for key in keys {
231 let (lookup, cache_status) = self.get(key);
232 if lookup.is_none() {
233 missed.push(key);
234 }
235 resp.push((lookup, cache_status));
236 }
237 (resp, missed)
238 }
239
240 }
242
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread::sleep;

    // Each lookup is compared as a whole `(value, status)` tuple.

    #[test]
    fn test_get() {
        let cache: MemoryCache<i32, ()> = MemoryCache::new(10);
        assert_eq!(cache.get(&1), (None, CacheStatus::Miss));
    }

    #[test]
    fn test_put_get() {
        let cache: MemoryCache<i32, i32> = MemoryCache::new(10);
        assert_eq!(cache.get(&1), (None, CacheStatus::Miss));
        cache.put(&1, 2, None);
        assert_eq!(cache.get(&1), (Some(2), CacheStatus::Hit));
    }

    #[test]
    fn test_put_get_remove() {
        let cache: MemoryCache<i32, i32> = MemoryCache::new(10);
        assert_eq!(cache.get(&1), (None, CacheStatus::Miss));
        for (key, value) in [(1, 2), (3, 4), (5, 6)] {
            cache.put(&key, value, None);
        }
        assert_eq!(cache.get(&1), (Some(2), CacheStatus::Hit));
        cache.remove(&1);
        cache.remove(&3);
        assert_eq!(cache.get(&1), (None, CacheStatus::Miss));
        assert_eq!(cache.get(&3), (None, CacheStatus::Miss));
        assert_eq!(cache.get(&5), (Some(6), CacheStatus::Hit));
    }

    #[test]
    fn test_get_expired() {
        let cache: MemoryCache<i32, i32> = MemoryCache::new(10);
        assert_eq!(cache.get(&1), (None, CacheStatus::Miss));
        cache.put(&1, 2, Some(Duration::from_secs(1)));
        sleep(Duration::from_millis(1100));
        assert_eq!(cache.get(&1), (None, CacheStatus::Expired));
    }

    #[test]
    fn test_get_stale() {
        let cache: MemoryCache<i32, i32> = MemoryCache::new(10);
        assert_eq!(cache.get(&1), (None, CacheStatus::Miss));
        cache.put(&1, 2, Some(Duration::from_secs(1)));
        sleep(Duration::from_millis(1100));
        let (value, status) = cache.get_stale(&1);
        assert_eq!(value, Some(2));
        let stale_for = status.stale().expect("entry should be stale");
        assert!(stale_for >= Duration::from_millis(100));
    }

    #[test]
    fn test_eviction() {
        let cache: MemoryCache<i32, i32> = MemoryCache::new(2);
        for (key, value) in [(1, 2), (2, 4), (3, 6)] {
            cache.put(&key, value, None);
        }
        assert_eq!(cache.get(&1), (None, CacheStatus::Miss));
        assert_eq!(cache.get(&2), (Some(4), CacheStatus::Hit));
        assert_eq!(cache.get(&3), (Some(6), CacheStatus::Hit));
    }

    #[test]
    fn test_multi_get() {
        let cache: MemoryCache<i32, i32> = MemoryCache::new(10);
        cache.put(&2, -2, None);
        let keys: Vec<i32> = vec![1, 2, 3];
        let expected = [
            (None, CacheStatus::Miss),
            (Some(-2), CacheStatus::Hit),
            (None, CacheStatus::Miss),
        ];

        let resp = cache.multi_get(keys.iter());
        for (idx, want) in expected.iter().enumerate() {
            assert_eq!(&resp[idx], want);
        }

        let (resp, missed) = cache.multi_get_with_miss(keys.iter());
        for (idx, want) in expected.iter().enumerate() {
            assert_eq!(&resp[idx], want);
        }
        assert_eq!(missed[0], &1);
        assert_eq!(missed[1], &3);
    }

    #[test]
    fn test_get_with_mismatched_key() {
        let cache: MemoryCache<String, ()> = MemoryCache::new(10);
        assert_eq!(cache.get("Hello"), (None, CacheStatus::Miss));
    }

    #[test]
    fn test_put_get_with_mismatched_key() {
        let cache: MemoryCache<String, i32> = MemoryCache::new(10);
        assert_eq!(cache.get("1"), (None, CacheStatus::Miss));
        cache.put("1", 2, None);
        assert_eq!(cache.get("1"), (Some(2), CacheStatus::Hit));
    }
}