1use super::{CacheStatus, MemoryCache};
19
20use async_trait::async_trait;
21use log::warn;
22use parking_lot::RwLock;
23use pingora_error::{Error, ErrorTrait};
24use std::collections::HashMap;
25use std::hash::Hash;
26use std::marker::PhantomData;
27use std::sync::Arc;
28use std::time::{Duration, Instant};
29use tokio::sync::Semaphore;
30
31struct CacheLock {
32 pub lock_start: Instant,
33 pub lock: Semaphore,
34}
35
36impl CacheLock {
37 pub fn new_arc() -> Arc<Self> {
38 Arc::new(CacheLock {
39 lock: Semaphore::new(0),
40 lock_start: Instant::now(),
41 })
42 }
43
44 pub fn too_old(&self, age: Option<&Duration>) -> bool {
45 match age {
46 Some(t) => Instant::now() - self.lock_start > *t,
47 None => false,
48 }
49 }
50}
51
/// Callback used by [`RTCache`] to fetch the value of a single key that is not served from the
/// cache.
///
/// `K` is the key type, `T` the value type and `S` the type of the optional extra argument.
#[async_trait]
pub trait Lookup<K, T, S> {
    /// Looks up the value for `key`.
    ///
    /// `extra` is the optional argument the caller handed to the cache method; it is forwarded
    /// unchanged.
    ///
    /// On success, returns the value together with an optional TTL. When that TTL is `Some`,
    /// `RTCache::get()` uses it instead of the TTL supplied by the caller.
    async fn lookup(
        key: &K,
        extra: Option<&S>,
    ) -> Result<(T, Option<Duration>), Box<dyn ErrorTrait + Send + Sync>>
    where
        K: 'async_trait,
        S: 'async_trait;
}
85
/// Callback used by [`RTCache`] to fetch the values of several keys in one call.
///
/// `K` is the key type, `T` the value type and `S` the type of the optional extra argument.
#[async_trait]
pub trait MultiLookup<K, T, S> {
    /// Looks up the values for all `keys`.
    ///
    /// The returned `Vec` must contain exactly one `(value, optional TTL)` entry per key, in the
    /// same order as `keys`: `RTCache::multi_get()` pairs keys and results by position and
    /// panics when the lengths differ.
    async fn multi_lookup(
        keys: &[&K],
        extra: Option<&S>,
    ) -> Result<Vec<(T, Option<Duration>)>, Box<dyn ErrorTrait + Send + Sync>>
    where
        K: 'async_trait,
        S: 'async_trait;
}
99
/// Message of the error returned by the cache when a lookup callback fails; the callback's own
/// error is attached to it as the cause.
const LOOKUP_ERR_MSG: &str = "RTCache: lookup error";
101
/// A read-through cache: values missing from the in-memory cache are fetched through the
/// callback type `CB` (see [`Lookup`] and [`MultiLookup`]).
///
/// `K` is the key type, `T` the value type and `S` the type of the optional extra argument that
/// is forwarded to the callback.
pub struct RTCache<K, T, CB, S>
where
    K: Hash + Send,
    T: Clone + Send,
{
    /// The underlying cache that stores the values.
    inner: MemoryCache<K, T>,
    /// Marker tying the callback type to the cache; no callback value is stored.
    _callback: PhantomData<CB>,
    /// Locks of the lookups currently in flight, keyed by the hash of the cache key.
    lockers: RwLock<HashMap<u64, Arc<CacheLock>>>,
    /// When set, a lock older than this is no longer waited on (see `CacheLock::too_old()`).
    lock_age: Option<Duration>,
    /// When set, the longest time `get()` waits on another task's lock before doing its own
    /// lookup.
    lock_timeout: Option<Duration>,
    /// Marker tying the extra-argument type to the cache; no value is stored.
    phantom: PhantomData<S>,
}
122
123impl<K, T, CB, S> RTCache<K, T, CB, S>
124where
125 K: Hash + Send,
126 T: Clone + Send + Sync + 'static,
127{
128 pub fn new(size: usize, lock_age: Option<Duration>, lock_timeout: Option<Duration>) -> Self {
131 RTCache {
132 inner: MemoryCache::new(size),
133 lockers: RwLock::new(HashMap::new()),
134 _callback: PhantomData,
135 lock_age,
136 lock_timeout,
137 phantom: PhantomData,
138 }
139 }
140}
141
142impl<K, T, CB, S> RTCache<K, T, CB, S>
143where
144 K: Hash + Send,
145 T: Clone + Send + Sync + 'static,
146 CB: Lookup<K, T, S>,
147{
148 pub async fn get(
151 &self,
152 key: &K,
153 ttl: Option<Duration>,
154 extra: Option<&S>,
155 ) -> (Result<T, Box<Error>>, CacheStatus) {
156 let (result, cache_state) = self.inner.get(key);
157 if let Some(result) = result {
158 return (Ok(result), cache_state);
160 }
161
162 let hashed_key = self.inner.hasher.hash_one(key);
163
164 let my_lock = {
166 let lockers = self.lockers.read();
167 lockers.get(&hashed_key).cloned()
169 }; let (my_write, my_read) = match my_lock {
173 Some(lock) => {
175 if lock.too_old(self.lock_age.as_ref()) {
177 (None, None)
178 } else {
179 (None, Some(lock))
180 }
181 }
182 None => {
183 let mut lockers = self.lockers.write();
184 match lockers.get(&hashed_key) {
185 Some(lock) => {
186 if lock.too_old(self.lock_age.as_ref()) {
188 (None, None)
189 } else {
190 (None, Some(lock.clone()))
191 }
192 }
193 None => {
194 let new_lock = CacheLock::new_arc();
195 let new_lock2 = new_lock.clone();
196 lockers.insert(hashed_key, new_lock2);
197 (Some(new_lock), None)
198 }
199 } }
201 };
202
203 if let Some(my_lock) = my_read {
204 if my_lock.lock.available_permits() == 0 {
208 let lock_fut = my_lock.lock.acquire();
210 let timed_out = match self.lock_timeout {
211 Some(t) => pingora_timeout::timeout(t, lock_fut).await.is_err(),
212 None => {
213 let _ = lock_fut.await;
214 false
215 }
216 };
217 if timed_out {
218 let value = CB::lookup(key, extra).await;
219 return match value {
220 Ok((v, _ttl)) => (Ok(v), cache_state),
221 Err(e) => {
222 let mut err = Error::new_str(LOOKUP_ERR_MSG);
223 err.set_cause(e);
224 (Err(err), cache_state)
225 }
226 };
227 }
228 } let (result, cache_state) = self.inner.get(key);
231 if let Some(result) = result {
232 (Ok(result), CacheStatus::LockHit)
234 } else {
235 warn!(
237 "RTCache: no result after read lock, cache status: {:?}",
238 cache_state
239 );
240 match CB::lookup(key, extra).await {
241 Ok((v, new_ttl)) => {
242 self.inner.force_put(key, v.clone(), new_ttl.or(ttl));
243 (Ok(v), cache_state)
244 }
245 Err(e) => {
246 let mut err = Error::new_str(LOOKUP_ERR_MSG);
247 err.set_cause(e);
248 (Err(err), cache_state)
249 }
250 }
251 }
252 } else {
253 let value = CB::lookup(key, extra).await;
256 let ret = match value {
257 Ok((v, new_ttl)) => {
258 if my_write.is_some() {
260 self.inner.force_put(key, v.clone(), new_ttl.or(ttl));
261 }
262 (Ok(v), cache_state) }
264 Err(e) => {
265 let mut err = Error::new_str(LOOKUP_ERR_MSG);
266 err.set_cause(e);
267 (Err(err), cache_state)
268 }
269 };
270 if let Some(my_write) = my_write {
271 my_write.lock.add_permits(10);
274
275 {
276 let mut lockers = self.lockers.write();
278 lockers.remove(&hashed_key);
279 } }
281
282 ret
283 }
284 }
285
286 pub async fn get_stale(
290 &self,
291 key: &K,
292 ttl: Option<Duration>,
293 extra: Option<&S>,
294 stale_ttl: Duration,
295 ) -> (Result<T, Box<Error>>, CacheStatus) {
296 let (result, cache_status) = self.inner.get_stale(key);
297 if let Some(result) = result {
298 let stale_duration = cache_status.stale();
299 if stale_duration.unwrap_or(Duration::ZERO) <= stale_ttl {
300 return (Ok(result), cache_status);
301 }
302 }
303 let (res, status) = self.get(key, ttl, extra).await;
304 (res, status)
305 }
306}
307
308impl<K, T, CB, S> RTCache<K, T, CB, S>
309where
310 K: Hash + Clone + Send + Sync,
311 T: Clone + Send + Sync + 'static,
312 S: Clone + Send + Sync,
313 CB: Lookup<K, T, S> + Sync + Send,
314{
315 pub async fn get_stale_while_update(
323 &'static self,
324 key: &K,
325 ttl: Option<Duration>,
326 extra: Option<&S>,
327 stale_ttl: Duration,
328 ) -> (Result<T, Box<Error>>, CacheStatus) {
329 let (result, cache_status) = self.get_stale(key, ttl, extra, stale_ttl).await;
330 let key = key.clone();
331 let extra = extra.cloned();
332 if cache_status.stale().is_some() {
333 tokio::spawn(async move {
334 let _ = self.get(&key, ttl, extra.as_ref()).await;
335 });
336 }
337 (result, cache_status)
338 }
339}
340
341impl<K, T, CB, S> RTCache<K, T, CB, S>
342where
343 K: Hash + Send,
344 T: Clone + Send + Sync + 'static,
345 CB: MultiLookup<K, T, S>,
346{
347 pub async fn multi_get<'a, I>(
356 &self,
357 keys: I,
358 ttl: Option<Duration>,
359 extra: Option<&S>,
360 ) -> Result<Vec<(T, CacheStatus)>, Box<Error>>
361 where
362 I: Iterator<Item = &'a K>,
363 K: 'a,
364 {
365 let size = keys.size_hint().0;
366 let (hits, misses) = self.inner.multi_get_with_miss(keys);
367 let mut final_results = Vec::with_capacity(size);
368 let miss_results = if !misses.is_empty() {
369 match CB::multi_lookup(&misses, extra).await {
370 Ok(miss_results) => {
371 assert!(
374 miss_results.len() == misses.len(),
375 "multi_lookup() failed to return the matching number of results"
376 );
377 for item in misses.iter().zip(miss_results.iter()) {
379 self.inner
380 .force_put(item.0, (item.1).0.clone(), (item.1).1.or(ttl));
381 }
382 miss_results
383 }
384 Err(e) => {
385 let mut err = Error::new_str(LOOKUP_ERR_MSG);
387 err.set_cause(e);
388 return Err(err);
389 }
390 }
391 } else {
392 vec![] };
394 let mut n_miss = 0;
396 for item in hits {
397 match item.0 {
398 Some(v) => final_results.push((v, item.1)),
399 None => {
400 final_results .push((miss_results[n_miss].0.clone(), CacheStatus::Miss));
402 n_miss += 1;
403 }
404 }
405 }
406 Ok(final_results)
407 }
408}
409
#[cfg(test)]
mod tests {
    use super::*;
    use atomic::AtomicI32;
    use std::sync::atomic;

