pingora_cache/
memory.rs

1// Copyright 2025 Cloudflare, Inc.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! Hash map based in memory cache
16//!
17//! For testing only, not for production use
18
19//TODO: Mark this module #[test] only
20
21use super::*;
22use crate::key::CompactCacheKey;
23use crate::storage::{streaming_write::U64WriteId, HandleHit, HandleMiss};
24use crate::trace::SpanHandle;
25
26use async_trait::async_trait;
27use bytes::Bytes;
28use parking_lot::RwLock;
29use pingora_error::*;
30use std::any::Any;
31use std::collections::HashMap;
32use std::sync::atomic::{AtomicU64, Ordering};
33use std::sync::Arc;
34use tokio::sync::watch;
35
/// Serialized cache metadata as a pair of byte buffers: the form returned by
/// `CacheMeta::serialize` and handed back, element by element, to `CacheMeta::deserialize`.
type BinaryMeta = (Vec<u8>, Vec<u8>);
37
/// A fully written cache entry: serialized metadata plus the complete body.
pub(crate) struct CacheObject {
    /// Serialized metadata of the cached asset.
    pub meta: BinaryMeta,
    /// Complete body bytes; behind an `Arc` so hit handlers can share them without copying.
    pub body: Arc<Vec<u8>>,
}
42
/// A cache entry that is still being written by a miss handler.
pub(crate) struct TempObject {
    /// Serialized metadata of the asset being written.
    pub meta: BinaryMeta,
    // these are Arc because they need to continue to exist after this TempObject is removed
    /// Body bytes received so far; shared with the writer and with partial readers.
    pub body: Arc<RwLock<Vec<u8>>>,
    /// Sender side of the watch channel that publishes the write progress to partial readers.
    bytes_written: Arc<watch::Sender<PartialState>>, // this should match body.len()
}
49
50impl TempObject {
51    fn new(meta: BinaryMeta) -> Self {
52        let (tx, _rx) = watch::channel(PartialState::Partial(0));
53        TempObject {
54            meta,
55            body: Arc::new(RwLock::new(Vec::new())),
56            bytes_written: Arc::new(tx),
57        }
58    }
59    // this is not at all optimized
60    fn make_cache_object(&self) -> CacheObject {
61        let meta = self.meta.clone();
62        let body = Arc::new(self.body.read().clone());
63        CacheObject { meta, body }
64    }
65}
66
/// Hash map based in memory cache
///
/// For testing only, not for production use.
pub struct MemCache {
    /// Completed objects: combined cache key -> cache object.
    pub(crate) cached: Arc<RwLock<HashMap<String, CacheObject>>>,
    /// Objects still being written: combined cache key -> (temp writer id -> temp object),
    /// which allows several concurrent writers for the same key.
    pub(crate) temp: Arc<RwLock<HashMap<String, HashMap<u64, TempObject>>>>,
    /// Counter from which every new miss handler takes its temp writer id.
    pub(crate) last_temp_id: AtomicU64,
}
75
76impl MemCache {
77    /// Create a new [MemCache]
78    pub fn new() -> Self {
79        MemCache {
80            cached: Arc::new(RwLock::new(HashMap::new())),
81            temp: Arc::new(RwLock::new(HashMap::new())),
82            last_temp_id: AtomicU64::new(0),
83        }
84    }
85}
86
/// Hit handler handed out by [MemCache] lookups.
pub enum MemHitHandler {
    /// The object is fully written; reads are served from the finished body.
    Complete(CompleteHit),
    /// The object is still being written; reads follow the writer's progress.
    Partial(PartialHit),
}
91
/// Progress of an in-flight body write, as published over the watch channel.
#[derive(Copy, Clone)]
enum PartialState {
    /// The write is still in progress; the payload is the number of body bytes written so far.
    Partial(usize),
    /// The writer signalled EOF; the payload is the final number of body bytes.
    Complete(usize),
}
97
/// Reader over a fully cached body, with optional byte-range selection.
pub struct CompleteHit {
    /// The complete cached body, shared with the cache itself.
    body: Arc<Vec<u8>>,
    /// Whether the currently selected range has already been returned by `get()`.
    done: bool,
    /// Index of the first byte of the selected range.
    range_start: usize,
    /// Exclusive end index of the selected range.
    range_end: usize,
}
104
105impl CompleteHit {
106    fn get(&mut self) -> Option<Bytes> {
107        if self.done {
108            None
109        } else {
110            self.done = true;
111            Some(Bytes::copy_from_slice(
112                &self.body.as_slice()[self.range_start..self.range_end],
113            ))
114        }
115    }
116
117    fn seek(&mut self, start: usize, end: Option<usize>) -> Result<()> {
118        if start >= self.body.len() {
119            return Error::e_explain(
120                ErrorType::InternalError,
121                format!("seek start out of range {start} >= {}", self.body.len()),
122            );
123        }
124        self.range_start = start;
125        if let Some(end) = end {
126            // end over the actual last byte is allowed, we just need to return the actual bytes
127            self.range_end = std::cmp::min(self.body.len(), end);
128        }
129        // seek resets read so that one handler can be used for multiple ranges
130        self.done = false;
131        Ok(())
132    }
133}
134
/// Reader that streams a body while it is still being written.
pub struct PartialHit {
    /// Body buffer shared with the writer.
    body: Arc<RwLock<Vec<u8>>>,
    /// Receiver side of the writer's progress channel.
    bytes_written: watch::Receiver<PartialState>,
    /// Number of body bytes this reader has already returned.
    bytes_read: usize,
}
140
141impl PartialHit {
142    async fn read(&mut self) -> Option<Bytes> {
143        loop {
144            let bytes_written = *self.bytes_written.borrow_and_update();
145            let bytes_end = match bytes_written {
146                PartialState::Partial(s) => s,
147                PartialState::Complete(c) => {
148                    // no more data will arrive
149                    if c == self.bytes_read {
150                        return None;
151                    }
152                    c
153                }
154            };
155            assert!(bytes_end >= self.bytes_read);
156
157            // more data available to read
158            if bytes_end > self.bytes_read {
159                let new_bytes =
160                    Bytes::copy_from_slice(&self.body.read()[self.bytes_read..bytes_end]);
161                self.bytes_read = bytes_end;
162                return Some(new_bytes);
163            }
164
165            // wait for more data
166            if self.bytes_written.changed().await.is_err() {
167                // err: sender dropped, body is finished
168                // FIXME: sender could drop because of an error
169                return None;
170            }
171        }
172    }
173}
174
175#[async_trait]
176impl HandleHit for MemHitHandler {
177    async fn read_body(&mut self) -> Result<Option<Bytes>> {
178        match self {
179            Self::Complete(c) => Ok(c.get()),
180            Self::Partial(p) => Ok(p.read().await),
181        }
182    }
183    async fn finish(
184        self: Box<Self>, // because self is always used as a trait object
185        _storage: &'static (dyn storage::Storage + Sync),
186        _key: &CacheKey,
187        _trace: &SpanHandle,
188    ) -> Result<()> {
189        Ok(())
190    }
191
192    fn can_seek(&self) -> bool {
193        match self {
194            Self::Complete(_) => true,
195            Self::Partial(_) => false, // TODO: support seeking in partial reads
196        }
197    }
198
199    fn seek(&mut self, start: usize, end: Option<usize>) -> Result<()> {
200        match self {
201            Self::Complete(c) => c.seek(start, end),
202            Self::Partial(_) => Error::e_explain(
203                ErrorType::InternalError,
204                "seek not supported for partial cache",
205            ),
206        }
207    }
208
209    fn should_count_access(&self) -> bool {
210        match self {
211            // avoid counting accesses for partial reads to keep things simple
212            Self::Complete(_) => true,
213            Self::Partial(_) => false,
214        }
215    }
216
217    fn get_eviction_weight(&self) -> usize {
218        match self {
219            // FIXME: just body size, also track meta size
220            Self::Complete(c) => c.body.len(),
221            // partial read cannot be estimated since body size is unknown
222            Self::Partial(_) => 0,
223        }
224    }
225
226    fn as_any(&self) -> &(dyn Any + Send + Sync) {
227        self
228    }
229}
230
/// Miss handler that writes a body into a temp object and admits it to the cache on finish.
pub struct MemMissHandler {
    /// Body buffer shared with the temp object and its partial readers.
    body: Arc<RwLock<Vec<u8>>>,
    /// Sender used to publish the write progress to partial readers.
    bytes_written: Arc<watch::Sender<PartialState>>,
    // these are used only in finish() to move data from temp to cache
    /// Combined cache key of the object being written.
    key: String,
    /// Id of this writer within the per-key temp map.
    temp_id: U64WriteId,
    // key -> cache object
    cache: Arc<RwLock<HashMap<String, CacheObject>>>,
    // key -> (temp writer id -> temp object) to support concurrent writers
    temp: Arc<RwLock<HashMap<String, HashMap<u64, TempObject>>>>,
}
242
243#[async_trait]
244impl HandleMiss for MemMissHandler {
245    async fn write_body(&mut self, data: bytes::Bytes, eof: bool) -> Result<()> {
246        let current_bytes = match *self.bytes_written.borrow() {
247            PartialState::Partial(p) => p,
248            PartialState::Complete(_) => panic!("already EOF"),
249        };
250        self.body.write().extend_from_slice(&data);
251        let written = current_bytes + data.len();
252        let new_state = if eof {
253            PartialState::Complete(written)
254        } else {
255            PartialState::Partial(written)
256        };
257        self.bytes_written.send_replace(new_state);
258        Ok(())
259    }
260
261    async fn finish(self: Box<Self>) -> Result<MissFinishType> {
262        // safe, the temp object is inserted when the miss handler is created
263        let cache_object = self
264            .temp
265            .read()
266            .get(&self.key)
267            .unwrap()
268            .get(&self.temp_id.into())
269            .unwrap()
270            .make_cache_object();
271        let size = cache_object.body.len(); // FIXME: this just body size, also track meta size
272        self.cache.write().insert(self.key.clone(), cache_object);
273        self.temp
274            .write()
275            .get_mut(&self.key)
276            .and_then(|map| map.remove(&self.temp_id.into()));
277        Ok(MissFinishType::Created(size))
278    }
279
280    fn streaming_write_tag(&self) -> Option<&[u8]> {
281        Some(self.temp_id.as_bytes())
282    }
283}
284
285impl Drop for MemMissHandler {
286    fn drop(&mut self) {
287        self.temp
288            .write()
289            .get_mut(&self.key)
290            .and_then(|map| map.remove(&self.temp_id.into()));
291    }
292}
293
294fn hit_from_temp_obj(temp_obj: &TempObject) -> Result<Option<(CacheMeta, HitHandler)>> {
295    let meta = CacheMeta::deserialize(&temp_obj.meta.0, &temp_obj.meta.1)?;
296    let partial = PartialHit {
297        body: temp_obj.body.clone(),
298        bytes_written: temp_obj.bytes_written.subscribe(),
299        bytes_read: 0,
300    };
301    let hit_handler = MemHitHandler::Partial(partial);
302    Ok(Some((meta, Box::new(hit_handler))))
303}
304
305#[async_trait]
306impl Storage for MemCache {
307    async fn lookup(
308        &'static self,
309        key: &CacheKey,
310        _trace: &SpanHandle,
311    ) -> Result<Option<(CacheMeta, HitHandler)>> {
312        let hash = key.combined();
313        // always prefer partial read otherwise fresh asset will not be visible on expired asset
314        // until it is fully updated
315        // no preference on which partial read we get (if there are multiple writers)
316        if let Some((_, temp_obj)) = self
317            .temp
318            .read()
319            .get(&hash)
320            .and_then(|map| map.iter().next())
321        {
322            hit_from_temp_obj(temp_obj)
323        } else if let Some(obj) = self.cached.read().get(&hash) {
324            let meta = CacheMeta::deserialize(&obj.meta.0, &obj.meta.1)?;
325            let hit_handler = CompleteHit {
326                body: obj.body.clone(),
327                done: false,
328                range_start: 0,
329                range_end: obj.body.len(),
330            };
331            let hit_handler = MemHitHandler::Complete(hit_handler);
332            Ok(Some((meta, Box::new(hit_handler))))
333        } else {
334            Ok(None)
335        }
336    }
337
338    async fn lookup_streaming_write(
339        &'static self,
340        key: &CacheKey,
341        streaming_write_tag: Option<&[u8]>,
342        _trace: &SpanHandle,
343    ) -> Result<Option<(CacheMeta, HitHandler)>> {
344        let hash = key.combined();
345        let write_tag: U64WriteId = streaming_write_tag
346            .expect("tag must be set during streaming write")
347            .try_into()
348            .expect("tag must be correct length");
349        hit_from_temp_obj(
350            self.temp
351                .read()
352                .get(&hash)
353                .and_then(|map| map.get(&write_tag.into()))
354                .expect("must have partial write in progress"),
355        )
356    }
357
358    async fn get_miss_handler(
359        &'static self,
360        key: &CacheKey,
361        meta: &CacheMeta,
362        _trace: &SpanHandle,
363    ) -> Result<MissHandler> {
364        let hash = key.combined();
365        let meta = meta.serialize()?;
366        let temp_obj = TempObject::new(meta);
367        let temp_id = self.last_temp_id.fetch_add(1, Ordering::Relaxed);
368        let miss_handler = MemMissHandler {
369            body: temp_obj.body.clone(),
370            bytes_written: temp_obj.bytes_written.clone(),
371            key: hash.clone(),
372            cache: self.cached.clone(),
373            temp: self.temp.clone(),
374            temp_id: temp_id.into(),
375        };
376        self.temp
377            .write()
378            .entry(hash)
379            .or_default()
380            .insert(miss_handler.temp_id.into(), temp_obj);
381        Ok(Box::new(miss_handler))
382    }
383
384    async fn purge(
385        &'static self,
386        key: &CompactCacheKey,
387        _type: PurgeType,
388        _trace: &SpanHandle,
389    ) -> Result<bool> {
390        // This usually purges the primary key because, without a lookup, the variance key is usually
391        // empty
392        let hash = key.combined();
393        let temp_removed = self.temp.write().remove(&hash).is_some();
394        let cache_removed = self.cached.write().remove(&hash).is_some();
395        Ok(temp_removed || cache_removed)
396    }
397
398    async fn update_meta(
399        &'static self,
400        key: &CacheKey,
401        meta: &CacheMeta,
402        _trace: &SpanHandle,
403    ) -> Result<bool> {
404        let hash = key.combined();
405        if let Some(obj) = self.cached.write().get_mut(&hash) {
406            obj.meta = meta.serialize()?;
407            Ok(true)
408        } else {
409            panic!("no meta found")
410        }
411    }
412
413    fn support_streaming_partial_write(&self) -> bool {
414        true
415    }
416
417    fn as_any(&self) -> &(dyn Any + Send + Sync) {
418        self
419    }
420}
421
// Unit tests for the in-memory cache. Every test uses its own static cache instance because
// the `Storage` methods take `&'static self`.
#[cfg(test)]
mod test {
    use super::*;
    use cf_rustracing::span::Span;
    use once_cell::sync::Lazy;

