1use super::*;
22use crate::key::CompactCacheKey;
23use crate::storage::{streaming_write::U64WriteId, HandleHit, HandleMiss};
24use crate::trace::SpanHandle;
25
26use async_trait::async_trait;
27use bytes::Bytes;
28use parking_lot::RwLock;
29use pingora_error::*;
30use std::any::Any;
31use std::collections::HashMap;
32use std::sync::atomic::{AtomicU64, Ordering};
33use std::sync::Arc;
34use tokio::sync::watch;
35
/// Serialized cache metadata, stored as a pair of byte buffers.
///
/// The pair is what `CacheMeta::serialize` produces and what
/// `CacheMeta::deserialize` takes back, in the same order.
type BinaryMeta = (Vec<u8>, Vec<u8>);

/// A completely written cache entry.
pub(crate) struct CacheObject {
    /// Serialized metadata of the entry.
    pub meta: BinaryMeta,
    /// The whole body, reference counted so hit handlers can share it.
    pub body: Arc<Vec<u8>>,
}
42
43pub(crate) struct TempObject {
44 pub meta: BinaryMeta,
45 pub body: Arc<RwLock<Vec<u8>>>,
47 bytes_written: Arc<watch::Sender<PartialState>>, }
49
50impl TempObject {
51 fn new(meta: BinaryMeta) -> Self {
52 let (tx, _rx) = watch::channel(PartialState::Partial(0));
53 TempObject {
54 meta,
55 body: Arc::new(RwLock::new(Vec::new())),
56 bytes_written: Arc::new(tx),
57 }
58 }
59 fn make_cache_object(&self) -> CacheObject {
61 let meta = self.meta.clone();
62 let body = Arc::new(self.body.read().clone());
63 CacheObject { meta, body }
64 }
65}
66
67pub struct MemCache {
71 pub(crate) cached: Arc<RwLock<HashMap<String, CacheObject>>>,
72 pub(crate) temp: Arc<RwLock<HashMap<String, HashMap<u64, TempObject>>>>,
73 pub(crate) last_temp_id: AtomicU64,
74}
75
76impl MemCache {
77 pub fn new() -> Self {
79 MemCache {
80 cached: Arc::new(RwLock::new(HashMap::new())),
81 temp: Arc::new(RwLock::new(HashMap::new())),
82 last_temp_id: AtomicU64::new(0),
83 }
84 }
85}
86
87pub enum MemHitHandler {
88 Complete(CompleteHit),
89 Partial(PartialHit),
90}
91
/// Write progress published from the miss handler to partial readers.
#[derive(Copy, Clone)]
enum PartialState {
    /// The body is still being written; carries the number of bytes written so far.
    Partial(usize),
    /// The writer signalled end of body; carries the final number of bytes.
    Complete(usize),
}
97
/// Reader over a completely written body.
pub struct CompleteHit {
    /// The full body of the cached object.
    body: Arc<Vec<u8>>,
    /// Set once the selected range has been handed out; cleared again by a seek.
    done: bool,
    /// Inclusive start offset of the range to serve.
    range_start: usize,
    /// Exclusive end offset of the range to serve.
    range_end: usize,
}
104
105impl CompleteHit {
106 fn get(&mut self) -> Option<Bytes> {
107 if self.done {
108 None
109 } else {
110 self.done = true;
111 Some(Bytes::copy_from_slice(
112 &self.body.as_slice()[self.range_start..self.range_end],
113 ))
114 }
115 }
116
117 fn seek(&mut self, start: usize, end: Option<usize>) -> Result<()> {
118 if start >= self.body.len() {
119 return Error::e_explain(
120 ErrorType::InternalError,
121 format!("seek start out of range {start} >= {}", self.body.len()),
122 );
123 }
124 self.range_start = start;
125 if let Some(end) = end {
126 self.range_end = std::cmp::min(self.body.len(), end);
128 }
129 self.done = false;
131 Ok(())
132 }
133}
134
135pub struct PartialHit {
136 body: Arc<RwLock<Vec<u8>>>,
137 bytes_written: watch::Receiver<PartialState>,
138 bytes_read: usize,
139}
140
141impl PartialHit {
142 async fn read(&mut self) -> Option<Bytes> {
143 loop {
144 let bytes_written = *self.bytes_written.borrow_and_update();
145 let bytes_end = match bytes_written {
146 PartialState::Partial(s) => s,
147 PartialState::Complete(c) => {
148 if c == self.bytes_read {
150 return None;
151 }
152 c
153 }
154 };
155 assert!(bytes_end >= self.bytes_read);
156
157 if bytes_end > self.bytes_read {
159 let new_bytes =
160 Bytes::copy_from_slice(&self.body.read()[self.bytes_read..bytes_end]);
161 self.bytes_read = bytes_end;
162 return Some(new_bytes);
163 }
164
165 if self.bytes_written.changed().await.is_err() {
167 return None;
170 }
171 }
172 }
173}
174
175#[async_trait]
176impl HandleHit for MemHitHandler {
177 async fn read_body(&mut self) -> Result<Option<Bytes>> {
178 match self {
179 Self::Complete(c) => Ok(c.get()),
180 Self::Partial(p) => Ok(p.read().await),
181 }
182 }
183 async fn finish(
184 self: Box<Self>, _storage: &'static (dyn storage::Storage + Sync),
186 _key: &CacheKey,
187 _trace: &SpanHandle,
188 ) -> Result<()> {
189 Ok(())
190 }
191
192 fn can_seek(&self) -> bool {
193 match self {
194 Self::Complete(_) => true,
195 Self::Partial(_) => false, }
197 }
198
199 fn seek(&mut self, start: usize, end: Option<usize>) -> Result<()> {
200 match self {
201 Self::Complete(c) => c.seek(start, end),
202 Self::Partial(_) => Error::e_explain(
203 ErrorType::InternalError,
204 "seek not supported for partial cache",
205 ),
206 }
207 }
208
209 fn should_count_access(&self) -> bool {
210 match self {
211 Self::Complete(_) => true,
213 Self::Partial(_) => false,
214 }
215 }
216
217 fn get_eviction_weight(&self) -> usize {
218 match self {
219 Self::Complete(c) => c.body.len(),
221 Self::Partial(_) => 0,
223 }
224 }
225
226 fn as_any(&self) -> &(dyn Any + Send + Sync) {
227 self
228 }
229
230 fn as_any_mut(&mut self) -> &mut (dyn Any + Send + Sync) {
231 self
232 }
233}
234
235pub struct MemMissHandler {
236 body: Arc<RwLock<Vec<u8>>>,
237 bytes_written: Arc<watch::Sender<PartialState>>,
238 key: String,
240 temp_id: U64WriteId,
241 cache: Arc<RwLock<HashMap<String, CacheObject>>>,
243 temp: Arc<RwLock<HashMap<String, HashMap<u64, TempObject>>>>,
245}
246
247#[async_trait]
248impl HandleMiss for MemMissHandler {
249 async fn write_body(&mut self, data: bytes::Bytes, eof: bool) -> Result<()> {
250 let current_bytes = match *self.bytes_written.borrow() {
251 PartialState::Partial(p) => p,
252 PartialState::Complete(_) => panic!("already EOF"),
253 };
254 self.body.write().extend_from_slice(&data);
255 let written = current_bytes + data.len();
256 let new_state = if eof {
257 PartialState::Complete(written)
258 } else {
259 PartialState::Partial(written)
260 };
261 self.bytes_written.send_replace(new_state);
262 Ok(())
263 }
264
265 async fn finish(self: Box<Self>) -> Result<MissFinishType> {
266 let cache_object = self
268 .temp
269 .read()
270 .get(&self.key)
271 .unwrap()
272 .get(&self.temp_id.into())
273 .unwrap()
274 .make_cache_object();
275 let size = cache_object.body.len(); self.cache.write().insert(self.key.clone(), cache_object);
277 self.temp
278 .write()
279 .get_mut(&self.key)
280 .and_then(|map| map.remove(&self.temp_id.into()));
281 Ok(MissFinishType::Created(size))
282 }
283
284 fn streaming_write_tag(&self) -> Option<&[u8]> {
285 Some(self.temp_id.as_bytes())
286 }
287}
288
289impl Drop for MemMissHandler {
290 fn drop(&mut self) {
291 self.temp
292 .write()
293 .get_mut(&self.key)
294 .and_then(|map| map.remove(&self.temp_id.into()));
295 }
296}
297
298fn hit_from_temp_obj(temp_obj: &TempObject) -> Result<Option<(CacheMeta, HitHandler)>> {
299 let meta = CacheMeta::deserialize(&temp_obj.meta.0, &temp_obj.meta.1)?;
300 let partial = PartialHit {
301 body: temp_obj.body.clone(),
302 bytes_written: temp_obj.bytes_written.subscribe(),
303 bytes_read: 0,
304 };
305 let hit_handler = MemHitHandler::Partial(partial);
306 Ok(Some((meta, Box::new(hit_handler))))
307}
308
309#[async_trait]
310impl Storage for MemCache {
311 async fn lookup(
312 &'static self,
313 key: &CacheKey,
314 _trace: &SpanHandle,
315 ) -> Result<Option<(CacheMeta, HitHandler)>> {
316 let hash = key.combined();
317 if let Some((_, temp_obj)) = self
321 .temp
322 .read()
323 .get(&hash)
324 .and_then(|map| map.iter().next())
325 {
326 hit_from_temp_obj(temp_obj)
327 } else if let Some(obj) = self.cached.read().get(&hash) {
328 let meta = CacheMeta::deserialize(&obj.meta.0, &obj.meta.1)?;
329 let hit_handler = CompleteHit {
330 body: obj.body.clone(),
331 done: false,
332 range_start: 0,
333 range_end: obj.body.len(),
334 };
335 let hit_handler = MemHitHandler::Complete(hit_handler);
336 Ok(Some((meta, Box::new(hit_handler))))
337 } else {
338 Ok(None)
339 }
340 }
341
342 async fn lookup_streaming_write(
343 &'static self,
344 key: &CacheKey,
345 streaming_write_tag: Option<&[u8]>,
346 _trace: &SpanHandle,
347 ) -> Result<Option<(CacheMeta, HitHandler)>> {
348 let hash = key.combined();
349 let write_tag: U64WriteId = streaming_write_tag
350 .expect("tag must be set during streaming write")
351 .try_into()
352 .expect("tag must be correct length");
353 hit_from_temp_obj(
354 self.temp
355 .read()
356 .get(&hash)
357 .and_then(|map| map.get(&write_tag.into()))
358 .expect("must have partial write in progress"),
359 )
360 }
361
362 async fn get_miss_handler(
363 &'static self,
364 key: &CacheKey,
365 meta: &CacheMeta,
366 _trace: &SpanHandle,
367 ) -> Result<MissHandler> {
368 let hash = key.combined();
369 let meta = meta.serialize()?;
370 let temp_obj = TempObject::new(meta);
371 let temp_id = self.last_temp_id.fetch_add(1, Ordering::Relaxed);
372 let miss_handler = MemMissHandler {
373 body: temp_obj.body.clone(),
374 bytes_written: temp_obj.bytes_written.clone(),
375 key: hash.clone(),
376 cache: self.cached.clone(),
377 temp: self.temp.clone(),
378 temp_id: temp_id.into(),
379 };
380 self.temp
381 .write()
382 .entry(hash)
383 .or_default()
384 .insert(miss_handler.temp_id.into(), temp_obj);
385 Ok(Box::new(miss_handler))
386 }
387
388 async fn purge(
389 &'static self,
390 key: &CompactCacheKey,
391 _type: PurgeType,
392 _trace: &SpanHandle,
393 ) -> Result<bool> {
394 let hash = key.combined();
397 let temp_removed = self.temp.write().remove(&hash).is_some();
398 let cache_removed = self.cached.write().remove(&hash).is_some();
399 Ok(temp_removed || cache_removed)
400 }
401
402 async fn update_meta(
403 &'static self,
404 key: &CacheKey,
405 meta: &CacheMeta,
406 _trace: &SpanHandle,
407 ) -> Result<bool> {
408 let hash = key.combined();
409 if let Some(obj) = self.cached.write().get_mut(&hash) {
410 obj.meta = meta.serialize()?;
411 Ok(true)
412 } else {
413 panic!("no meta found")
414 }
415 }
416
417 fn support_streaming_partial_write(&self) -> bool {
418 true
419 }
420
421 fn as_any(&self) -> &(dyn Any + Send + Sync) {
422 self
423 }
424}
425
#[cfg(test)]
mod test {
    use super::*;
    use cf_rustracing::span::Span;
    use once_cell::sync::Lazy;

