pingora_memory_cache/
lib.rs

1// Copyright 2024 Cloudflare, Inc.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use ahash::RandomState;
16use std::hash::Hash;
17use std::marker::PhantomData;
18use std::time::{Duration, Instant};
19
20use tinyufo::TinyUfo;
21
22mod read_through;
23pub use read_through::{Lookup, MultiLookup, RTCache};
24
/// [CacheStatus] indicates the response type for a query.
#[derive(Debug, PartialEq, Eq)]
pub enum CacheStatus {
    /// The key was found in the cache
    Hit,
    /// The key was not found.
    Miss,
    /// The key was found but it was expired.
    Expired,
    /// The key was not initially found but was found after awaiting a lock.
    LockHit,
}

impl CacheStatus {
    /// Return the string representation for [CacheStatus].
    ///
    /// Every label is a string literal, so the returned slice is `'static`
    /// and may outlive the [CacheStatus] value it was obtained from.
    pub fn as_str(&self) -> &'static str {
        match self {
            Self::Hit => "hit",
            Self::Miss => "miss",
            Self::Expired => "expired",
            Self::LockHit => "lock_hit",
        }
    }

    /// Returns whether this status represents a cache hit.
    ///
    /// Both [CacheStatus::Hit] and [CacheStatus::LockHit] count as hits;
    /// [CacheStatus::Miss] and [CacheStatus::Expired] do not.
    pub fn is_hit(&self) -> bool {
        // Kept exhaustive on purpose: adding a variant must force a decision here.
        match self {
            Self::Hit | Self::LockHit => true,
            Self::Miss | Self::Expired => false,
        }
    }
}
57
/// A cached value together with the instant (if any) at which it expires.
#[derive(Debug, Clone)]
struct Node<T: Clone> {
    pub value: T,
    // `None` means the node never expires.
    expire_on: Option<Instant>,
}

impl<T: Clone> Node<T> {
    /// Wrap `value`, computing its deadline as "now + ttl".
    ///
    /// Without a TTL, or when adding the TTL to the current instant
    /// overflows, no deadline is recorded and the node never expires.
    fn new(value: T, ttl: Option<Duration>) -> Self {
        let expire_on = ttl.and_then(|lifetime| Instant::now().checked_add(lifetime));
        Self { value, expire_on }
    }

    /// Whether the node's deadline is at or before `time`.
    fn will_expire_at(&self, time: &Instant) -> bool {
        matches!(&self.expire_on, Some(deadline) if deadline <= time)
    }