    /// Build a `CacheMeta` around a 200 response with a few headers and default internal meta.
    fn gen_meta() -> CacheMeta {
        let mut header = ResponseHeader::build(200, None).unwrap();
        header.append_header("foo1", "bar1").unwrap();
        header.append_header("foo2", "bar2").unwrap();
        header.append_header("foo3", "bar3").unwrap();
        header.append_header("Server", "Pingora").unwrap();
        let internal = crate::meta::InternalMeta::default();
        CacheMeta(Box::new(crate::meta::CacheMetaInner {
            internal,
            header,
            extensions: http::Extensions::new(),
        }))
    }

    /// A body written in two chunks and finished is read back as one complete hit.
    #[tokio::test]
    async fn test_write_then_read() {
        static MEM_CACHE: Lazy<MemCache> = Lazy::new(MemCache::new);
        let span = &Span::inactive().handle();

        // nothing is cached before the write
        let key1 = CacheKey::new("", "a", "1");
        let res = MEM_CACHE.lookup(&key1, span).await.unwrap();
        assert!(res.is_none());

        let cache_meta = gen_meta();

        let mut miss_handler = MEM_CACHE
            .get_miss_handler(&key1, &cache_meta, span)
            .await
            .unwrap();
        miss_handler
            .write_body(b"test1"[..].into(), false)
            .await
            .unwrap();
        miss_handler
            .write_body(b"test2"[..].into(), false)
            .await
            .unwrap();
        miss_handler.finish().await.unwrap();

        let (cache_meta2, mut hit_handler) = MEM_CACHE.lookup(&key1, span).await.unwrap().unwrap();
        assert_eq!(
            cache_meta.0.internal.fresh_until,
            cache_meta2.0.internal.fresh_until
        );

        // the whole body comes back in a single read, after which the handler is exhausted
        let data = hit_handler.read_body().await.unwrap().unwrap();
        assert_eq!("test1test2", data);
        let data = hit_handler.read_body().await.unwrap();
        assert!(data.is_none());
    }