    /// Builds metadata for a 200 response with a handful of headers and default internals.
    fn gen_meta() -> CacheMeta {
        let mut header = ResponseHeader::build(200, None).unwrap();
        header.append_header("foo1", "bar1").unwrap();
        header.append_header("foo2", "bar2").unwrap();
        header.append_header("foo3", "bar3").unwrap();
        header.append_header("Server", "Pingora").unwrap();
        let inner = crate::meta::CacheMetaInner {
            internal: crate::meta::InternalMeta::default(),
            header,
            extensions: http::Extensions::new(),
        };
        CacheMeta(Box::new(inner))
    }

    /// A finished write can be read back in one piece.
    #[tokio::test]
    async fn test_write_then_read() {
        static MEM_CACHE: Lazy<MemCache> = Lazy::new(MemCache::new);
        let trace = Span::inactive().handle();
        let key = CacheKey::new("", "a", "1");

        // Nothing is stored before the write.
        let initial = MEM_CACHE.lookup(&key, &trace).await.unwrap();
        assert!(initial.is_none());

        let written_meta = gen_meta();
        let mut writer = MEM_CACHE
            .get_miss_handler(&key, &written_meta, &trace)
            .await
            .unwrap();
        writer.write_body(b"test1"[..].into(), false).await.unwrap();
        writer.write_body(b"test2"[..].into(), false).await.unwrap();
        writer.finish().await.unwrap();

        let (read_meta, mut reader) = MEM_CACHE.lookup(&key, &trace).await.unwrap().unwrap();
        assert_eq!(
            written_meta.0.internal.fresh_until,
            read_meta.0.internal.fresh_until
        );

        let body = reader.read_body().await.unwrap().unwrap();
        assert_eq!("test1test2", body);
        let after_end = reader.read_body().await.unwrap();
        assert!(after_end.is_none());
    }

