use crate::key::CacheHashKey;
use crate::hashtable::ConcurrentHashTable;
use pingora_timeout::timeout;
use std::sync::Arc;
/// Shard count used as the const-generic parameter of the lock table
/// (presumably the number of independently locked sub-tables).
const N_SHARDS: usize = 16;

/// Table of per-key cache locks.
///
/// The first caller to lock a key receives a write permit; callers arriving while that
/// lock is live receive read locks they can wait on.
pub struct CacheLock {
    // Entries are keyed by the 128-bit cache hash of the key.
    lock_table: ConcurrentHashTable<LockStub, N_SHARDS>,
    // Lifetime given to every new lock; readers stop waiting once it has elapsed.
    timeout: Duration,
}
/// Outcome of `CacheLock::lock`.
#[derive(Debug)]
pub enum Locked {
    /// The caller now holds the write permit for the key.
    Write(WritePermit),
    /// Another caller holds the lock; this handle can be used to wait for it.
    Read(ReadLock),
}

impl Locked {
    /// Returns `true` for [`Locked::Write`], `false` for [`Locked::Read`].
    pub fn is_write(&self) -> bool {
        match self {
            Locked::Write(_) => true,
            Locked::Read(_) => false,
        }
    }
}
impl CacheLock {
    /// Creates an empty lock table; every lock it hands out expires `timeout` after it
    /// was taken.
    pub fn new(timeout: Duration) -> Self {
        CacheLock {
            lock_table: ConcurrentHashTable::new(),
            timeout,
        }
    }

    /// Tries to take the lock for `key`.
    ///
    /// Returns [`Locked::Write`] when there is no live lock for the key (none at all, or only
    /// a `Dangling` one whose writer went away); the caller then owns the write permit.
    /// Otherwise returns [`Locked::Read`] on the existing lock.
    pub fn lock<K: CacheHashKey>(&self, key: &K) -> Locked {
        let hash = key.combined_bin();
        // Only used as a table key, so any fixed byte order works.
        let key = u128::from_be_bytes(hash);
        let table = self.lock_table.get(key);

        // Fast path: a shared shard lock is enough to join an existing, live lock.
        if let Some(lock) = table.read().get(&key) {
            if lock.0.lock_status() != LockStatus::Dangling {
                return Locked::Read(lock.read_lock());
            }
        }

        // The permit is built before the exclusive shard lock is taken; if it turns out to
        // be unneeded below, it is simply dropped.
        let (permit, stub) = WritePermit::new(self.timeout);
        let mut table = table.write();
        // Re-check: another caller may have inserted a lock between releasing the shared
        // lock above and acquiring the exclusive one.
        if let Some(lock) = table.get(&key) {
            if lock.0.lock_status() != LockStatus::Dangling {
                return Locked::Read(lock.read_lock());
            }
        }
        table.insert(key, stub);
        Locked::Write(permit)
    }

    /// Removes the lock for `key` from the table and, if its writer has not recorded an
    /// outcome yet, unlocks it with `reason` so waiting readers wake up.
    ///
    /// Does nothing if no lock is registered for `key`.
    pub fn release<K: CacheHashKey>(&self, key: &K, reason: LockStatus) {
        let hash = key.combined_bin();
        let key = u128::from_be_bytes(hash);
        // Bind first so the shard's write guard is dropped before unlocking.
        let removed = self.lock_table.write(key).remove(&key);
        if let Some(lock) = removed {
            // Decide on the recorded status, not on `locked()`. After an `unlock`, permits
            // handed to queued readers count as unavailable until those readers return them,
            // so `locked()` can still read `true`. We would then overwrite the writer's
            // outcome (and add more permits) here.
            if lock.0.lock_status() == LockStatus::Waiting {
                lock.0.unlock(reason);
            }
        }
    }
}
use std::sync::atomic::{AtomicU8, Ordering};
use std::time::{Duration, Instant};
use tokio::sync::Semaphore;
/// State of a cache lock.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum LockStatus {
    /// No outcome recorded yet: the writer still holds the lock (initial state).
    Waiting,
    /// Writer-reported outcome; presumably the entry was filled successfully.
    Done,
    /// Writer-reported outcome; presumably the writer hit a retryable error.
    TransientError,
    /// Writer-reported outcome; presumably the writer abandoned the fill.
    GiveUp,
    /// The write permit was dropped without an outcome being recorded.
    Dangling,
    /// Reported to readers when the lock is still `Waiting` but has expired.
    Timeout,
}

