//! Hash map based in-memory cache
//!
//! For testing only, not for production use

use super::*;
use crate::key::CompactCacheKey;
use crate::storage::{streaming_write::U64WriteId, HandleHit, HandleMiss};
use crate::trace::SpanHandle;

use async_trait::async_trait;
use bytes::Bytes;
use parking_lot::RwLock;
use pingora_error::*;
use std::any::Any;
use std::collections::HashMap;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use tokio::sync::watch;

type BinaryMeta = (Vec<u8>, Vec<u8>);

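/// A fully written cache object: the serialized meta and the complete body.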
pub(crate) struct CacheObject {
    pub meta: BinaryMeta,
    pub body: Arc<Vec<u8>>,
}

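/// A cache object still being written: readers can stream the body while it grows.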
pub(crate) struct TempObject {
    pub meta: BinaryMeta,
    // Arc: shared with hit/miss handlers that may outlive this object
    pub body: Arc<RwLock<Vec<u8>>>,
    // tracks how many bytes are written; should match body.len()
    bytes_written: Arc<watch::Sender<PartialState>>,
}

impl TempObject {
    fn new(meta: BinaryMeta) -> Self {
        let (tx, _rx) = watch::channel(PartialState::Partial(0));
        TempObject {
            meta,
            body: Arc::new(RwLock::new(Vec::new())),
            bytes_written: Arc::new(tx),
        }
    }
    fn make_cache_object(&self) -> CacheObject {
        // NOTE: cloning the body briefly doubles the memory held for this asset
        let meta = self.meta.clone();
        let body = Arc::new(self.body.read().clone());
        CacheObject { meta, body }
    }
}

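/// Hash-map-based in-memory storage backend, mainly useful for testing.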
pub struct MemCache {
    // finished objects: combined key hash -> cache object
    pub(crate) cached: Arc<RwLock<HashMap<String, CacheObject>>>,
    // writes in progress: combined key hash -> (temp id -> temp object)
    pub(crate) temp: Arc<RwLock<HashMap<String, HashMap<u64, TempObject>>>>,
    // id used to distinguish concurrent writes of the same asset
    pub(crate) last_temp_id: AtomicU64,
}

impl MemCache {
    pub fn new() -> Self {
        MemCache {
            cached: Arc::new(RwLock::new(HashMap::new())),
            temp: Arc::new(RwLock::new(HashMap::new())),
            last_temp_id: AtomicU64::new(0),
        }
    }
}

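/// Hit handler for [MemCache]: either a fully cached object or a write in progress.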
pub enum MemHitHandler {
    Complete(CompleteHit),
    Partial(PartialHit),
}

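/// How far a streaming write has progressed, as observed by readers.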
#[derive(Copy, Clone)]
enum PartialState {
    // bytes written so far, with more to come
    Partial(usize),
    // the final body size, nothing more will be written
    Complete(usize),
}

pub struct CompleteHit {
    body: Arc<Vec<u8>>,
    done: bool,
    range_start: usize,
    range_end: usize,
}

impl CompleteHit {
    fn get(&mut self) -> Option<Bytes> {
        if self.done {
            None
        } else {
            self.done = true;
            Some(Bytes::copy_from_slice(
                &self.body.as_slice()[self.range_start..self.range_end],
            ))
        }
    }

    fn seek(&mut self, start: usize, end: Option<usize>) -> Result<()> {
        if start >= self.body.len() {
            return Error::e_explain(
                ErrorType::InternalError,
                format!("seek start out of range {start} >= {}", self.body.len()),
            );
        }
        self.range_start = start;
        if let Some(end) = end {
            // an end past the last byte is allowed; clamp it to the body length
            self.range_end = std::cmp::min(self.body.len(), end);
        }
        // allow the new range to be read out
        self.done = false;
        Ok(())
    }
}

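/// Streams a body that is still being written, waiting whenever it runs out of new bytes.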
pub struct PartialHit {
    body: Arc<RwLock<Vec<u8>>>,
    bytes_written: watch::Receiver<PartialState>,
    bytes_read: usize,
}

impl PartialHit {
    async fn read(&mut self) -> Option<Bytes> {
        loop {
            let bytes_written = *self.bytes_written.borrow_and_update();
            let bytes_end = match bytes_written {
                PartialState::Partial(s) => s,
                PartialState::Complete(c) => {
                    // no more data will arrive
                    if c == self.bytes_read {
                        return None;
                    }
                    c
                }
            };
            assert!(bytes_end >= self.bytes_read);

            // more data is available to read
            if bytes_end > self.bytes_read {
                let new_bytes =
                    Bytes::copy_from_slice(&self.body.read()[self.bytes_read..bytes_end]);
                self.bytes_read = bytes_end;
                return Some(new_bytes);
            }

            // wait for more data
            if self.bytes_written.changed().await.is_err() {
                // the sender is gone: the write ended (or the writer was dropped)
                return None;
            }
        }
    }
}

#[async_trait]
impl HandleHit for MemHitHandler {
    async fn read_body(&mut self) -> Result<Option<Bytes>> {
        match self {
            Self::Complete(c) => Ok(c.get()),
            Self::Partial(p) => Ok(p.read().await),
        }
    }
    async fn finish(
        self: Box<Self>,
        _storage: &'static (dyn storage::Storage + Sync),
        _key: &CacheKey,
        _trace: &SpanHandle,
    ) -> Result<()> {
        Ok(())
    }

    fn can_seek(&self) -> bool {
        match self {
            Self::Complete(_) => true,
            Self::Partial(_) => false,
        }
    }

    fn seek(&mut self, start: usize, end: Option<usize>) -> Result<()> {
        match self {
            Self::Complete(c) => c.seek(start, end),
            Self::Partial(_) => Error::e_explain(
                ErrorType::InternalError,
                "seek not supported for partial cache",
            ),
        }
    }

    fn should_count_access(&self) -> bool {
        match self {
            Self::Complete(_) => true,
            // in-progress partial hits are not counted
            Self::Partial(_) => false,
        }
    }

    fn get_eviction_weight(&self) -> usize {
        match self {
            Self::Complete(c) => c.body.len(),
            // the final size of a partial object is not known yet
            Self::Partial(_) => 0,
        }
    }

    fn as_any(&self) -> &(dyn Any + Send + Sync) {
        self
    }
}

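/// Miss handler for [MemCache]: buffers the body in a temp object and publishes
/// write progress so concurrent readers can stream it.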
pub struct MemMissHandler {
    body: Arc<RwLock<Vec<u8>>>,
    bytes_written: Arc<watch::Sender<PartialState>>,
    // the fields below are used in finish() to move the temp object into the cache
    key: String,
    temp_id: U64WriteId,
    // key -> cache object
    cache: Arc<RwLock<HashMap<String, CacheObject>>>,
    // key -> (temp id -> temp object)
    temp: Arc<RwLock<HashMap<String, HashMap<u64, TempObject>>>>,
}

#[async_trait]
impl HandleMiss for MemMissHandler {
    async fn write_body(&mut self, data: bytes::Bytes, eof: bool) -> Result<()> {
        let current_bytes = match *self.bytes_written.borrow() {
            PartialState::Partial(p) => p,
            PartialState::Complete(_) => panic!("already EOF"),
        };
        self.body.write().extend_from_slice(&data);
        let written = current_bytes + data.len();
        let new_state = if eof {
            PartialState::Complete(written)
        } else {
            PartialState::Partial(written)
        };
        self.bytes_written.send_replace(new_state);
        Ok(())
    }

    async fn finish(self: Box<Self>) -> Result<MissFinishType> {
        // safe to unwrap: the temp object is inserted when this handler is created
        let cache_object = self
            .temp
            .read()
            .get(&self.key)
            .unwrap()
            .get(&self.temp_id.into())
            .unwrap()
            .make_cache_object();
        let size = cache_object.body.len();
        self.cache.write().insert(self.key.clone(), cache_object);
        self.temp
            .write()
            .get_mut(&self.key)
            .and_then(|map| map.remove(&self.temp_id.into()));
        Ok(MissFinishType::Created(size))
    }

    fn streaming_write_tag(&self) -> Option<&[u8]> {
        Some(self.temp_id.as_bytes())
    }
}

impl Drop for MemMissHandler {
    fn drop(&mut self) {
        // clean up the temp object if the write is dropped before finish()
        self.temp
            .write()
            .get_mut(&self.key)
            .and_then(|map| map.remove(&self.temp_id.into()));
    }
}

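// Build a partial hit handler that follows the given in-progress write.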
fn hit_from_temp_obj(temp_obj: &TempObject) -> Result<Option<(CacheMeta, HitHandler)>> {
    let meta = CacheMeta::deserialize(&temp_obj.meta.0, &temp_obj.meta.1)?;
    let partial = PartialHit {
        body: temp_obj.body.clone(),
        bytes_written: temp_obj.bytes_written.subscribe(),
        bytes_read: 0,
    };
    let hit_handler = MemHitHandler::Partial(partial);
    Ok(Some((meta, Box::new(hit_handler))))
}

#[async_trait]
impl Storage for MemCache {
    async fn lookup(
        &'static self,
        key: &CacheKey,
        _trace: &SpanHandle,
    ) -> Result<Option<(CacheMeta, HitHandler)>> {
        let hash = key.combined();
        // always prefer a partial write in progress: otherwise a fresh asset would
        // not be visible on top of an expired one until it is fully updated
        if let Some((_, temp_obj)) = self
            .temp
            .read()
            .get(&hash)
            .and_then(|map| map.iter().next())
        {
            hit_from_temp_obj(temp_obj)
        } else if let Some(obj) = self.cached.read().get(&hash) {
            let meta = CacheMeta::deserialize(&obj.meta.0, &obj.meta.1)?;
            let hit_handler = CompleteHit {
                body: obj.body.clone(),
                done: false,
                range_start: 0,
                range_end: obj.body.len(),
            };
            let hit_handler = MemHitHandler::Complete(hit_handler);
            Ok(Some((meta, Box::new(hit_handler))))
        } else {
            Ok(None)
        }
    }

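    // Find the exact write in progress identified by the caller's streaming write tag.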
    async fn lookup_streaming_write(
        &'static self,
        key: &CacheKey,
        streaming_write_tag: Option<&[u8]>,
        _trace: &SpanHandle,
    ) -> Result<Option<(CacheMeta, HitHandler)>> {
        let hash = key.combined();
        let write_tag: U64WriteId = streaming_write_tag
            .expect("tag must be set during streaming write")
            .try_into()
            .expect("tag must be correct length");
        hit_from_temp_obj(
            self.temp
                .read()
                .get(&hash)
                .and_then(|map| map.get(&write_tag.into()))
                .expect("must have partial write in progress"),
        )
    }

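    // Register a new temp object under a fresh temp id so concurrent writes of the
    // same key do not clobber each other.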
    async fn get_miss_handler(
        &'static self,
        key: &CacheKey,
        meta: &CacheMeta,
        _trace: &SpanHandle,
    ) -> Result<MissHandler> {
        let hash = key.combined();
        let meta = meta.serialize()?;
        let temp_obj = TempObject::new(meta);
        let temp_id = self.last_temp_id.fetch_add(1, Ordering::Relaxed);
        let miss_handler = MemMissHandler {
            body: temp_obj.body.clone(),
            bytes_written: temp_obj.bytes_written.clone(),
            key: hash.clone(),
            cache: self.cached.clone(),
            temp: self.temp.clone(),
            temp_id: temp_id.into(),
        };
        self.temp
            .write()
            .entry(hash)
            .or_default()
            .insert(miss_handler.temp_id.into(), temp_obj);
        Ok(Box::new(miss_handler))
    }

    async fn purge(
        &'static self,
        key: &CompactCacheKey,
        _type: PurgeType,
        _trace: &SpanHandle,
    ) -> Result<bool> {
        // drop both the finished object and any writes still in progress
        let hash = key.combined();
        let temp_removed = self.temp.write().remove(&hash).is_some();
        let cache_removed = self.cached.write().remove(&hash).is_some();
        Ok(temp_removed || cache_removed)
    }

    async fn update_meta(
        &'static self,
        key: &CacheKey,
        meta: &CacheMeta,
        _trace: &SpanHandle,
    ) -> Result<bool> {
        let hash = key.combined();
        if let Some(obj) = self.cached.write().get_mut(&hash) {
            obj.meta = meta.serialize()?;
            Ok(true)
        } else {
            panic!("no meta found")
        }
    }

    fn support_streaming_partial_write(&self) -> bool {
        true
    }

    fn as_any(&self) -> &(dyn Any + Send + Sync) {
        self
    }
}

#[cfg(test)]
mod test {
    use super::*;
    use cf_rustracing::span::Span;
    use once_cell::sync::Lazy;

    fn gen_meta() -> CacheMeta {
        let mut header = ResponseHeader::build(200, None).unwrap();
        header.append_header("foo1", "bar1").unwrap();
        header.append_header("foo2", "bar2").unwrap();
        header.append_header("foo3", "bar3").unwrap();
        header.append_header("Server", "Pingora").unwrap();
        let internal = crate::meta::InternalMeta::default();
        CacheMeta(Box::new(crate::meta::CacheMetaInner {
            internal,
            header,
            extensions: http::Extensions::new(),
        }))
    }

    #[tokio::test]
    async fn test_write_then_read() {
        static MEM_CACHE: Lazy<MemCache> = Lazy::new(MemCache::new);
        let span = &Span::inactive().handle();

        let key1 = CacheKey::new("", "a", "1");
        let res = MEM_CACHE.lookup(&key1, span).await.unwrap();
        assert!(res.is_none());

        let cache_meta = gen_meta();

        let mut miss_handler = MEM_CACHE
            .get_miss_handler(&key1, &cache_meta, span)
            .await
            .unwrap();
        miss_handler
            .write_body(b"test1"[..].into(), false)
            .await
            .unwrap();
        miss_handler
            .write_body(b"test2"[..].into(), false)
            .await
            .unwrap();
        miss_handler.finish().await.unwrap();

        let (cache_meta2, mut hit_handler) =
            MEM_CACHE.lookup(&key1, span).await.unwrap().unwrap();
        assert_eq!(
            cache_meta.0.internal.fresh_until,
            cache_meta2.0.internal.fresh_until
        );

        let data = hit_handler.read_body().await.unwrap().unwrap();
        assert_eq!("test1test2", data);
        let data = hit_handler.read_body().await.unwrap();
        assert!(data.is_none());
    }

    #[tokio::test]
    async fn test_read_range() {
        static MEM_CACHE: Lazy<MemCache> = Lazy::new(MemCache::new);
        let span = &Span::inactive().handle();

        let key1 = CacheKey::new("", "a", "1");
        let res = MEM_CACHE.lookup(&key1, span).await.unwrap();
        assert!(res.is_none());

        let cache_meta = gen_meta();

        let mut miss_handler = MEM_CACHE
            .get_miss_handler(&key1, &cache_meta, span)
            .await
            .unwrap();
        miss_handler
            .write_body(b"test1test2"[..].into(), false)
            .await
            .unwrap();
        miss_handler.finish().await.unwrap();

        let (cache_meta2, mut hit_handler) =
            MEM_CACHE.lookup(&key1, span).await.unwrap().unwrap();
        assert_eq!(
            cache_meta.0.internal.fresh_until,
            cache_meta2.0.internal.fresh_until
        );

        // seeking beyond the end of the body should fail
        assert!(hit_handler.seek(10000, None).is_err());

        assert!(hit_handler.seek(5, None).is_ok());
        let data = hit_handler.read_body().await.unwrap().unwrap();
        assert_eq!("test2", data);
        let data = hit_handler.read_body().await.unwrap();
        assert!(data.is_none());

        assert!(hit_handler.seek(4, Some(5)).is_ok());
        let data = hit_handler.read_body().await.unwrap().unwrap();
        assert_eq!("1", data);
        let data = hit_handler.read_body().await.unwrap();
        assert!(data.is_none());
    }

    #[tokio::test]
    async fn test_write_while_read() {
        use futures::FutureExt;

        static MEM_CACHE: Lazy<MemCache> = Lazy::new(MemCache::new);
        let span = &Span::inactive().handle();

        let key1 = CacheKey::new("", "a", "1");
        let res = MEM_CACHE.lookup(&key1, span).await.unwrap();
        assert!(res.is_none());

        let cache_meta = gen_meta();

        let mut miss_handler = MEM_CACHE
            .get_miss_handler(&key1, &cache_meta, span)
            .await
            .unwrap();

        // a reader can see the asset as soon as the write starts
        let (cache_meta1, mut hit_handler1) =
            MEM_CACHE.lookup(&key1, span).await.unwrap().unwrap();
        assert_eq!(
            cache_meta.0.internal.fresh_until,
            cache_meta1.0.internal.fresh_until
        );

        // nothing has been written yet, so reading should not resolve
        let res = hit_handler1.read_body().now_or_never();
        assert!(res.is_none());

        miss_handler
            .write_body(b"test1"[..].into(), false)
            .await
            .unwrap();

        let data = hit_handler1.read_body().await.unwrap().unwrap();
        assert_eq!("test1", data);
        let res = hit_handler1.read_body().now_or_never();
        assert!(res.is_none());

        miss_handler
            .write_body(b"test2"[..].into(), false)
            .await
            .unwrap();
        let data = hit_handler1.read_body().await.unwrap().unwrap();
        assert_eq!("test2", data);

        // a second reader catches up from the beginning
        let (cache_meta2, mut hit_handler2) =
            MEM_CACHE.lookup(&key1, span).await.unwrap().unwrap();
        assert_eq!(
            cache_meta.0.internal.fresh_until,
            cache_meta2.0.internal.fresh_until
        );

        let data = hit_handler2.read_body().await.unwrap().unwrap();
        assert_eq!("test1test2", data);
        let res = hit_handler2.read_body().now_or_never();
        assert!(res.is_none());

        let res = hit_handler1.read_body().now_or_never();
        assert!(res.is_none());

        miss_handler.finish().await.unwrap();

        let data = hit_handler1.read_body().await.unwrap();
        assert!(data.is_none());
        let data = hit_handler2.read_body().await.unwrap();
        assert!(data.is_none());
    }

    #[tokio::test]
    async fn test_purge_partial() {
        static MEM_CACHE: Lazy<MemCache> = Lazy::new(MemCache::new);
        let cache = &MEM_CACHE;

        let key = CacheKey::new("", "a", "1").to_compact();
        let hash = key.combined();
        let meta = (
            "meta_key".as_bytes().to_vec(),
            "meta_value".as_bytes().to_vec(),
        );

        let temp_obj = TempObject::new(meta);
        let mut map = HashMap::new();
        map.insert(0, temp_obj);
        cache.temp.write().insert(hash.clone(), map);

        assert!(cache.temp.read().contains_key(&hash));

        let result = cache
            .purge(&key, PurgeType::Invalidation, &Span::inactive().handle())
            .await;
        assert!(result.is_ok());

        assert!(!cache.temp.read().contains_key(&hash));
    }
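
    // A sketch of a test (not in the original suite) for lookup_streaming_write:
    // the tag returned by the miss handler should route a reader to that exact
    // in-progress write.
    #[tokio::test]
    async fn test_lookup_streaming_write() {
        static MEM_CACHE: Lazy<MemCache> = Lazy::new(MemCache::new);
        let span = &Span::inactive().handle();

        let key1 = CacheKey::new("", "a", "1");
        let cache_meta = gen_meta();
        let mut miss_handler = MEM_CACHE
            .get_miss_handler(&key1, &cache_meta, span)
            .await
            .unwrap();
        miss_handler
            .write_body(b"test1"[..].into(), false)
            .await
            .unwrap();

        // the tag identifies this write among concurrent writers of the same key
        let tag = miss_handler.streaming_write_tag().map(|t| t.to_vec());
        let (_meta, mut hit_handler) = MEM_CACHE
            .lookup_streaming_write(&key1, tag.as_deref(), span)
            .await
            .unwrap()
            .unwrap();
        let data = hit_handler.read_body().await.unwrap().unwrap();
        assert_eq!("test1", data);
    }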

    #[tokio::test]
    async fn test_purge_complete() {
        static MEM_CACHE: Lazy<MemCache> = Lazy::new(MemCache::new);
        let cache = &MEM_CACHE;

        let key = CacheKey::new("", "a", "1").to_compact();
        let hash = key.combined();
        let meta = (
            "meta_key".as_bytes().to_vec(),
            "meta_value".as_bytes().to_vec(),
        );
        let body = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 0];
        let cache_obj = CacheObject {
            meta,
            body: Arc::new(body),
        };
        cache.cached.write().insert(hash.clone(), cache_obj);

        assert!(cache.cached.read().contains_key(&hash));

        let result = cache
            .purge(&key, PurgeType::Invalidation, &Span::inactive().handle())
            .await;
        assert!(result.is_ok());

        assert!(!cache.cached.read().contains_key(&hash));
    }
}
645}