    /// Extra argument handed to the test callbacks to control their behavior.
    #[derive(Clone, Debug)]
    struct ExtraOpt {
        // make `lookup()` return an error
        error: bool,
        // make `multi_lookup()` return an empty result list
        empty: bool,
        // make `lookup()` sleep for this long before returning
        delay_for: Option<Duration>,
        // number of `lookup()` calls made with this option (shared between clones)
        used: Arc<AtomicI32>,
    }

    /// Callback implementation used by all tests in this module.
    struct TestCB();

    #[async_trait]
    impl Lookup<i32, i32, ExtraOpt> for TestCB {
        async fn lookup(
            _key: &i32,
            extra: Option<&ExtraOpt>,
        ) -> Result<(i32, Option<Duration>), Box<dyn ErrorTrait + Send + Sync>> {
            // The returned value is the call count, so every lookup yields a new value and the
            // tests can tell how many lookups took place. Without `extra` the value is 0.
            let mut used = 0;
            if let Some(e) = extra {
                used = e.used.fetch_add(1, atomic::Ordering::Relaxed) + 1;
                if e.error {
                    return Err(Error::new_str("test error"));
                }
                if let Some(delay_for) = e.delay_for {
                    tokio::time::sleep(delay_for).await;
                }
            }
            Ok((used, None))
        }
    }

