pingora_cache/
memory.rs

1// Copyright 2025 Cloudflare, Inc.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! Hash map based in memory cache
16//!
17//! For testing only, not for production use
18
19//TODO: Mark this module #[test] only
20
21use super::*;
22use crate::key::CompactCacheKey;
23use crate::storage::{streaming_write::U64WriteId, HandleHit, HandleMiss};
24use crate::trace::SpanHandle;
25
26use async_trait::async_trait;
27use bytes::Bytes;
28use parking_lot::RwLock;
29use pingora_error::*;
30use std::any::Any;
31use std::collections::HashMap;
32use std::sync::atomic::{AtomicU64, Ordering};
33use std::sync::Arc;
34use tokio::sync::watch;
35
/// Serialized cache metadata: the pair of byte buffers produced by `CacheMeta::serialize()`
/// and passed back, in the same order, to `CacheMeta::deserialize()`.
type BinaryMeta = (Vec<u8>, Vec<u8>);
37
/// A completely written cache entry, as stored in [`MemCache`]'s `cached` map.
pub(crate) struct CacheObject {
    /// Serialized cache metadata.
    pub meta: BinaryMeta,
    /// The full body. Hit handlers clone this `Arc` instead of copying the bytes.
    pub body: Arc<Vec<u8>>,
}
42
/// An entry that is still being written, as stored in [`MemCache`]'s `temp` map.
pub(crate) struct TempObject {
    /// Serialized cache metadata.
    pub meta: BinaryMeta,
    // these are Arc because they need to continue to exist after this TempObject is removed
    /// Body bytes written so far; shared with the miss handler (writer) and partial readers.
    pub body: Arc<RwLock<Vec<u8>>>,
    /// Sender used to publish write progress; readers get a receiver via `subscribe()`.
    bytes_written: Arc<watch::Sender<PartialState>>, // this should match body.len()
}
49
50impl TempObject {
51    fn new(meta: BinaryMeta) -> Self {
52        let (tx, _rx) = watch::channel(PartialState::Partial(0));
53        TempObject {
54            meta,
55            body: Arc::new(RwLock::new(Vec::new())),
56            bytes_written: Arc::new(tx),
57        }
58    }
59    // this is not at all optimized
60    fn make_cache_object(&self) -> CacheObject {
61        let meta = self.meta.clone();
62        let body = Arc::new(self.body.read().clone());
63        CacheObject { meta, body }
64    }
65}
66
/// Hash map based in memory cache
///
/// For testing only, not for production use.
pub struct MemCache {
    /// Completed objects, keyed by the string returned from the cache key's `combined()`.
    pub(crate) cached: Arc<RwLock<HashMap<String, CacheObject>>>,
    /// In-progress objects: key -> (temp writer id -> temp object), so that several writers
    /// can write the same key concurrently.
    pub(crate) temp: Arc<RwLock<HashMap<String, HashMap<u64, TempObject>>>>,
    /// Counter that hands out the temp writer ids (incremented for every miss handler).
    pub(crate) last_temp_id: AtomicU64,
}
75
76impl MemCache {
77    /// Create a new [MemCache]
78    pub fn new() -> Self {
79        MemCache {
80            cached: Arc::new(RwLock::new(HashMap::new())),
81            temp: Arc::new(RwLock::new(HashMap::new())),
82            last_temp_id: AtomicU64::new(0),
83        }
84    }
85}
86
/// Hit handler returned by [`MemCache`] lookups.
pub enum MemHitHandler {
    /// The object is fully written; the body is read from the completed copy.
    Complete(CompleteHit),
    /// The object is still being written; the body is streamed as the writer appends to it.
    Partial(PartialHit),
}
91
/// Write progress published by the miss handler to partial readers.
#[derive(Copy, Clone)]
enum PartialState {
    /// This many bytes have been written so far; EOF has not been signalled yet.
    Partial(usize),
    /// The writer signalled EOF; this is the total number of bytes written.
    Complete(usize),
}
97
/// Reader over a fully written body.
pub struct CompleteHit {
    /// The complete body, shared with the cache entry.
    body: Arc<Vec<u8>>,
    /// Whether the currently selected range has already been returned by `get()`.
    done: bool,
    /// Start of the selected range (inclusive byte offset into `body`).
    range_start: usize,
    /// End of the selected range (exclusive byte offset into `body`).
    range_end: usize,
}
104
105impl CompleteHit {
106    fn get(&mut self) -> Option<Bytes> {
107        if self.done {
108            None
109        } else {
110            self.done = true;
111            Some(Bytes::copy_from_slice(
112                &self.body.as_slice()[self.range_start..self.range_end],
113            ))
114        }
115    }
116
117    fn seek(&mut self, start: usize, end: Option<usize>) -> Result<()> {
118        if start >= self.body.len() {
119            return Error::e_explain(
120                ErrorType::InternalError,
121                format!("seek start out of range {start} >= {}", self.body.len()),
122            );
123        }
124        self.range_start = start;
125        if let Some(end) = end {
126            // end over the actual last byte is allowed, we just need to return the actual bytes
127            self.range_end = std::cmp::min(self.body.len(), end);
128        }
129        // seek resets read so that one handler can be used for multiple ranges
130        self.done = false;
131        Ok(())
132    }
133}
134
/// Reader over a body that is still being written.
pub struct PartialHit {
    /// The body buffer shared with the writer.
    body: Arc<RwLock<Vec<u8>>>,
    /// Receiver on which the writer publishes its progress.
    bytes_written: watch::Receiver<PartialState>,
    /// Number of body bytes this reader has already returned.
    bytes_read: usize,
}
140
141impl PartialHit {
142    async fn read(&mut self) -> Option<Bytes> {
143        loop {
144            let bytes_written = *self.bytes_written.borrow_and_update();
145            let bytes_end = match bytes_written {
146                PartialState::Partial(s) => s,
147                PartialState::Complete(c) => {
148                    // no more data will arrive
149                    if c == self.bytes_read {
150                        return None;
151                    }
152                    c
153                }
154            };
155            assert!(bytes_end >= self.bytes_read);
156
157            // more data available to read
158            if bytes_end > self.bytes_read {
159                let new_bytes =
160                    Bytes::copy_from_slice(&self.body.read()[self.bytes_read..bytes_end]);
161                self.bytes_read = bytes_end;
162                return Some(new_bytes);
163            }
164
165            // wait for more data
166            if self.bytes_written.changed().await.is_err() {
167                // err: sender dropped, body is finished
168                // FIXME: sender could drop because of an error
169                return None;
170            }
171        }
172    }
173}
174
175#[async_trait]
176impl HandleHit for MemHitHandler {
177    async fn read_body(&mut self) -> Result<Option<Bytes>> {
178        match self {
179            Self::Complete(c) => Ok(c.get()),
180            Self::Partial(p) => Ok(p.read().await),
181        }
182    }
183    async fn finish(
184        self: Box<Self>, // because self is always used as a trait object
185        _storage: &'static (dyn storage::Storage + Sync),
186        _key: &CacheKey,
187        _trace: &SpanHandle,
188    ) -> Result<()> {
189        Ok(())
190    }
191
192    fn can_seek(&self) -> bool {
193        match self {
194            Self::Complete(_) => true,
195            Self::Partial(_) => false, // TODO: support seeking in partial reads
196        }
197    }
198
199    fn seek(&mut self, start: usize, end: Option<usize>) -> Result<()> {
200        match self {
201            Self::Complete(c) => c.seek(start, end),
202            Self::Partial(_) => Error::e_explain(
203                ErrorType::InternalError,
204                "seek not supported for partial cache",
205            ),
206        }
207    }
208
209    fn should_count_access(&self) -> bool {
210        match self {
211            // avoid counting accesses for partial reads to keep things simple
212            Self::Complete(_) => true,
213            Self::Partial(_) => false,
214        }
215    }
216
217    fn get_eviction_weight(&self) -> usize {
218        match self {
219            // FIXME: just body size, also track meta size
220            Self::Complete(c) => c.body.len(),
221            // partial read cannot be estimated since body size is unknown
222            Self::Partial(_) => 0,
223        }
224    }
225
226    fn as_any(&self) -> &(dyn Any + Send + Sync) {
227        self
228    }
229
230    fn as_any_mut(&mut self) -> &mut (dyn Any + Send + Sync) {
231        self
232    }
233}
234
/// Writer for one in-progress object of a [`MemCache`].
pub struct MemMissHandler {
    /// Body buffer shared with the corresponding `TempObject` and its partial readers.
    body: Arc<RwLock<Vec<u8>>>,
    /// Sender used to publish write progress to partial readers.
    bytes_written: Arc<watch::Sender<PartialState>>,
    // these are used only in finish() to move data from temp to cache
    /// Combined cache key string of the object being written.
    key: String,
    /// Id of this writer inside the per-key temp map.
    temp_id: U64WriteId,
    // key -> cache object
    cache: Arc<RwLock<HashMap<String, CacheObject>>>,
    // key -> (temp writer id -> temp object) to support concurrent writers
    temp: Arc<RwLock<HashMap<String, HashMap<u64, TempObject>>>>,
}
246
247#[async_trait]
248impl HandleMiss for MemMissHandler {
249    async fn write_body(&mut self, data: bytes::Bytes, eof: bool) -> Result<()> {
250        let current_bytes = match *self.bytes_written.borrow() {
251            PartialState::Partial(p) => p,
252            PartialState::Complete(_) => panic!("already EOF"),
253        };
254        self.body.write().extend_from_slice(&data);
255        let written = current_bytes + data.len();
256        let new_state = if eof {
257            PartialState::Complete(written)
258        } else {
259            PartialState::Partial(written)
260        };
261        self.bytes_written.send_replace(new_state);
262        Ok(())
263    }
264
265    async fn finish(self: Box<Self>) -> Result<MissFinishType> {
266        // safe, the temp object is inserted when the miss handler is created
267        let cache_object = self
268            .temp
269            .read()
270            .get(&self.key)
271            .unwrap()
272            .get(&self.temp_id.into())
273            .unwrap()
274            .make_cache_object();
275        let size = cache_object.body.len(); // FIXME: this just body size, also track meta size
276        self.cache.write().insert(self.key.clone(), cache_object);
277        self.temp
278            .write()
279            .get_mut(&self.key)
280            .and_then(|map| map.remove(&self.temp_id.into()));
281        Ok(MissFinishType::Created(size))
282    }
283
284    fn streaming_write_tag(&self) -> Option<&[u8]> {
285        Some(self.temp_id.as_bytes())
286    }
287}
288
289impl Drop for MemMissHandler {
290    fn drop(&mut self) {
291        self.temp
292            .write()
293            .get_mut(&self.key)
294            .and_then(|map| map.remove(&self.temp_id.into()));
295    }
296}
297
298fn hit_from_temp_obj(temp_obj: &TempObject) -> Result<Option<(CacheMeta, HitHandler)>> {
299    let meta = CacheMeta::deserialize(&temp_obj.meta.0, &temp_obj.meta.1)?;
300    let partial = PartialHit {
301        body: temp_obj.body.clone(),
302        bytes_written: temp_obj.bytes_written.subscribe(),
303        bytes_read: 0,
304    };
305    let hit_handler = MemHitHandler::Partial(partial);
306    Ok(Some((meta, Box::new(hit_handler))))
307}
308
309#[async_trait]
310impl Storage for MemCache {
311    async fn lookup(
312        &'static self,
313        key: &CacheKey,
314        _trace: &SpanHandle,
315    ) -> Result<Option<(CacheMeta, HitHandler)>> {
316        let hash = key.combined();
317        // always prefer partial read otherwise fresh asset will not be visible on expired asset
318        // until it is fully updated
319        // no preference on which partial read we get (if there are multiple writers)
320        if let Some((_, temp_obj)) = self
321            .temp
322            .read()
323            .get(&hash)
324            .and_then(|map| map.iter().next())
325        {
326            hit_from_temp_obj(temp_obj)
327        } else if let Some(obj) = self.cached.read().get(&hash) {
328            let meta = CacheMeta::deserialize(&obj.meta.0, &obj.meta.1)?;
329            let hit_handler = CompleteHit {
330                body: obj.body.clone(),
331                done: false,
332                range_start: 0,
333                range_end: obj.body.len(),
334            };
335            let hit_handler = MemHitHandler::Complete(hit_handler);
336            Ok(Some((meta, Box::new(hit_handler))))
337        } else {
338            Ok(None)
339        }
340    }
341
342    async fn lookup_streaming_write(
343        &'static self,
344        key: &CacheKey,
345        streaming_write_tag: Option<&[u8]>,
346        _trace: &SpanHandle,
347    ) -> Result<Option<(CacheMeta, HitHandler)>> {
348        let hash = key.combined();
349        let write_tag: U64WriteId = streaming_write_tag
350            .expect("tag must be set during streaming write")
351            .try_into()
352            .expect("tag must be correct length");
353        hit_from_temp_obj(
354            self.temp
355                .read()
356                .get(&hash)
357                .and_then(|map| map.get(&write_tag.into()))
358                .expect("must have partial write in progress"),
359        )
360    }
361
362    async fn get_miss_handler(
363        &'static self,
364        key: &CacheKey,
365        meta: &CacheMeta,
366        _trace: &SpanHandle,
367    ) -> Result<MissHandler> {
368        let hash = key.combined();
369        let meta = meta.serialize()?;
370        let temp_obj = TempObject::new(meta);
371        let temp_id = self.last_temp_id.fetch_add(1, Ordering::Relaxed);
372        let miss_handler = MemMissHandler {
373            body: temp_obj.body.clone(),
374            bytes_written: temp_obj.bytes_written.clone(),
375            key: hash.clone(),
376            cache: self.cached.clone(),
377            temp: self.temp.clone(),
378            temp_id: temp_id.into(),
379        };
380        self.temp
381            .write()
382            .entry(hash)
383            .or_default()
384            .insert(miss_handler.temp_id.into(), temp_obj);
385        Ok(Box::new(miss_handler))
386    }
387
388    async fn purge(
389        &'static self,
390        key: &CompactCacheKey,
391        _type: PurgeType,
392        _trace: &SpanHandle,
393    ) -> Result<bool> {
394        // This usually purges the primary key because, without a lookup, the variance key is usually
395        // empty
396        let hash = key.combined();
397        let temp_removed = self.temp.write().remove(&hash).is_some();
398        let cache_removed = self.cached.write().remove(&hash).is_some();
399        Ok(temp_removed || cache_removed)
400    }
401
402    async fn update_meta(
403        &'static self,
404        key: &CacheKey,
405        meta: &CacheMeta,
406        _trace: &SpanHandle,
407    ) -> Result<bool> {
408        let hash = key.combined();
409        if let Some(obj) = self.cached.write().get_mut(&hash) {
410            obj.meta = meta.serialize()?;
411            Ok(true)
412        } else {
413            panic!("no meta found")
414        }
415    }
416
417    fn support_streaming_partial_write(&self) -> bool {
418        true
419    }
420
421    fn as_any(&self) -> &(dyn Any + Send + Sync) {
422        self
423    }
424}
425
// Unit tests for the in-memory cache. Every test uses its own static `MemCache`, because
// the `Storage` methods take `&'static self`.
#[cfg(test)]
mod test {
    use super::*;
    use cf_rustracing::span::Span;
    use once_cell::sync::Lazy;

