1use crate::{hashtable::ConcurrentHashTable, key::CacheHashKey, CacheKey};
18
19use pingora_timeout::timeout;
20use std::sync::Arc;
21
/// Boxed trait-object type for a [`CacheKeyLock`] shareable across threads.
pub type CacheKeyLockImpl = (dyn CacheKeyLock + Send + Sync);
23
/// A lock manager that coordinates concurrent access per [`CacheKey`].
pub trait CacheKeyLock {
    /// Try to lock the given key.
    ///
    /// Returns [`Locked::Write`] with a [`WritePermit`] when the caller is the
    /// first to lock this key (it becomes the writer); otherwise returns
    /// [`Locked::Read`] with a [`ReadLock`] that can wait on the writer.
    fn lock(&self, key: &CacheKey) -> Locked;

    /// Release the write lock held on `key`.
    ///
    /// `reason` is stored as the final [`LockStatus`] observed by readers.
    fn release(&self, key: &CacheKey, permit: WritePermit, reason: LockStatus);
}
37
/// Number of shards in the lock table, to reduce cross-key contention.
const N_SHARDS: usize = 16;
39
40pub struct CacheLock {
42 lock_table: ConcurrentHashTable<LockStub, N_SHARDS>,
43 timeout: Duration, }
45
/// The result of [`CacheKeyLock::lock`].
#[derive(Debug)]
pub enum Locked {
    /// This caller acquired the write lock and holds the permit.
    Write(WritePermit),
    /// Another caller holds the write lock; this handle can wait for it.
    Read(ReadLock),
}
54
55impl Locked {
56 pub fn is_write(&self) -> bool {
58 matches!(self, Self::Write(_))
59 }
60}
61
62impl CacheLock {
63 pub fn new_boxed(timeout: Duration) -> Box<Self> {
67 Box::new(CacheLock {
68 lock_table: ConcurrentHashTable::new(),
69 timeout,
70 })
71 }
72
73 pub fn new(timeout: Duration) -> Self {
77 CacheLock {
78 lock_table: ConcurrentHashTable::new(),
79 timeout,
80 }
81 }
82}
83
impl CacheKeyLock for CacheLock {
    // Lock a key: the first locker becomes the writer; later callers get a
    // read handle. Uses a double-checked read-then-write lookup so the common
    // "already locked" path only takes the shard's read lock.
    fn lock(&self, key: &CacheKey) -> Locked {
        let hash = key.combined_bin();
        // The lock table is keyed by the key's combined hash.
        let key = u128::from_be_bytes(hash);
        let table = self.lock_table.get(key);
        // Fast path: shared read access to the shard.
        if let Some(lock) = table.read().get(&key) {
            // A Dangling stub's writer dropped its permit without releasing;
            // don't wait on it — fall through and take over as writer.
            if lock.0.lock_status() != LockStatus::Dangling {
                return Locked::Read(lock.read_lock());
            }
        }

        // Slow path: take the shard's write lock and re-check, because
        // another thread may have inserted a stub between dropping the read
        // guard above and acquiring the write guard here.
        let mut table = table.write();
        if let Some(lock) = table.get(&key) {
            if lock.0.lock_status() != LockStatus::Dangling {
                return Locked::Read(lock.read_lock());
            }
        }
        // No usable stub: this caller becomes the writer. Inserting the stub
        // (which shares the permit's LockCore) publishes the lock to readers.
        let (permit, stub) = WritePermit::new(self.timeout);
        table.insert(key, stub);
        Locked::Write(permit)
    }

