1use super::{CacheStatus, MemoryCache};
19
20use async_trait::async_trait;
21use log::warn;
22use parking_lot::RwLock;
23use pingora_error::{Error, ErrorTrait};
24use std::collections::HashMap;
25use std::hash::Hash;
26use std::marker::PhantomData;
27use std::sync::Arc;
28use std::time::{Duration, Instant};
29use tokio::sync::Semaphore;
30
31struct CacheLock {
32 pub lock_start: Instant,
33 pub lock: Semaphore,
34}
35
36impl CacheLock {
37 pub fn new_arc() -> Arc<Self> {
38 Arc::new(CacheLock {
39 lock: Semaphore::new(0),
40 lock_start: Instant::now(),
41 })
42 }
43
44 pub fn too_old(&self, age: Option<&Duration>) -> bool {
45 match age {
46 Some(t) => Instant::now() - self.lock_start > *t,
47 None => false,
48 }
49 }
50}
51
/// Callback used by `RTCache::get` to fetch a value that is not in the cache.
#[async_trait]
pub trait Lookup<K, T, S> {
    /// Fetch the value for `key`.
    ///
    /// `extra` is the optional caller context handed to `RTCache::get`, passed through unchanged.
    /// On success, returns the value plus an optional TTL. When a TTL is returned, the cache uses
    /// it in preference to the `ttl` given to `RTCache::get`. A returned error is attached as the
    /// cause of the error the cache hands back to its caller.
    async fn lookup(
        key: &K,
        extra: Option<&S>,
    ) -> Result<(T, Option<Duration>), Box<dyn ErrorTrait + Send + Sync>>
    where
        // `'async_trait` is the lifetime `#[async_trait]` gives the returned future; these bounds
        // let that future borrow `key` and `extra`.
        K: 'async_trait,
        S: 'async_trait;
}
85
/// Batch callback used by `RTCache::multi_get` to fetch all keys that missed the cache at once.
#[async_trait]
pub trait MultiLookup<K, T, S> {
    /// Fetch values for every key in `keys`.
    ///
    /// Must return exactly one `(value, optional TTL)` pair per key, in the same order as `keys`.
    /// Results are matched to keys by position, and `RTCache::multi_get` panics if the number of
    /// results differs from the number of keys. `extra` is passed through unchanged from
    /// `RTCache::multi_get`.
    async fn multi_lookup(
        keys: &[&K],
        extra: Option<&S>,
    ) -> Result<Vec<(T, Option<Duration>)>, Box<dyn ErrorTrait + Send + Sync>>
    where
        // Lifetime bounds required by the `#[async_trait]` expansion (see `Lookup::lookup`).
        K: 'async_trait,
        S: 'async_trait;
}
99
/// Message of the error returned when a lookup callback fails; the callback's own error is
/// attached to it as the cause.
const LOOKUP_ERR_MSG: &str = "RTCache: lookup error";
101
/// A read-through in-memory cache.
///
/// On a miss, values are fetched with the `CB` callback (`Lookup` for `get`, `MultiLookup` for
/// `multi_get`) and stored in the cache. In `get`, concurrent misses on the same key are
/// coalesced: one task performs the lookup while the others wait on a per-key lock.
pub struct RTCache<K, T, CB, S>
where
    K: Hash + Send,
    T: Clone + Send,
{
    // The underlying cache holding the values.
    inner: MemoryCache<K, T>,
    // Type marker for the lookup callback; no callback value is stored.
    _callback: PhantomData<CB>,
    // Locks for in-flight `get` lookups, keyed by the hash `inner.hasher` produces for the key.
    lockers: RwLock<HashMap<u64, Arc<CacheLock>>>,
    // A waiter ignores a lookup lock older than this and looks the key up itself.
    lock_age: Option<Duration>,
    // Maximum time a waiter blocks on another task's lookup before looking up itself.
    lock_timeout: Option<Duration>,
    // Type marker for the `extra` argument type passed to the callbacks.
    phantom: PhantomData<S>,
}
122
123impl<K, T, CB, S> RTCache<K, T, CB, S>
124where
125 K: Hash + Send,
126 T: Clone + Send + Sync + 'static,
127{
128 pub fn new(size: usize, lock_age: Option<Duration>, lock_timeout: Option<Duration>) -> Self {
131 RTCache {
132 inner: MemoryCache::new(size),
133 lockers: RwLock::new(HashMap::new()),
134 _callback: PhantomData,
135 lock_age,
136 lock_timeout,
137 phantom: PhantomData,
138 }
139 }
140}
141
142impl<K, T, CB, S> RTCache<K, T, CB, S>
143where
144 K: Hash + Send,
145 T: Clone + Send + Sync + 'static,
146 CB: Lookup<K, T, S>,
147{
148 pub async fn get(
151 &self,
152 key: &K,
153 ttl: Option<Duration>,
154 extra: Option<&S>,
155 ) -> (Result<T, Box<Error>>, CacheStatus) {
156 let (result, cache_state) = self.inner.get(key);
157 if let Some(result) = result {
158 return (Ok(result), cache_state);
160 }
161
162 let hashed_key = self.inner.hasher.hash_one(key);
163
164 let my_lock = {
166 let lockers = self.lockers.read();
167 lockers.get(&hashed_key).cloned()
169 }; let (my_write, my_read) = match my_lock {
173 Some(lock) => {
175 if lock.too_old(self.lock_age.as_ref()) {
177 (None, None)
178 } else {
179 (None, Some(lock))
180 }
181 }
182 None => {
183 let mut lockers = self.lockers.write();
184 match lockers.get(&hashed_key) {
185 Some(lock) => {
186 if lock.too_old(self.lock_age.as_ref()) {
188 (None, None)
189 } else {
190 (None, Some(lock.clone()))
191 }
192 }
193 None => {
194 let new_lock = CacheLock::new_arc();
195 let new_lock2 = new_lock.clone();
196 lockers.insert(hashed_key, new_lock2);
197 (Some(new_lock), None)
198 }
199 } }
201 };
202
203 if my_read.is_some() {
204 let my_lock = my_read.unwrap();
207 if my_lock.lock.available_permits() == 0 {
209 let lock_fut = my_lock.lock.acquire();
211 let timed_out = match self.lock_timeout {
212 Some(t) => pingora_timeout::timeout(t, lock_fut).await.is_err(),
213 None => {
214 let _ = lock_fut.await;
215 false
216 }
217 };
218 if timed_out {
219 let value = CB::lookup(key, extra).await;
220 return match value {
221 Ok((v, _ttl)) => (Ok(v), cache_state),
222 Err(e) => {
223 let mut err = Error::new_str(LOOKUP_ERR_MSG);
224 err.set_cause(e);
225 (Err(err), cache_state)
226 }
227 };
228 }
229 } let (result, cache_state) = self.inner.get(key);
232 if let Some(result) = result {
233 (Ok(result), CacheStatus::LockHit)
235 } else {
236 warn!(
238 "RTCache: no result after read lock, cache status: {:?}",
239 cache_state
240 );
241 match CB::lookup(key, extra).await {
242 Ok((v, new_ttl)) => {
243 self.inner.force_put(key, v.clone(), new_ttl.or(ttl));
244 (Ok(v), cache_state)
245 }
246 Err(e) => {
247 let mut err = Error::new_str(LOOKUP_ERR_MSG);
248 err.set_cause(e);
249 (Err(err), cache_state)
250 }
251 }
252 }
253 } else {
254 let value = CB::lookup(key, extra).await;
257 let ret = match value {
258 Ok((v, new_ttl)) => {
259 if my_write.is_some() {
261 self.inner.force_put(key, v.clone(), new_ttl.or(ttl));
262 }
263 (Ok(v), cache_state) }
265 Err(e) => {
266 let mut err = Error::new_str(LOOKUP_ERR_MSG);
267 err.set_cause(e);
268 (Err(err), cache_state)
269 }
270 };
271 if my_write.is_some() {
272 my_write.unwrap().lock.add_permits(10);
275
276 {
277 let mut lockers = self.lockers.write();
279 lockers.remove(&hashed_key);
280 } }
282
283 ret
284 }
285 }
286}
287
288impl<K, T, CB, S> RTCache<K, T, CB, S>
289where
290 K: Hash + Send,
291 T: Clone + Send + Sync + 'static,
292 CB: MultiLookup<K, T, S>,
293{
294 pub async fn multi_get<'a, I>(
303 &self,
304 keys: I,
305 ttl: Option<Duration>,
306 extra: Option<&S>,
307 ) -> Result<Vec<(T, CacheStatus)>, Box<Error>>
308 where
309 I: Iterator<Item = &'a K>,
310 K: 'a,
311 {
312 let size = keys.size_hint().0;
313 let (hits, misses) = self.inner.multi_get_with_miss(keys);
314 let mut final_results = Vec::with_capacity(size);
315 let miss_results = if !misses.is_empty() {
316 match CB::multi_lookup(&misses, extra).await {
317 Ok(miss_results) => {
318 assert!(
321 miss_results.len() == misses.len(),
322 "multi_lookup() failed to return the matching number of results"
323 );
324 for item in misses.iter().zip(miss_results.iter()) {
326 self.inner
327 .force_put(item.0, (item.1).0.clone(), (item.1).1.or(ttl));
328 }
329 miss_results
330 }
331 Err(e) => {
332 let mut err = Error::new_str(LOOKUP_ERR_MSG);
334 err.set_cause(e);
335 return Err(err);
336 }
337 }
338 } else {
339 vec![] };
341 let mut n_miss = 0;
343 for item in hits {
344 match item.0 {
345 Some(v) => final_results.push((v, item.1)),
346 None => {
347 final_results .push((miss_results[n_miss].0.clone(), CacheStatus::Miss));
349 n_miss += 1;
350 }
351 }
352 }
353 Ok(final_results)
354 }
355}
356
#[cfg(test)]
mod tests {
    //! Tests for `RTCache`. Several tests rely on real sleeps and on spawned tasks being polled
    //! roughly in spawn order on the test runtime (assumed; the default `#[tokio::test]` runtime
    //! is single-threaded — TODO confirm if this ever becomes flaky).
    use super::*;
    use atomic::AtomicI32;
    use std::sync::atomic;

