pingora_memory_cache/
lib.rs

1// Copyright 2025 Cloudflare, Inc.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use ahash::RandomState;
16use std::borrow::Borrow;
17use std::hash::Hash;
18use std::marker::PhantomData;
19use std::time::{Duration, Instant};
20
21use tinyufo::TinyUfo;
22
23mod read_through;
24pub use read_through::{Lookup, MultiLookup, RTCache};
25
/// [CacheStatus] indicates the response type for a query.
#[derive(Debug, PartialEq, Eq)]
pub enum CacheStatus {
    /// The key was found in the cache.
    Hit,
    /// The key was not found.
    Miss,
    /// The key was found but it was expired.
    Expired,
    /// The key was not initially found but was found after awaiting a lock.
    LockHit,
    /// The returned value was expired but still returned. The [Duration] is
    /// how long it has been since its expiration time.
    Stale(Duration),
}

impl CacheStatus {
    /// Return the string representation for [CacheStatus].
    ///
    /// Every label is a string literal, so the returned slice is `'static`
    /// and can be kept around after the [CacheStatus] itself is gone (for
    /// example when it is stored in a log record or a metric label).
    pub fn as_str(&self) -> &'static str {
        match self {
            Self::Hit => "hit",
            Self::Miss => "miss",
            Self::Expired => "expired",
            Self::LockHit => "lock_hit",
            Self::Stale(_) => "stale",
        }
    }

    /// Returns whether this status represents a cache hit.
    ///
    /// [CacheStatus::Hit], [CacheStatus::LockHit] and [CacheStatus::Stale]
    /// count as hits; [CacheStatus::Miss] and [CacheStatus::Expired] do not.
    pub fn is_hit(&self) -> bool {
        // Exhaustive on purpose: adding a variant must force a decision here.
        match self {
            Self::Hit | Self::LockHit | Self::Stale(_) => true,
            Self::Miss | Self::Expired => false,
        }
    }

    /// Returns the stale duration if any.
    ///
    /// Only [CacheStatus::Stale] carries a duration; every other variant
    /// yields `None`.
    pub fn stale(&self) -> Option<Duration> {
        // Exhaustive on purpose: adding a variant must force a decision here.
        match self {
            Self::Stale(elapsed) => Some(*elapsed),
            Self::Hit | Self::Miss | Self::Expired | Self::LockHit => None,
        }
    }
}
70
/// A cached value paired with the optional instant at which it expires.
#[derive(Debug, Clone)]
struct Node<T: Clone> {
    /// The cached payload.
    pub value: T,
    /// Expiration deadline; `None` means the node never expires.
    expire_on: Option<Instant>,
}

impl<T: Clone> Node<T> {
    /// Build a node holding `value` whose deadline lies `ttl` from now.
    ///
    /// No deadline is recorded when `ttl` is `None`, or when adding `ttl` to
    /// the current instant overflows (`checked_add` returns `None`).
    fn new(value: T, ttl: Option<Duration>) -> Self {
        let expire_on = ttl.and_then(|lifetime| Instant::now().checked_add(lifetime));
        Self { value, expire_on }
    }

    /// Whether the node's deadline is at or before `time`.
    fn will_expire_at(&self, time: &Instant) -> bool {
        matches!(self.expire_on, Some(deadline) if deadline <= *time)
    }

    /// Whether the node's deadline has already been reached.
    fn is_expired(&self) -> bool {
        let now = Instant::now();
        self.will_expire_at(&now)
    }