    /// A complete hit serves sub-ranges after seeking.
    #[tokio::test]
    async fn test_read_range() {
        static MEM_CACHE: Lazy<MemCache> = Lazy::new(MemCache::new);
        let trace = Span::inactive().handle();
        let key = CacheKey::new("", "a", "1");

        let initial = MEM_CACHE.lookup(&key, &trace).await.unwrap();
        assert!(initial.is_none());

        let written_meta = gen_meta();
        let mut writer = MEM_CACHE
            .get_miss_handler(&key, &written_meta, &trace)
            .await
            .unwrap();
        writer
            .write_body(b"test1test2"[..].into(), false)
            .await
            .unwrap();
        writer.finish().await.unwrap();

        let (read_meta, mut reader) = MEM_CACHE.lookup(&key, &trace).await.unwrap().unwrap();
        assert_eq!(
            written_meta.0.internal.fresh_until,
            read_meta.0.internal.fresh_until
        );

        // A start beyond the body is rejected.
        assert!(reader.seek(10000, None).is_err());

        // Open-ended range from the middle of the body.
        assert!(reader.seek(5, None).is_ok());
        let tail = reader.read_body().await.unwrap().unwrap();
        assert_eq!("test2", tail);
        let after_tail = reader.read_body().await.unwrap();
        assert!(after_tail.is_none());

        // Bounded range of a single byte.
        assert!(reader.seek(4, Some(5)).is_ok());
        let single = reader.read_body().await.unwrap().unwrap();
        assert_eq!("1", single);
        let after_single = reader.read_body().await.unwrap();
        assert!(after_single.is_none());
    }