    /// Knobs controlling how the `TestCB` callbacks behave.
    #[derive(Clone, Debug)]
    struct ExtraOpt {
        // `lookup` fails with "test error" (after counting the call).
        error: bool,
        // `multi_lookup` returns no results at all, violating its contract.
        empty: bool,
        // `lookup` sleeps this long after counting the call, keeping the lookup in flight.
        delay_for: Option<Duration>,
        // Call counter shared between clones; `lookup` returns its incremented value.
        used: Arc<AtomicI32>,
    }

    /// Test callback. `lookup` ignores the key and returns the number of lookups made so far
    /// (or 0 without `extra`); `multi_lookup` echoes each key back as its value.
    struct TestCB();

    #[async_trait]
    impl Lookup<i32, i32, ExtraOpt> for TestCB {
        async fn lookup(
            _key: &i32,
            extra: Option<&ExtraOpt>,
        ) -> Result<(i32, Option<Duration>), Box<dyn ErrorTrait + Send + Sync>> {
            let mut used = 0;
            if let Some(e) = extra {
                // Count the call before any error or delay, so every attempt is visible.
                used = e.used.fetch_add(1, atomic::Ordering::Relaxed) + 1;
                if e.error {
                    return Err(Error::new_str("test error"));
                }
                if let Some(delay_for) = e.delay_for {
                    tokio::time::sleep(delay_for).await;
                }
            }
            Ok((used, None))
        }
    }

