use std::any::{Any, TypeId};
use std::cell::RefCell;
use std::collections::hash_map;
use std::sync::{Arc, LazyLock, OnceLock, RwLock};

use hash_hasher::HashedMap;

use crate::{ERR_POISONED_LOCK, Family};
/// A set of linked instances of `T` that can be obtained from any thread.
///
/// Each call to `get()` returns a new `T`. All instances returned by the same value belong to a
/// single family, which is registered process-wide on first use and cached per thread afterwards.
/// Usually declared through the `instances!` macro rather than constructed by hand.
#[derive(Debug)]
pub struct StaticInstances<T>
where
T: linked::Object,
{
// Returns the key under which this family is stored in the global and thread-local
// registries; values returning the same key share one registry entry.
family_key_provider: fn() -> TypeId,
// Builds the first instance of `T`; its family is what every later instance joins.
first_instance_provider: fn() -> T,
}
impl<T> StaticInstances<T>
where
T: linked::Object,
{
#[cfg_attr(coverage_nightly, coverage(off))]
#[doc(hidden)]
#[must_use]
pub const fn new(
family_key_provider: fn() -> TypeId,
first_instance_provider: fn() -> T,
) -> Self {
Self {
family_key_provider,
first_instance_provider,
}
}
#[must_use]
pub fn get(&self) -> T {
if let Some(instance) = self.new_from_local_registry() {
return instance;
}
self.try_initialize_global_registry(self.first_instance_provider);
let family = self
.get_family_global()
.expect("we just initialized it, the family must exist");
self.set_local(family);
self.new_from_local_registry()
.expect("we just set the value, it must be there")
}
fn set_local(&self, value: Family<T>) {
LOCAL_REGISTRY.with(|local_registry| {
local_registry
.borrow_mut()
.insert((self.family_key_provider)(), Box::new(value));
});
}
fn get_family_global(&self) -> Option<Family<T>> {
GLOBAL_REGISTRY
.read()
.expect(ERR_POISONED_LOCK)
.get(&(self.family_key_provider)())
.and_then(|w| w.downcast_ref::<Family<T>>())
.cloned()
}
fn try_initialize_global_registry<FIP>(&self, first_instance_provider: FIP)
where
FIP: FnOnce() -> T,
{
let mut global_registry = GLOBAL_REGISTRY.write().expect(ERR_POISONED_LOCK);
let family_key = (self.family_key_provider)();
let entry = global_registry.entry(family_key);
match entry {
hash_map::Entry::Occupied(_) => (),
hash_map::Entry::Vacant(entry) => {
let first_instance = first_instance_provider();
entry.insert(Box::new(first_instance.family()));
}
}
}
fn new_from_local_registry(&self) -> Option<T> {
LOCAL_REGISTRY.with_borrow(|registry| {
let family_key = (self.family_key_provider)();
registry
.get(&family_key)
.and_then(|w| w.downcast_ref::<Family<T>>())
.map(|family| family.clone().into())
})
}
}
/// Declares one or more [`StaticInstances`] constants using `static`-like syntax.
///
/// Each `static NAME: T = expr;` item expands to two things:
///
/// * a hidden marker struct named `__lookup_key_NAME`;
/// * a `const NAME: StaticInstances<T>` whose family key is that marker's `TypeId` and whose
///   first-instance provider evaluates `expr`.
///
/// Although the input says `static`, the generated item is a `const`. Outer attributes and
/// visibility on each item are forwarded to the generated constant. Multiple items are separated
/// by `;`.
#[macro_export]
macro_rules! instances {
// Base case for the recursion below: nothing left to expand.
() => {};
// Several items: expand the first on its own, then recurse on the remaining tokens.
($(#[$attr:meta])* $vis:vis static $NAME:ident: $t:ty = $e:expr; $($rest:tt)*) => (
$crate::instances!($(#[$attr])* $vis static $NAME: $t = $e);
$crate::instances!($($rest)*);
);
// A single item without a trailing `;`: generate the marker type and the constant.
($(#[$attr:meta])* $vis:vis static $NAME:ident: $t:ty = $e:expr) => {
$crate::__private::paste! {
// The marker type is only used through `TypeId::of`, giving each NAME its own family key.
#[doc(hidden)]
#[expect(non_camel_case_types, reason = "intentionally uglified macro generated code")]
struct [<__lookup_key_ $NAME>];
$(#[$attr])* $vis const $NAME: $crate::StaticInstances<$t> =
$crate::StaticInstances::new(
::std::any::TypeId::of::<[<__lookup_key_ $NAME>]>,
move || $e);
}
};
}
/// Type-erased per-family entries keyed by the family key `TypeId`; `StaticInstances<T>`
/// downcasts each entry back to its concrete type.
type FamilyRegistry = HashedMap<TypeId, Box<dyn Any + Send + Sync>>;

/// Process-wide registry shared by all threads, created empty on first access.
static GLOBAL_REGISTRY: LazyLock<RwLock<FamilyRegistry>> = LazyLock::new(RwLock::default);

thread_local! {
    /// Per-thread registry of families, consulted before the global one.
    static LOCAL_REGISTRY: RefCell<FamilyRegistry> = RefCell::default();
}
/// Removes every entry from the process-wide registry.
///
/// Entries already copied into thread-local registries are left untouched.
///
/// # Panics
///
/// Panics if the global registry lock is poisoned.
#[cfg_attr(test, mutants::skip)]
#[doc(hidden)]
pub fn __private_clear_linked_variables_global() {
    GLOBAL_REGISTRY.write().expect(ERR_POISONED_LOCK).clear();
}
/// Removes every entry from the current thread's registry.
///
/// # Panics
///
/// Panics if the thread-local registry is currently borrowed.
#[cfg_attr(test, mutants::skip)]
#[doc(hidden)]
pub fn __private_clear_linked_variables_local() {
    LOCAL_REGISTRY.with_borrow_mut(|registry| registry.clear());
}
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod tests {
    use std::any::TypeId;
    use std::panic::{RefUnwindSafe, UnwindSafe};
    use std::rc::Rc;
    use std::sync::{Arc, Mutex};
    use std::thread;

    use static_assertions::assert_impl_all;

    use crate::StaticInstances;

    /// A linked object whose instances all share a single counter.
    #[linked::object]
    struct TokenCache {
        value: Arc<Mutex<usize>>,
    }

    impl TokenCache {
        /// Creates a family of caches that all observe one counter, starting at `value`.
        fn new(value: usize) -> Self {
            let shared = Arc::new(Mutex::new(value));

            linked::new!(Self {
                value: Arc::clone(&shared),
            })
        }

        /// Reads the current counter value.
        fn value(&self) -> usize {
            let guard = self.value.lock().unwrap();
            *guard
        }

        /// Adds one to the shared counter, saturating at `usize::MAX`.
        fn increment(&self) {
            let mut guard = self.value.lock().unwrap();
            *guard = guard.saturating_add(1);
        }
    }

    // `StaticInstances` only holds function pointers, so it must remain unwind-safe.
    assert_impl_all!(StaticInstances<TokenCache>: UnwindSafe, RefUnwindSafe);

    #[test]
    fn instances_non_macro() {
        struct RedKey;
        struct GreenKey;

        const RED_TOKEN_CACHE: StaticInstances<TokenCache> =
            StaticInstances::new(TypeId::of::<RedKey>, || TokenCache::new(42));
        const GREEN_TOKEN_CACHE: StaticInstances<TokenCache> =
            StaticInstances::new(TypeId::of::<GreenKey>, || TokenCache::new(99));

        let (red, green) = (RED_TOKEN_CACHE.get(), GREEN_TOKEN_CACHE.get());
        assert_eq!((red.value(), green.value()), (42, 99));
        red.increment();
        green.increment();

        // Instances obtained on another thread share state with the ones on this thread.
        let worker = thread::spawn(move || {
            let (red, green) = (RED_TOKEN_CACHE.get(), GREEN_TOKEN_CACHE.get());
            assert_eq!((red.value(), green.value()), (43, 100));
            red.increment();
            green.increment();
        });
        worker.join().unwrap();

        assert_eq!((red.value(), green.value()), (44, 101));
        assert_eq!(
            (RED_TOKEN_CACHE.get().value(), GREEN_TOKEN_CACHE.get().value()),
            (44, 101)
        );
    }

    #[test]
    fn instances_macro() {
        linked::instances! {
            static BLUE_TOKEN_CACHE: TokenCache = TokenCache::new(1000);
            static YELLOW_TOKEN_CACHE: TokenCache = TokenCache::new(2000);
        }

        let (blue, yellow) = (BLUE_TOKEN_CACHE.get(), YELLOW_TOKEN_CACHE.get());
        assert_eq!((blue.value(), yellow.value()), (1000, 2000));
        blue.increment();
        yellow.increment();

        // Instances obtained on another thread share state with the ones on this thread.
        let worker = thread::spawn(move || {
            let (blue, yellow) = (BLUE_TOKEN_CACHE.get(), YELLOW_TOKEN_CACHE.get());
            assert_eq!((blue.value(), yellow.value()), (1001, 2001));
            blue.increment();
            yellow.increment();
        });
        worker.join().unwrap();

        assert_eq!((blue.value(), yellow.value()), (1002, 2002));
        assert_eq!(
            (BLUE_TOKEN_CACHE.get().value(), YELLOW_TOKEN_CACHE.get().value()),
            (1002, 2002)
        );
    }

    #[test]
    fn stored_in_thread_local() {
        linked::instances!(static LINKED_CACHE: TokenCache = TokenCache::new(1000));
        thread_local!(static LOCAL_CACHE: Rc<TokenCache> = Rc::new(LINKED_CACHE.get()));

        let local = LOCAL_CACHE.with(Rc::clone);
        assert_eq!(local.value(), 1000);
        local.increment();

        // The other thread initializes its own thread-local copy, linked to ours.
        let worker = thread::spawn(move || {
            let remote = LOCAL_CACHE.with(Rc::clone);
            assert_eq!(remote.value(), 1001);
            remote.increment();
            assert_eq!(remote.value(), 1002);
        });
        worker.join().unwrap();

        assert_eq!(local.value(), 1002);

        let another_handle = LOCAL_CACHE.with(Rc::clone);
        assert_eq!(another_handle.value(), 1002);
    }
}