1use async_lock::{RwLock, RwLockWriteGuard};
2use futures_util::FutureExt;
3use std::{
4 any::{Any, TypeId},
5 fmt,
6 future::Future,
7 hash::{BuildHasher, Hash},
8 pin::Pin,
9 sync::Arc,
10};
11
12use crate::{
13 common::concurrent::arc::MiniArc,
14 ops::compute::{CompResult, Op},
15 Entry,
16};
17
18use super::{Cache, ComputeNone, OptionallyNone};
19
/// Number of segments in the segmented hash map that backs the waiter map.
const WAITER_MAP_NUM_SEGMENTS: usize = 64;

/// Type-erased error published to waiting callers; readers downcast it back
/// to the concrete `Arc<E>` they expect.
type ErrorObject = Arc<dyn Any + Send + Sync + 'static>;
23
/// Outcome of `ValueInitializer::try_init_or_read`.
pub(crate) enum InitResult<V, E> {
    /// This caller resolved its `init` future; the value was inserted into the cache.
    Initialized(V),
    /// The value was read from the cache, or from the result another caller published.
    ReadExisting(V),
    /// An `init` future (this caller's or a concurrent caller's) produced an error.
    InitErr(Arc<E>),
}
29
/// State that the owner of a waiter publishes (under the waiter's write lock)
/// to other callers waiting on the same key.
enum WaiterValue<V> {
    /// Initial state, held while the owner still has the write lock.
    Computing,
    /// The owner finished: either the value or a type-erased error.
    Ready(Result<V, ErrorObject>),
    /// Published by the compute paths when they finish; carries no value.
    ReadyNone,
    /// The owner's `init` future (or compute closure/future) panicked.
    InitFuturePanicked,
    /// The owner's guard was dropped without publishing a value, e.g. because
    /// the enclosing future was dropped before completing.
    EnclosingFutureAborted,
}
39
40impl<V> fmt::Debug for WaiterValue<V> {
41 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
42 match self {
43 WaiterValue::Computing => write!(f, "Computing"),
44 WaiterValue::Ready(_) => write!(f, "Ready"),
45 WaiterValue::ReadyNone => write!(f, "ReadyNone"),
46 WaiterValue::InitFuturePanicked => write!(f, "InitFuturePanicked"),
47 WaiterValue::EnclosingFutureAborted => write!(f, "EnclosingFutureAborted"),
48 }
49 }
50}
51
/// A waiter: a shared async `RwLock` holding the state its owner publishes.
type Waiter<V> = MiniArc<RwLock<WaiterValue<V>>>;
/// Concurrent map from `(cache key, TypeId)` to the waiter currently registered
/// for that key; owners remove their entry once they publish a value.
type WaiterMap<K, V, S> = crate::cht::SegmentedHashMap<(Arc<K>, TypeId), Waiter<V>, S>;
54
/// Holds the write lock of a waiter that this caller inserted into the waiter
/// map.
///
/// Either `set_waiter_value` publishes a result, or the `Drop` impl publishes
/// `EnclosingFutureAborted`. In both cases the waiter is removed from the map
/// before the write lock is released.
struct WaiterGuard<'a, K, V, S>
where
    // These bounds must match the bounds of the `Drop` impl below (E0367).
    K: Eq + Hash,
    V: Clone,
    S: BuildHasher,
{
    /// Key of the waiter in `waiters`; set to `None` once the waiter is removed.
    w_key: Option<(Arc<K>, TypeId)>,
    /// Precomputed hash of `w_key`.
    w_hash: u64,
    /// The map the waiter was inserted into.
    waiters: &'a WaiterMap<K, V, S>,
    /// Write lock on the waiter; readers block until this guard is dropped.
    write_lock: RwLockWriteGuard<'a, WaiterValue<V>>,
}
68
69impl<'a, K, V, S> WaiterGuard<'a, K, V, S>
70where
71 K: Eq + Hash,
72 V: Clone,
73 S: BuildHasher,
74{
75 fn new(
76 w_key: (Arc<K>, TypeId),
77 w_hash: u64,
78 waiters: &'a WaiterMap<K, V, S>,
79 write_lock: RwLockWriteGuard<'a, WaiterValue<V>>,
80 ) -> Self {
81 Self {
82 w_key: Some(w_key),
83 w_hash,
84 waiters,
85 write_lock,
86 }
87 }
88
89 fn set_waiter_value(mut self, v: WaiterValue<V>) {
90 *self.write_lock = v;
91 if let Some(w_key) = self.w_key.take() {
92 remove_waiter(self.waiters, w_key, self.w_hash);
93 }
94 }
95}
96
97impl<K, V, S> Drop for WaiterGuard<'_, K, V, S>
98where
99 K: Eq + Hash,
100 V: Clone,
101 S: BuildHasher,
102{
103 fn drop(&mut self) {
104 if let Some(w_key) = self.w_key.take() {
105 *self.write_lock = WaiterValue::EnclosingFutureAborted;
109 remove_waiter(self.waiters, w_key, self.w_hash);
110 }
111 }
112}
113
/// Coordinates concurrent initialization and compute calls. For a given key
/// and `TypeId`, only the caller that registers a waiter does the work; other
/// callers wait on that waiter.
pub(crate) struct ValueInitializer<K, V, S> {
    /// Waiters currently registered, keyed by `(key, TypeId)`.
    waiters: MiniArc<WaiterMap<K, V, S>>,
}
121
122impl<K, V, S> ValueInitializer<K, V, S>
123where
124 K: Eq + Hash + Send + Sync + 'static,
125 V: Clone + Send + Sync + 'static,
126 S: BuildHasher + Clone + Send + Sync + 'static,
127{
128 pub(crate) fn with_hasher(hasher: S) -> Self {
129 Self {
130 waiters: MiniArc::new(crate::cht::SegmentedHashMap::with_num_segments_and_hasher(
131 WAITER_MAP_NUM_SEGMENTS,
132 hasher,
133 )),
134 }
135 }
136
137 #[allow(clippy::too_many_arguments)]
150 pub(crate) async fn try_init_or_read<I, O, E>(
151 &self,
152 c_key: &Arc<K>,
153 c_hash: u64,
154 type_id: TypeId,
155 cache: &Cache<K, V, S>,
156 mut ignore_if: Option<I>,
157 init: Pin<&mut impl Future<Output = O>>,
159 post_init: fn(O) -> Result<V, E>,
162 ) -> InitResult<V, E>
163 where
164 I: FnMut(&V) -> bool + Send,
165 E: Send + Sync + 'static,
166 {
167 use std::panic::{resume_unwind, AssertUnwindSafe};
168 use InitResult::{InitErr, Initialized, ReadExisting};
169
170 const MAX_RETRIES: usize = 200;
171 let mut retries = 0;
172
173 let (w_key, w_hash) = waiter_key_hash(&self.waiters, c_key, type_id);
174
175 let waiter = MiniArc::new(RwLock::new(WaiterValue::Computing));
176 let lock = waiter.write().await;
179
180 loop {
181 let Some(existing_waiter) =
182 try_insert_waiter(&self.waiters, w_key.clone(), w_hash, &waiter)
183 else {
184 break;
186 };
187
188 let waiter_result = existing_waiter.read().await;
190 match &*waiter_result {
191 WaiterValue::Ready(Ok(value)) => return ReadExisting(value.clone()),
192 WaiterValue::Ready(Err(e)) => return InitErr(Arc::clone(e).downcast().unwrap()),
193 WaiterValue::InitFuturePanicked => {
195 retries += 1;
196 panic_if_retry_exhausted_for_panicking(retries, MAX_RETRIES);
197 continue;
199 }
200 WaiterValue::EnclosingFutureAborted => {
203 retries += 1;
204 panic_if_retry_exhausted_for_aborting(retries, MAX_RETRIES);
205 continue;
207 }
208 s @ (WaiterValue::Computing | WaiterValue::ReadyNone) => panic!(
210 "Got unexpected state `{s:?}` after resolving `init` future. \
211 This might be a bug in Moka"
212 ),
213 }
214 }
215
216 let waiter_guard = WaiterGuard::new(w_key, w_hash, &self.waiters, lock);
222
223 if let Some(value) = cache
225 .base
226 .get_with_hash(&**c_key, c_hash, ignore_if.as_mut(), false, false)
227 .await
228 .map(Entry::into_value)
229 {
230 waiter_guard.set_waiter_value(WaiterValue::Ready(Ok(value.clone())));
233 return ReadExisting(value);
234 }
235
236 match AssertUnwindSafe(init).catch_unwind().await {
240 Ok(value) => match post_init(value) {
242 Ok(value) => {
243 cache
244 .insert_with_hash(Arc::clone(c_key), c_hash, value.clone())
245 .await;
246 waiter_guard.set_waiter_value(WaiterValue::Ready(Ok(value.clone())));
247 Initialized(value)
248 }
249 Err(e) => {
250 let err: ErrorObject = Arc::new(e);
251 waiter_guard.set_waiter_value(WaiterValue::Ready(Err(Arc::clone(&err))));
252 InitErr(err.downcast().unwrap())
253 }
254 },
255 Err(payload) => {
257 waiter_guard.set_waiter_value(WaiterValue::InitFuturePanicked);
258 resume_unwind(payload);
259 }
260 }
261 }
263
264 pub(crate) async fn try_compute<'a, F, Fut, O, E>(
267 &'a self,
268 c_key: Arc<K>,
269 c_hash: u64,
270 cache: &Cache<K, V, S>,
271 f: F,
272 post_init: fn(O) -> Result<Op<V>, E>,
273 allow_nop: bool,
274 ) -> Result<CompResult<K, V>, E>
275 where
276 F: FnOnce(Option<Entry<K, V>>) -> Fut,
277 Fut: Future<Output = O> + 'a,
278 E: Send + Sync + 'static,
279 {
280 use std::panic::{resume_unwind, AssertUnwindSafe};
281
282 let type_id = TypeId::of::<ComputeNone>();
283 let (w_key, w_hash) = waiter_key_hash(&self.waiters, &c_key, type_id);
284 let waiter = MiniArc::new(RwLock::new(WaiterValue::Computing));
285 let lock = waiter.write().await;
288
289 loop {
290 let Some(existing_waiter) =
291 try_insert_waiter(&self.waiters, w_key.clone(), w_hash, &waiter)
292 else {
293 break;
295 };
296
297 let waiter_result = existing_waiter.read().await;
300 match &*waiter_result {
301 WaiterValue::Computing => panic!(
303 "Got unexpected state `Computing` after resolving `init` future. \
304 This might be a bug in Moka"
305 ),
306 _ => {
307 continue;
309 }
310 }
311 }
312
313 let waiter_guard = WaiterGuard::new(w_key, w_hash, &self.waiters, lock);
319
320 let ignore_if = None as Option<&mut fn(&V) -> bool>;
322 let maybe_entry = cache
323 .base
324 .get_with_hash(&*c_key, c_hash, ignore_if, true, true)
325 .await;
326 let maybe_value = if allow_nop {
327 maybe_entry.as_ref().map(|ent| ent.value().clone())
328 } else {
329 None
330 };
331 let entry_existed = maybe_entry.is_some();
332
333 let fut = match std::panic::catch_unwind(AssertUnwindSafe(|| f(maybe_entry))) {
336 Ok(fut) => fut,
338 Err(payload) => {
340 waiter_guard.set_waiter_value(WaiterValue::InitFuturePanicked);
341 resume_unwind(payload);
342 }
343 };
344
345 let output = match AssertUnwindSafe(fut).catch_unwind().await {
348 Ok(output) => output,
350 Err(payload) => {
352 waiter_guard.set_waiter_value(WaiterValue::InitFuturePanicked);
353 resume_unwind(payload);
354 }
355 };
356
357 match post_init(output) {
361 Ok(Op::Nop) => {
362 waiter_guard.set_waiter_value(WaiterValue::ReadyNone);
363 if let Some(value) = maybe_value {
364 Ok(CompResult::Unchanged(Entry::new(
365 Some(c_key),
366 value,
367 false,
368 false,
369 )))
370 } else {
371 Ok(CompResult::StillNone(c_key))
372 }
373 }
374 Ok(Op::Put(value)) => {
375 cache
376 .insert_with_hash(Arc::clone(&c_key), c_hash, value.clone())
377 .await;
378 waiter_guard.set_waiter_value(WaiterValue::ReadyNone);
379 if entry_existed {
380 crossbeam_epoch::pin().flush();
381 let entry = Entry::new(Some(c_key), value, true, true);
382 Ok(CompResult::ReplacedWith(entry))
383 } else {
384 let entry = Entry::new(Some(c_key), value, true, false);
385 Ok(CompResult::Inserted(entry))
386 }
387 }
388 Ok(Op::Remove) => {
389 let maybe_prev_v = cache.invalidate_with_hash(&*c_key, c_hash, true).await;
390 waiter_guard.set_waiter_value(WaiterValue::ReadyNone);
391 if let Some(prev_v) = maybe_prev_v {
392 crossbeam_epoch::pin().flush();
393 let entry = Entry::new(Some(c_key), prev_v, false, false);
394 Ok(CompResult::Removed(entry))
395 } else {
396 Ok(CompResult::StillNone(c_key))
397 }
398 }
399 Err(e) => {
400 waiter_guard.set_waiter_value(WaiterValue::ReadyNone);
401 Err(e)
402 }
403 }
404
405 }
407
408 pub(crate) async fn try_compute_if_nobody_else<'a, F, Fut, O, E>(
409 &'a self,
410 c_key: Arc<K>,
411 c_hash: u64,
412 cache: &Cache<K, V, S>,
413 f: F,
414 post_init: fn(O) -> Result<Op<V>, E>,
415 allow_nop: bool,
416 ) -> Result<CompResult<K, V>, E>
417 where
418 F: FnOnce(Option<Entry<K, V>>) -> Fut,
419 Fut: Future<Output = O> + 'a,
420 E: Send + Sync + 'static,
421 {
422 use std::panic::{resume_unwind, AssertUnwindSafe};
423
424 let type_id = TypeId::of::<ComputeNone>();
425 let (w_key, w_hash) = waiter_key_hash(&self.waiters, &c_key, type_id);
426 let waiter = MiniArc::new(RwLock::new(WaiterValue::Computing));
427 let lock = waiter.write().await;
430
431 if let Some(_existing_waiter) =
432 try_insert_waiter(&self.waiters, w_key.clone(), w_hash, &waiter)
433 {
434 let ignore_if = None as Option<&mut fn(&V) -> bool>;
438 let maybe_entry = cache
439 .base
440 .get_with_hash(&*c_key, c_hash, ignore_if, true, true)
441 .await;
442 let maybe_value = maybe_entry.as_ref().map(|ent| ent.value().clone());
443
444 return if let Some(value) = maybe_value {
445 Ok(CompResult::Unchanged(Entry::new(
446 Some(c_key),
447 value,
448 false,
449 false,
450 )))
451 } else {
452 Ok(CompResult::StillNone(c_key))
453 };
454 } else {
456 }
458
459 let waiter_guard = WaiterGuard::new(w_key, w_hash, &self.waiters, lock);
465
466 let ignore_if = None as Option<&mut fn(&V) -> bool>;
468 let maybe_entry = cache
469 .base
470 .get_with_hash(&*c_key, c_hash, ignore_if, true, true)
471 .await;
472 let maybe_value = if allow_nop {
473 maybe_entry.as_ref().map(|ent| ent.value().clone())
474 } else {
475 None
476 };
477 let entry_existed = maybe_entry.is_some();
478
479 let fut = match std::panic::catch_unwind(AssertUnwindSafe(|| f(maybe_entry))) {
482 Ok(fut) => fut,
484 Err(payload) => {
485 waiter_guard.set_waiter_value(WaiterValue::InitFuturePanicked);
486 resume_unwind(payload);
487 }
488 };
489
490 let output = match AssertUnwindSafe(fut).catch_unwind().await {
493 Ok(output) => output,
495 Err(payload) => {
497 waiter_guard.set_waiter_value(WaiterValue::InitFuturePanicked);
498 resume_unwind(payload);
499 }
500 };
501
502 match post_init(output) {
506 Ok(Op::Nop) => {
507 waiter_guard.set_waiter_value(WaiterValue::ReadyNone);
508 if let Some(value) = maybe_value {
509 Ok(CompResult::Unchanged(Entry::new(
510 Some(c_key),
511 value,
512 false,
513 false,
514 )))
515 } else {
516 Ok(CompResult::StillNone(c_key))
517 }
518 }
519 Ok(Op::Put(value)) => {
520 cache
521 .insert_with_hash(Arc::clone(&c_key), c_hash, value.clone())
522 .await;
523 waiter_guard.set_waiter_value(WaiterValue::ReadyNone);
524 if entry_existed {
525 crossbeam_epoch::pin().flush();
526 let entry = Entry::new(Some(c_key), value, true, true);
527 Ok(CompResult::ReplacedWith(entry))
528 } else {
529 let entry = Entry::new(Some(c_key), value, true, false);
530 Ok(CompResult::Inserted(entry))
531 }
532 }
533 Ok(Op::Remove) => {
534 let maybe_prev_v = cache.invalidate_with_hash(&*c_key, c_hash, true).await;
535 waiter_guard.set_waiter_value(WaiterValue::ReadyNone);
536 if let Some(prev_v) = maybe_prev_v {
537 crossbeam_epoch::pin().flush();
538 let entry = Entry::new(Some(c_key), prev_v, false, false);
539 Ok(CompResult::Removed(entry))
540 } else {
541 Ok(CompResult::StillNone(c_key))
542 }
543 }
544 Err(e) => {
545 waiter_guard.set_waiter_value(WaiterValue::ReadyNone);
546 Err(e)
547 }
548 }
549
550 }
552
553 pub(crate) fn post_init_for_get_with(value: V) -> Result<V, ()> {
555 Ok(value)
556 }
557
558 pub(crate) fn post_init_for_optionally_get_with(
560 value: Option<V>,
561 ) -> Result<V, Arc<OptionallyNone>> {
562 value.ok_or(Arc::new(OptionallyNone))
567 }
568
569 pub(crate) fn post_init_for_try_get_with<E>(result: Result<V, E>) -> Result<V, E> {
571 result
572 }
573
574 pub(crate) fn post_init_for_upsert_with(value: V) -> Result<Op<V>, ()> {
576 Ok(Op::Put(value))
577 }
578
579 pub(crate) fn post_init_for_compute_with(op: Op<V>) -> Result<Op<V>, ()> {
581 Ok(op)
582 }
583
584 pub(crate) fn post_init_for_try_compute_with<E>(op: Result<Op<V>, E>) -> Result<Op<V>, E>
586 where
587 E: Send + Sync + 'static,
588 {
589 op
590 }
591
592 pub(crate) fn post_init_for_try_compute_with_if_nobody_else<E>(
594 op: Result<Op<V>, E>,
595 ) -> Result<Op<V>, E>
596 where
597 E: Send + Sync + 'static,
598 {
599 op
600 }
601
602 pub(crate) fn type_id_for_get_with() -> TypeId {
604 TypeId::of::<()>()
607 }
608
609 pub(crate) fn type_id_for_optionally_get_with() -> TypeId {
611 TypeId::of::<OptionallyNone>()
612 }
613
614 pub(crate) fn type_id_for_try_get_with<E: 'static>() -> TypeId {
616 TypeId::of::<E>()
617 }
618}
619
#[cfg(test)]
impl<K, V, S> ValueInitializer<K, V, S> {
    /// Test-only: number of waiters currently registered in the waiter map.
    pub(crate) fn waiter_count(&self) -> usize {
        let map = &self.waiters;
        map.len()
    }
}
626
627#[inline]
628fn remove_waiter<K, V, S>(waiter_map: &WaiterMap<K, V, S>, w_key: (Arc<K>, TypeId), w_hash: u64)
629where
630 (Arc<K>, TypeId): Eq + Hash,
631 S: BuildHasher,
632{
633 waiter_map.remove(w_hash, |k| k == &w_key);
634}
635
636#[inline]
637fn try_insert_waiter<K, V, S>(
638 waiter_map: &WaiterMap<K, V, S>,
639 w_key: (Arc<K>, TypeId),
640 w_hash: u64,
641 waiter: &Waiter<V>,
642) -> Option<Waiter<V>>
643where
644 (Arc<K>, TypeId): Eq + Hash,
645 S: BuildHasher,
646{
647 let waiter = MiniArc::clone(waiter);
648 waiter_map.insert_if_not_present(w_key, w_hash, waiter)
649}
650
651#[inline]
652fn waiter_key_hash<K, V, S>(
653 waiter_map: &WaiterMap<K, V, S>,
654 c_key: &Arc<K>,
655 type_id: TypeId,
656) -> ((Arc<K>, TypeId), u64)
657where
658 (Arc<K>, TypeId): Eq + Hash,
659 S: BuildHasher,
660{
661 let w_key = (Arc::clone(c_key), type_id);
662 let w_hash = waiter_map.hash(&w_key);
663 (w_key, w_hash)
664}
665
/// Panics once `retries` reaches `max`. Used when another caller's `init`
/// future keeps panicking.
fn panic_if_retry_exhausted_for_panicking(retries: usize, max: usize) {
    if retries >= max {
        panic!(
            "Too many retries. Tried to read the return value from the `init` future \
            but failed {retries} times. Maybe the `init` kept panicking?"
        );
    }
}
673
/// Panics once `retries` reaches `max`. Used when the future that owns
/// another caller's waiter keeps being aborted.
fn panic_if_retry_exhausted_for_aborting(retries: usize, max: usize) {
    if retries >= max {
        panic!(
            "Too many retries. Tried to read the return value from the `init` future \
            but failed {retries} times. Maybe the future containing `get_with`/`try_get_with` \
            kept being aborted?"
        );
    }
}
681}