    // Release the write lock for `key`, recording `reason` and waking readers.
    fn release(&self, key: &CacheKey, mut permit: WritePermit, reason: LockStatus) {
        let hash = key.combined_bin();
        let key = u128::from_be_bytes(hash);
        // Remove the stub so future lockers can become writers, then unlock.
        // NOTE(review): if the stub is already gone, the permit is dropped
        // while still "Waiting", which triggers the Dangling debug_assert in
        // WritePermit::drop — confirm this path is intended.
        if let Some(_lock) = self.lock_table.write(key).remove(&key) {
            permit.unlock(reason);
        }
    }
}
119
120use log::warn;
121use std::sync::atomic::{AtomicU8, Ordering};
122use std::time::{Duration, Instant};
123use strum::IntoStaticStr;
124use tokio::sync::Semaphore;
125
/// The status of a cache lock as observed by readers.
#[derive(Debug, Copy, Clone, PartialEq, Eq, IntoStaticStr)]
pub enum LockStatus {
    /// Initial state: the writer is still working; readers should wait.
    Waiting,
    /// The writer released the lock after finishing successfully.
    Done,
    /// Release reason reported by the writer for a retryable failure.
    TransientError,
    /// Release reason reported by the writer when abandoning the work.
    /// Also the decoding fallback for unknown `u8` values.
    GiveUp,
    /// The writer dropped its permit without releasing the lock.
    Dangling,
    /// A reader timed out waiting for the writer.
    Timeout,
}
143
144impl From<LockStatus> for u8 {
145 fn from(l: LockStatus) -> u8 {
146 match l {
147 LockStatus::Waiting => 0,
148 LockStatus::Done => 1,
149 LockStatus::TransientError => 2,
150 LockStatus::GiveUp => 3,
151 LockStatus::Dangling => 4,
152 LockStatus::Timeout => 5,
153 }
154 }
155}
156
157impl From<u8> for LockStatus {
158 fn from(v: u8) -> Self {
159 match v {
160 0 => Self::Waiting,
161 1 => Self::Done,
162 2 => Self::TransientError,
163 3 => Self::GiveUp,
164 4 => Self::Dangling,
165 5 => Self::Timeout,
166 _ => Self::GiveUp, }
168 }
169}
170
/// Shared state behind a single cache lock, shared via `Arc` between the
/// writer's permit and all reader handles.
#[derive(Debug)]
pub struct LockCore {
    /// When the lock was created; readers' timeout is measured from here.
    pub lock_start: Instant,
    /// How long readers wait before giving up on the writer.
    pub timeout: Duration,
    /// Zero permits while locked; permits are added on unlock to wake readers.
    pub(super) lock: Semaphore,
    /// Current [`LockStatus`], stored as its `u8` encoding.
    lock_status: AtomicU8,
}
179
impl LockCore {
    /// Create a new locked core: zero semaphore permits, status `Waiting`,
    /// clock started now.
    pub fn new_arc(timeout: Duration) -> Arc<Self> {
        Arc::new(LockCore {
            lock: Semaphore::new(0),
            timeout,
            lock_start: Instant::now(),
            lock_status: AtomicU8::new(LockStatus::Waiting.into()),
        })
    }

    /// Whether the lock is still held (no permits released yet).
    pub fn locked(&self) -> bool {
        self.lock.available_permits() == 0
    }

    /// Record the final `reason` and wake waiting readers.
    pub fn unlock(&self, reason: LockStatus) {
        // Store the status before adding permits so readers woken by the
        // semaphore observe the final status rather than `Waiting`.
        self.lock_status.store(reason.into(), Ordering::SeqCst);
        // Wake readers. Each reader's acquired permit is returned when its
        // `SemaphorePermit` drops, so any positive count suffices; 10 lets
        // several readers proceed at once. NOTE(review): the constant looks
        // arbitrary — confirm intent.
        self.lock.add_permits(10);
    }

    /// The current status of the lock.
    pub fn lock_status(&self) -> LockStatus {
        self.lock_status.load(Ordering::SeqCst).into()
    }
}
205
/// A handle for waiting on another caller's in-flight write lock.
#[derive(Debug)]
pub struct ReadLock(Arc<LockCore>);
211
impl ReadLock {
    /// Wait for the writer to release the lock, up to the lock's timeout.
    pub async fn wait(&self) {
        // Already unlocked, or past the deadline: nothing to wait for.
        if !self.locked() || self.expired() {
            return;
        }

        // Wait only for the remainder of the lock's overall budget, which is
        // measured from `lock_start`, not from this call. `checked_sub`
        // returns None if the budget is already exhausted.
        if let Some(duration) = self.0.timeout.checked_sub(self.0.lock_start.elapsed()) {
            match timeout(duration, self.0.lock.acquire()).await {
                // Woken by the writer's unlock. Dropping the `SemaphorePermit`
                // returns the permit for other waiting readers.
                Ok(Ok(_)) => { }
                Ok(Err(e)) => {
                    // acquire() only fails if the semaphore is closed, which
                    // nothing in this module does — log and fall through.
                    warn!("error acquiring semaphore {e:?}")
                }
                Err(_) => {
                    // Timed out: publish `Timeout` on the shared status so
                    // other readers and `lock_status()` observe it.
                    self.0
                        .lock_status
                        .store(LockStatus::Timeout.into(), Ordering::SeqCst);
                }
            }
        }
    }

    /// Whether the writer still holds the lock.
    pub fn locked(&self) -> bool {
        self.0.locked()
    }

    /// Whether the lock has outlived its timeout budget.
    pub fn expired(&self) -> bool {
        self.0.lock_start.elapsed() >= self.0.timeout
    }