    #[async_trait]
    impl MultiLookup<i32, i32, ExtraOpt> for TestCB {
        async fn multi_lookup(
            keys: &[&i32],
            extra: Option<&ExtraOpt>,
        ) -> Result<Vec<(i32, Option<Duration>)>, Box<dyn ErrorTrait + Send + Sync>> {
            let mut resp = vec![];
            // `empty` simulates a callback that returns fewer results than keys
            if let Some(extra) = extra {
                if extra.empty {
                    return Ok(resp);
                }
            }
            // every key maps to itself, without a TTL
            for key in keys {
                resp.push((**key, None));
            }
            Ok(resp)
        }
    }

    // The first get is a miss that triggers a lookup, the second one is served from the cache.
    #[tokio::test]
    async fn test_basic_get() {
        let cache: RTCache<i32, i32, TestCB, ExtraOpt> = RTCache::new(10, None, None);
        let opt = Some(ExtraOpt {
            error: false,
            empty: false,
            delay_for: None,
            used: Arc::new(AtomicI32::new(0)),
        });
        let (res, hit) = cache.get(&1, None, opt.as_ref()).await;
        assert_eq!(res.unwrap(), 1);
        assert_eq!(hit, CacheStatus::Miss);
        let (res, hit) = cache.get(&1, None, opt.as_ref()).await;
        assert_eq!(res.unwrap(), 1);
        assert_eq!(hit, CacheStatus::Hit);
    }

