1use crate::{hashtable::ConcurrentHashTable, key::CacheHashKey, CacheKey};
18use crate::{Span, Tag};
19
20use http::Extensions;
21use pingora_timeout::timeout;
22use std::sync::Arc;
23use std::time::Duration;
24
25pub type CacheKeyLockImpl = dyn CacheKeyLock + Send + Sync;
26
27pub trait CacheKeyLock {
28 fn lock(&self, key: &CacheKey, stale_writer: bool) -> Locked;
36
37 fn release(&self, key: &CacheKey, permit: WritePermit, reason: LockStatus);
42
43 fn trace_lock_wait(&self, span: &mut Span, _read_lock: &ReadLock, lock_status: LockStatus) {
45 let tag_value: &'static str = lock_status.into();
46 span.set_tag(|| Tag::new("status", tag_value));
47 }
48
49 fn custom_lock_status(&self, _custom_no_cache: &'static str) -> LockStatus {
51 LockStatus::GiveUp
54 }
55}
56
/// Const parameter of the `ConcurrentHashTable` behind [`CacheLock`]
/// (the shard count, going by the name).
const N_SHARDS: usize = 16;

/// In-memory [`CacheKeyLock`] implementation.
///
/// Maps the 128-bit hash of a cache key to the [`LockStub`] of the writer that
/// currently holds the lock for that key.
#[derive(Debug)]
pub struct CacheLock {
    // Cache-key hash (`u128`) -> stub of the current writer's lock.
    lock_table: ConcurrentHashTable<LockStub, N_SHARDS>,
    // Age timeout given to every lock this table creates.
    age_timeout_default: Duration,
}
66
/// Outcome of [`CacheKeyLock::lock`].
#[derive(Debug)]
pub enum Locked {
    /// The caller holds the lock: it should do the work and then finish the permit.
    Write(WritePermit),
    /// Another request holds the lock: the caller can wait on this read lock.
    Read(ReadLock),
}
75
76impl Locked {
77 pub fn is_write(&self) -> bool {
79 matches!(self, Self::Write(_))
80 }
81}
82
83impl CacheLock {
84 pub fn new_boxed(age_timeout: Duration) -> Box<Self> {
90 Box::new(CacheLock {
91 lock_table: ConcurrentHashTable::new(),
92 age_timeout_default: age_timeout,
93 })
94 }
95
96 pub fn new(age_timeout_default: Duration) -> Self {
102 CacheLock {
103 lock_table: ConcurrentHashTable::new(),
104 age_timeout_default,
105 }
106 }
107}
108
109impl CacheKeyLock for CacheLock {
110 fn lock(&self, key: &CacheKey, stale_writer: bool) -> Locked {
111 let hash = key.combined_bin();
112 let key = u128::from_be_bytes(hash); let table = self.lock_table.get(key);
114 if let Some(lock) = table.read().get(&key) {
115 if !matches!(
123 lock.0.lock_status(),
124 LockStatus::Dangling | LockStatus::AgeTimeout
125 ) {
126 return Locked::Read(lock.read_lock());
127 }
128 }
131
132 let mut table = table.write();
133 if let Some(lock) = table.get(&key) {
135 if !matches!(
136 lock.0.lock_status(),
137 LockStatus::Dangling | LockStatus::AgeTimeout
138 ) {
139 return Locked::Read(lock.read_lock());
140 }
141 }
142 let (permit, stub) =
143 WritePermit::new(self.age_timeout_default, stale_writer, Extensions::new());
144 table.insert(key, stub);
145 Locked::Write(permit)
146 }
147
148 fn release(&self, key: &CacheKey, mut permit: WritePermit, reason: LockStatus) {
149 let hash = key.combined_bin();
150 let key = u128::from_be_bytes(hash); if permit.lock.lock_status() == LockStatus::AgeTimeout {
152 permit.unlock(LockStatus::AgeTimeout);
158 } else if let Some(_lock) = self.lock_table.write(key).remove(&key) {
159 permit.unlock(reason);
160 }
161 }
164}
165
166use log::warn;
167use std::sync::atomic::{AtomicU8, Ordering};
168use std::time::Instant;
169use strum::{FromRepr, IntoStaticStr};
170use tokio::sync::Semaphore;
171
/// Status of a cache lock, stored as a `u8` inside [`LockCore`].
///
/// The explicit discriminants are the stored byte values: they are what
/// `From<LockStatus> for u8` produces and what `LockStatus::from_repr` decodes,
/// so they must stay stable.
#[derive(Debug, Copy, Clone, PartialEq, Eq, IntoStaticStr, FromRepr)]
#[repr(u8)]
pub enum LockStatus {
    /// The writer still holds the lock; initial status of every new lock.
    Waiting = 0,
    /// The writer finished (the status the tests unlock with on success).
    Done = 1,
    /// NOTE(review): presumably the writer hit a retryable error — confirm with callers.
    TransientError = 2,
    /// The writer gave up. Also the fallback for unrecognized status bytes and
    /// the default result of `CacheKeyLock::custom_lock_status`.
    GiveUp = 3,
    /// The write permit was dropped without being unlocked; `CacheLock::lock`
    /// lets a new writer replace such a lock.
    Dangling = 4,
    /// Never stored in [`LockCore`] (`LockCore::unlock` rejects it).
    /// NOTE(review): presumably reported when a reader's own wait times out — confirm.
    WaitTimeout = 5,
    /// The lock outlived its age timeout; `CacheLock::lock` lets a new writer
    /// replace such a lock.
    AgeTimeout = 6,
}
192
193impl From<LockStatus> for u8 {
194 fn from(l: LockStatus) -> u8 {
195 match l {
196 LockStatus::Waiting => 0,
197 LockStatus::Done => 1,
198 LockStatus::TransientError => 2,
199 LockStatus::GiveUp => 3,
200 LockStatus::Dangling => 4,
201 LockStatus::WaitTimeout => 5,
202 LockStatus::AgeTimeout => 6,
203 }
204 }
205}
206
207impl From<u8> for LockStatus {
208 fn from(v: u8) -> Self {
209 Self::from_repr(v).unwrap_or(Self::GiveUp)
210 }
211}
212
/// State shared by one writer ([`WritePermit`]), its lock-table entry
/// ([`LockStub`]) and any readers ([`ReadLock`]).
#[derive(Debug)]
pub struct LockCore {
    /// When the lock was created; the age timeout is measured from here.
    pub lock_start: Instant,
    /// How long after `lock_start` readers keep waiting for the writer.
    pub age_timeout: Duration,
    /// Created with zero permits: readers wait to acquire one, `unlock` adds permits.
    pub(super) lock: Semaphore,
    /// The current [`LockStatus`], encoded as a `u8`.
    lock_status: AtomicU8,
    /// Flag supplied at creation and exposed via `stale_writer()`.
    /// NOTE(review): presumably marks a writer refreshing stale content — confirm.
    stale_writer: bool,
    /// Extensions supplied at creation.
    extensions: Extensions,
}
223
224impl LockCore {
225 pub fn new_arc(timeout: Duration, stale_writer: bool, extensions: Extensions) -> Arc<Self> {
226 Arc::new(LockCore {
227 lock: Semaphore::new(0),
228 age_timeout: timeout,
229 lock_start: Instant::now(),
230 lock_status: AtomicU8::new(LockStatus::Waiting.into()),
231 stale_writer,
232 extensions,
233 })
234 }
235
236 pub fn locked(&self) -> bool {
237 self.lock.available_permits() == 0
238 }
239
240 pub fn unlock(&self, reason: LockStatus) {
241 assert!(
242 reason != LockStatus::WaitTimeout,
243 "WaitTimeout is not stored in LockCore"
244 );
245 self.lock_status.store(reason.into(), Ordering::SeqCst);
246 self.lock.add_permits(10);
249 }
250
251 pub fn lock_status(&self) -> LockStatus {
252 self.lock_status.load(Ordering::SeqCst).into()
253 }
254
255 pub fn stale_writer(&self) -> bool {
257 self.stale_writer
258 }
259
260 pub fn extensions(&self) -> &Extensions {
261 &self.extensions
262 }
263}
264
/// Reader-side handle on a cache lock: shares the writer's [`LockCore`] and can
/// wait for it to be unlocked.
#[derive(Debug)]
pub struct ReadLock(Arc<LockCore>);
270
271impl ReadLock {
272 pub async fn wait(&self) {
274 if !self.locked() {
275 return;
276 }
277
278 if let Some(duration) = self.0.age_timeout.checked_sub(self.0.lock_start.elapsed()) {
284 match timeout(duration, self.0.lock.acquire()).await {
285 Ok(Ok(_)) => { }
287 Ok(Err(e)) => {
288 warn!("error acquiring semaphore {e:?}")
289 }
290 Err(_) => {
291 self.0
292 .lock_status
293 .store(LockStatus::AgeTimeout.into(), Ordering::SeqCst);
294 }
295 }
296 } else {
297 self.0
299 .lock_status
300 .store(LockStatus::AgeTimeout.into(), Ordering::SeqCst);
301 }
302 }
303
304 pub fn locked(&self) -> bool {
306 self.0.locked()
307 }
308
309 pub fn expired(&self) -> bool {
311 self.0.lock_start.elapsed() >= self.0.age_timeout
314 }
315
316 pub fn lock_status(&self) -> LockStatus {
318 let status = self.0.lock_status();
319 if matches!(status, LockStatus::Waiting) && self.expired() {
320 LockStatus::AgeTimeout
321 } else {
322 status
323 }
324 }
325
326 pub fn extensions(&self) -> &Extensions {
327 self.0.extensions()
328 }
329}
330
/// Writer-side handle on a cache lock.
///
/// Should be finished through [`WritePermit::unlock`] (which `CacheLock`'s
/// `release` calls); dropping it unfinished marks the lock
/// [`LockStatus::Dangling`].
#[derive(Debug)]
pub struct WritePermit {
    // Core shared with the table stub and every reader.
    lock: Arc<LockCore>,
    // Set once `unlock` has run; checked on drop.
    finished: bool,
}
337
338impl WritePermit {
339 pub fn new(
341 timeout: Duration,
342 stale_writer: bool,
343 extensions: Extensions,
344 ) -> (WritePermit, LockStub) {
345 let lock = LockCore::new_arc(timeout, stale_writer, extensions);
346 let stub = LockStub(lock.clone());
347 (
348 WritePermit {
349 lock,
350 finished: false,
351 },
352 stub,
353 )
354 }
355
356 pub fn stale_writer(&self) -> bool {
358 self.lock.stale_writer()
359 }
360
361 pub fn unlock(&mut self, reason: LockStatus) {
362 self.finished = true;
363 self.lock.unlock(reason);
364 }
365
366 pub fn lock_status(&self) -> LockStatus {
367 self.lock.lock_status()
368 }
369
370 pub fn extensions(&self) -> &Extensions {
371 self.lock.extensions()
372 }
373}
374
375impl Drop for WritePermit {
376 fn drop(&mut self) {
377 if !self.finished {
379 debug_assert!(false, "Dangling cache lock started!");
380 self.unlock(LockStatus::Dangling);
381 }
382 }
383}
384
385#[derive(Debug)]
386pub struct LockStub(pub Arc<LockCore>);
387impl LockStub {
388 pub fn read_lock(&self) -> ReadLock {
389 ReadLock(self.0.clone())
390 }
391
392 pub fn extensions(&self) -> &Extensions {
393 &self.0.extensions
394 }
395}
396
#[cfg(test)]
mod test {
    use super::*;
    use crate::CacheKey;

