use crate::{hashtable::ConcurrentHashTable, key::CacheHashKey, CacheKey};
use crate::{Span, Tag};
use http::Extensions;
use pingora_timeout::timeout;
use std::sync::Arc;
use std::time::Duration;
/// Thread-safe trait-object form of [`CacheKeyLock`].
pub type CacheKeyLockImpl = dyn CacheKeyLock + Send + Sync;

/// Locking interface keyed by [`CacheKey`].
///
/// A call to [`CacheKeyLock::lock`] yields either a [`WritePermit`] (wrapped in
/// [`Locked::Write`]) or a [`ReadLock`] (wrapped in [`Locked::Read`]) that can be waited on.
pub trait CacheKeyLock {
    /// Locks `key`, returning a write permit or a read lock on an existing writer.
    ///
    /// `stale_writer` is recorded on a newly created lock (see `LockCore::stale_writer`).
    fn lock(&self, key: &CacheKey, stale_writer: bool) -> Locked;

    /// Releases a write permit previously returned by [`CacheKeyLock::lock`] for `key`,
    /// using `reason` as the status reported to readers.
    fn release(&self, key: &CacheKey, permit: WritePermit, reason: LockStatus);

    /// Records the outcome of a lock wait on `span` as a `status` tag.
    ///
    /// The default implementation does not inspect the read lock.
    fn trace_lock_wait(&self, span: &mut Span, _read_lock: &ReadLock, lock_status: LockStatus) {
        let status_name = <&'static str>::from(lock_status);
        span.set_tag(|| Tag::new("status", status_name));
    }

    /// Maps a custom no-cache reason to a [`LockStatus`].
    ///
    /// The default implementation ignores the reason and always returns `GiveUp`.
    fn custom_lock_status(&self, _custom_no_cache: &'static str) -> LockStatus {
        LockStatus::GiveUp
    }
}
// Shard count passed as the const parameter of `ConcurrentHashTable` for the lock table
// (presumably the number of independently locked sub-tables — confirm in `hashtable`).
const N_SHARDS: usize = 16;
/// Default [`CacheKeyLock`] implementation: a sharded table of [`LockStub`]s keyed by the
/// 128-bit hash of the cache key.
#[derive(Debug)]
pub struct CacheLock {
// Live lock stubs keyed by `u128::from_be_bytes(key.combined_bin())`.
lock_table: ConcurrentHashTable<LockStub, N_SHARDS>,
// Age timeout given to every lock created through `lock`.
age_timeout_default: Duration,
}
/// Outcome of [`CacheKeyLock::lock`].
#[derive(Debug)]
pub enum Locked {
/// The caller owns the write permit. It should be released or unlocked; dropping it
/// unfinished marks the lock `Dangling`.
Write(WritePermit),
/// Another writer holds the lock; the caller can wait on this read lock.
Read(ReadLock),
}
impl Locked {
    /// Returns `true` when this is [`Locked::Write`], `false` for [`Locked::Read`].
    pub fn is_write(&self) -> bool {
        match self {
            Self::Write(_) => true,
            Self::Read(_) => false,
        }
    }
}
impl CacheLock {
    /// Same as [`CacheLock::new`], but returns the lock on the heap.
    pub fn new_boxed(age_timeout: Duration) -> Box<Self> {
        Box::new(Self::new(age_timeout))
    }

