use crate::{hashtable::ConcurrentHashTable, key::CacheHashKey, CacheKey};
use pingora_timeout::timeout;
use std::sync::Arc;
use std::time::Duration;
/// Trait-object form of [`CacheKeyLock`] that is `Send + Sync`, so one lock implementation
/// can be shared across threads behind a pointer.
pub type CacheKeyLockImpl = (dyn CacheKeyLock + Send + Sync);
/// A per-[`CacheKey`] lock: one caller becomes the writer for a key while other callers for
/// the same key receive a read handle they can wait on.
pub trait CacheKeyLock {
    /// Acquire the lock for `key`.
    ///
    /// Returns [`Locked::Write`] when the caller becomes the writer, or [`Locked::Read`]
    /// when another caller already holds the write lock.
    fn lock(&self, key: &CacheKey) -> Locked;
    /// Give back a write `permit` previously obtained from [`CacheKeyLock::lock`] for `key`.
    ///
    /// `reason` is the outcome to report to waiting readers; an implementation may record a
    /// different status (e.g. [`CacheLock`] keeps `Timeout` for a lock that already timed out).
    fn release(&self, key: &CacheKey, permit: WritePermit, reason: LockStatus);
}
/// Shard-count const parameter for the `ConcurrentHashTable` backing [`CacheLock`].
const N_SHARDS: usize = 16;
/// Default [`CacheKeyLock`] implementation: a concurrent table mapping the 128-bit hash of a
/// cache key to the [`LockStub`] of that key's current writer.
#[derive(Debug)]
pub struct CacheLock {
    /// Keyed by `u128::from_be_bytes(key.combined_bin())` (see `lock` / `release`).
    lock_table: ConcurrentHashTable<LockStub, N_SHARDS>,
    /// Age timeout given to every lock this table creates (passed to `WritePermit::new`).
    age_timeout_default: Duration,
}
/// Result of [`CacheKeyLock::lock`].
#[derive(Debug)]
pub enum Locked {
    /// The caller is the writer. The permit must be unlocked (directly or through
    /// [`CacheKeyLock::release`]); dropping it unlocked is treated as a dangling lock.
    Write(WritePermit),
    /// Another caller is the writer; use this handle to wait for it.
    Read(ReadLock),
}
impl Locked {
pub fn is_write(&self) -> bool {
matches!(self, Self::Write(_))
}
}
impl CacheLock {
    /// Heap-allocated variant of [`CacheLock::new`].
    pub fn new_boxed(age_timeout: Duration) -> Box<Self> {
        Box::new(Self::new(age_timeout))
    }

