use slumber_template::{Identifier, Value};
use std::{
collections::{HashMap, hash_map::Entry},
ops::DerefMut,
sync::Arc,
};
use tokio::sync::{Mutex, OwnedMutexGuard};
/// Cache of [`Value`]s keyed by field [`Identifier`].
///
/// Every field gets its own slot: an `Arc`-shared async mutex around an
/// `Option<Value>`, where `None` means no value has been stored yet. Use
/// [`FieldCache::get_or_init`] to either read a stored value or claim the
/// slot in order to fill it.
#[derive(Default, Debug)]
pub struct FieldCache {
    // The outer mutex protects the map structure; each inner mutex protects
    // a single field's slot
    cache: Mutex<HashMap<Identifier, Arc<Mutex<Option<Value>>>>>,
}
impl FieldCache {
    /// Get the cached value for `field`, or claim the right to compute it.
    ///
    /// Returns [`FieldCacheOutcome::Hit`] with a clone of the stored value if
    /// the field has already been set. Otherwise returns
    /// [`FieldCacheOutcome::Miss`] carrying a [`FieldCacheGuard`] that holds
    /// the field's lock; the caller should compute the value and store it
    /// with [`FieldCacheGuard::set`].
    ///
    /// Concurrent callers asking for the same field wait on that lock, so at
    /// most one caller is initializing a given field at any time. If the
    /// guard is dropped without a value being set, the next waiter receives a
    /// `Miss` of its own and gets a chance to initialize the field.
    pub(crate) async fn get_or_init(
        &self,
        field: Identifier,
    ) -> FieldCacheOutcome {
        let mut cache = self.cache.lock().await;
        match cache.entry(field) {
            Entry::Occupied(entry) => {
                let lock = Arc::clone(entry.get());
                // Release the map before waiting on the field's lock, so a
                // slow initialization of one field doesn't block lookups of
                // unrelated fields
                drop(cache);
                // `lock_owned` takes the `Arc` by value. We already own a
                // clone of it, so no further clone is needed
                let guard = lock.lock_owned().await;
                if let Some(value) = &*guard {
                    FieldCacheOutcome::Hit(value.clone())
                } else {
                    // Slot is still empty: the previous lock holder released
                    // it without setting a value, so it's our turn
                    FieldCacheOutcome::Miss(FieldCacheGuard(guard))
                }
            }
            Entry::Vacant(entry) => {
                // First request for this field: create its slot and lock it
                // right away. The map lock is still held here, so no other
                // caller can have obtained the new slot yet
                let lock = Arc::new(Mutex::new(None));
                entry.insert(Arc::clone(&lock));
                let guard = lock
                    .try_lock_owned()
                    .expect("Lock was just created, who the hell grabbed it??");
                drop(cache);
                FieldCacheOutcome::Miss(FieldCacheGuard(guard))
            }
        }
    }
}
/// What [`FieldCache::get_or_init`] found for the requested field.
#[derive(Debug)]
pub(crate) enum FieldCacheOutcome {
    /// A value was already stored for the field; this is a clone of it.
    Hit(Value),
    /// No value is stored yet. The contained guard holds the field's lock;
    /// fill the slot with [`FieldCacheGuard::set`], or drop the guard to let
    /// the next caller try.
    Miss(FieldCacheGuard),
}
/// Exclusive, owned handle to an empty cache slot, handed out on a miss.
///
/// While this value is alive the field's mutex stays locked, so other
/// callers of [`FieldCache::get_or_init`] for the same field wait. Dropping
/// it without calling [`FieldCacheGuard::set`] releases the lock and leaves
/// the slot empty.
#[derive(Debug)]
pub(crate) struct FieldCacheGuard(OwnedMutexGuard<Option<Value>>);
impl FieldCacheGuard {
    /// Store `value` in the field's slot.
    ///
    /// Consumes the guard, so the field's lock is released when this returns
    /// and later callers for the field observe the value as a cache hit.
    pub fn set(mut self, value: Value) {
        let slot: &mut Option<Value> = self.0.deref_mut();
        *slot = Some(value);
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use futures::join;
    use slumber_util::assert_matches;

    /// Two lookups of the same field driven together: the first claims the
    /// empty slot and fills it, the second sees the stored value as a hit.
    #[tokio::test]
    async fn test_field_cache() {
        let cache = FieldCache::default();
        let field: Identifier = "field".into();
        // No entry exists yet, so this lookup is expected to miss and then
        // store `true`
        let fut1 = async {
            let guard = assert_matches!(
                cache.get_or_init(field.clone()).await,
                FieldCacheOutcome::Miss(guard) => guard,
            );
            let value: Value = true.into();
            guard.set(value.clone());
            value
        };
        // Expected to observe the value stored by `fut1`
        let fut2 = async {
            assert_matches!(
                cache.get_or_init(field.clone()).await,
                FieldCacheOutcome::Hit(value) => value,
            )
        };
        // NOTE(review): the expected outcomes assume `join!` polls `fut1`
        // before `fut2` — confirm against the `futures::join!` docs
        let (v1, v2) = join!(fut1, fut2);
        assert_eq!(v1, true.into());
        assert_eq!(v2, true.into());
    }

    /// If the first caller drops its guard without setting a value, the slot
    /// stays empty: the next caller gets a `Miss` and fills it, and the
    /// caller after that gets a `Hit`.
    #[tokio::test]
    async fn test_field_cache_dropped_guard() {
        let cache = FieldCache::default();
        let field: Identifier = "field".into();
        // Claims the slot, then gives up without storing anything
        let fut1 = async {
            let guard = assert_matches!(
                cache.get_or_init(field.clone()).await,
                FieldCacheOutcome::Miss(guard) => guard,
            );
            drop(guard);
        };
        // Finds the slot still empty, so it gets its own `Miss` and stores
        // `true`
        let fut2 = async {
            let guard = assert_matches!(
                cache.get_or_init(field.clone()).await,
                FieldCacheOutcome::Miss(guard) => guard,
            );
            let value: Value = true.into();
            guard.set(value.clone());
            value
        };
        // Expected to observe the value stored by `fut2`
        let fut3 = async {
            assert_matches!(
                cache.get_or_init(field.clone()).await,
                FieldCacheOutcome::Hit(value) => value,
            )
        };
        // NOTE(review): relies on `join!` polling the futures in declaration
        // order — confirm against the `futures::join!` docs
        let ((), v2, v3) = join!(fut1, fut2, fut3);
        assert_eq!(v2, true.into());
        assert_eq!(v3, true.into());
    }
}