impl From<LockStatus> for u8 {
    /// Encodes a status as the byte stored in the lock's atomic; inverse of `From<u8>`.
    fn from(status: LockStatus) -> u8 {
        use self::LockStatus as S;
        match status {
            S::Waiting => 0,
            S::Done => 1,
            S::TransientError => 2,
            S::GiveUp => 3,
            S::Dangling => 4,
            S::Timeout => 5,
        }
    }
}

impl From<u8> for LockStatus {
    /// Decodes a status byte; any unknown value falls back to `GiveUp`.
    fn from(code: u8) -> Self {
        // Indexed by the byte produced by `From<LockStatus> for u8`.
        const BY_CODE: [LockStatus; 6] = [
            LockStatus::Waiting,
            LockStatus::Done,
            LockStatus::TransientError,
            LockStatus::GiveUp,
            LockStatus::Dangling,
            LockStatus::Timeout,
        ];
        BY_CODE
            .get(usize::from(code))
            .copied()
            .unwrap_or(LockStatus::GiveUp)
    }
}
/// Shared state behind one cache lock, referenced by the writer's permit, the table stub
/// and every reader.
#[derive(Debug)]
struct LockCore {
    /// When the lock was created; expiry is measured from here.
    pub lock_start: Instant,
    /// How long after `lock_start` the lock counts as expired.
    pub timeout: Duration,
    /// Readers wait on this. It starts with no permits; `unlock` adds some.
    pub(super) lock: Semaphore,
    /// The current [`LockStatus`], stored as its `u8` encoding.
    lock_status: AtomicU8,
}

impl LockCore {
    /// Creates a locked core: status `Waiting`, zero semaphore permits, clock started now.
    pub fn new_arc(timeout: Duration) -> Arc<Self> {
        Arc::new(LockCore {
            lock: Semaphore::new(0),
            timeout,
            lock_start: Instant::now(),
            lock_status: AtomicU8::new(LockStatus::Waiting.into()),
        })
    }

    /// Returns `true` while the semaphore reports no available permits.
    ///
    /// Caveat: tokio hands permits added by `unlock` directly to queued waiters, and those
    /// only become available again once the waiters drop them. With many queued readers,
    /// this can therefore briefly still return `true` after `unlock`. Use `lock_status()`
    /// to tell whether an outcome has actually been recorded.
    fn locked(&self) -> bool {
        self.lock.available_permits() == 0
    }

    /// Records `reason` as the lock's status and releases waiting readers.
    fn unlock(&self, reason: LockStatus) {
        // Publish the status before adding permits so readers woken by them observe it.
        self.lock_status.store(reason.into(), Ordering::SeqCst);
        // Readers hand their permit back immediately after acquiring it, so a small batch
        // is enough to eventually let every waiter through.
        self.lock.add_permits(10);
    }

    /// Returns the recorded status, without accounting for expiry.
    fn lock_status(&self) -> LockStatus {
        // Match the `SeqCst` store in `unlock` instead of mixing in a weaker ordering.
        self.lock_status.load(Ordering::SeqCst).into()
    }
}
/// Handle held by a caller that has to wait for another caller's write permit on the
/// same key.
#[derive(Debug)]
pub struct ReadLock(Arc<LockCore>);

impl ReadLock {
    /// Waits until the writer releases the lock or the lock expires, whichever comes first.
    ///
    /// Returns immediately if the lock is already released or expired. Expiry is measured
    /// from when the lock was taken, so no reader waits past `lock_start + timeout`.
    pub async fn wait(&self) {
        if !self.locked() || self.expired() {
            return;
        }
        // Wait only for the lock's remaining lifetime. Waiting a full `timeout` from *now*
        // would keep a late-arriving reader blocked until almost `lock_start + 2 * timeout`,
        // even though `expired()` (and thus `lock_status()`) report a timeout much earlier.
        let remaining = self.0.timeout.saturating_sub(self.0.lock_start.elapsed());
        // Elapsing is not an error: callers inspect `lock_status()` afterwards. Any permit
        // acquired here is dropped immediately, handing it back to the semaphore.
        let _ = timeout(remaining, self.0.lock.acquire()).await;
    }

    /// Returns `true` while the underlying semaphore has no available permits.
    pub fn locked(&self) -> bool {
        self.0.locked()
    }

    /// Returns `true` once `timeout` has elapsed since the lock was taken.
    pub fn expired(&self) -> bool {
        self.0.lock_start.elapsed() >= self.0.timeout
    }

    /// Returns the lock's status. A lock that is still `Waiting` but has expired is
    /// reported as [`LockStatus::Timeout`].
    pub fn lock_status(&self) -> LockStatus {
        let status = self.0.lock_status();
        if matches!(status, LockStatus::Waiting) && self.expired() {
            LockStatus::Timeout
        } else {
            status
        }
    }
}
/// Permit held by the caller responsible for filling a cache entry.
///
/// Readers on the same key are released when the permit is unlocked. If it is dropped
/// without being unlocked, the lock is marked `Dangling` and readers are released too.
#[derive(Debug)]
pub struct WritePermit(Arc<LockCore>);