    /// The lock's status, reporting `Timeout` when the lock is past its
    /// deadline but still `Waiting` (i.e. no waiter has stored Timeout yet).
    pub fn lock_status(&self) -> LockStatus {
        let status = self.0.lock_status();
        if matches!(status, LockStatus::Waiting) && self.expired() {
            LockStatus::Timeout
        } else {
            status
        }
    }
}
259
/// Held by the writer; unlocking (or dropping) it releases the cache lock.
#[derive(Debug)]
pub struct WritePermit {
    lock: Arc<LockCore>,
    /// Set once `unlock` has run; checked in `Drop` to detect leaked permits.
    finished: bool,
}
266
267impl WritePermit {
268 pub fn new(timeout: Duration) -> (WritePermit, LockStub) {
269 let lock = LockCore::new_arc(timeout);
270 let stub = LockStub(lock.clone());
271 (
272 WritePermit {
273 lock,
274 finished: false,
275 },
276 stub,
277 )
278 }
279
280 pub fn unlock(&mut self, reason: LockStatus) {
281 self.finished = true;
282 self.lock.unlock(reason);
283 }
284}
285
286impl Drop for WritePermit {
287 fn drop(&mut self) {
288 if !self.finished {
290 debug_assert!(false, "Dangling cache lock started!");
291 self.unlock(LockStatus::Dangling);
292 }
293 }
294}
295
/// The table-resident side of a lock, from which reader handles are minted.
pub struct LockStub(pub Arc<LockCore>);
impl LockStub {
    /// Create a reader handle sharing this lock's core.
    pub fn read_lock(&self) -> ReadLock {
        ReadLock(self.0.clone())
    }
}
302
#[cfg(test)]
mod test {
    use super::*;
    use crate::CacheKey;

    #[test]
    fn test_get_release() {
        let cache_lock = CacheLock::new_boxed(Duration::from_secs(1000));
        let key = CacheKey::new("", "a", "1");

        // The first locker becomes the writer; a second locker must wait.
        let first = cache_lock.lock(&key);
        assert!(first.is_write());
        let second = cache_lock.lock(&key);
        assert!(!second.is_write());

        // After release, the key can be write-locked again.
        if let Locked::Write(permit) = first {
            cache_lock.release(&key, permit, LockStatus::Done);
        }
        let third = cache_lock.lock(&key);
        assert!(third.is_write());
        if let Locked::Write(permit) = third {
            cache_lock.release(&key, permit, LockStatus::Done);
        }
    }

    #[tokio::test]
    async fn test_lock() {
        let cache_lock = CacheLock::new_boxed(Duration::from_secs(1000));
        let key = CacheKey::new("", "a", "1");
        let mut permit = match cache_lock.lock(&key) {
            Locked::Write(w) => w,
            _ => panic!(),
        };
        let reader = match cache_lock.lock(&key) {
            Locked::Read(r) => r,
            _ => panic!(),
        };
        assert!(reader.locked());

        // The reader blocks until the writer unlocks, then observes Done.
        let waiter = tokio::spawn(async move {
            reader.wait().await;
            assert_eq!(reader.lock_status(), LockStatus::Done);
        });
        permit.unlock(LockStatus::Done);
        waiter.await.unwrap();
    }

    #[tokio::test]
    async fn test_lock_timeout() {
        let cache_lock = CacheLock::new_boxed(Duration::from_secs(1));
        let key = CacheKey::new("", "a", "1");
        let mut permit = match cache_lock.lock(&key) {
            Locked::Write(w) => w,
            _ => panic!(),
        };
        let reader = match cache_lock.lock(&key) {
            Locked::Read(r) => r,
            _ => panic!(),
        };
        assert!(reader.locked());

        // This reader outlives the 1s budget, so its wait() times out.
        let waiter = tokio::spawn(async move {
            reader.wait().await;
            assert_eq!(reader.lock_status(), LockStatus::Timeout);
        });

        tokio::time::sleep(Duration::from_secs(2)).await;

        // A reader arriving after expiry sees Timeout immediately, and its
        // wait() returns right away without blocking.
        let late_reader = match cache_lock.lock(&key) {
            Locked::Read(r) => r,
            _ => panic!(),
        };
        assert!(late_reader.locked());
        assert_eq!(late_reader.lock_status(), LockStatus::Timeout);
        late_reader.wait().await;
        assert_eq!(late_reader.lock_status(), LockStatus::Timeout);

        permit.unlock(LockStatus::Done);
        waiter.await.unwrap();
    }

    #[tokio::test]
    async fn test_lock_concurrent() {
        let _ = env_logger::builder().is_test(true).try_init();
        let cache_lock = Arc::new(CacheLock::new_boxed(Duration::from_secs(1)));
        let key = CacheKey::new("", "a", "1");

        const READERS: usize = 30;
        let mut handles = Vec::new();
        for _ in 0..READERS {
            let key = key.clone();
            let cache_lock = cache_lock.clone();
            // Each task loops until it wins the write lock exactly once.
            handles.push(tokio::spawn(async move {
                loop {
                    match cache_lock.lock(&key) {
                        Locked::Write(permit) => {
                            tokio::time::sleep(Duration::from_millis(5)).await;
                            cache_lock.release(&key, permit, LockStatus::Done);
                            break;
                        }
                        Locked::Read(r) => {
                            r.wait().await;
                        }
                    }
                }
            }));
        }

        for handle in handles {
            handle.await.unwrap();
        }
    }
}