    // Build a `CacheMeta` for a 200 response with a few headers and default internal meta.
    fn gen_meta() -> CacheMeta {
        let mut header = ResponseHeader::build(200, None).unwrap();
        header.append_header("foo1", "bar1").unwrap();
        header.append_header("foo2", "bar2").unwrap();
        header.append_header("foo3", "bar3").unwrap();
        header.append_header("Server", "Pingora").unwrap();
        let internal = crate::meta::InternalMeta::default();
        CacheMeta(Box::new(crate::meta::CacheMetaInner {
            internal,
            header,
            extensions: http::Extensions::new(),
        }))
    }

    // A finished write is visible to a later lookup as one complete body.
    #[tokio::test]
    async fn test_write_then_read() {
        static MEM_CACHE: Lazy<MemCache> = Lazy::new(MemCache::new);
        let span = &Span::inactive().handle();

        // nothing is stored yet
        let key1 = CacheKey::new("", "a", "1");
        let res = MEM_CACHE.lookup(&key1, span).await.unwrap();
        assert!(res.is_none());

        let cache_meta = gen_meta();

        // write the body in two chunks, then finish the write
        let mut miss_handler = MEM_CACHE
            .get_miss_handler(&key1, &cache_meta, span)
            .await
            .unwrap();
        miss_handler
            .write_body(b"test1"[..].into(), false)
            .await
            .unwrap();
        miss_handler
            .write_body(b"test2"[..].into(), false)
            .await
            .unwrap();
        miss_handler.finish().await.unwrap();

        let (cache_meta2, mut hit_handler) = MEM_CACHE.lookup(&key1, span).await.unwrap().unwrap();
        assert_eq!(
            cache_meta.0.internal.fresh_until,
            cache_meta2.0.internal.fresh_until
        );

        // the complete hit returns the whole body once, then None
        let data = hit_handler.read_body().await.unwrap().unwrap();
        assert_eq!("test1test2", data);
        let data = hit_handler.read_body().await.unwrap();
        assert!(data.is_none());
    }