    /// Readers attached to an in-progress write see data as it is written.
    #[tokio::test]
    async fn test_write_while_read() {
        use futures::FutureExt;

        static MEM_CACHE: Lazy<MemCache> = Lazy::new(MemCache::new);
        let trace = Span::inactive().handle();
        let key = CacheKey::new("", "a", "1");

        let initial = MEM_CACHE.lookup(&key, &trace).await.unwrap();
        assert!(initial.is_none());

        let written_meta = gen_meta();
        let mut writer = MEM_CACHE
            .get_miss_handler(&key, &written_meta, &trace)
            .await
            .unwrap();

        // The first reader attaches before anything is written.
        let (first_meta, mut first_reader) =
            MEM_CACHE.lookup(&key, &trace).await.unwrap().unwrap();
        assert_eq!(
            written_meta.0.internal.fresh_until,
            first_meta.0.internal.fresh_until
        );

        // With no data yet, the read does not complete right away.
        let pending = first_reader.read_body().now_or_never();
        assert!(pending.is_none());

        writer.write_body(b"test1"[..].into(), false).await.unwrap();

        let chunk = first_reader.read_body().await.unwrap().unwrap();
        assert_eq!("test1", chunk);
        let pending = first_reader.read_body().now_or_never();
        assert!(pending.is_none());

        writer.write_body(b"test2"[..].into(), false).await.unwrap();
        let chunk = first_reader.read_body().await.unwrap().unwrap();
        assert_eq!("test2", chunk);

        // The second reader attaches late and gets everything written so far at once.
        let (second_meta, mut second_reader) =
            MEM_CACHE.lookup(&key, &trace).await.unwrap().unwrap();
        assert_eq!(
            written_meta.0.internal.fresh_until,
            second_meta.0.internal.fresh_until
        );

        let chunk = second_reader.read_body().await.unwrap().unwrap();
        assert_eq!("test1test2", chunk);
        let pending = second_reader.read_body().now_or_never();
        assert!(pending.is_none());

        let pending = first_reader.read_body().now_or_never();
        assert!(pending.is_none());

        writer.finish().await.unwrap();

        // Once the write is finished, both readers reach the end of the body.
        let first_end = first_reader.read_body().await.unwrap();
        assert!(first_end.is_none());
        let second_end = second_reader.read_body().await.unwrap();
        assert!(second_end.is_none());
    }