    /// Create an empty lock table whose locks time out `age_timeout_default` after creation.
    pub fn new(age_timeout_default: Duration) -> Self {
        Self {
            lock_table: ConcurrentHashTable::new(),
            age_timeout_default,
        }
    }
}
impl CacheKeyLock for CacheLock {
fn lock(&self, key: &CacheKey) -> Locked {
let hash = key.combined_bin();
let key = u128::from_be_bytes(hash); let table = self.lock_table.get(key);
if let Some(lock) = table.read().get(&key) {
if !matches!(
lock.0.lock_status(),
LockStatus::Dangling | LockStatus::Timeout
) {
return Locked::Read(lock.read_lock());
}
}
let mut table = table.write();
if let Some(lock) = table.get(&key) {
if !matches!(
lock.0.lock_status(),
LockStatus::Dangling | LockStatus::Timeout
) {
return Locked::Read(lock.read_lock());
}
}
let (permit, stub) = WritePermit::new(self.age_timeout_default);
table.insert(key, stub);
Locked::Write(permit)
}
fn release(&self, key: &CacheKey, mut permit: WritePermit, reason: LockStatus) {
let hash = key.combined_bin();
let key = u128::from_be_bytes(hash); if permit.lock.lock_status() == LockStatus::Timeout {
permit.unlock(LockStatus::Timeout);
} else if let Some(_lock) = self.lock_table.write(key).remove(&key) {
permit.unlock(reason);
}
}
}
use log::warn;
use std::sync::atomic::{AtomicU8, Ordering};
use std::time::Instant;
use strum::IntoStaticStr;
use tokio::sync::Semaphore;
/// Status of a cache lock; stored inside [`LockCore`] as a `u8` (see the `From` impls).
#[derive(Debug, Copy, Clone, PartialEq, Eq, IntoStaticStr)]
pub enum LockStatus {
    /// Initial status of every new lock; the writer has not unlocked yet.
    Waiting,
    /// Release reason supplied by the writer; presumably "write finished" — confirm with callers.
    Done,
    /// Release reason supplied by the writer; meaning is defined by callers (not visible here).
    TransientError,
    /// Release reason supplied by the writer; also the fallback when decoding an unknown `u8`.
    GiveUp,
    /// The writer's permit was dropped without being unlocked; `CacheLock::lock` replaces such entries.
    Dangling,
    /// The lock passed its age timeout; `CacheLock::lock` replaces such entries.
    Timeout,
}
impl From<LockStatus> for u8 {
    /// Encode a status as the `u8` stored in a lock's atomic status field.
    /// Must stay the inverse of `From<u8> for LockStatus`.
    fn from(status: LockStatus) -> u8 {
        match status {
            LockStatus::Waiting => 0,
            LockStatus::Done => 1,
            LockStatus::TransientError => 2,
            LockStatus::GiveUp => 3,
            LockStatus::Dangling => 4,
            LockStatus::Timeout => 5,
        }
    }
}
impl From<u8> for LockStatus {
fn from(v: u8) -> Self {
match v {
0 => Self::Waiting,
1 => Self::Done,
2 => Self::TransientError,
3 => Self::GiveUp,
4 => Self::Dangling,
5 => Self::Timeout,
_ => Self::GiveUp, }
}
}
/// State shared by a lock's writer ([`WritePermit`]), its table entry ([`LockStub`]) and its
/// readers ([`ReadLock`]).
#[derive(Debug)]
pub struct LockCore {
    /// When the lock was created; the age timeout is measured from here.
    pub lock_start: Instant,
    /// How long after `lock_start` readers stop waiting and report a timeout.
    pub age_timeout: Duration,
    /// Created with zero permits; readers wait to acquire one, and unlocking adds permits.
    pub(super) lock: Semaphore,
    /// The current [`LockStatus`], encoded as `u8` so it can be updated through `&self`.
    lock_status: AtomicU8,
}
impl LockCore {
    /// Create a locked core (no permits) in the `Waiting` state whose age starts now.
    pub fn new_arc(timeout: Duration) -> Arc<Self> {
        let initial_status: u8 = LockStatus::Waiting.into();
        let core = LockCore {
            lock: Semaphore::new(0),
            age_timeout: timeout,
            lock_start: Instant::now(),
            lock_status: AtomicU8::new(initial_status),
        };
        Arc::new(core)
    }

    /// `true` while no semaphore permits are available.
    pub fn locked(&self) -> bool {
        matches!(self.lock.available_permits(), 0)
    }

    /// Record `reason` as the final status, then add permits so pending `acquire()` calls finish.
    pub fn unlock(&self, reason: LockStatus) {
        // Store first so readers woken by the permits observe the new status.
        self.lock_status.store(u8::from(reason), Ordering::SeqCst);
        self.lock.add_permits(10);
    }

    /// Current status as stored, without age-timeout adjustment.
    pub fn lock_status(&self) -> LockStatus {
        let raw = self.lock_status.load(Ordering::SeqCst);
        LockStatus::from(raw)
    }
}
/// A reader's handle on a lock whose write permit is held by another caller.
#[derive(Debug)]
pub struct ReadLock(Arc<LockCore>);
impl ReadLock {
    /// Wait until the writer unlocks, or until the lock reaches its age timeout.
    ///
    /// Returns immediately if the lock is not locked. If the age timeout has already passed,
    /// or passes while waiting, the shared status is set to [`LockStatus::Timeout`] unless the
    /// writer has already recorded a different status. Check [`ReadLock::lock_status`]
    /// afterwards for the outcome.
    pub async fn wait(&self) {
        if !self.locked() {
            return;
        }
        // `None` when the age timeout has already elapsed.
        match self.0.age_timeout.checked_sub(self.0.lock_start.elapsed()) {
            Some(remaining) => match timeout(remaining, self.0.lock.acquire()).await {
                // Writer unlocked; the acquired permit is dropped right away.
                Ok(Ok(_)) => {}
                Ok(Err(e)) => {
                    warn!("error acquiring semaphore {e:?}")
                }
                Err(_) => self.mark_timeout(),
            },
            None => self.mark_timeout(),
        }
    }

    /// Move the shared status from `Waiting` to `Timeout`.
    ///
    /// A writer unlocks by storing its status before adding permits, so a reader can time out
    /// in between; a status the writer already stored must not be overwritten.
    fn mark_timeout(&self) {
        // A failed exchange means the status already left `Waiting`; keep that value.
        let _ = self.0.lock_status.compare_exchange(
            LockStatus::Waiting.into(),
            LockStatus::Timeout.into(),
            Ordering::SeqCst,
            Ordering::SeqCst,
        );
    }