    #[async_trait]
    impl MultiLookup<i32, i32, ExtraOpt> for TestCB {
        async fn multi_lookup(
            keys: &[&i32],
            extra: Option<&ExtraOpt>,
        ) -> Result<Vec<(i32, Option<Duration>)>, Box<dyn ErrorTrait + Send + Sync>> {
            let mut resp = vec![];
            if let Some(extra) = extra {
                if extra.empty {
                    return Ok(resp);
                }
            }
            for key in keys {
                resp.push((**key, None));
            }
            Ok(resp)
        }
    }

    // A miss triggers one lookup; the next get is served from the cache.
    #[tokio::test]
    async fn test_basic_get() {
        let cache: RTCache<i32, i32, TestCB, ExtraOpt> = RTCache::new(10, None, None);
        let opt = Some(ExtraOpt {
            error: false,
            empty: false,
            delay_for: None,
            used: Arc::new(AtomicI32::new(0)),
        });
        let (res, hit) = cache.get(&1, None, opt.as_ref()).await;
        assert_eq!(res.unwrap(), 1);
        assert_eq!(hit, CacheStatus::Miss);
        let (res, hit) = cache.get(&1, None, opt.as_ref()).await;
        assert_eq!(res.unwrap(), 1);
        assert_eq!(hit, CacheStatus::Hit);
    }