impl WritePermit {
    /// Creates a fresh locked core. Returns the writer's permit together with the stub that
    /// goes into the lock table; both share the same core.
    fn new(timeout: Duration) -> (WritePermit, LockStub) {
        let core = LockCore::new_arc(timeout);
        let stub = LockStub(Arc::clone(&core));
        (WritePermit(core), stub)
    }

    /// Records `reason` as the lock's outcome and wakes waiting readers.
    fn unlock(&self, reason: LockStatus) {
        self.0.unlock(reason)
    }
}

impl Drop for WritePermit {
    /// If no outcome was ever recorded, marks the lock `Dangling` and wakes readers. A
    /// `Dangling` lock is then replaced by the next caller of `CacheLock::lock`.
    fn drop(&mut self) {
        // Check the recorded status rather than `locked()`. Once `unlock` has run, up to 10
        // permits may be assigned to woken readers that have not returned them yet, which
        // leaves `available_permits()` at 0. `locked()` would then report `true`, and this
        // would overwrite the real outcome (e.g. `Done`) with `Dangling`.
        if self.0.lock_status() == LockStatus::Waiting {
            self.unlock(LockStatus::Dangling);
        }
    }
}
/// Lock-table entry: a shared handle on the writer's [`LockCore`].
struct LockStub(Arc<LockCore>);

impl LockStub {
    /// Hands out a new reader handle on the same lock.
    pub fn read_lock(&self) -> ReadLock {
        let core = Arc::clone(&self.0);
        ReadLock(core)
    }
}
#[cfg(test)]
mod test {
    use super::*;
    use crate::CacheKey;

    /// Unwraps a write permit, panicking if a read lock was handed out instead.
    fn expect_write(locked: Locked) -> WritePermit {
        match locked {
            Locked::Write(permit) => permit,
            Locked::Read(_) => panic!(),
        }
    }

    /// Unwraps a read lock, panicking if a write permit was handed out instead.
    fn expect_read(locked: Locked) -> ReadLock {
        match locked {
            Locked::Read(reader) => reader,
            Locked::Write(_) => panic!(),
        }
    }

    #[test]
    fn test_get_release() {
        let cache_lock = CacheLock::new(Duration::from_secs(1000));
        let key1 = CacheKey::new("", "a", "1");

        // First caller becomes the writer; the next one only gets a read lock.
        let locked1 = cache_lock.lock(&key1);
        assert!(locked1.is_write());
        let locked2 = cache_lock.lock(&key1);
        assert!(!locked2.is_write());

        // After release, the key can be write-locked again.
        cache_lock.release(&key1, LockStatus::Done);
        let locked3 = cache_lock.lock(&key1);
        assert!(locked3.is_write());
    }

    #[tokio::test]
    async fn test_lock() {
        let cache_lock = CacheLock::new(Duration::from_secs(1000));
        let key1 = CacheKey::new("", "a", "1");
        let permit = expect_write(cache_lock.lock(&key1));
        let reader = expect_read(cache_lock.lock(&key1));
        assert!(reader.locked());

        let waiter = tokio::spawn(async move {
            reader.wait().await;
            assert_eq!(reader.lock_status(), LockStatus::Done);
        });
        permit.unlock(LockStatus::Done);
        waiter.await.unwrap();
    }

    #[tokio::test]
    async fn test_lock_timeout() {
        let cache_lock = CacheLock::new(Duration::from_secs(1));
        let key1 = CacheKey::new("", "a", "1");
        let permit = expect_write(cache_lock.lock(&key1));
        let reader = expect_read(cache_lock.lock(&key1));
        assert!(reader.locked());

        // This reader gives up once the lock times out, while the writer still holds it.
        let waiter = tokio::spawn(async move {
            reader.wait().await;
            assert_eq!(reader.lock_status(), LockStatus::Timeout);
        });
        tokio::time::sleep(Duration::from_secs(2)).await;

        // A reader arriving after expiry sees the timeout and does not block.
        let late_reader = expect_read(cache_lock.lock(&key1));
        assert!(late_reader.locked());
        assert_eq!(late_reader.lock_status(), LockStatus::Timeout);
        late_reader.wait().await;
        assert_eq!(late_reader.lock_status(), LockStatus::Timeout);

        permit.unlock(LockStatus::Done);
        waiter.await.unwrap();
    }
}