    // Writer/reader hand-off through `lock` and `release`, without async waiting.
    #[test]
    fn test_get_release() {
        let cache_lock = CacheLock::new_boxed(Duration::from_secs(1000));
        let key1 = CacheKey::new("", "a", "1");
        // The first caller becomes the writer.
        let locked1 = cache_lock.lock(&key1, false);
        assert!(locked1.is_write());
        // A second caller joins as a reader while the writer is outstanding.
        let locked2 = cache_lock.lock(&key1, false);
        assert!(!locked2.is_write());
        if let Locked::Write(permit) = locked1 {
            cache_lock.release(&key1, permit, LockStatus::Done);
        }
        // After release the entry is gone, so the next caller writes again.
        let locked3 = cache_lock.lock(&key1, false);
        assert!(locked3.is_write());
        if let Locked::Write(permit) = locked3 {
            cache_lock.release(&key1, permit, LockStatus::Done);
        }
    }

    // A waiting reader wakes up and observes the status the writer unlocked with.
    #[tokio::test]
    async fn test_lock() {
        let cache_lock = CacheLock::new_boxed(Duration::from_secs(1000));
        let key1 = CacheKey::new("", "a", "1");
        let mut permit = match cache_lock.lock(&key1, false) {
            Locked::Write(w) => w,
            _ => panic!(),
        };
        let lock = match cache_lock.lock(&key1, false) {
            Locked::Read(r) => r,
            _ => panic!(),
        };
        assert!(lock.locked());
        let handle = tokio::spawn(async move {
            lock.wait().await;
            assert_eq!(lock.lock_status(), LockStatus::Done);
        });
        permit.unlock(LockStatus::Done);
        handle.await.unwrap();
    }

