1use ahash::RandomState;
16use std::hash::Hash;
17use std::marker::PhantomData;
18use std::time::{Duration, Instant};
19
20use tinyufo::TinyUfo;
21
22mod read_through;
23pub use read_through::{Lookup, MultiLookup, RTCache};
24
/// Outcome of a cache lookup.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CacheStatus {
    /// The key was present and its entry had not expired.
    Hit,
    /// The key was not present.
    Miss,
    /// The key was present but its entry's TTL had elapsed.
    Expired,
    /// Counted as a hit by [`CacheStatus::is_hit`]. NOTE(review): presumably reported by the
    /// read-through layer when the value was obtained while waiting on a lookup lock — confirm.
    LockHit,
}

impl CacheStatus {
    /// Returns a short lowercase label for this status.
    ///
    /// Every label is a string literal, so the result is `'static` and may outlive `self`.
    pub fn as_str(&self) -> &'static str {
        match self {
            Self::Hit => "hit",
            Self::Miss => "miss",
            Self::Expired => "expired",
            Self::LockHit => "lock_hit",
        }
    }

    /// Returns `true` when the lookup produced a usable value, i.e. for
    /// [`CacheStatus::Hit`] and [`CacheStatus::LockHit`].
    pub fn is_hit(&self) -> bool {
        // Exhaustive on purpose: adding a variant must force a decision here.
        match self {
            Self::Hit | Self::LockHit => true,
            Self::Miss | Self::Expired => false,
        }
    }
}
57
/// A cached value together with the absolute instant at which it stops being valid.
#[derive(Debug, Clone)]
struct Node<T: Clone> {
    pub value: T,
    /// Deadline for this entry; `None` means it never expires.
    expire_on: Option<Instant>,
}

impl<T: Clone> Node<T> {
    /// Wraps `value`, turning the relative `ttl` into a deadline measured from now.
    ///
    /// With no `ttl`, or when `now + ttl` is not representable as an `Instant`, the node
    /// gets no deadline and therefore never expires.
    fn new(value: T, ttl: Option<Duration>) -> Self {
        let expire_on = ttl.and_then(|lifetime| Instant::now().checked_add(lifetime));
        Self { value, expire_on }
    }

    /// Reports whether the deadline is at or before `time` (always `false` without a deadline).
    fn will_expire_at(&self, time: &Instant) -> bool {
        matches!(self.expire_on, Some(deadline) if deadline <= *time)
    }

