use super::{CacheStatus, MemoryCache};
use async_trait::async_trait;
use log::warn;
use parking_lot::RwLock;
use pingora_error::{Error, ErrorTrait};
use std::collections::HashMap;
use std::hash::Hash;
use std::marker::PhantomData;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::Semaphore;
/// Coordination state for one in-flight lookup of a key.
///
/// The semaphore is created with zero permits. The task performing the lookup adds permits
/// once it is done, and the tasks waiting on that lookup block on acquiring a permit.
struct CacheLock {
    /// When the lookup started; compared against the caller-supplied age in `too_old`.
    pub lock_start: Instant,
    /// Has no permits until the lookup has finished.
    pub lock: Semaphore,
}

impl CacheLock {
    /// Create a new, not-yet-released lock stamped with the current time.
    pub fn new_arc() -> Arc<Self> {
        let lock = Semaphore::new(0);
        Arc::new(Self {
            lock,
            lock_start: Instant::now(),
        })
    }

    /// Whether more than `age` has elapsed since the lock was created.
    ///
    /// A lock is never considered too old when `age` is `None`.
    pub fn too_old(&self, age: Option<&Duration>) -> bool {
        matches!(age, Some(max_age) if self.lock_start.elapsed() > *max_age)
    }
}
/// Callback used by [`RTCache::get`] to fetch the value of a single key on a cache miss.
///
/// `lookup` is an associated function (no `self`): `RTCache` only stores the callback type
/// as a `PhantomData` and calls `CB::lookup` directly.
#[async_trait]
pub trait Lookup<K, T, S> {
    /// Fetch the value for `key`.
    ///
    /// `extra` is forwarded unchanged from the `extra` argument of [`RTCache::get`].
    ///
    /// On success, returns the value and an optional TTL. When `RTCache::get` caches the
    /// value, a returned TTL takes precedence over the `ttl` passed to `get`. On failure, the
    /// returned error becomes the cause of the error that `RTCache::get` reports.
    async fn lookup(
        key: &K,
        extra: Option<&S>,
    ) -> Result<(T, Option<Duration>), Box<dyn ErrorTrait + Send + Sync>>
    where
        // Lifetime bounds needed by `async_trait`, whose boxed future borrows the arguments.
        K: 'async_trait,
        S: 'async_trait;
}
/// Callback used by [`RTCache::multi_get`] to fetch the values of all missed keys in one call.
///
/// Like [`Lookup`], `multi_lookup` is an associated function invoked as `CB::multi_lookup`.
#[async_trait]
pub trait MultiLookup<K, T, S> {
    /// Fetch the values for every key in `keys`.
    ///
    /// Must return exactly one `(value, ttl)` pair per key, in the same order as `keys`:
    /// `RTCache::multi_get` pairs results with keys by position and panics if the counts
    /// differ. A returned TTL takes precedence over the `ttl` passed to `multi_get`.
    async fn multi_lookup(
        keys: &[&K],
        extra: Option<&S>,
    ) -> Result<Vec<(T, Option<Duration>)>, Box<dyn ErrorTrait + Send + Sync>>
    where
        // Lifetime bounds needed by `async_trait`, whose boxed future borrows the arguments.
        K: 'async_trait,
        S: 'async_trait;
}
/// Message of the error returned when a lookup callback fails; the callback's own error is
/// attached to it as the cause.
const LOOKUP_ERR_MSG: &str = "RTCache: lookup error";