    /// Seeking on a complete hit selects byte ranges and re-arms the handler for another read.
    #[tokio::test]
    async fn test_read_range() {
        static MEM_CACHE: Lazy<MemCache> = Lazy::new(MemCache::new);
        let span = &Span::inactive().handle();

        let key1 = CacheKey::new("", "a", "1");
        let res = MEM_CACHE.lookup(&key1, span).await.unwrap();
        assert!(res.is_none());

        let cache_meta = gen_meta();

        let mut miss_handler = MEM_CACHE
            .get_miss_handler(&key1, &cache_meta, span)
            .await
            .unwrap();
        miss_handler
            .write_body(b"test1test2"[..].into(), false)
            .await
            .unwrap();
        miss_handler.finish().await.unwrap();

        let (cache_meta2, mut hit_handler) = MEM_CACHE.lookup(&key1, span).await.unwrap().unwrap();
        assert_eq!(
            cache_meta.0.internal.fresh_until,
            cache_meta2.0.internal.fresh_until
        );

        // out of range
        assert!(hit_handler.seek(10000, None).is_err());

        // open-ended range: from byte 5 to the end of the body
        assert!(hit_handler.seek(5, None).is_ok());
        let data = hit_handler.read_body().await.unwrap().unwrap();
        assert_eq!("test2", data);
        let data = hit_handler.read_body().await.unwrap();
        assert!(data.is_none());

        // bounded range: the end index is exclusive, so this selects the single byte at index 4
        assert!(hit_handler.seek(4, Some(5)).is_ok());
        let data = hit_handler.read_body().await.unwrap().unwrap();
        assert_eq!("1", data);
        let data = hit_handler.read_body().await.unwrap();
        assert!(data.is_none());
    }