    /// Reports whether the node has expired as of the current instant.
    fn is_expired(&self) -> bool {
        let now = Instant::now();
        self.will_expire_at(&now)
    }
}
84
85pub struct MemoryCache<K: Hash, T: Clone> {
87 store: TinyUfo<u64, Node<T>>,
88 _key_type: PhantomData<K>,
89 pub(crate) hasher: RandomState,
90}
91
92impl<K: Hash, T: Clone + Send + Sync + 'static> MemoryCache<K, T> {
93 pub fn new(size: usize) -> Self {
95 MemoryCache {
96 store: TinyUfo::new(size, size),
97 _key_type: PhantomData,
98 hasher: RandomState::new(),
99 }
100 }
101
102 pub fn get(&self, key: &K) -> (Option<T>, CacheStatus) {
104 let hashed_key = self.hasher.hash_one(key);
105
106 if let Some(n) = self.store.get(&hashed_key) {
107 if !n.is_expired() {
108 (Some(n.value), CacheStatus::Hit)
109 } else {
110 (None, CacheStatus::Expired)
112 }
113 } else {
114 (None, CacheStatus::Miss)
115 }
116 }
117
118 pub fn put(&self, key: &K, value: T, ttl: Option<Duration>) {
122 if let Some(t) = ttl {
123 if t.is_zero() {
124 return;
125 }
126 }
127 let hashed_key = self.hasher.hash_one(key);
128 let node = Node::new(value, ttl);
129 self.store.put(hashed_key, node, 1);
131 }
132
133 pub fn remove(&self, key: &K) {
135 let hashed_key = self.hasher.hash_one(key);
136 self.store.remove(&hashed_key);
137 }
138
139 pub(crate) fn force_put(&self, key: &K, value: T, ttl: Option<Duration>) {
140 if let Some(t) = ttl {
141 if t.is_zero() {
142 return;
143 }
144 }
145 let hashed_key = self.hasher.hash_one(key);
146 let node = Node::new(value, ttl);
147 self.store.force_put(hashed_key, node, 1);
149 }
150
151 pub fn multi_get<'a, I>(&self, keys: I) -> Vec<(Option<T>, CacheStatus)>
153 where
154 I: Iterator<Item = &'a K>,
155 K: 'a,
156 {
157 let mut resp = Vec::with_capacity(keys.size_hint().0);
158 for key in keys {
159 resp.push(self.get(key));
160 }
161 resp
162 }
163
164 pub fn multi_get_with_miss<'a, I>(&self, keys: I) -> (Vec<(Option<T>, CacheStatus)>, Vec<&'a K>)
166 where
167 I: Iterator<Item = &'a K>,
168 K: 'a,
169 {
170 let mut resp = Vec::with_capacity(keys.size_hint().0);
171 let mut missed = Vec::with_capacity(keys.size_hint().0 / 2);
172 for key in keys {
173 let (lookup, cache_status) = self.get(key);
174 if lookup.is_none() {
175 missed.push(key);
176 }
177 resp.push((lookup, cache_status));
178 }
179 (resp, missed)
180 }
181
182 }
184
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread::sleep;

    #[test]
    fn test_get() {
        let cache: MemoryCache<i32, ()> = MemoryCache::new(10);
        let (res, hit) = cache.get(&1);
        assert_eq!(res, None);
        assert_eq!(hit, CacheStatus::Miss);
    }

    #[test]
    fn test_put_get() {
        let cache: MemoryCache<i32, i32> = MemoryCache::new(10);
        let (res, hit) = cache.get(&1);
        assert_eq!(res, None);
        assert_eq!(hit, CacheStatus::Miss);
        cache.put(&1, 2, None);
        let (res, hit) = cache.get(&1);
        assert_eq!(res.unwrap(), 2);
        assert_eq!(hit, CacheStatus::Hit);
    }

    #[test]
    fn test_put_zero_ttl_is_skipped() {
        let cache: MemoryCache<i32, i32> = MemoryCache::new(10);
        cache.put(&1, 2, Some(Duration::from_secs(0)));
        let (res, hit) = cache.get(&1);
        assert_eq!(res, None);
        assert_eq!(hit, CacheStatus::Miss);

        cache.force_put(&1, 2, Some(Duration::from_secs(0)));
        let (res, hit) = cache.get(&1);
        assert_eq!(res, None);
        assert_eq!(hit, CacheStatus::Miss);
    }

    #[test]
    fn test_force_put_get() {
        let cache: MemoryCache<i32, i32> = MemoryCache::new(10);
        cache.force_put(&1, 2, None);
        let (res, hit) = cache.get(&1);
        assert_eq!(res, Some(2));
        assert_eq!(hit, CacheStatus::Hit);
    }

    #[test]
    fn test_put_get_remove() {
        let cache: MemoryCache<i32, i32> = MemoryCache::new(10);
        let (res, hit) = cache.get(&1);
        assert_eq!(res, None);
        assert_eq!(hit, CacheStatus::Miss);
        cache.put(&1, 2, None);
        cache.put(&3, 4, None);
        cache.put(&5, 6, None);
        let (res, hit) = cache.get(&1);
        assert_eq!(res.unwrap(), 2);
        assert_eq!(hit, CacheStatus::Hit);
        cache.remove(&1);
        cache.remove(&3);
        let (res, hit) = cache.get(&1);
        assert_eq!(res, None);
        assert_eq!(hit, CacheStatus::Miss);
        let (res, hit) = cache.get(&3);
        assert_eq!(res, None);
        assert_eq!(hit, CacheStatus::Miss);
        let (res, hit) = cache.get(&5);
        assert_eq!(res.unwrap(), 6);
        assert_eq!(hit, CacheStatus::Hit);
    }

    #[test]
    fn test_get_expired() {
        let cache: MemoryCache<i32, i32> = MemoryCache::new(10);
        let (res, hit) = cache.get(&1);
        assert_eq!(res, None);
        assert_eq!(hit, CacheStatus::Miss);
        // `sleep` waits at least the requested time and `Instant` is monotonic, so a sleep
        // longer than the TTL deterministically lands past the deadline; no need for seconds.
        cache.put(&1, 2, Some(Duration::from_millis(50)));
        sleep(Duration::from_millis(100));
        let (res, hit) = cache.get(&1);
        assert_eq!(res, None);
        assert_eq!(hit, CacheStatus::Expired);
    }

    #[test]
    fn test_eviction() {
        let cache: MemoryCache<i32, i32> = MemoryCache::new(2);
        cache.put(&1, 2, None);
        cache.put(&2, 4, None);
        cache.put(&3, 6, None);
        let (res, hit) = cache.get(&1);
        assert_eq!(res, None);
        assert_eq!(hit, CacheStatus::Miss);
        let (res, hit) = cache.get(&2);
        assert_eq!(res.unwrap(), 4);
        assert_eq!(hit, CacheStatus::Hit);
        let (res, hit) = cache.get(&3);
        assert_eq!(res.unwrap(), 6);
        assert_eq!(hit, CacheStatus::Hit);
    }

    #[test]
    fn test_multi_get() {
        let cache: MemoryCache<i32, i32> = MemoryCache::new(10);
        cache.put(&2, -2, None);
        let keys: Vec<i32> = vec![1, 2, 3];
        let resp = cache.multi_get(keys.iter());
        assert_eq!(resp[0].0, None);
        assert_eq!(resp[0].1, CacheStatus::Miss);
        assert_eq!(resp[1].0.unwrap(), -2);
        assert_eq!(resp[1].1, CacheStatus::Hit);
        assert_eq!(resp[2].0, None);
        assert_eq!(resp[2].1, CacheStatus::Miss);

        let (resp, missed) = cache.multi_get_with_miss(keys.iter());
        assert_eq!(resp[0].0, None);
        assert_eq!(resp[0].1, CacheStatus::Miss);
        assert_eq!(resp[1].0.unwrap(), -2);
        assert_eq!(resp[1].1, CacheStatus::Hit);
        assert_eq!(resp[2].0, None);
        assert_eq!(resp[2].1, CacheStatus::Miss);
        assert_eq!(missed[0], &1);
        assert_eq!(missed[1], &3);
    }

    #[test]
    fn test_multi_get_with_miss_reports_expired() {
        let cache: MemoryCache<i32, i32> = MemoryCache::new(10);
        cache.put(&1, 10, None);
        cache.put(&2, 20, Some(Duration::from_millis(10)));
        sleep(Duration::from_millis(20));
        let keys = [1, 2, 3];
        let (resp, missed) = cache.multi_get_with_miss(keys.iter());
        assert_eq!(resp[0], (Some(10), CacheStatus::Hit));
        assert_eq!(resp[1], (None, CacheStatus::Expired));
        assert_eq!(resp[2], (None, CacheStatus::Miss));
        assert_eq!(missed, vec![&2, &3]);
    }

    #[test]
    fn test_cache_status_labels() {
        assert_eq!(CacheStatus::Hit.as_str(), "hit");
        assert_eq!(CacheStatus::Miss.as_str(), "miss");
        assert_eq!(CacheStatus::Expired.as_str(), "expired");
        assert_eq!(CacheStatus::LockHit.as_str(), "lock_hit");
        assert!(CacheStatus::Hit.is_hit());
        assert!(CacheStatus::LockHit.is_hit());
        assert!(!CacheStatus::Miss.is_hit());
        assert!(!CacheStatus::Expired.is_hit());
    }
}