/// A read-through in-memory cache.
///
/// On a miss, values are fetched with the callback type `CB` (see [`Lookup`] and
/// [`MultiLookup`]) and then stored in the underlying cache. `S` is the type of the optional
/// `extra` argument that is forwarded to the callback.
pub struct RTCache<K, T, CB, S>
where
    K: Hash + Send,
    T: Clone + Send,
{
    /// Underlying storage.
    inner: MemoryCache<K, T>,
    /// The callback type; only its associated functions are used, no instance is stored.
    _callback: PhantomData<CB>,
    /// In-flight single-key lookups, keyed by the hash of the cache key (so keys whose
    /// hashes collide share one lock).
    lockers: RwLock<HashMap<u64, Arc<CacheLock>>>,
    /// In-flight locks older than this are ignored by new callers; `None` means never.
    lock_age: Option<Duration>,
    /// How long a caller waits for another caller's lookup before doing its own; `None`
    /// means wait indefinitely.
    lock_timeout: Option<Duration>,
    /// Marks the `extra` argument type `S`.
    phantom: PhantomData<S>,
}
impl<K, T, CB, S> RTCache<K, T, CB, S>
where
    K: Hash + Send,
    T: Clone + Send + Sync + 'static,
{
    /// Create an empty read-through cache.
    ///
    /// - `size` is handed to [`MemoryCache::new`] for the underlying storage.
    /// - `lock_age`: an in-flight lookup older than this is no longer waited on by new
    ///   callers (`None` = always wait).
    /// - `lock_timeout`: the longest a caller waits for another caller's lookup before
    ///   performing its own (`None` = no limit).
    pub fn new(size: usize, lock_age: Option<Duration>, lock_timeout: Option<Duration>) -> Self {
        let inner = MemoryCache::new(size);
        let lockers = RwLock::new(HashMap::new());
        Self {
            inner,
            _callback: PhantomData,
            lockers,
            lock_age,
            lock_timeout,
            phantom: PhantomData,
        }
    }
}
impl<K, T, CB, S> RTCache<K, T, CB, S>
where
K: Hash + Send,
T: Clone + Send + Sync + 'static,
CB: Lookup<K, T, S>,
{
pub async fn get(
&self,
key: &K,
ttl: Option<Duration>,
extra: Option<&S>,
) -> (Result<T, Box<Error>>, CacheStatus) {
let (result, cache_state) = self.inner.get(key);
if let Some(result) = result {
return (Ok(result), cache_state);
}
let hashed_key = self.inner.hasher.hash_one(key);
let my_lock = {
let lockers = self.lockers.read();
lockers.get(&hashed_key).cloned()
};
let (my_write, my_read) = match my_lock {
Some(lock) => {
if lock.too_old(self.lock_age.as_ref()) {
(None, None)
} else {
(None, Some(lock))
}
}
None => {
let mut lockers = self.lockers.write();
match lockers.get(&hashed_key) {
Some(lock) => {
if lock.too_old(self.lock_age.as_ref()) {
(None, None)
} else {
(None, Some(lock.clone()))
}
}
None => {
let new_lock = CacheLock::new_arc();
let new_lock2 = new_lock.clone();
lockers.insert(hashed_key, new_lock2);
(Some(new_lock), None)
}
} }
};
if my_read.is_some() {
let my_lock = my_read.unwrap();
if my_lock.lock.available_permits() == 0 {
let lock_fut = my_lock.lock.acquire();
let timed_out = match self.lock_timeout {
Some(t) => pingora_timeout::timeout(t, lock_fut).await.is_err(),
None => {
let _ = lock_fut.await;
false
}
};
if timed_out {
let value = CB::lookup(key, extra).await;
return match value {
Ok((v, _ttl)) => (Ok(v), cache_state),
Err(e) => {
let mut err = Error::new_str(LOOKUP_ERR_MSG);
err.set_cause(e);
(Err(err), cache_state)
}
};
}
}
let (result, cache_state) = self.inner.get(key);
if let Some(result) = result {
(Ok(result), CacheStatus::LockHit)
} else {
warn!(
"RTCache: no result after read lock, cache status: {:?}",
cache_state
);
match CB::lookup(key, extra).await {
Ok((v, new_ttl)) => {
self.inner.force_put(key, v.clone(), new_ttl.or(ttl));
(Ok(v), cache_state)
}
Err(e) => {
let mut err = Error::new_str(LOOKUP_ERR_MSG);
err.set_cause(e);
(Err(err), cache_state)
}
}
}
} else {
let value = CB::lookup(key, extra).await;
let ret = match value {
Ok((v, new_ttl)) => {
if my_write.is_some() {
self.inner.force_put(key, v.clone(), new_ttl.or(ttl));
}
(Ok(v), cache_state) }
Err(e) => {
let mut err = Error::new_str(LOOKUP_ERR_MSG);
err.set_cause(e);
(Err(err), cache_state)
}
};
if my_write.is_some() {
my_write.unwrap().lock.add_permits(10);
{
let mut lockers = self.lockers.write();
lockers.remove(&hashed_key);
} }
ret
}
}
}
impl<K, T, CB, S> RTCache<K, T, CB, S>
where
    K: Hash + Send,
    T: Clone + Send + Sync + 'static,
    CB: MultiLookup<K, T, S>,
{
    /// Query the cache for several keys, fetching all misses with one `CB::multi_lookup` call.
    ///
    /// Results are returned in the same order as `keys`. Cache hits keep the status reported
    /// by the cache. Fetched values are reported as [`CacheStatus::Miss`] and are cached with
    /// the TTL returned by the callback, falling back to `ttl`. Unlike `get`, concurrent
    /// lookups of the same key are not coalesced.
    ///
    /// # Errors
    /// Returns an error wrapping the callback's error if `multi_lookup` fails; nothing is
    /// cached in that case.
    ///
    /// # Panics
    /// Panics if `multi_lookup` returns a different number of results than keys it was given.
    pub async fn multi_get<'a, I>(
        &self,
        keys: I,
        ttl: Option<Duration>,
        extra: Option<&S>,
    ) -> Result<Vec<(T, CacheStatus)>, Box<Error>>
    where
        I: Iterator<Item = &'a K>,
        K: 'a,
    {
        // `hits` has one entry per input key; `None` marks a miss, and `misses` lists the
        // missed keys in the same order.
        let (hits, misses) = self.inner.multi_get_with_miss(keys);

        let miss_results = if misses.is_empty() {
            Vec::new()
        } else {
            let results = match CB::multi_lookup(&misses, extra).await {
                Ok(results) => results,
                Err(e) => {
                    let mut err = Error::new_str(LOOKUP_ERR_MSG);
                    err.set_cause(e);
                    return Err(err);
                }
            };
            assert!(
                results.len() == misses.len(),
                "multi_lookup() failed to return the matching number of results"
            );
            for (key, (value, new_ttl)) in misses.iter().zip(results.iter()) {
                self.inner.force_put(key, value.clone(), new_ttl.or(ttl));
            }
            results
        };

        // Move fetched values into the output instead of cloning them a second time.
        let mut fetched = miss_results.into_iter().map(|(value, _ttl)| value);
        let mut final_results = Vec::with_capacity(hits.len());
        for (cached, status) in hits {
            match cached {
                Some(value) => final_results.push((value, status)),
                None => {
                    let value = fetched
                        .next()
                        .expect("multi_get_with_miss reports one miss per None entry");
                    final_results.push((value, CacheStatus::Miss));
                }
            }
        }
        Ok(final_results)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use atomic::AtomicI32;
    use std::sync::atomic;

    /// The cache type exercised by every test.
    type TestCache = RTCache<i32, i32, TestCB, ExtraOpt>;

    /// Knobs passed to the test callbacks through the `extra` argument.
    #[derive(Clone, Debug)]
    struct ExtraOpt {
        /// `lookup` fails when set.
        error: bool,
        /// `multi_lookup` returns no results when set.
        empty: bool,
        /// `lookup` sleeps this long before answering.
        delay_for: Option<Duration>,
        /// Shared call counter; `lookup` answers with the count including its own call.
        used: Arc<AtomicI32>,
    }

    /// Build `Some(ExtraOpt)` sharing the call counter `used`.
    fn make_opt(
        error: bool,
        empty: bool,
        delay_for: Option<Duration>,
        used: &Arc<AtomicI32>,
    ) -> Option<ExtraOpt> {
        Some(ExtraOpt {
            error,
            empty,
            delay_for,
            used: Arc::clone(used),
        })
    }

    /// Build `Some(ExtraOpt)` with no delay and a fresh call counter.
    fn fresh_opt(error: bool, empty: bool) -> Option<ExtraOpt> {
        make_opt(error, empty, None, &Arc::new(AtomicI32::new(0)))
    }

    /// Spawn `cache.get(&key)` on the runtime; the task yields the unwrapped value.
    fn spawn_get(
        cache: &Arc<TestCache>,
        key: i32,
        opt: Option<ExtraOpt>,
    ) -> tokio::task::JoinHandle<i32> {
        let cache = Arc::clone(cache);
        tokio::spawn(async move {
            let (res, _hit) = cache.get(&key, None, opt.as_ref()).await;
            res.unwrap()
        })
    }

    /// Callback used by every test; its behaviour is driven by [`ExtraOpt`].
    struct TestCB();

    #[async_trait]
    impl Lookup<i32, i32, ExtraOpt> for TestCB {
        async fn lookup(
            _key: &i32,
            extra: Option<&ExtraOpt>,
        ) -> Result<(i32, Option<Duration>), Box<dyn ErrorTrait + Send + Sync>> {
            // Without options the answer is always 0.
            let opt = match extra {
                Some(opt) => opt,
                None => return Ok((0, None)),
            };
            let used = opt.used.fetch_add(1, atomic::Ordering::Relaxed) + 1;
            if opt.error {
                return Err(Error::new_str("test error"));
            }
            if let Some(delay) = opt.delay_for {
                tokio::time::sleep(delay).await;
            }
            Ok((used, None))
        }
    }

    #[async_trait]
    impl MultiLookup<i32, i32, ExtraOpt> for TestCB {
        async fn multi_lookup(
            keys: &[&i32],
            extra: Option<&ExtraOpt>,
        ) -> Result<Vec<(i32, Option<Duration>)>, Box<dyn ErrorTrait + Send + Sync>> {
            // `empty` simulates a callback that returns too few results.
            if matches!(extra, Some(opt) if opt.empty) {
                return Ok(Vec::new());
            }
            // Every key maps to itself.
            Ok(keys.iter().map(|key| (**key, None)).collect::<Vec<_>>())
        }
    }

    #[tokio::test]
    async fn test_basic_get() {
        let cache = TestCache::new(10, None, None);
        let opt = fresh_opt(false, false);
        // The first call looks the value up; the second is served from the cache.
        for expected in [CacheStatus::Miss, CacheStatus::Hit] {
            let (res, hit) = cache.get(&1, None, opt.as_ref()).await;
            assert_eq!(res.unwrap(), 1);
            assert_eq!(hit, expected);
        }
    }

    #[tokio::test]
    async fn test_basic_get_error() {
        let cache = TestCache::new(10, None, None);
        let opt = fresh_opt(true, false);
        let (res, hit) = cache.get(&-1, None, opt.as_ref()).await;
        assert!(res.is_err());
        assert_eq!(hit, CacheStatus::Miss);
    }

    #[tokio::test]
    async fn test_concurrent_get() {
        let cache = Arc::new(TestCache::new(10, None, None));
        let opt = fresh_opt(false, false);
        let t1 = spawn_get(&cache, 1, opt.clone());
        let t2 = spawn_get(&cache, 1, opt.clone());
        let t3 = spawn_get(&cache, 1, opt.clone());
        let (r1, r2, r3) = tokio::join!(t1, t2, t3);
        assert_eq!(r1.unwrap(), 1);
        assert_eq!(r2.unwrap(), 1);
        assert_eq!(r3.unwrap(), 1);
    }

    #[tokio::test]
    async fn test_concurrent_get_error() {
        let cache = Arc::new(TestCache::new(10, None, None));
        let opt = fresh_opt(true, false);
        let spawn_is_err = |opt: Option<ExtraOpt>| {
            let cache = Arc::clone(&cache);
            tokio::spawn(async move {
                let (res, _hit) = cache.get(&-1, None, opt.as_ref()).await;
                res.is_err()
            })
        };
        let t1 = spawn_is_err(opt.clone());
        let t2 = spawn_is_err(opt.clone());
        let t3 = spawn_is_err(opt);
        let (r1, r2, r3) = tokio::join!(t1, t2, t3);
        assert!(r1.unwrap());
        assert!(r2.unwrap());
        assert!(r3.unwrap());
    }

    #[tokio::test]
    async fn test_concurrent_get_different_value() {
        let cache = Arc::new(TestCache::new(10, None, None));
        // One shared counter: the three distinct keys are answered with 1, 2 and 3.
        let opt = fresh_opt(false, false);
        let t1 = spawn_get(&cache, 1, opt.clone());
        let t2 = spawn_get(&cache, 3, opt.clone());
        let t3 = spawn_get(&cache, 5, opt.clone());
        let (r1, r2, r3) = tokio::join!(t1, t2, t3);
        assert_eq!(r1.unwrap() + r2.unwrap() + r3.unwrap(), 6);
    }

    #[tokio::test]
    async fn test_get_lock_age() {
        let cache = Arc::new(TestCache::new(10, Some(Duration::from_secs(1)), None));
        let counter = Arc::new(AtomicI32::new(0));
        let slow = make_opt(false, false, Some(Duration::from_secs(2)), &counter);
        let quick = make_opt(false, false, None, &counter);
        let t1 = spawn_get(&cache, 1, slow);
        // By now the first lookup's lock is older than `lock_age`.
        tokio::time::sleep(Duration::from_secs_f32(1.5)).await;
        let t2 = spawn_get(&cache, 1, quick.clone());
        let t3 = spawn_get(&cache, 1, quick);
        let (r1, r2, r3) = tokio::join!(t1, t2, t3);
        assert_eq!(r1.unwrap() + r2.unwrap() + r3.unwrap(), 6);
    }

    #[tokio::test]
    async fn test_get_lock_timeout() {
        let cache = Arc::new(TestCache::new(10, None, Some(Duration::from_secs(1))));
        let counter = Arc::new(AtomicI32::new(0));
        let slow = make_opt(false, false, Some(Duration::from_secs(2)), &counter);
        let quick = make_opt(false, false, None, &counter);
        let t1 = spawn_get(&cache, 1, slow);
        let t2 = spawn_get(&cache, 1, quick.clone());
        let t3 = spawn_get(&cache, 1, quick);
        let (r1, r2, r3) = tokio::join!(t1, t2, t3);
        assert_eq!(r1.unwrap() + r2.unwrap() + r3.unwrap(), 6);
    }

    #[tokio::test]
    async fn test_multi_get() {
        let cache = TestCache::new(10, None, None);
        let counter = Arc::new(AtomicI32::new(0));
        let opt = make_opt(false, false, Some(Duration::from_secs(2)), &counter);
        for expected in [CacheStatus::Miss, CacheStatus::Hit] {
            let (res, hit) = cache.get(&1, None, opt.as_ref()).await;
            assert_eq!(res.unwrap(), 1);
            assert_eq!(hit, expected);
        }

        // Key 1 is already cached; 2 and 3 are fetched by the first call and cached.
        let first = [
            (1, CacheStatus::Hit),
            (2, CacheStatus::Miss),
            (3, CacheStatus::Miss),
        ];
        let second = [
            (1, CacheStatus::Hit),
            (2, CacheStatus::Hit),
            (3, CacheStatus::Hit),
        ];
        for expected in [first, second] {
            let resp = cache
                .multi_get([1, 2, 3].iter(), None, opt.as_ref())
                .await
                .unwrap();
            for (i, (value, status)) in expected.iter().enumerate() {
                assert_eq!(resp[i].0, *value);
                assert_eq!(resp[i].1, *status);
            }
        }
    }

    #[tokio::test]
    #[should_panic(expected = "multi_lookup() failed to return the matching number of results")]
    async fn test_inconsistent_miss_results() {
        let opt = fresh_opt(false, true);
        let cache = TestCache::new(10, None, None);
        cache
            .multi_get([4, 5, 6].iter(), None, opt.as_ref())
            .await
            .unwrap();
    }
}