    // A failing lookup surfaces as an error with a Miss status.
    #[tokio::test]
    async fn test_basic_get_error() {
        let cache: RTCache<i32, i32, TestCB, ExtraOpt> = RTCache::new(10, None, None);
        let opt1 = Some(ExtraOpt {
            error: true,
            empty: false,
            delay_for: None,
            used: Arc::new(AtomicI32::new(0)),
        });
        let (res, hit) = cache.get(&-1, None, opt1.as_ref()).await;
        assert!(res.is_err());
        assert_eq!(hit, CacheStatus::Miss);
    }

    // Three concurrent gets on one key: only the first performs a lookup, so all see value 1.
    #[tokio::test]
    async fn test_concurrent_get() {
        let cache: RTCache<i32, i32, TestCB, ExtraOpt> = RTCache::new(10, None, None);
        let cache = Arc::new(cache);
        let opt = Some(ExtraOpt {
            error: false,
            empty: false,
            delay_for: None,
            used: Arc::new(AtomicI32::new(0)),
        });
        let cache_c = cache.clone();
        let opt1 = opt.clone();
        let t1 = tokio::spawn(async move {
            let (res, _hit) = cache_c.get(&1, None, opt1.as_ref()).await;
            res.unwrap()
        });
        let cache_c = cache.clone();
        let opt2 = opt.clone();
        let t2 = tokio::spawn(async move {
            let (res, _hit) = cache_c.get(&1, None, opt2.as_ref()).await;
            res.unwrap()
        });
        let opt3 = opt.clone();
        let cache_c = cache.clone();
        let t3 = tokio::spawn(async move {
            let (res, _hit) = cache_c.get(&1, None, opt3.as_ref()).await;
            res.unwrap()
        });
        let (r1, r2, r3) = tokio::join!(t1, t2, t3);
        assert_eq!(r1.unwrap(), 1);
        assert_eq!(r2.unwrap(), 1);
        assert_eq!(r3.unwrap(), 1);
    }