    // seek() on a complete hit selects the byte range returned by the next read.
    #[tokio::test]
    async fn test_read_range() {
        static MEM_CACHE: Lazy<MemCache> = Lazy::new(MemCache::new);
        let span = &Span::inactive().handle();

        let key1 = CacheKey::new("", "a", "1");
        let res = MEM_CACHE.lookup(&key1, span).await.unwrap();
        assert!(res.is_none());

        let cache_meta = gen_meta();

        let mut miss_handler = MEM_CACHE
            .get_miss_handler(&key1, &cache_meta, span)
            .await
            .unwrap();
        miss_handler
            .write_body(b"test1test2"[..].into(), false)
            .await
            .unwrap();
        miss_handler.finish().await.unwrap();

        let (cache_meta2, mut hit_handler) = MEM_CACHE.lookup(&key1, span).await.unwrap().unwrap();
        assert_eq!(
            cache_meta.0.internal.fresh_until,
            cache_meta2.0.internal.fresh_until
        );

        // out of range
        assert!(hit_handler.seek(10000, None).is_err());

        // open-ended range: from byte 5 to the end of the body
        assert!(hit_handler.seek(5, None).is_ok());
        let data = hit_handler.read_body().await.unwrap().unwrap();
        assert_eq!("test2", data);
        let data = hit_handler.read_body().await.unwrap();
        assert!(data.is_none());

        // bounded range: the end offset is exclusive, so this is the single byte at index 4
        assert!(hit_handler.seek(4, Some(5)).is_ok());
        let data = hit_handler.read_body().await.unwrap().unwrap();
        assert_eq!("1", data);
        let data = hit_handler.read_body().await.unwrap();
        assert!(data.is_none());
    }