    /// Purging removes the in-progress writes of a key.
    #[tokio::test]
    async fn test_purge_partial() {
        static MEM_CACHE: Lazy<MemCache> = Lazy::new(MemCache::new);
        let cache = &MEM_CACHE;

        let key = CacheKey::new("", "a", "1").to_compact();
        let hash = key.combined();
        let meta = (b"meta_key".to_vec(), b"meta_value".to_vec());

        let mut writes = HashMap::new();
        writes.insert(0, TempObject::new(meta));
        cache.temp.write().insert(hash.clone(), writes);

        assert!(cache.temp.read().contains_key(&hash));

        let trace = Span::inactive().handle();
        let result = cache.purge(&key, PurgeType::Invalidation, &trace).await;
        assert!(result.is_ok());

        assert!(!cache.temp.read().contains_key(&hash));
    }

    /// Purging removes the completed object of a key.
    #[tokio::test]
    async fn test_purge_complete() {
        static MEM_CACHE: Lazy<MemCache> = Lazy::new(MemCache::new);
        let cache = &MEM_CACHE;

        let key = CacheKey::new("", "a", "1").to_compact();
        let hash = key.combined();
        let meta = (b"meta_key".to_vec(), b"meta_value".to_vec());
        let cache_obj = CacheObject {
            meta,
            body: Arc::new(vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 0]),
        };
        cache.cached.write().insert(hash.clone(), cache_obj);

        assert!(cache.cached.read().contains_key(&hash));

        let trace = Span::inactive().handle();
        let result = cache.purge(&key, PurgeType::Invalidation, &trace).await;
        assert!(result.is_ok());

        assert!(!cache.cached.read().contains_key(&hash));
    }
}