    // Concurrent gets on one key whose lookup always fails: every caller gets an error.
    #[tokio::test]
    async fn test_concurrent_get_error() {
        let cache: RTCache<i32, i32, TestCB, ExtraOpt> = RTCache::new(10, None, None);
        let cache = Arc::new(cache);
        let cache_c = cache.clone();
        let opt1 = Some(ExtraOpt {
            error: true,
            empty: false,
            delay_for: None,
            used: Arc::new(AtomicI32::new(0)),
        });
        let opt2 = opt1.clone();
        let opt3 = opt1.clone();
        let t1 = tokio::spawn(async move {
            let (res, _hit) = cache_c.get(&-1, None, opt1.as_ref()).await;
            res.is_err()
        });
        let cache_c = cache.clone();
        let t2 = tokio::spawn(async move {
            let (res, _hit) = cache_c.get(&-1, None, opt2.as_ref()).await;
            res.is_err()
        });
        let cache_c = cache.clone();
        let t3 = tokio::spawn(async move {
            let (res, _hit) = cache_c.get(&-1, None, opt3.as_ref()).await;
            res.is_err()
        });
        let (r1, r2, r3) = tokio::join!(t1, t2, t3);
        assert!(r1.unwrap());
        assert!(r2.unwrap());
        assert!(r3.unwrap());
    }

    // Distinct keys do not share a lock: three independent lookups return 1, 2 and 3 in some order.
    #[tokio::test]
    async fn test_concurrent_get_different_value() {
        let cache: RTCache<i32, i32, TestCB, ExtraOpt> = RTCache::new(10, None, None);
        let cache = Arc::new(cache);
        let opt1 = Some(ExtraOpt {
            error: false,
            empty: false,
            delay_for: None,
            used: Arc::new(AtomicI32::new(0)),
        });
        let opt2 = opt1.clone();
        let opt3 = opt1.clone();
        let cache_c = cache.clone();
        let t1 = tokio::spawn(async move {
            let (res, _hit) = cache_c.get(&1, None, opt1.as_ref()).await;
            res.unwrap()
        });
        let cache_c = cache.clone();
        let t2 = tokio::spawn(async move {
            let (res, _hit) = cache_c.get(&3, None, opt2.as_ref()).await;
            res.unwrap()
        });
        let cache_c = cache.clone();
        let t3 = tokio::spawn(async move {
            let (res, _hit) = cache_c.get(&5, None, opt3.as_ref()).await;
            res.unwrap()
        });
        let (r1, r2, r3) = tokio::join!(t1, t2, t3);
        assert_eq!(r1.unwrap() + r2.unwrap() + r3.unwrap(), 6);
    }

    // t1 holds the lock for 2s with lock_age = 1s. t2 and t3 arrive at 1.5s, find the lock too
    // old and look up themselves, so the values are 1, 2 and 3.
    #[tokio::test]
    async fn test_get_lock_age() {
        let cache: RTCache<i32, i32, TestCB, ExtraOpt> =
            RTCache::new(10, Some(Duration::from_secs(1)), None);
        let cache = Arc::new(cache);
        let counter = Arc::new(AtomicI32::new(0));
        let opt1 = Some(ExtraOpt {
            error: false,
            empty: false,
            delay_for: Some(Duration::from_secs(2)),
            used: counter.clone(),
        });

        let opt2 = Some(ExtraOpt {
            error: false,
            empty: false,
            delay_for: None,
            used: counter.clone(),
        });
        let opt3 = opt2.clone();
        let cache_c = cache.clone();
        let t1 = tokio::spawn(async move {
            let (res, _hit) = cache_c.get(&1, None, opt1.as_ref()).await;
            res.unwrap()
        });
        tokio::time::sleep(Duration::from_secs_f32(1.5)).await;
        let cache_c = cache.clone();
        let t2 = tokio::spawn(async move {
            let (res, _hit) = cache_c.get(&1, None, opt2.as_ref()).await;
            res.unwrap()
        });
        let cache_c = cache.clone();
        let t3 = tokio::spawn(async move {
            let (res, _hit) = cache_c.get(&1, None, opt3.as_ref()).await;
            res.unwrap()
        });
        let (r1, r2, r3) = tokio::join!(t1, t2, t3);
        assert_eq!(r1.unwrap() + r2.unwrap() + r3.unwrap(), 6);
    }