    // With a 1s age timeout, a reader stops waiting and the key can be re-locked.
    #[tokio::test]
    async fn test_lock_timeout() {
        let cache_lock = CacheLock::new_boxed(Duration::from_secs(1));
        let key1 = CacheKey::new("", "a", "1");
        let mut permit = match cache_lock.lock(&key1, false) {
            Locked::Write(w) => w,
            _ => panic!(),
        };
        let lock = match cache_lock.lock(&key1, false) {
            Locked::Read(r) => r,
            _ => panic!(),
        };
        assert!(lock.locked());

        let handle = tokio::spawn(async move {
            // The writer does not unlock in time, so the wait ends at the age timeout.
            lock.wait().await;
            assert_eq!(lock.lock_status(), LockStatus::AgeTimeout);
        });

        tokio::time::sleep(Duration::from_millis(2100)).await;

        handle.await.unwrap();
        // The timed-out lock is replaceable, so a new caller becomes the writer.
        let mut permit2 = match cache_lock.lock(&key1, false) {
            Locked::Write(w) => w,
            _ => panic!(),
        };
        let lock2 = match cache_lock.lock(&key1, false) {
            Locked::Read(r) => r,
            _ => panic!(),
        };
        assert!(lock2.locked());
        let handle = tokio::spawn(async move {
            lock2.wait().await;
            assert_eq!(lock2.lock_status(), LockStatus::Done);
        });

        permit.unlock(LockStatus::Done);
        permit2.unlock(LockStatus::Done);
        handle.await.unwrap();
    }