    // A failing lookup is reported as an error together with the miss status.
    #[tokio::test]
    async fn test_basic_get_error() {
        let cache: RTCache<i32, i32, TestCB, ExtraOpt> = RTCache::new(10, None, None);
        let opt1 = Some(ExtraOpt {
            error: true,
            empty: false,
            delay_for: None,
            used: Arc::new(AtomicI32::new(0)),
        });
        let (res, hit) = cache.get(&-1, None, opt1.as_ref()).await;
        assert!(res.is_err());
        assert_eq!(hit, CacheStatus::Miss);
    }

    // Three concurrent gets of the same key all see the value of the first lookup (1).
    #[tokio::test]
    async fn test_concurrent_get() {
        let cache: RTCache<i32, i32, TestCB, ExtraOpt> = RTCache::new(10, None, None);
        let cache = Arc::new(cache);
        let opt = Some(ExtraOpt {
            error: false,
            empty: false,
            delay_for: None,
            used: Arc::new(AtomicI32::new(0)),
        });
        let cache_c = cache.clone();
        let opt1 = opt.clone();
        let t1 = tokio::spawn(async move {
            let (res, _hit) = cache_c.get(&1, None, opt1.as_ref()).await;
            res.unwrap()
        });
        let cache_c = cache.clone();
        let opt2 = opt.clone();
        let t2 = tokio::spawn(async move {
            let (res, _hit) = cache_c.get(&1, None, opt2.as_ref()).await;
            res.unwrap()
        });
        let opt3 = opt.clone();
        let cache_c = cache.clone();
        let t3 = tokio::spawn(async move {
            let (res, _hit) = cache_c.get(&1, None, opt3.as_ref()).await;
            res.unwrap()
        });
        let (r1, r2, r3) = tokio::join!(t1, t2, t3);
        assert_eq!(r1.unwrap(), 1);
        assert_eq!(r2.unwrap(), 1);
        assert_eq!(r3.unwrap(), 1);
    }

