use std::{
any::Any,
collections::{HashMap, HashSet},
sync::{Arc, LazyLock, RwLock},
};
use crate::fused_rw::FusedRw;
/// One cache slot: a shared handle to a `FusedRw`-guarded, possibly-empty,
/// type-erased value. `None` means the slot exists but holds no value yet.
type Item = Arc<FusedRw<Option<Arc<dyn Any + Send + Sync>>>>;

/// Process-wide map from cache path to its slot, built on first access.
static CACHE: LazyLock<RwLock<HashMap<String, Item>>> =
    LazyLock::new(|| RwLock::new(HashMap::default()));
/// Namespace for the process-wide cache; every operation is an associated
/// function, so no instance is ever needed.
///
/// The private zero-sized field keeps the type from being constructed outside
/// this module.
pub struct GlobalCache {
    _p: core::marker::PhantomData<[u8]>,
}
impl GlobalCache {
pub fn is_cached(path: impl AsRef<str>) -> bool {
let Some(entry) = CACHE
.read()
.expect("Cache cannot poison")
.get(path.as_ref())
.cloned()
else {
return false;
};
entry.read().is_some()
}
pub async fn is_cached_async(path: impl AsRef<str>) -> bool {
let Some(entry) = CACHE
.read()
.expect("Cache cannot poison")
.get(path.as_ref())
.cloned()
else {
return false;
};
entry.read_async().await.is_some()
}
pub fn is_cached_as<T: Any + Send + Sync>(path: impl AsRef<str>) -> bool {
let Some(entry) = CACHE
.read()
.expect("Cache cannot poison")
.get(path.as_ref())
.cloned()
else {
return false;
};
entry.read().as_ref().is_some_and(|e| e.is::<T>())
}
pub async fn is_cached_as_async<T: Any + Send + Sync>(path: impl AsRef<str>) -> bool {
let Some(entry) = CACHE
.read()
.expect("Cache cannot poison")
.get(path.as_ref())
.cloned()
else {
return false;
};
entry
.read_async()
.await
.as_ref()
.is_some_and(|e| e.is::<T>())
}
pub fn get_dyn(path: impl AsRef<str>) -> Option<Arc<dyn Any + Send + Sync>> {
let entry = {
let guard = CACHE.read().expect("Cache cannot poison");
let entry = guard.get(path.as_ref())?.clone();
drop(guard);
entry
};
entry.read().clone()
}
pub async fn get_dyn_async(path: impl AsRef<str>) -> Option<Arc<dyn Any + Send + Sync>> {
let entry = {
let guard = CACHE.read().expect("Cache cannot poison");
let entry = guard.get(path.as_ref())?.clone();
drop(guard);
entry
};
entry.read_async().await.clone()
}
pub fn get<T: Any + Send + Sync>(path: impl AsRef<str>) -> Option<Arc<T>> {
Self::get_dyn(path).and_then(|v| v.downcast().ok())
}
pub async fn get_async<T: Any + Send + Sync>(path: impl AsRef<str>) -> Option<Arc<T>> {
Self::get_dyn_async(path)
.await
.and_then(|v| v.downcast().ok())
}
pub fn get_or<T: Any + Send + Sync>(path: impl AsRef<str>, default: Arc<T>) -> Option<Arc<T>> {
if let Some(good) = Self::get_dyn(path.as_ref()) {
return good.downcast().ok();
}
let entry = {
let mut guard = CACHE.write().expect("Cache cannot poison");
guard
.entry(path.as_ref().to_string())
.or_insert_with(|| Arc::new(FusedRw::new(None)))
.clone()
};
if let Some(good) = entry.read().as_ref() {
good.clone().downcast().ok()
} else {
assert!(
entry.write().replace(default.clone()).is_none(),
"Cached value should be None"
);
Some(default)
}
}
pub async fn get_or_async<T: Any + Send + Sync>(
path: impl AsRef<str>,
default: Arc<T>,
) -> Option<Arc<T>> {
if let Some(good) = Self::get_dyn_async(path.as_ref()).await {
return good.downcast().ok();
}
let entry = {
let mut guard = CACHE.write().expect("Cache cannot poison");
guard
.entry(path.as_ref().to_string())
.or_insert_with(|| Arc::new(FusedRw::new(None)))
.clone()
};
if let Some(good) = entry.read_async().await.as_ref() {
good.clone().downcast().ok()
} else {
assert!(
entry.write_async().await.replace(default.clone()).is_none(),
"Cached value should be None"
);
Some(default)
}
}
pub fn get_or_else<T: Any + Send + Sync, F: FnOnce() -> Arc<T>>(
path: impl AsRef<str>,
f: F,
) -> Option<Arc<T>> {
if let Some(good) = Self::get_dyn(path.as_ref()) {
return good.downcast().ok();
}
let entry = {
let mut guard = CACHE.write().expect("Cache cannot poison");
guard
.entry(path.as_ref().to_string())
.or_insert_with(|| Arc::new(FusedRw::new(None)))
.clone()
};
if let Some(good) = entry.read().as_ref() {
good.clone().downcast().ok()
} else {
let loaded = f();
assert!(
entry.write().replace(loaded.clone()).is_none(),
"Cached value should be None"
);
Some(loaded)
}
}
pub async fn get_or_else_async<T: Any + Send + Sync, F: AsyncFnOnce() -> Arc<T>>(
path: impl AsRef<str>,
f: F,
) -> Option<Arc<T>> {
if let Some(good) = Self::get_dyn_async(path.as_ref()).await {
return good.downcast().ok();
}
let entry = {
let mut guard = CACHE.write().expect("Cache cannot poison");
guard
.entry(path.as_ref().to_string())
.or_insert_with(|| Arc::new(FusedRw::new(None)))
.clone()
};
if let Some(good) = entry.read_async().await.as_ref() {
good.clone().downcast().ok()
} else {
let loaded = f().await;
assert!(
entry.write_async().await.replace(loaded.clone()).is_none(),
"Cached value should be None"
);
Some(loaded)
}
}
pub fn uncache(path: impl AsRef<str>) {
let mut guard = CACHE.write().expect("Cache cannot poison");
guard.remove(path.as_ref());
}
pub fn clear() {
let mut guard = CACHE.write().expect("Cache cannot poison");
guard.clear();
}
pub fn namespaced_paths(namespace: impl AsRef<str>) -> HashSet<String> {
let ns = namespace.as_ref();
let guard = CACHE.read().expect("Cache cannot poison");
guard
.iter()
.filter_map(|(k, v)| k.strip_prefix(ns).map(|s| (s.to_string(), v)))
.filter(|(_, v)| v.read().is_some())
.map(|(k, _)| k)
.collect()
}
pub async fn namespaced_paths_async(namespace: impl AsRef<str>) -> HashSet<String> {
let ns = namespace.as_ref();
let namespaced: Vec<_> = {
CACHE
.read()
.expect("Cache cannot poison")
.iter()
.filter_map(|(k, _)| k.strip_prefix(ns))
.map(|s| s.to_string())
.collect()
};
let mut hs = HashSet::new();
for k in namespaced {
let Some(entry) = CACHE.read().expect("Cache cannot poison").get(&k).cloned() else {
continue;
};
if entry.read_async().await.is_some() {
hs.insert(k);
}
}
hs
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Barrier;
    use std::thread;

    /// A key that has never been cached yields the freshly computed value.
    #[test]
    fn simple() {
        let value = GlobalCache::get_or_else("abc", || Arc::new("hello".to_string())).unwrap();
        assert_eq!(*value, "hello");
    }

    /// Two lookups of the same key hand back the very same allocation.
    #[test]
    fn shared() {
        let make_first = || {
            println!("Running one");
            Arc::new("hello".to_string())
        };
        let make_second = || {
            println!("Running two");
            Arc::new("hello".to_string())
        };
        let first = GlobalCache::get_or_else("123", make_first).unwrap();
        let second = GlobalCache::get_or_else("123", make_second).unwrap();
        assert!(Arc::ptr_eq(&first, &second));
    }

    /// Many threads racing on one key must run the initializer exactly once.
    #[test]
    fn simultaneous30() {
        const N: usize = 30;
        let barrier = Arc::new(Barrier::new(N));
        let runs = Arc::new(AtomicUsize::new(0));

        let workers: Vec<_> = (0..N)
            .map(|_| {
                let barrier = Arc::clone(&barrier);
                let runs = Arc::clone(&runs);
                thread::spawn(move || {
                    // Release every worker at the same moment to maximise contention.
                    barrier.wait();
                    let value = GlobalCache::get_or_else("xyz", || {
                        runs.fetch_add(1, Ordering::SeqCst);
                        Arc::new("hello".to_string())
                    })
                    .unwrap();
                    assert_eq!(*value, "hello");
                })
            })
            .collect();

        for worker in workers {
            worker.join().unwrap();
        }
        assert_eq!(runs.load(Ordering::SeqCst), 1);
    }
}