moka 0.9.6

A fast and concurrent cache library inspired by Java Caffeine
Documentation
use parking_lot::RwLock;
use std::{
    any::{Any, TypeId},
    hash::{BuildHasher, Hash},
    sync::Arc,
};
use triomphe::Arc as TrioArc;

use super::OptionallyNone;

const WAITER_MAP_NUM_SEGMENTS: usize = 64;

type ErrorObject = Arc<dyn Any + Send + Sync + 'static>;
type WaiterValue<V> = Option<Result<V, ErrorObject>>;
type Waiter<V> = TrioArc<RwLock<WaiterValue<V>>>;

pub(crate) enum InitResult<V, E> {
    Initialized(V),
    ReadExisting(V),
    InitErr(Arc<E>),
}

pub(crate) struct ValueInitializer<K, V, S> {
    // TypeId is the type ID of the concrete error type of generic type E in the
    // try_get_with method. We use the type ID as a part of the key to ensure that
    // we can always downcast the trait object ErrorObject (in Waiter<V>) into
    // its concrete type.
    waiters: crate::cht::SegmentedHashMap<(Arc<K>, TypeId), Waiter<V>, S>,
}

impl<K, V, S> ValueInitializer<K, V, S>
where
    K: Eq + Hash,
    V: Clone,
    S: BuildHasher,
{
    pub(crate) fn with_hasher(hasher: S) -> Self {
        Self {
            waiters: crate::cht::SegmentedHashMap::with_num_segments_and_hasher(
                WAITER_MAP_NUM_SEGMENTS,
                hasher,
            ),
        }
    }

    /// # Panics
    /// Panics if the `init` closure has been panicked.
    pub(crate) fn init_or_read(
        &self,
        key: Arc<K>,
        // Closure to get an existing value from cache.
        get: impl FnMut() -> Option<V>,
        init: impl FnOnce() -> V,
        // Closure to insert a new value into cache.
        mut insert: impl FnMut(V),
    ) -> InitResult<V, ()> {
        // This closure will be called before the init closure is called, in order to
        // check if the value has already been inserted by other thread.
        let pre_init = make_pre_init(get);

        // This closure will be called after the init closure has returned a value.
        // It will insert the returned value (from init) to the cache, and convert
        // the value into a pair of a WaiterValue and an InitResult.
        let post_init = |value: V| {
            insert(value.clone());
            (Some(Ok(value.clone())), InitResult::Initialized(value))
        };

        let type_id = TypeId::of::<()>();
        self.do_try_init(&key, type_id, pre_init, init, post_init)
    }

    /// # Panics
    /// Panics if the `init` closure has been panicked.
    pub(crate) fn try_init_or_read<E>(
        &self,
        key: Arc<K>,
        get: impl FnMut() -> Option<V>,
        init: impl FnOnce() -> Result<V, E>,
        mut insert: impl FnMut(V),
    ) -> InitResult<V, E>
    where
        E: Send + Sync + 'static,
    {
        let type_id = TypeId::of::<E>();

        // This closure will be called before the init closure is called, in order to
        // check if the value has already been inserted by other thread.
        let pre_init = make_pre_init(get);

        // This closure will be called after the init closure has returned a value.
        // It will insert the returned value (from init) to the cache, and convert
        // the value into a pair of a WaiterValue and an InitResult.
        let post_init = |value: Result<V, E>| match value {
            Ok(value) => {
                insert(value.clone());
                (Some(Ok(value.clone())), InitResult::Initialized(value))
            }
            Err(e) => {
                let err: ErrorObject = Arc::new(e);
                (
                    Some(Err(Arc::clone(&err))),
                    InitResult::InitErr(err.downcast().unwrap()),
                )
            }
        };

        self.do_try_init(&key, type_id, pre_init, init, post_init)
    }

    /// # Panics
    /// Panics if the `init` closure has been panicked.
    pub(super) fn optionally_init_or_read(
        &self,
        key: Arc<K>,
        get: impl FnMut() -> Option<V>,
        init: impl FnOnce() -> Option<V>,
        mut insert: impl FnMut(V),
    ) -> InitResult<V, OptionallyNone> {
        let type_id = TypeId::of::<OptionallyNone>();

        // This closure will be called before the init closure is called, in order to
        // check if the value has already been inserted by other thread.
        let pre_init = make_pre_init(get);

        // This closure will be called after the init closure has returned a value.
        // It will insert the returned value (from init) to the cache, and convert
        // the value into a pair of a WaiterValue and an InitResult.
        let post_init = |value: Option<V>| match value {
            Some(value) => {
                insert(value.clone());
                (Some(Ok(value.clone())), InitResult::Initialized(value))
            }
            None => {
                // `value` can be either `Some` or `None`. For `None` case, without
                // change the existing API too much, we will need to convert `None`
                // to Arc<E> here. `Infallible` could not be instantiated. So it
                // might be good to use an empty struct to indicate the error type.
                let err: ErrorObject = Arc::new(OptionallyNone);
                (
                    Some(Err(Arc::clone(&err))),
                    InitResult::InitErr(err.downcast().unwrap()),
                )
            }
        };

        self.do_try_init(&key, type_id, pre_init, init, post_init)
    }

    /// # Panics
    /// Panics if the `init` closure has been panicked.
    fn do_try_init<O, E>(
        &self,
        key: &Arc<K>,
        type_id: TypeId,
        mut pre_init: impl FnMut() -> Option<(WaiterValue<V>, InitResult<V, E>)>,
        init: impl FnOnce() -> O,
        mut post_init: impl FnMut(O) -> (WaiterValue<V>, InitResult<V, E>),
    ) -> InitResult<V, E>
    where
        E: Send + Sync + 'static,
    {
        use std::panic::{catch_unwind, resume_unwind, AssertUnwindSafe};
        use InitResult::*;

        const MAX_RETRIES: usize = 200;
        let mut retries = 0;

        let (cht_key, hash) = self.cht_key_hash(key, type_id);

        loop {
            let waiter = TrioArc::new(RwLock::new(None));
            let mut lock = waiter.write();

            match self.try_insert_waiter(cht_key.clone(), hash, &waiter) {
                None => {
                    // Our waiter was inserted.
                    // Check if the value has already been inserted by other thread.
                    if let Some((waiter_val, init_res)) = pre_init() {
                        // Yes. Set the waiter value, remove our waiter, and return
                        // the existing value.
                        *lock = waiter_val;
                        self.remove_waiter(cht_key, hash);
                        return init_res;
                    }

                    // The value still does note exist. Let's evaluate the init
                    // closure. Catching panic is safe here as we do not try to
                    // evaluate the closure again.
                    match catch_unwind(AssertUnwindSafe(init)) {
                        // Evaluated.
                        Ok(value) => {
                            let (waiter_val, init_res) = post_init(value);
                            *lock = waiter_val;
                            self.remove_waiter(cht_key, hash);
                            return init_res;
                        }
                        // Panicked.
                        Err(payload) => {
                            *lock = None;
                            // Remove the waiter so that others can retry.
                            self.remove_waiter(cht_key, hash);
                            resume_unwind(payload);
                        }
                    } // The write lock will be unlocked here.
                }
                Some(res) => {
                    // Somebody else's waiter already exists. Drop our write lock and
                    // wait for the read lock to become available.
                    std::mem::drop(lock);
                    match &*res.read() {
                        Some(Ok(value)) => return ReadExisting(value.clone()),
                        Some(Err(e)) => return InitErr(Arc::clone(e).downcast().unwrap()),
                        // None means somebody else's init closure has been panicked.
                        None => {
                            retries += 1;
                            if retries < MAX_RETRIES {
                                // Retry from the beginning.
                                continue;
                            } else {
                                panic!(
                                    "Too many retries. Tried to read the return value from the `init` \
                                closure but failed {} times. Maybe the `init` kept panicking?",
                                    retries
                                );
                            }
                        }
                    }
                }
            }
        }
    }

    #[inline]
    fn remove_waiter(&self, cht_key: (Arc<K>, TypeId), hash: u64) {
        self.waiters.remove(hash, |k| k == &cht_key);
    }

    #[inline]
    fn try_insert_waiter(
        &self,
        cht_key: (Arc<K>, TypeId),
        hash: u64,
        waiter: &Waiter<V>,
    ) -> Option<Waiter<V>> {
        let waiter = TrioArc::clone(waiter);
        self.waiters.insert_if_not_present(cht_key, hash, waiter)
    }

    #[inline]
    fn cht_key_hash(&self, key: &Arc<K>, type_id: TypeId) -> ((Arc<K>, TypeId), u64) {
        let cht_key = (Arc::clone(key), type_id);
        let hash = self.waiters.hash(&cht_key);
        (cht_key, hash)
    }
}

#[inline]
fn make_pre_init<V, E>(
    mut get: impl FnMut() -> Option<V>,
) -> impl FnMut() -> Option<(WaiterValue<V>, InitResult<V, E>)>
where
    V: Clone,
{
    move || get().map(|value| (Some(Ok(value.clone())), InitResult::ReadExisting(value)))
}