    // Readers that look up an object while it is being written stream the body as it grows.
    #[tokio::test]
    async fn test_write_while_read() {
        use futures::FutureExt;

        static MEM_CACHE: Lazy<MemCache> = Lazy::new(MemCache::new);
        let span = &Span::inactive().handle();

        let key1 = CacheKey::new("", "a", "1");
        let res = MEM_CACHE.lookup(&key1, span).await.unwrap();
        assert!(res.is_none());

        let cache_meta = gen_meta();

        let mut miss_handler = MEM_CACHE
            .get_miss_handler(&key1, &cache_meta, span)
            .await
            .unwrap();

        // first reader
        let (cache_meta1, mut hit_handler1) = MEM_CACHE.lookup(&key1, span).await.unwrap().unwrap();
        assert_eq!(
            cache_meta.0.internal.fresh_until,
            cache_meta1.0.internal.fresh_until
        );

        // No body to read
        // (now_or_never() returning None means the read is still pending)
        let res = hit_handler1.read_body().now_or_never();
        assert!(res.is_none());

        miss_handler
            .write_body(b"test1"[..].into(), false)
            .await
            .unwrap();

        // the first chunk is now available; after that the reader is pending again
        let data = hit_handler1.read_body().await.unwrap().unwrap();
        assert_eq!("test1", data);
        let res = hit_handler1.read_body().now_or_never();
        assert!(res.is_none());

        miss_handler
            .write_body(b"test2"[..].into(), false)
            .await
            .unwrap();
        let data = hit_handler1.read_body().await.unwrap().unwrap();
        assert_eq!("test2", data);

        // second reader
        let (cache_meta2, mut hit_handler2) = MEM_CACHE.lookup(&key1, span).await.unwrap().unwrap();
        assert_eq!(
            cache_meta.0.internal.fresh_until,
            cache_meta2.0.internal.fresh_until
        );

        // a reader that joins late gets everything written so far in one chunk
        let data = hit_handler2.read_body().await.unwrap().unwrap();
        assert_eq!("test1test2", data);
        let res = hit_handler2.read_body().now_or_never();
        assert!(res.is_none());

        let res = hit_handler1.read_body().now_or_never();
        assert!(res.is_none());

        // finishing the write ends the body for both readers
        miss_handler.finish().await.unwrap();

        let data = hit_handler1.read_body().await.unwrap();
        assert!(data.is_none());
        let data = hit_handler2.read_body().await.unwrap();
        assert!(data.is_none());
    }