    /// `true` while the writer has not unlocked.
    pub fn locked(&self) -> bool {
        self.0.locked()
    }

    /// `true` once the lock is at least `age_timeout` old.
    pub fn expired(&self) -> bool {
        self.0.lock_start.elapsed() >= self.0.age_timeout
    }

    /// The shared status, reported as `Timeout` when still `Waiting` past the age timeout.
    pub fn lock_status(&self) -> LockStatus {
        match self.0.lock_status() {
            LockStatus::Waiting if self.expired() => LockStatus::Timeout,
            status => status,
        }
    }
}
/// The writer's handle on a lock. It must be unlocked (directly or via
/// [`CacheKeyLock::release`]) before being dropped; otherwise its `Drop` marks it `Dangling`.
#[derive(Debug)]
pub struct WritePermit {
    lock: Arc<LockCore>,
    /// Set by `unlock`; checked on drop.
    finished: bool,
}
impl WritePermit {
pub fn new(timeout: Duration) -> (WritePermit, LockStub) {
let lock = LockCore::new_arc(timeout);
let stub = LockStub(lock.clone());
(
WritePermit {
lock,
finished: false,
},
stub,
)
}
pub fn unlock(&mut self, reason: LockStatus) {
self.finished = true;
self.lock.unlock(reason);
}
}
impl Drop for WritePermit {
    /// Unlock a permit that was never explicitly unlocked, as [`LockStatus::Dangling`], so
    /// waiting readers wake up and `CacheLock::lock` can hand the key to a new writer.
    ///
    /// In debug builds this is also reported as a bug by panicking, unless the thread is
    /// already unwinding.
    fn drop(&mut self) {
        if self.finished {
            return;
        }
        // Unlock before the debug panic below so readers are woken even in debug builds.
        self.unlock(LockStatus::Dangling);
        // Panicking again while already unwinding would abort the process.
        if cfg!(debug_assertions) && !std::thread::panicking() {
            panic!("Dangling cache lock started!");
        }
    }
}
/// Lock-table entry for a key's current writer; shares the writer's [`LockCore`] and hands
/// out [`ReadLock`]s on it.
#[derive(Debug)]
pub struct LockStub(pub Arc<LockCore>);
impl LockStub {
pub fn read_lock(&self) -> ReadLock {
ReadLock(self.0.clone())
}
}
#[cfg(test)]
mod test {
    use super::*;
    use crate::CacheKey;

    /// First `lock` on a key yields the writer and the next yields a reader; after the
    /// writer releases, `lock` yields a fresh writer.
    #[test]
    fn test_get_release() {
        let cache_lock = CacheLock::new_boxed(Duration::from_secs(1000));
        let key1 = CacheKey::new("", "a", "1");
        let locked1 = cache_lock.lock(&key1);
        assert!(locked1.is_write()); // first caller becomes the writer
        let locked2 = cache_lock.lock(&key1);
        assert!(!locked2.is_write()); // second caller joins as a reader
        if let Locked::Write(permit) = locked1 {
            cache_lock.release(&key1, permit, LockStatus::Done);
        }
        let locked3 = cache_lock.lock(&key1);
        assert!(locked3.is_write()); // release removed the entry, so a new writer is created
        if let Locked::Write(permit) = locked3 {
            cache_lock.release(&key1, permit, LockStatus::Done);
        }
    }

    /// A reader blocked in `wait` is woken by the writer's unlock and sees `Done`.
    #[tokio::test]
    async fn test_lock() {
        let cache_lock = CacheLock::new_boxed(Duration::from_secs(1000));
        let key1 = CacheKey::new("", "a", "1");
        let mut permit = match cache_lock.lock(&key1) {
            Locked::Write(w) => w,
            _ => panic!(),
        };
        let lock = match cache_lock.lock(&key1) {
            Locked::Read(r) => r,
            _ => panic!(),
        };
        assert!(lock.locked());
        let handle = tokio::spawn(async move {
            lock.wait().await;
            assert_eq!(lock.lock_status(), LockStatus::Done);
        });
        permit.unlock(LockStatus::Done);
        handle.await.unwrap();
    }