    /// Whether the node's deadline has already been reached.
    fn is_expired(&self) -> bool {
        let now = Instant::now();
        self.will_expire_at(&now)
    }
}
84
85/// A high performant in-memory cache with S3-FIFO + TinyLFU
86pub struct MemoryCache<K: Hash, T: Clone> {
87    store: TinyUfo<u64, Node<T>>,
88    _key_type: PhantomData<K>,
89    pub(crate) hasher: RandomState,
90}
91
92impl<K: Hash, T: Clone + Send + Sync + 'static> MemoryCache<K, T> {
93    /// Create a new [MemoryCache] with the given size.
94    pub fn new(size: usize) -> Self {
95        MemoryCache {
96            store: TinyUfo::new(size, size),
97            _key_type: PhantomData,
98            hasher: RandomState::new(),
99        }
100    }
101
102    /// Fetch the key and return its value in addition to a [CacheStatus].
103    pub fn get(&self, key: &K) -> (Option<T>, CacheStatus) {
104        let hashed_key = self.hasher.hash_one(key);
105
106        if let Some(n) = self.store.get(&hashed_key) {
107            if !n.is_expired() {
108                (Some(n.value), CacheStatus::Hit)
109            } else {
110                // TODO: consider returning the staled value
111                (None, CacheStatus::Expired)
112            }
113        } else {
114            (None, CacheStatus::Miss)
115        }
116    }
117
118    /// Insert a key and value pair with an optional TTL into the cache.
119    ///
120    /// An item with zero TTL of zero will not be inserted.
121    pub fn put(&self, key: &K, value: T, ttl: Option<Duration>) {
122        if let Some(t) = ttl {
123            if t.is_zero() {
124                return;
125            }
126        }
127        let hashed_key = self.hasher.hash_one(key);
128        let node = Node::new(value, ttl);
129        // weight is always 1 for now
130        self.store.put(hashed_key, node, 1);
131    }
132
133    /// Remove a key from the cache if it exists.
134    pub fn remove(&self, key: &K) {
135        let hashed_key = self.hasher.hash_one(key);
136        self.store.remove(&hashed_key);
137    }
138
139    pub(crate) fn force_put(&self, key: &K, value: T, ttl: Option<Duration>) {
140        if let Some(t) = ttl {
141            if t.is_zero() {
142                return;
143            }
144        }
145        let hashed_key = self.hasher.hash_one(key);
146        let node = Node::new(value, ttl);
147        // weight is always 1 for now
148        self.store.force_put(hashed_key, node, 1);
149    }
150
151    /// This is equivalent to [MemoryCache::get] but for an arbitrary amount of keys.
152    pub fn multi_get<'a, I>(&self, keys: I) -> Vec<(Option<T>, CacheStatus)>
153    where
154        I: Iterator<Item = &'a K>,
155        K: 'a,
156    {
157        let mut resp = Vec::with_capacity(keys.size_hint().0);
158        for key in keys {
159            resp.push(self.get(key));
160        }
161        resp
162    }
163
164    /// Same as [MemoryCache::multi_get] but returns the keys that are missing from the cache.
165    pub fn multi_get_with_miss<'a, I>(&self, keys: I) -> (Vec<(Option<T>, CacheStatus)>, Vec<&'a K>)
166    where
167        I: Iterator<Item = &'a K>,
168        K: 'a,
169    {
170        let mut resp = Vec::with_capacity(keys.size_hint().0);
171        let mut missed = Vec::with_capacity(keys.size_hint().0 / 2);
172        for key in keys {
173            let (lookup, cache_status) = self.get(key);
174            if lookup.is_none() {
175                missed.push(key);
176            }
177            resp.push((lookup, cache_status));
178        }
179        (resp, missed)
180    }
181
182    // TODO: evict expired first
183}
184
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread::sleep;

    /// Look up `key` and check both the returned value and the status.
    #[track_caller]
    fn assert_lookup<T>(
        cache: &MemoryCache<i32, T>,
        key: i32,
        value: Option<T>,
        status: CacheStatus,
    ) where
        T: Clone + Send + Sync + 'static + PartialEq + std::fmt::Debug,
    {
        let (found, outcome) = cache.get(&key);
        assert_eq!(found, value);
        assert_eq!(outcome, status);
    }

    // An empty cache reports a miss.
    #[test]
    fn test_get() {
        let cache: MemoryCache<i32, ()> = MemoryCache::new(10);
        assert_lookup(&cache, 1, None, CacheStatus::Miss);
    }

    // A stored value without TTL is returned as a hit.
    #[test]
    fn test_put_get() {
        let cache: MemoryCache<i32, i32> = MemoryCache::new(10);
        assert_lookup(&cache, 1, None, CacheStatus::Miss);
        cache.put(&1, 2, None);
        assert_lookup(&cache, 1, Some(2), CacheStatus::Hit);
    }

    // Removed keys miss afterwards; untouched keys stay available.
    #[test]
    fn test_put_get_remove() {
        let cache: MemoryCache<i32, i32> = MemoryCache::new(10);
        assert_lookup(&cache, 1, None, CacheStatus::Miss);
        for (key, value) in [(1, 2), (3, 4), (5, 6)] {
            cache.put(&key, value, None);
        }
        assert_lookup(&cache, 1, Some(2), CacheStatus::Hit);
        cache.remove(&1);
        cache.remove(&3);
        assert_lookup(&cache, 1, None, CacheStatus::Miss);
        assert_lookup(&cache, 3, None, CacheStatus::Miss);
        assert_lookup(&cache, 5, Some(6), CacheStatus::Hit);
    }

    // Once the TTL has elapsed the lookup reports `Expired` without a value.
    #[test]
    fn test_get_expired() {
        let cache: MemoryCache<i32, i32> = MemoryCache::new(10);
        assert_lookup(&cache, 1, None, CacheStatus::Miss);
        cache.put(&1, 2, Some(Duration::from_secs(1)));
        sleep(Duration::from_millis(1100));
        assert_lookup(&cache, 1, None, CacheStatus::Expired);
    }

    // Inserting a third item into a cache of size 2 pushes out the first one.
    #[test]
    fn test_eviction() {
        let cache: MemoryCache<i32, i32> = MemoryCache::new(2);
        cache.put(&1, 2, None);
        cache.put(&2, 4, None);
        cache.put(&3, 6, None);
        assert_lookup(&cache, 1, None, CacheStatus::Miss);
        assert_lookup(&cache, 2, Some(4), CacheStatus::Hit);
        assert_lookup(&cache, 3, Some(6), CacheStatus::Hit);
    }

    // Batch lookups keep key order and report the keys that had no value.
    #[test]
    fn test_multi_get() {
        let cache: MemoryCache<i32, i32> = MemoryCache::new(10);
        cache.put(&2, -2, None);
        let keys: Vec<i32> = vec![1, 2, 3];

        let resp = cache.multi_get(keys.iter());
        assert_eq!(resp[0], (None, CacheStatus::Miss));
        assert_eq!(resp[1], (Some(-2), CacheStatus::Hit));
        assert_eq!(resp[2], (None, CacheStatus::Miss));

        let (resp, missed) = cache.multi_get_with_miss(keys.iter());
        assert_eq!(resp[0], (None, CacheStatus::Miss));
        assert_eq!(resp[1], (Some(-2), CacheStatus::Hit));
        assert_eq!(resp[2], (None, CacheStatus::Miss));
        assert_eq!(*missed[0], 1);
        assert_eq!(*missed[1], 3);
    }
}