    // Three concurrent gets of the same key all fail when the lookup fails.
    #[tokio::test]
    async fn test_concurrent_get_error() {
        let cache: RTCache<i32, i32, TestCB, ExtraOpt> = RTCache::new(10, None, None);
        let cache = Arc::new(cache);
        let cache_c = cache.clone();
        let opt1 = Some(ExtraOpt {
            error: true,
            empty: false,
            delay_for: None,
            used: Arc::new(AtomicI32::new(0)),
        });
        let opt2 = opt1.clone();
        let opt3 = opt1.clone();
        let t1 = tokio::spawn(async move {
            let (res, _hit) = cache_c.get(&-1, None, opt1.as_ref()).await;
            res.is_err()
        });
        let cache_c = cache.clone();
        let t2 = tokio::spawn(async move {
            let (res, _hit) = cache_c.get(&-1, None, opt2.as_ref()).await;
            res.is_err()
        });
        let cache_c = cache.clone();
        let t3 = tokio::spawn(async move {
            let (res, _hit) = cache_c.get(&-1, None, opt3.as_ref()).await;
            res.is_err()
        });
        let (r1, r2, r3) = tokio::join!(t1, t2, t3);
        assert!(r1.unwrap());
        assert!(r2.unwrap());
        assert!(r3.unwrap());
    }

    // Concurrent gets of different keys each do their own lookup.
    #[tokio::test]
    async fn test_concurrent_get_different_value() {
        let cache: RTCache<i32, i32, TestCB, ExtraOpt> = RTCache::new(10, None, None);
        let cache = Arc::new(cache);
        let opt1 = Some(ExtraOpt {
            error: false,
            empty: false,
            delay_for: None,
            used: Arc::new(AtomicI32::new(0)),
        });
        let opt2 = opt1.clone();
        let opt3 = opt1.clone();
        let cache_c = cache.clone();
        let t1 = tokio::spawn(async move {
            let (res, _hit) = cache_c.get(&1, None, opt1.as_ref()).await;
            res.unwrap()
        });
        let cache_c = cache.clone();
        let t2 = tokio::spawn(async move {
            let (res, _hit) = cache_c.get(&3, None, opt2.as_ref()).await;
            res.unwrap()
        });
        let cache_c = cache.clone();
        let t3 = tokio::spawn(async move {
            let (res, _hit) = cache_c.get(&5, None, opt3.as_ref()).await;
            res.unwrap()
        });
        let (r1, r2, r3) = tokio::join!(t1, t2, t3);
        // three separate lookups return 1, 2 and 3 in some order
        assert_eq!(r1.unwrap() + r2.unwrap() + r3.unwrap(), 6);
    }

    // With a lock age of 1 second, gets arriving 1.5 seconds into a 2 second lookup do not wait
    // for the lock but do their own lookups.
    #[tokio::test]
    async fn test_get_lock_age() {
        let cache: RTCache<i32, i32, TestCB, ExtraOpt> =
            RTCache::new(10, Some(Duration::from_secs(1)), None);
        let cache = Arc::new(cache);
        let counter = Arc::new(AtomicI32::new(0));
        let opt1 = Some(ExtraOpt {
            error: false,
            empty: false,
            delay_for: Some(Duration::from_secs(2)),
            used: counter.clone(),
        });

        let opt2 = Some(ExtraOpt {
            error: false,
            empty: false,
            delay_for: None,
            used: counter.clone(),
        });
        let opt3 = opt2.clone();
        let cache_c = cache.clone();
        // t1 takes the lock and spends 2 seconds in its lookup
        let t1 = tokio::spawn(async move {
            let (res, _hit) = cache_c.get(&1, None, opt1.as_ref()).await;
            res.unwrap()
        });
        // let the lock of t1 become older than the lock age
        tokio::time::sleep(Duration::from_secs_f32(1.5)).await;
        let cache_c = cache.clone();
        let t2 = tokio::spawn(async move {
            let (res, _hit) = cache_c.get(&1, None, opt2.as_ref()).await;
            res.unwrap()
        });
        let cache_c = cache.clone();
        let t3 = tokio::spawn(async move {
            let (res, _hit) = cache_c.get(&1, None, opt3.as_ref()).await;
            res.unwrap()
        });
        let (r1, r2, r3) = tokio::join!(t1, t2, t3);
        // three separate lookups return 1, 2 and 3 in some order
        assert_eq!(r1.unwrap() + r2.unwrap() + r3.unwrap(), 6);
    }