    /// A reader whose writer outlives the 1s age timeout sees `Timeout`; the timed-out entry
    /// is then replaced by a new writer, whose own reader later sees `Done`.
    #[tokio::test]
    async fn test_lock_timeout() {
        let cache_lock = CacheLock::new_boxed(Duration::from_secs(1));
        let key1 = CacheKey::new("", "a", "1");
        let mut permit = match cache_lock.lock(&key1) {
            Locked::Write(w) => w,
            _ => panic!(),
        };
        let lock = match cache_lock.lock(&key1) {
            Locked::Read(r) => r,
            _ => panic!(),
        };
        assert!(lock.locked());
        let handle = tokio::spawn(async move {
            lock.wait().await;
            assert_eq!(lock.lock_status(), LockStatus::Timeout);
        });
        tokio::time::sleep(Duration::from_millis(2100)).await;
        handle.await.unwrap();
        // The old entry is now `Timeout`, so `lock` creates a new writer instead of joining it.
        let mut permit2 = match cache_lock.lock(&key1) {
            Locked::Write(w) => w,
            _ => panic!(),
        };
        let lock2 = match cache_lock.lock(&key1) {
            Locked::Read(r) => r,
            _ => panic!(),
        };
        assert!(lock2.locked());
        let handle = tokio::spawn(async move {
            lock2.wait().await;
            assert_eq!(lock2.lock_status(), LockStatus::Done);
        });
        permit.unlock(LockStatus::Done);
        permit2.unlock(LockStatus::Done);
        handle.await.unwrap();
    }

    /// Releasing a writer whose lock already timed out keeps the `Timeout` entry in the table;
    /// the next `lock` replaces it with a fresh `Waiting` writer.
    #[tokio::test]
    async fn test_lock_expired_release() {
        let cache_lock = CacheLock::new_boxed(Duration::from_secs(1));
        let key1 = CacheKey::new("", "a", "1");
        let permit = match cache_lock.lock(&key1) {
            Locked::Write(w) => w,
            _ => panic!(),
        };
        let lock = match cache_lock.lock(&key1) {
            Locked::Read(r) => r,
            _ => panic!(),
        };
        assert!(lock.locked());
        let handle = tokio::spawn(async move {
            lock.wait().await;
            assert_eq!(lock.lock_status(), LockStatus::Timeout);
        });
        tokio::time::sleep(Duration::from_millis(1100)).await;
        handle.await.unwrap();
        cache_lock.release(&key1, permit, LockStatus::Done);
        let mut permit = match cache_lock.lock(&key1) {
            Locked::Write(w) => w,
            _ => panic!(),
        };
        assert_eq!(permit.lock.lock_status(), LockStatus::Waiting);
        let lock2 = match cache_lock.lock(&key1) {
            Locked::Read(r) => r,
            _ => panic!(),
        };
        assert!(lock2.locked());
        let handle = tokio::spawn(async move {
            lock2.wait().await;
            assert_eq!(lock2.lock_status(), LockStatus::Done);
        });
        permit.unlock(LockStatus::Done);
        handle.await.unwrap();
    }

    /// With no reader waiting, an expired lock's stored status stays `Waiting`, because only
    /// readers set `Timeout`. A later reader still joins it, times out immediately, and marks
    /// it `Timeout`.
    #[tokio::test]
    async fn test_lock_expired_no_reader() {
        let cache_lock = CacheLock::new_boxed(Duration::from_secs(1));
        let key1 = CacheKey::new("", "a", "1");
        let mut permit = match cache_lock.lock(&key1) {
            Locked::Write(w) => w,
            _ => panic!(),
        };
        tokio::time::sleep(Duration::from_millis(1100)).await;
        assert_eq!(permit.lock.lock_status(), LockStatus::Waiting);
        let lock = match cache_lock.lock(&key1) {
            Locked::Read(r) => r,
            _ => panic!(),
        };
        lock.wait().await;
        assert_eq!(lock.lock_status(), LockStatus::Timeout);
        assert_eq!(permit.lock.lock_status(), LockStatus::Timeout);
        permit.unlock(LockStatus::Timeout);
    }

    /// Many tasks contend for one key. Each either becomes the writer and releases, or waits
    /// and retries; every task must eventually finish.
    #[tokio::test]
    async fn test_lock_concurrent() {
        let _ = env_logger::builder().is_test(true).try_init();
        let cache_lock = Arc::new(CacheLock::new_boxed(Duration::from_secs(1)));
        let key1 = CacheKey::new("", "a", "1");
        let mut handles = vec![];
        const READERS: usize = 30;
        for _ in 0..READERS {
            let key1 = key1.clone();
            let cache_lock = cache_lock.clone();
            handles.push(tokio::spawn(async move {
                loop {
                    match cache_lock.lock(&key1) {
                        Locked::Write(permit) => {
                            let _ = tokio::time::sleep(Duration::from_millis(5)).await;
                            cache_lock.release(&key1, permit, LockStatus::Done);
                            break;
                        }
                        Locked::Read(r) => {
                            // Retry after the current writer finishes or times out.
                            r.wait().await;
                        }
                    }
                }
            }));
        }
        for handle in handles {
            handle.await.unwrap();
        }
    }
}