    // t1 holds the lock for 2s with lock_timeout = 1s. t2 and t3 give up waiting after 1s and
    // look up themselves, so the values are 1, 2 and 3.
    #[tokio::test]
    async fn test_get_lock_timeout() {
        let cache: RTCache<i32, i32, TestCB, ExtraOpt> =
            RTCache::new(10, None, Some(Duration::from_secs(1)));
        let cache = Arc::new(cache);
        let counter = Arc::new(AtomicI32::new(0));
        let opt1 = Some(ExtraOpt {
            error: false,
            empty: false,
            delay_for: Some(Duration::from_secs(2)),
            used: counter.clone(),
        });
        let opt2 = Some(ExtraOpt {
            error: false,
            empty: false,
            delay_for: None,
            used: counter.clone(),
        });
        let opt3 = opt2.clone();
        let cache_c = cache.clone();
        let t1 = tokio::spawn(async move {
            let (res, _hit) = cache_c.get(&1, None, opt1.as_ref()).await;
            res.unwrap()
        });
        let cache_c = cache.clone();
        let t2 = tokio::spawn(async move {
            let (res, _hit) = cache_c.get(&1, None, opt2.as_ref()).await;
            res.unwrap()
        });
        let cache_c = cache.clone();
        let t3 = tokio::spawn(async move {
            let (res, _hit) = cache_c.get(&1, None, opt3.as_ref()).await;
            res.unwrap()
        });
        let (r1, r2, r3) = tokio::join!(t1, t2, t3);
        assert_eq!(r1.unwrap() + r2.unwrap() + r3.unwrap(), 6);
    }

    // Mix of `get` and `multi_get`: key 1 is cached by `get` (which sleeps 2s), then 2 and 3 are
    // fetched by `multi_get` and are hits on the second call.
    #[tokio::test]
    async fn test_multi_get() {
        let cache: RTCache<i32, i32, TestCB, ExtraOpt> = RTCache::new(10, None, None);
        let counter = Arc::new(AtomicI32::new(0));
        let opt1 = Some(ExtraOpt {
            error: false,
            empty: false,
            delay_for: Some(Duration::from_secs(2)),
            used: counter.clone(),
        });
        let (res, hit) = cache.get(&1, None, opt1.as_ref()).await;
        assert_eq!(res.unwrap(), 1);
        assert_eq!(hit, CacheStatus::Miss);
        let (res, hit) = cache.get(&1, None, opt1.as_ref()).await;
        assert_eq!(res.unwrap(), 1);
        assert_eq!(hit, CacheStatus::Hit);
        let resp = cache
            .multi_get([1, 2, 3].iter(), None, opt1.as_ref())
            .await
            .unwrap();
        assert_eq!(resp[0].0, 1);
        assert_eq!(resp[0].1, CacheStatus::Hit);
        assert_eq!(resp[1].0, 2);
        assert_eq!(resp[1].1, CacheStatus::Miss);
        assert_eq!(resp[2].0, 3);
        assert_eq!(resp[2].1, CacheStatus::Miss);
        let resp = cache
            .multi_get([1, 2, 3].iter(), None, opt1.as_ref())
            .await
            .unwrap();
        assert_eq!(resp[0].0, 1);
        assert_eq!(resp[0].1, CacheStatus::Hit);
        assert_eq!(resp[1].0, 2);
        assert_eq!(resp[1].1, CacheStatus::Hit);
        assert_eq!(resp[2].0, 3);
        assert_eq!(resp[2].1, CacheStatus::Hit);
    }

    // A `multi_lookup` that returns the wrong number of results trips the assertion in `multi_get`.
    #[tokio::test]
    #[should_panic(expected = "multi_lookup() failed to return the matching number of results")]
    async fn test_inconsistent_miss_results() {
        let opt1 = Some(ExtraOpt {
            error: false,
            empty: true,
            delay_for: None,
            used: Arc::new(AtomicI32::new(0)),
        });
        let cache: RTCache<i32, i32, TestCB, ExtraOpt> = RTCache::new(10, None, None);
        cache
            .multi_get([4, 5, 6].iter(), None, opt1.as_ref())
            .await
            .unwrap();
    }
}