    // With a lock timeout of 1 second, gets waiting on a 2 second lookup give up waiting and do
    // their own lookups. The expected sum relies on t1 being the task that takes the lock.
    #[tokio::test]
    async fn test_get_lock_timeout() {
        let cache: RTCache<i32, i32, TestCB, ExtraOpt> =
            RTCache::new(10, None, Some(Duration::from_secs(1)));
        let cache = Arc::new(cache);
        let counter = Arc::new(AtomicI32::new(0));
        let opt1 = Some(ExtraOpt {
            error: false,
            empty: false,
            delay_for: Some(Duration::from_secs(2)),
            used: counter.clone(),
        });
        let opt2 = Some(ExtraOpt {
            error: false,
            empty: false,
            delay_for: None,
            used: counter.clone(),
        });
        let opt3 = opt2.clone();
        let cache_c = cache.clone();
        // the lookup of t1 takes 2 seconds
        let t1 = tokio::spawn(async move {
            let (res, _hit) = cache_c.get(&1, None, opt1.as_ref()).await;
            res.unwrap()
        });
        let cache_c = cache.clone();
        // t2 and t3 are expected to time out after 1 second and do their own lookups
        let t2 = tokio::spawn(async move {
            let (res, _hit) = cache_c.get(&1, None, opt2.as_ref()).await;
            res.unwrap()
        });
        let cache_c = cache.clone();
        let t3 = tokio::spawn(async move {
            let (res, _hit) = cache_c.get(&1, None, opt3.as_ref()).await;
            res.unwrap()
        });
        let (r1, r2, r3) = tokio::join!(t1, t2, t3);
        // three separate lookups return 1, 2 and 3 in some order
        assert_eq!(r1.unwrap() + r2.unwrap() + r3.unwrap(), 6);
    }

    // multi_get() serves cached keys from the cache, looks up the others and caches them.
    #[tokio::test]
    async fn test_multi_get() {
        let cache: RTCache<i32, i32, TestCB, ExtraOpt> = RTCache::new(10, None, None);
        let counter = Arc::new(AtomicI32::new(0));
        let opt1 = Some(ExtraOpt {
            error: false,
            empty: false,
            delay_for: Some(Duration::from_secs(2)),
            used: counter.clone(),
        });
        // put key 1 into the cache first
        let (res, hit) = cache.get(&1, None, opt1.as_ref()).await;
        assert_eq!(res.unwrap(), 1);
        assert_eq!(hit, CacheStatus::Miss);
        let (res, hit) = cache.get(&1, None, opt1.as_ref()).await;
        assert_eq!(res.unwrap(), 1);
        assert_eq!(hit, CacheStatus::Hit);
        // key 1 is a hit, keys 2 and 3 are misses
        let resp = cache
            .multi_get([1, 2, 3].iter(), None, opt1.as_ref())
            .await
            .unwrap();
        assert_eq!(resp[0].0, 1);
        assert_eq!(resp[0].1, CacheStatus::Hit);
        assert_eq!(resp[1].0, 2);
        assert_eq!(resp[1].1, CacheStatus::Miss);
        assert_eq!(resp[2].0, 3);
        assert_eq!(resp[2].1, CacheStatus::Miss);
        // all keys are hits now
        let resp = cache
            .multi_get([1, 2, 3].iter(), None, opt1.as_ref())
            .await
            .unwrap();
        assert_eq!(resp[0].0, 1);
        assert_eq!(resp[0].1, CacheStatus::Hit);
        assert_eq!(resp[1].0, 2);
        assert_eq!(resp[1].1, CacheStatus::Hit);
        assert_eq!(resp[2].0, 3);
        assert_eq!(resp[2].1, CacheStatus::Hit);
    }