    /// Readers that look up a key while it is being written follow the writer's progress.
    #[tokio::test]
    async fn test_write_while_read() {
        use futures::FutureExt;

        static MEM_CACHE: Lazy<MemCache> = Lazy::new(MemCache::new);
        let span = &Span::inactive().handle();

        let key1 = CacheKey::new("", "a", "1");
        let res = MEM_CACHE.lookup(&key1, span).await.unwrap();
        assert!(res.is_none());

        let cache_meta = gen_meta();

        let mut miss_handler = MEM_CACHE
            .get_miss_handler(&key1, &cache_meta, span)
            .await
            .unwrap();

        // first reader
        let (cache_meta1, mut hit_handler1) = MEM_CACHE.lookup(&key1, span).await.unwrap().unwrap();
        assert_eq!(
            cache_meta.0.internal.fresh_until,
            cache_meta1.0.internal.fresh_until
        );

        // No body to read
        // (`now_or_never` yields `None` because the read future is still waiting for data)
        let res = hit_handler1.read_body().now_or_never();
        assert!(res.is_none());

        miss_handler
            .write_body(b"test1"[..].into(), false)
            .await
            .unwrap();

        let data = hit_handler1.read_body().await.unwrap().unwrap();
        assert_eq!("test1", data);
        let res = hit_handler1.read_body().now_or_never();
        assert!(res.is_none());

        miss_handler
            .write_body(b"test2"[..].into(), false)
            .await
            .unwrap();
        let data = hit_handler1.read_body().await.unwrap().unwrap();
        assert_eq!("test2", data);

        // second reader
        let (cache_meta2, mut hit_handler2) = MEM_CACHE.lookup(&key1, span).await.unwrap().unwrap();
        assert_eq!(
            cache_meta.0.internal.fresh_until,
            cache_meta2.0.internal.fresh_until
        );

        // a reader that joins late gets everything written so far in one chunk
        let data = hit_handler2.read_body().await.unwrap().unwrap();
        assert_eq!("test1test2", data);
        let res = hit_handler2.read_body().now_or_never();
        assert!(res.is_none());

        let res = hit_handler1.read_body().now_or_never();
        assert!(res.is_none());

        miss_handler.finish().await.unwrap();

        // once the writer has finished, both partial readers report the end of the body
        let data = hit_handler1.read_body().await.unwrap();
        assert!(data.is_none());
        let data = hit_handler2.read_body().await.unwrap();
        assert!(data.is_none());
    }