    // purge() removes the in-progress (temp) objects of a key.
    #[tokio::test]
    async fn test_purge_partial() {
        static MEM_CACHE: Lazy<MemCache> = Lazy::new(MemCache::new);
        let cache = &MEM_CACHE;

        let key = CacheKey::new("", "a", "1").to_compact();
        let hash = key.combined();
        let meta = (
            "meta_key".as_bytes().to_vec(),
            "meta_value".as_bytes().to_vec(),
        );

        // insert a temp object directly, without going through a miss handler
        let temp_obj = TempObject::new(meta);
        let mut map = HashMap::new();
        map.insert(0, temp_obj);
        cache.temp.write().insert(hash.clone(), map);

        assert!(cache.temp.read().contains_key(&hash));

        let result = cache
            .purge(&key, PurgeType::Invalidation, &Span::inactive().handle())
            .await;
        assert!(result.is_ok());

        assert!(!cache.temp.read().contains_key(&hash));
    }

    // purge() removes the completed object of a key.
    #[tokio::test]
    async fn test_purge_complete() {
        static MEM_CACHE: Lazy<MemCache> = Lazy::new(MemCache::new);
        let cache = &MEM_CACHE;

        let key = CacheKey::new("", "a", "1").to_compact();
        let hash = key.combined();
        let meta = (
            "meta_key".as_bytes().to_vec(),
            "meta_value".as_bytes().to_vec(),
        );
        // insert a completed object directly, without going through a miss handler
        let body = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 0];
        let cache_obj = CacheObject {
            meta,
            body: Arc::new(body),
        };
        cache.cached.write().insert(hash.clone(), cache_obj);

        assert!(cache.cached.read().contains_key(&hash));

        let result = cache
            .purge(&key, PurgeType::Invalidation, &Span::inactive().handle())
            .await;
        assert!(result.is_ok());

        assert!(!cache.cached.read().contains_key(&hash));
    }
}
647        assert!(!cache.cached.read().contains_key(&hash));
648    }
649}