    // multi_get() panics when the callback returns fewer results than there are misses.
    #[tokio::test]
    #[should_panic(expected = "multi_lookup() failed to return the matching number of results")]
    async fn test_inconsistent_miss_results() {
        // `empty: true` forces an empty result list
        let opt1 = Some(ExtraOpt {
            error: false,
            empty: true,
            delay_for: None,
            used: Arc::new(AtomicI32::new(0)),
        });
        let cache: RTCache<i32, i32, TestCB, ExtraOpt> = RTCache::new(10, None, None);
        cache
            .multi_get([4, 5, 6].iter(), None, opt1.as_ref())
            .await
            .unwrap();
    }

    // get_stale() serves an expired value while it is within the stale TTL and does a new lookup
    // otherwise.
    #[tokio::test]
    async fn test_get_stale() {
        let ttl = Some(Duration::from_millis(100));
        let cache: RTCache<i32, i32, TestCB, ExtraOpt> = RTCache::new(10, None, None);
        let opt = Some(ExtraOpt {
            error: false,
            empty: false,
            delay_for: None,
            used: Arc::new(AtomicI32::new(0)),
        });
        let (res, hit) = cache.get(&1, ttl, opt.as_ref()).await;
        assert_eq!(res.unwrap(), 1);
        assert_eq!(hit, CacheStatus::Miss);
        let (res, hit) = cache.get(&1, ttl, opt.as_ref()).await;
        assert_eq!(res.unwrap(), 1);
        assert_eq!(hit, CacheStatus::Hit);
        // wait until the 100 ms TTL has passed
        tokio::time::sleep(Duration::from_millis(150)).await;
        // a stale TTL of 1000 ms still accepts the old value
        let (res, hit) = cache
            .get_stale(&1, ttl, opt.as_ref(), Duration::from_millis(1000))
            .await;
        assert_eq!(res.unwrap(), 1);
        assert!(hit.stale().is_some());

        // a stale TTL of 30 ms rejects the old value, so a second lookup returns 2
        let (res, hit) = cache
            .get_stale(&1, ttl, opt.as_ref(), Duration::from_millis(30))
            .await;
        assert_eq!(res.unwrap(), 2);
        assert_eq!(hit, CacheStatus::Expired);
    }

    // get_stale_while_update() serves the stale value and refreshes the cache in the background.
    #[tokio::test]
    async fn test_get_stale_while_update() {
        use once_cell::sync::Lazy;
        let ttl = Some(Duration::from_millis(100));
        // get_stale_while_update() takes `&'static self`, hence the static cache
        static CACHE: Lazy<RTCache<i32, i32, TestCB, ExtraOpt>> =
            Lazy::new(|| RTCache::new(10, None, None));
        let opt = Some(ExtraOpt {
            error: false,
            empty: false,
            delay_for: None,
            used: Arc::new(AtomicI32::new(0)),
        });
        let (res, hit) = CACHE.get(&1, ttl, opt.as_ref()).await;
        assert_eq!(res.unwrap(), 1);
        assert_eq!(hit, CacheStatus::Miss);
        let (res, hit) = CACHE.get(&1, ttl, opt.as_ref()).await;
        assert_eq!(res.unwrap(), 1);
        assert_eq!(hit, CacheStatus::Hit);
        // wait until the 100 ms TTL has passed
        tokio::time::sleep(Duration::from_millis(150)).await;
        let (res, hit) = CACHE
            .get_stale_while_update(&1, ttl, opt.as_ref(), Duration::from_millis(1000))
            .await;
        assert_eq!(res.unwrap(), 1);
        assert!(hit.stale().is_some());

        // give the background refresh time to finish
        tokio::time::sleep(Duration::from_millis(10)).await;

        // the refreshed value (second lookup) is now served from the cache
        let (res, hit) = CACHE.get(&1, ttl, opt.as_ref()).await;
        assert_eq!(res.unwrap(), 2);
        assert_eq!(hit, CacheStatus::Hit);
    }
}