    /// Purging a key removes its in-progress (temp) objects.
    #[tokio::test]
    async fn test_purge_partial() {
        static MEM_CACHE: Lazy<MemCache> = Lazy::new(MemCache::new);
        let cache = &MEM_CACHE;

        let key = CacheKey::new("", "a", "1").to_compact();
        let hash = key.combined();
        let meta = (
            "meta_key".as_bytes().to_vec(),
            "meta_value".as_bytes().to_vec(),
        );

        // insert the temp object directly, without going through a miss handler
        let temp_obj = TempObject::new(meta);
        let mut map = HashMap::new();
        map.insert(0, temp_obj);
        cache.temp.write().insert(hash.clone(), map);

        assert!(cache.temp.read().contains_key(&hash));

        let result = cache
            .purge(&key, PurgeType::Invalidation, &Span::inactive().handle())
            .await;
        assert!(result.is_ok());

        assert!(!cache.temp.read().contains_key(&hash));
    }

    /// Purging a key removes its completed object.
    #[tokio::test]
    async fn test_purge_complete() {
        static MEM_CACHE: Lazy<MemCache> = Lazy::new(MemCache::new);
        let cache = &MEM_CACHE;

        let key = CacheKey::new("", "a", "1").to_compact();
        let hash = key.combined();
        let meta = (
            "meta_key".as_bytes().to_vec(),
            "meta_value".as_bytes().to_vec(),
        );
        // insert the completed object directly, without going through a miss handler
        let body = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 0];
        let cache_obj = CacheObject {
            meta,
            body: Arc::new(body),
        };
        cache.cached.write().insert(hash.clone(), cache_obj);

        assert!(cache.cached.read().contains_key(&hash));

        let result = cache
            .purge(&key, PurgeType::Invalidation, &Span::inactive().handle())
            .await;
        assert!(result.is_ok());

        assert!(!cache.cached.read().contains_key(&hash));
    }
}
645}