    /// Creates an empty lock table whose locks all use `age_timeout_default`.
    pub fn new(age_timeout_default: Duration) -> Self {
        Self {
            lock_table: ConcurrentHashTable::new(),
            age_timeout_default,
        }
    }
}
impl CacheKeyLock for CacheLock {
fn lock(&self, key: &CacheKey, stale_writer: bool) -> Locked {
let hash = key.combined_bin();
let key = u128::from_be_bytes(hash); let table = self.lock_table.get(key);
if let Some(lock) = table.read().get(&key) {
if !matches!(
lock.0.lock_status(),
LockStatus::Dangling | LockStatus::AgeTimeout
) {
return Locked::Read(lock.read_lock());
}
}
let mut table = table.write();
if let Some(lock) = table.get(&key) {
if !matches!(
lock.0.lock_status(),
LockStatus::Dangling | LockStatus::AgeTimeout
) {
return Locked::Read(lock.read_lock());
}
}
let (permit, stub) =
WritePermit::new(self.age_timeout_default, stale_writer, Extensions::new());
table.insert(key, stub);
Locked::Write(permit)
}
fn release(&self, key: &CacheKey, mut permit: WritePermit, reason: LockStatus) {
let hash = key.combined_bin();
let key = u128::from_be_bytes(hash); if permit.lock.lock_status() == LockStatus::AgeTimeout {
permit.unlock(LockStatus::AgeTimeout);
} else if let Some(_lock) = self.lock_table.write(key).remove(&key) {
permit.unlock(reason);
}
}
}
use log::warn;
use std::sync::atomic::{AtomicU8, Ordering};
use std::time::Instant;
use strum::{FromRepr, IntoStaticStr};
use tokio::sync::Semaphore;
/// Status of a cache lock, stored in [`LockCore`] as its `u8` discriminant.
///
/// The explicit discriminants must stay in sync with `From<LockStatus> for u8`.
#[derive(Debug, Copy, Clone, PartialEq, Eq, IntoStaticStr, FromRepr)]
#[repr(u8)]
pub enum LockStatus {
/// Initial status: the writer has not unlocked yet.
Waiting = 0,
/// Writer unlocked with this reason (meaning defined by callers of `unlock`).
Done = 1,
/// Writer unlocked with this reason (meaning defined by callers of `unlock`).
TransientError = 2,
/// Writer unlocked with this reason; also the fallback when decoding an unknown `u8`.
GiveUp = 3,
/// The write permit was dropped without being released or unlocked.
Dangling = 4,
/// Reader-side status only; `LockCore::unlock` refuses to store it.
WaitTimeout = 5,
/// The lock exceeded its age timeout while readers were waiting.
AgeTimeout = 6,
}
impl From<LockStatus> for u8 {
fn from(l: LockStatus) -> u8 {
match l {
LockStatus::Waiting => 0,
LockStatus::Done => 1,
LockStatus::TransientError => 2,
LockStatus::GiveUp => 3,
LockStatus::Dangling => 4,
LockStatus::WaitTimeout => 5,
LockStatus::AgeTimeout => 6,
}
}
}
impl From<u8> for LockStatus {
fn from(v: u8) -> Self {
Self::from_repr(v).unwrap_or(Self::GiveUp)
}
}
/// State shared between one [`WritePermit`], its [`LockStub`] and all [`ReadLock`]s.
#[derive(Debug)]
pub struct LockCore {
/// When the lock was created (set in `new_arc`); used for the age timeout.
pub lock_start: Instant,
/// How long after `lock_start` the lock is considered aged out.
pub age_timeout: Duration,
/// Starts with zero permits; `unlock` adds permits so waiting readers can proceed.
pub(super) lock: Semaphore,
// Current `LockStatus`, stored as its `u8` discriminant.
lock_status: AtomicU8,
// Flag supplied when the lock was created (see `CacheKeyLock::lock`).
stale_writer: bool,
// Extensions supplied at creation time (`CacheLock` passes an empty set).
extensions: Extensions,
}
impl LockCore {
    /// Creates a locked core: no semaphore permits, status `Waiting`, age clock starting now.
    pub fn new_arc(timeout: Duration, stale_writer: bool, extensions: Extensions) -> Arc<Self> {
        let core = LockCore {
            lock_start: Instant::now(),
            age_timeout: timeout,
            lock: Semaphore::new(0),
            lock_status: AtomicU8::new(u8::from(LockStatus::Waiting)),
            stale_writer,
            extensions,
        };
        Arc::new(core)
    }

    /// Returns `true` while the semaphore has no available permits.
    pub fn locked(&self) -> bool {
        let available = self.lock.available_permits();
        available == 0
    }