    /// How long the node has been expired as of `time`.
    ///
    /// Returns `None` when the node has no deadline or the deadline is still
    /// after `time`; a deadline exactly equal to `time` yields a zero duration.
    fn stale_duration(&self, time: &Instant) -> Option<Duration> {
        self.expire_on
            .and_then(|deadline| time.checked_duration_since(deadline))
    }
}
103
104/// A high performant in-memory cache with S3-FIFO + TinyLFU
105pub struct MemoryCache<K: Hash, T: Clone> {
106    store: TinyUfo<u64, Node<T>>,
107    _key_type: PhantomData<K>,
108    pub(crate) hasher: RandomState,
109}
110
111impl<K: Hash, T: Clone + Send + Sync + 'static> MemoryCache<K, T> {
112    /// Create a new [MemoryCache] with the given size.
113    pub fn new(size: usize) -> Self {
114        MemoryCache {
115            store: TinyUfo::new(size, size),
116            _key_type: PhantomData,
117            hasher: RandomState::new(),
118        }
119    }
120
121    /// Fetch the key and return its value in addition to a [CacheStatus].
122    pub fn get<Q>(&self, key: &Q) -> (Option<T>, CacheStatus)
123    where
124        K: Borrow<Q>,
125        Q: Hash + ?Sized,
126    {
127        let hashed_key = self.hasher.hash_one(key);
128
129        if let Some(n) = self.store.get(&hashed_key) {
130            if !n.is_expired() {
131                (Some(n.value), CacheStatus::Hit)
132            } else {
133                (None, CacheStatus::Expired)
134            }
135        } else {
136            (None, CacheStatus::Miss)
137        }
138    }
139
140    /// Similar to [Self::get], fetch the key and return its value in addition to a
141    /// [CacheStatus] but also return the value even if it is expired. When the
142    /// value is expired, the [Duration] of how long it has been stale will
143    /// also be returned.
144    pub fn get_stale<Q>(&self, key: &Q) -> (Option<T>, CacheStatus)
145    where
146        K: Borrow<Q>,
147        Q: Hash + ?Sized,
148    {
149        let hashed_key = self.hasher.hash_one(key);
150
151        if let Some(n) = self.store.get(&hashed_key) {
152            let stale_duration = n.stale_duration(&Instant::now());
153            if let Some(stale_duration) = stale_duration {
154                (Some(n.value), CacheStatus::Stale(stale_duration))
155            } else {
156                (Some(n.value), CacheStatus::Hit)
157            }
158        } else {
159            (None, CacheStatus::Miss)
160        }
161    }
162
163    /// Insert a key and value pair with an optional TTL into the cache.
164    ///
165    /// An item with zero TTL of zero will not be inserted.
166    pub fn put<Q>(&self, key: &Q, value: T, ttl: Option<Duration>)
167    where
168        K: Borrow<Q>,
169        Q: Hash + ?Sized,
170    {
171        if let Some(t) = ttl {
172            if t.is_zero() {
173                return;
174            }
175        }
176        let hashed_key = self.hasher.hash_one(key);
177        let node = Node::new(value, ttl);
178        // weight is always 1 for now
179        self.store.put(hashed_key, node, 1);
180    }
181
182    /// Remove a key from the cache if it exists.
183    pub fn remove<Q>(&self, key: &Q)
184    where
185        K: Borrow<Q>,
186        Q: Hash + ?Sized,
187    {
188        let hashed_key = self.hasher.hash_one(key);
189        self.store.remove(&hashed_key);
190    }
191
192    pub(crate) fn force_put(&self, key: &K, value: T, ttl: Option<Duration>) {
193        if let Some(t) = ttl {
194            if t.is_zero() {
195                return;
196            }
197        }
198        let hashed_key = self.hasher.hash_one(key);
199        let node = Node::new(value, ttl);
200        // weight is always 1 for now
201        self.store.force_put(hashed_key, node, 1);
202    }
203
204    /// This is equivalent to [MemoryCache::get] but for an arbitrary amount of keys.
205    pub fn multi_get<'a, I, Q>(&self, keys: I) -> Vec<(Option<T>, CacheStatus)>
206    where
207        I: Iterator<Item = &'a Q>,
208        Q: Hash + ?Sized + 'a,
209        K: Borrow<Q> + 'a,
210    {
211        let mut resp = Vec::with_capacity(keys.size_hint().0);
212        for key in keys {
213            resp.push(self.get(key));
214        }
215        resp
216    }
217
218    /// Same as [MemoryCache::multi_get] but returns the keys that are missing from the cache.
219    pub fn multi_get_with_miss<'a, I, Q>(
220        &self,
221        keys: I,
222    ) -> (Vec<(Option<T>, CacheStatus)>, Vec<&'a Q>)
223    where
224        I: Iterator<Item = &'a Q>,
225        Q: Hash + ?Sized + 'a,
226        K: Borrow<Q> + 'a,
227    {
228        let mut resp = Vec::with_capacity(keys.size_hint().0);
229        let mut missed = Vec::with_capacity(keys.size_hint().0 / 2);
230        for key in keys {
231            let (lookup, cache_status) = self.get(key);
232            if lookup.is_none() {
233                missed.push(key);
234            }
235            resp.push((lookup, cache_status));
236        }
237        (resp, missed)
238    }
239
240    // TODO: evict expired first
241}
242
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread::sleep;

    /// Looking up a key in an empty cache is a miss.
    #[test]
    fn test_get() {
        let cache: MemoryCache<i32, ()> = MemoryCache::new(10);
        assert_eq!(cache.get(&1), (None, CacheStatus::Miss));
    }

    /// A value that was put can be read back as a hit.
    #[test]
    fn test_put_get() {
        let cache: MemoryCache<i32, i32> = MemoryCache::new(10);
        assert_eq!(cache.get(&1), (None, CacheStatus::Miss));
        cache.put(&1, 2, None);
        assert_eq!(cache.get(&1), (Some(2), CacheStatus::Hit));
    }

    /// Removed keys turn into misses while untouched keys stay available.
    #[test]
    fn test_put_get_remove() {
        let cache: MemoryCache<i32, i32> = MemoryCache::new(10);
        assert_eq!(cache.get(&1), (None, CacheStatus::Miss));
        cache.put(&1, 2, None);
        cache.put(&3, 4, None);
        cache.put(&5, 6, None);
        assert_eq!(cache.get(&1), (Some(2), CacheStatus::Hit));
        cache.remove(&1);
        cache.remove(&3);
        assert_eq!(cache.get(&1), (None, CacheStatus::Miss));
        assert_eq!(cache.get(&3), (None, CacheStatus::Miss));
        assert_eq!(cache.get(&5), (Some(6), CacheStatus::Hit));
    }

    /// `get` withholds a value whose TTL has elapsed and reports it as expired.
    #[test]
    fn test_get_expired() {
        let cache: MemoryCache<i32, i32> = MemoryCache::new(10);
        assert_eq!(cache.get(&1), (None, CacheStatus::Miss));
        cache.put(&1, 2, Some(Duration::from_secs(1)));
        sleep(Duration::from_millis(1100));
        assert_eq!(cache.get(&1), (None, CacheStatus::Expired));
    }

    /// `get_stale` still returns an expired value, along with its staleness.
    #[test]
    fn test_get_stale() {
        let cache: MemoryCache<i32, i32> = MemoryCache::new(10);
        assert_eq!(cache.get(&1), (None, CacheStatus::Miss));
        cache.put(&1, 2, Some(Duration::from_secs(1)));
        sleep(Duration::from_millis(1100));
        let (value, status) = cache.get_stale(&1);
        assert_eq!(value, Some(2));
        // we slept 1100ms and the ttl is 1000ms
        let staleness = status.stale().unwrap();
        assert!(staleness >= Duration::from_millis(100));
    }

    /// Inserting past the capacity drops the oldest entry.
    #[test]
    fn test_eviction() {
        let cache: MemoryCache<i32, i32> = MemoryCache::new(2);
        cache.put(&1, 2, None);
        cache.put(&2, 4, None);
        cache.put(&3, 6, None);
        assert_eq!(cache.get(&1), (None, CacheStatus::Miss));
        assert_eq!(cache.get(&2), (Some(4), CacheStatus::Hit));
        assert_eq!(cache.get(&3), (Some(6), CacheStatus::Hit));
    }

    /// Batch lookups answer in key order and report the keys without a value.
    #[test]
    fn test_multi_get() {
        let cache: MemoryCache<i32, i32> = MemoryCache::new(10);
        cache.put(&2, -2, None);
        let keys: Vec<i32> = vec![1, 2, 3];

        let resp = cache.multi_get(keys.iter());
        assert_eq!(resp[0], (None, CacheStatus::Miss));
        assert_eq!(resp[1], (Some(-2), CacheStatus::Hit));
        assert_eq!(resp[2], (None, CacheStatus::Miss));

        let (resp, missed) = cache.multi_get_with_miss(keys.iter());
        assert_eq!(resp[0], (None, CacheStatus::Miss));
        assert_eq!(resp[1], (Some(-2), CacheStatus::Hit));
        assert_eq!(resp[2], (None, CacheStatus::Miss));
        assert_eq!(*missed[0], 1);
        assert_eq!(*missed[1], 3);
    }

    /// A `String`-keyed cache can be queried with a borrowed `str`.
    #[test]
    fn test_get_with_mismatched_key() {
        let cache: MemoryCache<String, ()> = MemoryCache::new(10);
        assert_eq!(cache.get("Hello"), (None, CacheStatus::Miss));
    }

    /// A `String`-keyed cache can be written and read with a borrowed `str`.
    #[test]
    fn test_put_get_with_mismatched_key() {
        let cache: MemoryCache<String, i32> = MemoryCache::new(10);
        assert_eq!(cache.get("1"), (None, CacheStatus::Miss));
        cache.put("1", 2, None);
        assert_eq!(cache.get("1"), (Some(2), CacheStatus::Hit));
    }
}