1use super::{CacheStatus, MemoryCache};
19
20use async_trait::async_trait;
21use log::warn;
22use parking_lot::RwLock;
23use pingora_error::{Error, ErrorTrait};
24use std::collections::HashMap;
25use std::hash::Hash;
26use std::marker::PhantomData;
27use std::sync::Arc;
28use std::time::{Duration, Instant};
29use tokio::sync::Semaphore;
30
31struct CacheLock {
32 pub lock_start: Instant,
33 pub lock: Semaphore,
34}
35
36impl CacheLock {
37 pub fn new_arc() -> Arc<Self> {
38 Arc::new(CacheLock {
39 lock: Semaphore::new(0),
40 lock_start: Instant::now(),
41 })
42 }
43
44 pub fn too_old(&self, age: Option<&Duration>) -> bool {
45 match age {
46 Some(t) => Instant::now() - self.lock_start > *t,
47 None => false,
48 }
49 }
50}
51
/// Callback used by [`RTCache`] to fetch a single value that could not be
/// served from the in-memory cache.
///
/// `K` is the key type, `T` the cached value type and `S` the type of the
/// optional caller-supplied `extra` argument, which `RTCache` hands to the
/// callback unchanged.
#[async_trait]
pub trait Lookup<K, T, S> {
    /// Fetches the value for `key`.
    ///
    /// On success returns the value together with an optional TTL. When the
    /// TTL is `Some`, `RTCache::get` uses it instead of the TTL supplied by
    /// its own caller.
    ///
    /// # Errors
    ///
    /// Any error returned here is attached as the cause of the error that
    /// `RTCache::get` hands back to its caller.
    async fn lookup(
        key: &K,
        extra: Option<&S>,
    ) -> Result<(T, Option<Duration>), Box<dyn ErrorTrait + Send + Sync>>
    where
        K: 'async_trait,
        S: 'async_trait;
}
85
/// Callback used by [`RTCache::multi_get`] to fetch several missing values in
/// one call.
#[async_trait]
pub trait MultiLookup<K, T, S> {
    /// Fetches the values for `keys`.
    ///
    /// The returned vector must contain exactly one `(value, optional TTL)`
    /// entry per key: `RTCache::multi_get` asserts that the lengths match and
    /// pairs results with keys by position.
    ///
    /// # Errors
    ///
    /// Any error returned here is attached as the cause of the error that
    /// `RTCache::multi_get` returns.
    async fn multi_lookup(
        keys: &[&K],
        extra: Option<&S>,
    ) -> Result<Vec<(T, Option<Duration>)>, Box<dyn ErrorTrait + Send + Sync>>
    where
        K: 'async_trait,
        S: 'async_trait;
}
99
/// Message of the error returned when a lookup callback fails; the callback's
/// own error is attached to it as the cause.
const LOOKUP_ERR_MSG: &str = "RTCache: lookup error";
101
/// A read-through cache: reads that cannot be served from the in-memory cache
/// are resolved through a lookup callback.
///
/// Type parameters:
/// - `K`: key type.
/// - `T`: cached value type.
/// - `CB`: callback type implementing [`Lookup`] and/or [`MultiLookup`].
/// - `S`: type of the optional extra argument forwarded to the callback.
pub struct RTCache<K, T, CB, S>
where
    K: Hash + Send,
    T: Clone + Send,
{
    // Backing in-memory cache that stores the values.
    inner: MemoryCache<K, T>,
    // Marker tying the callback type to the cache; no value is stored.
    _callback: PhantomData<CB>,
    // In-flight lookup locks, keyed by the hash of the cache key.
    lockers: RwLock<HashMap<u64, Arc<CacheLock>>>,
    // Maximum age of a lock before `get` stops honoring it; `None` = no limit.
    lock_age: Option<Duration>,
    // Maximum time `get` waits on another task's lock; `None` = wait forever.
    lock_timeout: Option<Duration>,
    // Marker for the extra-argument type; no value is stored.
    phantom: PhantomData<S>,
}
122
123impl<K, T, CB, S> RTCache<K, T, CB, S>
124where
125 K: Hash + Send,
126 T: Clone + Send + Sync + 'static,
127{
128 pub fn new(size: usize, lock_age: Option<Duration>, lock_timeout: Option<Duration>) -> Self {
131 RTCache {
132 inner: MemoryCache::new(size),
133 lockers: RwLock::new(HashMap::new()),
134 _callback: PhantomData,
135 lock_age,
136 lock_timeout,
137 phantom: PhantomData,
138 }
139 }
140}
141
142impl<K, T, CB, S> RTCache<K, T, CB, S>
143where
144 K: Hash + Send,
145 T: Clone + Send + Sync + 'static,
146 CB: Lookup<K, T, S>,
147{
148 pub async fn get(
151 &self,
152 key: &K,
153 ttl: Option<Duration>,
154 extra: Option<&S>,
155 ) -> (Result<T, Box<Error>>, CacheStatus) {
156 let (result, cache_state) = self.inner.get(key);
157 if let Some(result) = result {
158 return (Ok(result), cache_state);
160 }
161
162 let hashed_key = self.inner.hasher.hash_one(key);
163
164 let my_lock = {
166 let lockers = self.lockers.read();
167 lockers.get(&hashed_key).cloned()
169 }; let (my_write, my_read) = match my_lock {
173 Some(lock) => {
175 if lock.too_old(self.lock_age.as_ref()) {
177 (None, None)
178 } else {
179 (None, Some(lock))
180 }
181 }
182 None => {
183 let mut lockers = self.lockers.write();
184 match lockers.get(&hashed_key) {
185 Some(lock) => {
186 if lock.too_old(self.lock_age.as_ref()) {
188 (None, None)
189 } else {
190 (None, Some(lock.clone()))
191 }
192 }
193 None => {
194 let new_lock = CacheLock::new_arc();
195 let new_lock2 = new_lock.clone();
196 lockers.insert(hashed_key, new_lock2);
197 (Some(new_lock), None)
198 }
199 } }
201 };
202
203 if my_read.is_some() {
204 let my_lock = my_read.unwrap();
207 if my_lock.lock.available_permits() == 0 {
209 let lock_fut = my_lock.lock.acquire();
211 let timed_out = match self.lock_timeout {
212 Some(t) => pingora_timeout::timeout(t, lock_fut).await.is_err(),
213 None => {
214 let _ = lock_fut.await;
215 false
216 }
217 };
218 if timed_out {
219 let value = CB::lookup(key, extra).await;
220 return match value {
221 Ok((v, _ttl)) => (Ok(v), cache_state),
222 Err(e) => {
223 let mut err = Error::new_str(LOOKUP_ERR_MSG);
224 err.set_cause(e);
225 (Err(err), cache_state)
226 }
227 };
228 }
229 } let (result, cache_state) = self.inner.get(key);
232 if let Some(result) = result {
233 (Ok(result), CacheStatus::LockHit)
235 } else {
236 warn!(
238 "RTCache: no result after read lock, cache status: {:?}",
239 cache_state
240 );
241 match CB::lookup(key, extra).await {
242 Ok((v, new_ttl)) => {
243 self.inner.force_put(key, v.clone(), new_ttl.or(ttl));
244 (Ok(v), cache_state)
245 }
246 Err(e) => {
247 let mut err = Error::new_str(LOOKUP_ERR_MSG);
248 err.set_cause(e);
249 (Err(err), cache_state)
250 }
251 }
252 }
253 } else {
254 let value = CB::lookup(key, extra).await;
257 let ret = match value {
258 Ok((v, new_ttl)) => {
259 if my_write.is_some() {
261 self.inner.force_put(key, v.clone(), new_ttl.or(ttl));
262 }
263 (Ok(v), cache_state) }
265 Err(e) => {
266 let mut err = Error::new_str(LOOKUP_ERR_MSG);
267 err.set_cause(e);
268 (Err(err), cache_state)
269 }
270 };
271 if my_write.is_some() {
272 my_write.unwrap().lock.add_permits(10);
275
276 {
277 let mut lockers = self.lockers.write();
279 lockers.remove(&hashed_key);
280 } }
282
283 ret
284 }
285 }
286
287 pub async fn get_stale(
291 &self,
292 key: &K,
293 ttl: Option<Duration>,
294 extra: Option<&S>,
295 stale_ttl: Duration,
296 ) -> (Result<T, Box<Error>>, CacheStatus) {
297 let (result, cache_status) = self.inner.get_stale(key);
298 if let Some(result) = result {
299 let stale_duration = cache_status.stale();
300 if stale_duration.unwrap_or(Duration::ZERO) <= stale_ttl {
301 return (Ok(result), cache_status);
302 }
303 }
304 let (res, status) = self.get(key, ttl, extra).await;
305 (res, status)
306 }
307}
308
309impl<K, T, CB, S> RTCache<K, T, CB, S>
310where
311 K: Hash + Clone + Send + Sync,
312 T: Clone + Send + Sync + 'static,
313 S: Clone + Send + Sync,
314 CB: Lookup<K, T, S> + Sync + Send,
315{
316 pub async fn get_stale_while_update(
324 &'static self,
325 key: &K,
326 ttl: Option<Duration>,
327 extra: Option<&S>,
328 stale_ttl: Duration,
329 ) -> (Result<T, Box<Error>>, CacheStatus) {
330 let (result, cache_status) = self.get_stale(key, ttl, extra, stale_ttl).await;
331 let key = key.clone();
332 let extra = extra.cloned();
333 if cache_status.stale().is_some() {
334 tokio::spawn(async move {
335 let _ = self.get(&key, ttl, extra.as_ref()).await;
336 });
337 }
338 (result, cache_status)
339 }
340}
341
342impl<K, T, CB, S> RTCache<K, T, CB, S>
343where
344 K: Hash + Send,
345 T: Clone + Send + Sync + 'static,
346 CB: MultiLookup<K, T, S>,
347{
348 pub async fn multi_get<'a, I>(
357 &self,
358 keys: I,
359 ttl: Option<Duration>,
360 extra: Option<&S>,
361 ) -> Result<Vec<(T, CacheStatus)>, Box<Error>>
362 where
363 I: Iterator<Item = &'a K>,
364 K: 'a,
365 {
366 let size = keys.size_hint().0;
367 let (hits, misses) = self.inner.multi_get_with_miss(keys);
368 let mut final_results = Vec::with_capacity(size);
369 let miss_results = if !misses.is_empty() {
370 match CB::multi_lookup(&misses, extra).await {
371 Ok(miss_results) => {
372 assert!(
375 miss_results.len() == misses.len(),
376 "multi_lookup() failed to return the matching number of results"
377 );
378 for item in misses.iter().zip(miss_results.iter()) {
380 self.inner
381 .force_put(item.0, (item.1).0.clone(), (item.1).1.or(ttl));
382 }
383 miss_results
384 }
385 Err(e) => {
386 let mut err = Error::new_str(LOOKUP_ERR_MSG);
388 err.set_cause(e);
389 return Err(err);
390 }
391 }
392 } else {
393 vec![] };
395 let mut n_miss = 0;
397 for item in hits {
398 match item.0 {
399 Some(v) => final_results.push((v, item.1)),
400 None => {
401 final_results .push((miss_results[n_miss].0.clone(), CacheStatus::Miss));
403 n_miss += 1;
404 }
405 }
406 }
407 Ok(final_results)
408 }
409}
410
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicI32, Ordering};
    use tokio::task::JoinHandle;

    /// Cache type exercised by every test in this module.
    type TestCache = RTCache<i32, i32, TestCB, ExtraOpt>;

    /// Knobs handed to the test callbacks through the `extra` argument.
    #[derive(Clone, Debug)]
    struct ExtraOpt {
        /// Make `lookup` fail.
        error: bool,
        /// Make `multi_lookup` return no results at all.
        empty: bool,
        /// Make `lookup` sleep this long before returning.
        delay_for: Option<Duration>,
        /// Counts how many times `lookup` was invoked.
        used: Arc<AtomicI32>,
    }

    /// Options with no error, no delay and the given call counter.
    fn plain_opt(used: Arc<AtomicI32>) -> ExtraOpt {
        ExtraOpt {
            error: false,
            empty: false,
            delay_for: None,
            used,
        }
    }

    /// Options with no error, no delay and a brand-new call counter.
    fn fresh_opt() -> ExtraOpt {
        plain_opt(Arc::new(AtomicI32::new(0)))
    }

    struct TestCB();

    #[async_trait]
    impl Lookup<i32, i32, ExtraOpt> for TestCB {
        /// Returns the 1-based number of this call (0 when no options are
        /// given) so that tests can tell individual lookups apart.
        async fn lookup(
            _key: &i32,
            extra: Option<&ExtraOpt>,
        ) -> Result<(i32, Option<Duration>), Box<dyn ErrorTrait + Send + Sync>> {
            let e = match extra {
                Some(e) => e,
                None => return Ok((0, None)),
            };
            // The counter is bumped before any failure or delay.
            let used = e.used.fetch_add(1, Ordering::Relaxed) + 1;
            if e.error {
                return Err(Error::new_str("test error"));
            }
            if let Some(delay_for) = e.delay_for {
                tokio::time::sleep(delay_for).await;
            }
            Ok((used, None))
        }
    }

    #[async_trait]
    impl MultiLookup<i32, i32, ExtraOpt> for TestCB {
        /// Echoes every key back as its own value, or returns nothing when
        /// `empty` is set.
        async fn multi_lookup(
            keys: &[&i32],
            extra: Option<&ExtraOpt>,
        ) -> Result<Vec<(i32, Option<Duration>)>, Box<dyn ErrorTrait + Send + Sync>> {
            if extra.map_or(false, |e| e.empty) {
                return Ok(vec![]);
            }
            Ok(keys.iter().map(|key| (**key, None)).collect())
        }
    }

    /// Calls `get` for key 1 and checks both the value and the cache status.
    async fn assert_get(
        cache: &TestCache,
        ttl: Option<Duration>,
        opt: &Option<ExtraOpt>,
        value: i32,
        status: CacheStatus,
    ) {
        let (res, hit) = cache.get(&1, ttl, opt.as_ref()).await;
        assert_eq!(res.unwrap(), value);
        assert_eq!(hit, status);
    }

    /// Spawns a task that calls `get` without a TTL and yields the value.
    fn spawn_get(cache: &Arc<TestCache>, key: i32, opt: Option<ExtraOpt>) -> JoinHandle<i32> {
        let cache = Arc::clone(cache);
        tokio::spawn(async move {
            let (res, _hit) = cache.get(&key, None, opt.as_ref()).await;
            res.unwrap()
        })
    }

    /// Spawns a task that calls `get` without a TTL and yields whether it
    /// failed.
    fn spawn_get_is_err(
        cache: &Arc<TestCache>,
        key: i32,
        opt: Option<ExtraOpt>,
    ) -> JoinHandle<bool> {
        let cache = Arc::clone(cache);
        tokio::spawn(async move {
            let (res, _hit) = cache.get(&key, None, opt.as_ref()).await;
            res.is_err()
        })
    }

    // The first get is a miss that runs the lookup; the second is a hit.
    #[tokio::test]
    async fn test_basic_get() {
        let cache = TestCache::new(10, None, None);
        let opt = Some(fresh_opt());
        assert_get(&cache, None, &opt, 1, CacheStatus::Miss).await;
        assert_get(&cache, None, &opt, 1, CacheStatus::Hit).await;
    }

    // A failing lookup surfaces as an error with the original miss status.
    #[tokio::test]
    async fn test_basic_get_error() {
        let cache = TestCache::new(10, None, None);
        let opt1 = Some(ExtraOpt {
            error: true,
            ..fresh_opt()
        });
        let (res, hit) = cache.get(&-1, None, opt1.as_ref()).await;
        assert!(res.is_err());
        assert_eq!(hit, CacheStatus::Miss);
    }

    // Three concurrent gets of one key all see the value of the first lookup.
    #[tokio::test]
    async fn test_concurrent_get() {
        let cache = Arc::new(TestCache::new(10, None, None));
        let opt = Some(fresh_opt());
        let t1 = spawn_get(&cache, 1, opt.clone());
        let t2 = spawn_get(&cache, 1, opt.clone());
        let t3 = spawn_get(&cache, 1, opt.clone());
        let (r1, r2, r3) = tokio::join!(t1, t2, t3);
        assert_eq!(r1.unwrap(), 1);
        assert_eq!(r2.unwrap(), 1);
        assert_eq!(r3.unwrap(), 1);
    }

    // Every concurrent get of a key whose lookup fails gets an error.
    #[tokio::test]
    async fn test_concurrent_get_error() {
        let cache = Arc::new(TestCache::new(10, None, None));
        let opt = Some(ExtraOpt {
            error: true,
            ..fresh_opt()
        });
        let t1 = spawn_get_is_err(&cache, -1, opt.clone());
        let t2 = spawn_get_is_err(&cache, -1, opt.clone());
        let t3 = spawn_get_is_err(&cache, -1, opt);
        let (r1, r2, r3) = tokio::join!(t1, t2, t3);
        assert!(r1.unwrap());
        assert!(r2.unwrap());
        assert!(r3.unwrap());
    }

    // Different keys each run their own lookup: the call numbers are 1, 2
    // and 3 in some order, hence the sum.
    #[tokio::test]
    async fn test_concurrent_get_different_value() {
        let cache = Arc::new(TestCache::new(10, None, None));
        let opt = Some(fresh_opt());
        let t1 = spawn_get(&cache, 1, opt.clone());
        let t2 = spawn_get(&cache, 3, opt.clone());
        let t3 = spawn_get(&cache, 5, opt);
        let (r1, r2, r3) = tokio::join!(t1, t2, t3);
        assert_eq!(r1.unwrap() + r2.unwrap() + r3.unwrap(), 6);
    }

    // The first lookup takes 2s while locks age out after 1s, so gets issued
    // 1.5s later ignore the lock and run their own lookups (1 + 2 + 3).
    #[tokio::test]
    async fn test_get_lock_age() {
        let cache = Arc::new(TestCache::new(10, Some(Duration::from_secs(1)), None));
        let counter = Arc::new(AtomicI32::new(0));
        let slow = Some(ExtraOpt {
            delay_for: Some(Duration::from_secs(2)),
            ..plain_opt(counter.clone())
        });
        let fast = Some(plain_opt(counter.clone()));

        let t1 = spawn_get(&cache, 1, slow);
        tokio::time::sleep(Duration::from_secs_f32(1.5)).await;
        let t2 = spawn_get(&cache, 1, fast.clone());
        let t3 = spawn_get(&cache, 1, fast);
        let (r1, r2, r3) = tokio::join!(t1, t2, t3);
        assert_eq!(r1.unwrap() + r2.unwrap() + r3.unwrap(), 6);
    }

    // The first lookup takes 2s while waiting is capped at 1s, so the other
    // gets give up on the lock and run their own lookups (1 + 2 + 3).
    #[tokio::test]
    async fn test_get_lock_timeout() {
        let cache = Arc::new(TestCache::new(10, None, Some(Duration::from_secs(1))));
        let counter = Arc::new(AtomicI32::new(0));
        let slow = Some(ExtraOpt {
            delay_for: Some(Duration::from_secs(2)),
            ..plain_opt(counter.clone())
        });
        let fast = Some(plain_opt(counter.clone()));

        let t1 = spawn_get(&cache, 1, slow);
        let t2 = spawn_get(&cache, 1, fast.clone());
        let t3 = spawn_get(&cache, 1, fast);
        let (r1, r2, r3) = tokio::join!(t1, t2, t3);
        assert_eq!(r1.unwrap() + r2.unwrap() + r3.unwrap(), 6);
    }

    // Key 1 is cached through `get` first; `multi_get` then fetches 2 and 3
    // as misses, and a second `multi_get` hits on all three.
    #[tokio::test]
    async fn test_multi_get() {
        let cache = TestCache::new(10, None, None);
        let counter = Arc::new(AtomicI32::new(0));
        let opt1 = Some(ExtraOpt {
            delay_for: Some(Duration::from_secs(2)),
            ..plain_opt(counter.clone())
        });
        assert_get(&cache, None, &opt1, 1, CacheStatus::Miss).await;
        assert_get(&cache, None, &opt1, 1, CacheStatus::Hit).await;

        let first_round = [
            (1, CacheStatus::Hit),
            (2, CacheStatus::Miss),
            (3, CacheStatus::Miss),
        ];
        let resp = cache
            .multi_get([1, 2, 3].iter(), None, opt1.as_ref())
            .await
            .unwrap();
        for (i, (value, status)) in first_round.iter().enumerate() {
            assert_eq!(resp[i].0, *value);
            assert_eq!(resp[i].1, *status);
        }

        let second_round = [
            (1, CacheStatus::Hit),
            (2, CacheStatus::Hit),
            (3, CacheStatus::Hit),
        ];
        let resp = cache
            .multi_get([1, 2, 3].iter(), None, opt1.as_ref())
            .await
            .unwrap();
        for (i, (value, status)) in second_round.iter().enumerate() {
            assert_eq!(resp[i].0, *value);
            assert_eq!(resp[i].1, *status);
        }
    }

    // A callback returning fewer results than keys makes `multi_get` panic.
    #[tokio::test]
    #[should_panic(expected = "multi_lookup() failed to return the matching number of results")]
    async fn test_inconsistent_miss_results() {
        let opt1 = Some(ExtraOpt {
            empty: true,
            ..fresh_opt()
        });
        let cache = TestCache::new(10, None, None);
        cache
            .multi_get([4, 5, 6].iter(), None, opt1.as_ref())
            .await
            .unwrap();
    }

    // An entry past its 100ms TTL is still served under a generous staleness
    // bound, but is looked up again once the bound is tighter than its age.
    #[tokio::test]
    async fn test_get_stale() {
        let ttl = Some(Duration::from_millis(100));
        let cache = TestCache::new(10, None, None);
        let opt = Some(fresh_opt());
        assert_get(&cache, ttl, &opt, 1, CacheStatus::Miss).await;
        assert_get(&cache, ttl, &opt, 1, CacheStatus::Hit).await;

        tokio::time::sleep(Duration::from_millis(150)).await;
        let (res, hit) = cache
            .get_stale(&1, ttl, opt.as_ref(), Duration::from_millis(1000))
            .await;
        assert_eq!(res.unwrap(), 1);
        assert!(hit.stale().is_some());

        let (res, hit) = cache
            .get_stale(&1, ttl, opt.as_ref(), Duration::from_millis(30))
            .await;
        assert_eq!(res.unwrap(), 2);
        assert_eq!(hit, CacheStatus::Expired);
    }

    // A stale read returns the old value at once and refreshes the entry in
    // the background; shortly afterwards a plain get hits the new value.
    #[tokio::test]
    async fn test_get_stale_while_update() {
        use once_cell::sync::Lazy;

        static CACHE: Lazy<TestCache> = Lazy::new(|| RTCache::new(10, None, None));

        let ttl = Some(Duration::from_millis(100));
        let opt = Some(fresh_opt());
        assert_get(&CACHE, ttl, &opt, 1, CacheStatus::Miss).await;
        assert_get(&CACHE, ttl, &opt, 1, CacheStatus::Hit).await;

        tokio::time::sleep(Duration::from_millis(150)).await;
        let (res, hit) = CACHE
            .get_stale_while_update(&1, ttl, opt.as_ref(), Duration::from_millis(1000))
            .await;
        assert_eq!(res.unwrap(), 1);
        assert!(hit.stale().is_some());

        // Give the background refresh time to finish.
        tokio::time::sleep(Duration::from_millis(10)).await;

        assert_get(&CACHE, ttl, &opt, 2, CacheStatus::Hit).await;
    }
}