    /// Publishes `reason` as the lock status and releases waiting readers.
    ///
    /// # Panics
    /// Panics if `reason` is `LockStatus::WaitTimeout`, which is never stored here.
    pub fn unlock(&self, reason: LockStatus) {
        assert!(
            reason != LockStatus::WaitTimeout,
            "WaitTimeout is not stored in LockCore"
        );
        self.lock_status.store(u8::from(reason), Ordering::SeqCst);
        // A small batch of permits is enough: a reader that acquires one drops it right
        // away (see `ReadLock::wait`), handing it on to the next waiter.
        self.lock.add_permits(10);
    }

    /// Returns the currently stored status.
    pub fn lock_status(&self) -> LockStatus {
        let raw = self.lock_status.load(Ordering::SeqCst);
        LockStatus::from(raw)
    }

    /// Returns the `stale_writer` flag given at creation.
    pub fn stale_writer(&self) -> bool {
        self.stale_writer
    }

    /// Returns the extensions given at creation.
    pub fn extensions(&self) -> &Extensions {
        &self.extensions
    }
}
/// Handle held by requests waiting on another writer's lock; shares its [`LockCore`].
#[derive(Debug)]
pub struct ReadLock(Arc<LockCore>);
impl ReadLock {
    /// Waits until the writer unlocks or the lock's age timeout elapses, whichever is first.
    ///
    /// When the age timeout wins, the shared status moves from `Waiting` to `AgeTimeout`. A
    /// status the writer has already set (e.g. `Done`) is never overwritten.
    pub async fn wait(&self) {
        if !self.locked() {
            return;
        }
        // Remaining time before the lock ages out; `None` if it already has.
        let remaining = self.0.age_timeout.checked_sub(self.0.lock_start.elapsed());
        match remaining {
            Some(duration) => match timeout(duration, self.0.lock.acquire()).await {
                // The acquired permit is dropped (returned to the semaphore) right away.
                Ok(Ok(_)) => {}
                Ok(Err(e)) => {
                    warn!("error acquiring semaphore {e:?}")
                }
                Err(_) => self.mark_age_timeout(),
            },
            None => self.mark_age_timeout(),
        }
    }

    /// Records that the lock aged out, but only while the writer has not unlocked yet.
    fn mark_age_timeout(&self) {
        // compare_exchange rather than store: the writer may unlock right as the timeout
        // fires, and its terminal status must not be clobbered. Failure is expected then.
        let _ = self.0.lock_status.compare_exchange(
            LockStatus::Waiting.into(),
            LockStatus::AgeTimeout.into(),
            Ordering::SeqCst,
            Ordering::SeqCst,
        );
    }

    /// Returns `true` while the writer has not unlocked (no semaphore permits available).
    pub fn locked(&self) -> bool {
        self.0.locked()
    }

    /// Returns `true` once the lock's age timeout has elapsed since it was created.
    pub fn expired(&self) -> bool {
        self.0.lock_start.elapsed() >= self.0.age_timeout
    }

    /// Returns the stored status, reporting `AgeTimeout` for a still-`Waiting` expired lock.
    pub fn lock_status(&self) -> LockStatus {
        let status = self.0.lock_status();
        if matches!(status, LockStatus::Waiting) && self.expired() {
            LockStatus::AgeTimeout
        } else {
            status
        }
    }