    // Releasing a permit after its age timeout keeps the status AgeTimeout, and
    // the next caller still gets a fresh writer.
    #[tokio::test]
    async fn test_lock_expired_release() {
        let cache_lock = CacheLock::new_boxed(Duration::from_secs(1));
        let key1 = CacheKey::new("", "a", "1");
        let permit = match cache_lock.lock(&key1, false) {
            Locked::Write(w) => w,
            _ => panic!(),
        };

        let lock = match cache_lock.lock(&key1, false) {
            Locked::Read(r) => r,
            _ => panic!(),
        };
        assert!(lock.locked());
        let handle = tokio::spawn(async move {
            lock.wait().await;
            assert_eq!(lock.lock_status(), LockStatus::AgeTimeout);
        });

        tokio::time::sleep(Duration::from_millis(1100)).await;
        handle.await.unwrap();
        cache_lock.release(&key1, permit, LockStatus::Done);

        // The fresh writer's lock starts out Waiting.
        let mut permit = match cache_lock.lock(&key1, false) {
            Locked::Write(w) => w,
            _ => panic!(),
        };
        assert_eq!(permit.lock.lock_status(), LockStatus::Waiting);

        let lock2 = match cache_lock.lock(&key1, false) {
            Locked::Read(r) => r,
            _ => panic!(),
        };
        assert!(lock2.locked());
        let handle = tokio::spawn(async move {
            lock2.wait().await;
            assert_eq!(lock2.lock_status(), LockStatus::Done);
        });

        permit.unlock(LockStatus::Done);
        handle.await.unwrap();
    }

    // Nothing records an age timeout until a reader waits on the expired lock.
    #[tokio::test]
    async fn test_lock_expired_no_reader() {
        let cache_lock = CacheLock::new_boxed(Duration::from_secs(1));
        let key1 = CacheKey::new("", "a", "1");
        let mut permit = match cache_lock.lock(&key1, false) {
            Locked::Write(w) => w,
            _ => panic!(),
        };
        tokio::time::sleep(Duration::from_millis(1100)).await;
        // The stored status is still Waiting even though the age has passed.
        assert_eq!(permit.lock.lock_status(), LockStatus::Waiting);

        // A Waiting lock is still joinable, so the late caller becomes a reader...
        let lock = match cache_lock.lock(&key1, false) {
            Locked::Read(r) => r,
            _ => panic!(),
        };
        // ...and its wait records AgeTimeout on the shared core.
        lock.wait().await;
        assert_eq!(lock.lock_status(), LockStatus::AgeTimeout);
        assert_eq!(permit.lock.lock_status(), LockStatus::AgeTimeout);
        permit.unlock(LockStatus::AgeTimeout);
    }

    // Many tasks contend for one key; each loops until it has been the writer once.
    #[tokio::test]
    async fn test_lock_concurrent() {
        let _ = env_logger::builder().is_test(true).try_init();
        let cache_lock = Arc::new(CacheLock::new_boxed(Duration::from_secs(1)));
        let key1 = CacheKey::new("", "a", "1");

        let mut handles = vec![];

        const READERS: usize = 30;
        for _ in 0..READERS {
            let key1 = key1.clone();
            let cache_lock = cache_lock.clone();
            handles.push(tokio::spawn(async move {
                // Retry until this task gets the write permit, then release and stop.
                loop {
                    match cache_lock.lock(&key1, false) {
                        Locked::Write(permit) => {
                            let _ = tokio::time::sleep(Duration::from_millis(5)).await;
                            cache_lock.release(&key1, permit, LockStatus::Done);
                            break;
                        }
                        Locked::Read(r) => {
                            r.wait().await;
                        }
                    }
                }
            }));
        }

        for handle in handles {
            handle.await.unwrap();
        }
    }
}