    /// Returns the extensions stored on the shared lock.
    pub fn extensions(&self) -> &Extensions {
        self.0.extensions()
    }
}
/// The writer's side of a lock; readers are woken when it is unlocked.
#[derive(Debug)]
pub struct WritePermit {
// Core shared with the table's `LockStub` and any `ReadLock`s.
lock: Arc<LockCore>,
// Set by `unlock`; an unfinished permit is marked `Dangling` on drop.
finished: bool,
}
impl WritePermit {
pub fn new(
timeout: Duration,
stale_writer: bool,
extensions: Extensions,
) -> (WritePermit, LockStub) {
let lock = LockCore::new_arc(timeout, stale_writer, extensions);
let stub = LockStub(lock.clone());
(
WritePermit {
lock,
finished: false,
},
stub,
)
}
pub fn stale_writer(&self) -> bool {
self.lock.stale_writer()
}
pub fn unlock(&mut self, reason: LockStatus) {
self.finished = true;
self.lock.unlock(reason);
}
pub fn lock_status(&self) -> LockStatus {
self.lock.lock_status()
}
pub fn extensions(&self) -> &Extensions {
self.lock.extensions()
}
}
impl Drop for WritePermit {
    /// Unlocks a permit dropped without being released, marking it `Dangling` so waiting
    /// readers wake up and new requests compete for the write lock again.
    fn drop(&mut self) {
        if self.finished {
            return;
        }
        // Unlock before asserting so readers are released even when the debug assertion
        // fires. Skip the assertion while already unwinding: a second panic inside `drop`
        // would abort the process.
        self.unlock(LockStatus::Dangling);
        debug_assert!(std::thread::panicking(), "Dangling cache lock started!");
    }
}
/// Entry stored in the lock table; hands out [`ReadLock`]s on the writer's [`LockCore`].
#[derive(Debug)]
pub struct LockStub(pub Arc<LockCore>);
impl LockStub {
pub fn read_lock(&self) -> ReadLock {
ReadLock(self.0.clone())
}
pub fn extensions(&self) -> &Extensions {
&self.0.extensions
}
}
#[cfg(test)]
mod test {
use super::*;
use crate::CacheKey;
// First `lock` gets the write permit, a second one gets a read lock; after `release`
// removes the entry, the next `lock` gets a fresh write permit.
#[test]
fn test_get_release() {
let cache_lock = CacheLock::new_boxed(Duration::from_secs(1000));
let key1 = CacheKey::new("", "a", "1");
let locked1 = cache_lock.lock(&key1, false);
assert!(locked1.is_write()); let locked2 = cache_lock.lock(&key1, false);
assert!(!locked2.is_write()); if let Locked::Write(permit) = locked1 {
cache_lock.release(&key1, permit, LockStatus::Done);
}
let locked3 = cache_lock.lock(&key1, false);
assert!(locked3.is_write()); if let Locked::Write(permit) = locked3 {
cache_lock.release(&key1, permit, LockStatus::Done);
}
}
// A waiting reader is woken by the writer's unlock and observes its reason (`Done`).
#[tokio::test]
async fn test_lock() {
let cache_lock = CacheLock::new_boxed(Duration::from_secs(1000));
let key1 = CacheKey::new("", "a", "1");
let mut permit = match cache_lock.lock(&key1, false) {
Locked::Write(w) => w,
_ => panic!(),
};
let lock = match cache_lock.lock(&key1, false) {
Locked::Read(r) => r,
_ => panic!(),
};
assert!(lock.locked());
let handle = tokio::spawn(async move {
lock.wait().await;
assert_eq!(lock.lock_status(), LockStatus::Done);
});
permit.unlock(LockStatus::Done);
handle.await.unwrap(); }
// With a 1s age timeout, a waiting reader times out with `AgeTimeout`; the next `lock`
// replaces the aged-out stub with a new writer whose reader then sees `Done`.
#[tokio::test]
async fn test_lock_timeout() {
let cache_lock = CacheLock::new_boxed(Duration::from_secs(1));
let key1 = CacheKey::new("", "a", "1");
let mut permit = match cache_lock.lock(&key1, false) {
Locked::Write(w) => w,
_ => panic!(),
};
let lock = match cache_lock.lock(&key1, false) {
Locked::Read(r) => r,
_ => panic!(),
};
assert!(lock.locked());
let handle = tokio::spawn(async move {
lock.wait().await;
assert_eq!(lock.lock_status(), LockStatus::AgeTimeout);
});
tokio::time::sleep(Duration::from_millis(2100)).await;
handle.await.unwrap();
// the previous writer's status is now AgeTimeout, so this request becomes a writer
let mut permit2 = match cache_lock.lock(&key1, false) {
Locked::Write(w) => w,
_ => panic!(),
};
let lock2 = match cache_lock.lock(&key1, false) {
Locked::Read(r) => r,
_ => panic!(),
};
assert!(lock2.locked());
let handle = tokio::spawn(async move {
lock2.wait().await;
assert_eq!(lock2.lock_status(), LockStatus::Done);
});
permit.unlock(LockStatus::Done);
permit2.unlock(LockStatus::Done);
handle.await.unwrap();
}
// Releasing a permit whose lock already aged out keeps its table entry (status stays
// AgeTimeout), so the next `lock` still becomes a fresh writer in `Waiting` state.
#[tokio::test]
async fn test_lock_expired_release() {
let cache_lock = CacheLock::new_boxed(Duration::from_secs(1));
let key1 = CacheKey::new("", "a", "1");
let permit = match cache_lock.lock(&key1, false) {
Locked::Write(w) => w,
_ => panic!(),
};
let lock = match cache_lock.lock(&key1, false) {
Locked::Read(r) => r,
_ => panic!(),
};
assert!(lock.locked());
let handle = tokio::spawn(async move {
lock.wait().await;
assert_eq!(lock.lock_status(), LockStatus::AgeTimeout);
});
tokio::time::sleep(Duration::from_millis(1100)).await; handle.await.unwrap();
cache_lock.release(&key1, permit, LockStatus::Done);
let mut permit = match cache_lock.lock(&key1, false) {
Locked::Write(w) => w,
_ => panic!(),
};
assert_eq!(permit.lock.lock_status(), LockStatus::Waiting);
let lock2 = match cache_lock.lock(&key1, false) {
Locked::Read(r) => r,
_ => panic!(),
};
assert!(lock2.locked());
let handle = tokio::spawn(async move {
lock2.wait().await;
assert_eq!(lock2.lock_status(), LockStatus::Done);
});
permit.unlock(LockStatus::Done);
handle.await.unwrap();
}
// Expiry alone does not change the stored status; the first reader to wait after the
// timeout records AgeTimeout on the shared core, which the writer then observes.
#[tokio::test]
async fn test_lock_expired_no_reader() {
let cache_lock = CacheLock::new_boxed(Duration::from_secs(1));
let key1 = CacheKey::new("", "a", "1");
let mut permit = match cache_lock.lock(&key1, false) {
Locked::Write(w) => w,
_ => panic!(),
};
tokio::time::sleep(Duration::from_millis(1100)).await;
assert_eq!(permit.lock.lock_status(), LockStatus::Waiting);
let lock = match cache_lock.lock(&key1, false) {
Locked::Read(r) => r,
_ => panic!(),
};
lock.wait().await;
assert_eq!(lock.lock_status(), LockStatus::AgeTimeout);
assert_eq!(permit.lock.lock_status(), LockStatus::AgeTimeout);
permit.unlock(LockStatus::AgeTimeout);
}
// Many tasks contend on one key: each either becomes the writer (sleeps briefly, then
// releases) or waits and retries, until every task has been a writer once.
#[tokio::test]
async fn test_lock_concurrent() {
let _ = env_logger::builder().is_test(true).try_init();
let cache_lock = Arc::new(CacheLock::new_boxed(Duration::from_secs(1)));
let key1 = CacheKey::new("", "a", "1");
let mut handles = vec![];
const READERS: usize = 30;
for _ in 0..READERS {
let key1 = key1.clone();
let cache_lock = cache_lock.clone();
handles.push(tokio::spawn(async move {
loop {
match cache_lock.lock(&key1, false) {
Locked::Write(permit) => {
let _ = tokio::time::sleep(Duration::from_millis(5)).await;
cache_lock.release(&key1, permit, LockStatus::Done);
break;
}
Locked::Read(r) => {
r.wait().await;
}
}
}
}));
}
for handle in handles {
handle.await.unwrap();
}
}
}