use std::{
any::{Any, TypeId},
fmt::Debug,
mem::MaybeUninit,
sync::{Arc, RwLock},
};
use ahash::AHashMap;
#[derive(Clone)]
pub struct AnyCtx<I: Send + Sync + 'static> {
init: Arc<I>,
dynamic: Arc<RwLock<AHashMap<CtxKey, Arc<RwLock<MaybeUninit<Box<dyn Any + Send + Sync>>>>>>>,
}
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
struct CtxKey {
ty: TypeId,
fn_addr: Option<usize>,
}
unsafe impl<T: Send + Sync + 'static> Send for AnyCtx<T> {}
unsafe impl<T: Send + Sync + 'static> Sync for AnyCtx<T> {}
impl<T: Send + Sync + 'static> Debug for AnyCtx<T> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
format!("AnyCtx({} keys)", self.dynamic.read().unwrap().len()).fmt(f)
}
}
impl<I: Send + Sync + 'static> AnyCtx<I> {
pub fn new(init: I) -> Self {
Self {
init: init.into(),
dynamic: Default::default(),
}
}
pub fn init(&self) -> &I {
&self.init
}
pub fn get<T: 'static + Send + Sync, F: Fn(&Self) -> T + 'static + Send + Sync + Copy>(
&self,
construct: F,
) -> &T {
let key = Self::key_for(construct);
loop {
if let Some(exists) = self.get_inner(key) {
return exists;
} else {
let mut inner = self.dynamic.write().unwrap();
if inner.contains_key(&key) {
continue;
}
let to_init = Arc::new(RwLock::new(MaybeUninit::uninit()));
let mut entry = to_init.write().unwrap();
inner.insert(key, to_init.clone());
drop(inner);
let value = construct(self);
entry.write(Box::new(value));
}
}
}
fn key_for<T: 'static + Send + Sync, F: Fn(&Self) -> T + 'static + Send + Sync + Copy>(
construct: F,
) -> CtxKey {
let ty = TypeId::of::<F>();
let fn_addr = if ty == TypeId::of::<fn(&Self) -> T>() {
assert_eq!(
std::mem::size_of::<F>(),
std::mem::size_of::<fn(&Self) -> T>()
);
let f = unsafe { std::mem::transmute_copy::<F, fn(&Self) -> T>(&construct) };
Some(f as usize)
} else {
None
};
CtxKey { ty, fn_addr }
}
fn get_inner<'a, T: 'static + Send + Sync>(&'a self, key: CtxKey) -> Option<&'a T> {
let inner = self.dynamic.read().unwrap();
let b = inner.get(&key)?;
let b = b.read().unwrap();
let b = unsafe { b.assume_init_ref() };
let downcasted: &T = b
.downcast_ref()
.expect("downcast failed, this should not happen");
let downcasted: &'a T = unsafe { std::mem::transmute(downcasted) };
Some(downcasted)
}
}
#[cfg(test)]
mod tests {
    use std::any::Any;
    use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};

    use crate::AnyCtx;

    fn one(_ctx: &AnyCtx<()>) -> usize {
        1
    }

    /// Same body as `one`, but a distinct fn item and therefore a distinct type.
    fn also_one(_ctx: &AnyCtx<()>) -> usize {
        1
    }

    fn hello(_ctx: &AnyCtx<()>) -> String {
        "hello".to_string()
    }

    /// A constructor that depends on another cached value.
    fn two(ctx: &AnyCtx<()>) -> usize {
        ctx.get(one) + ctx.get(one)
    }

    /// Number of times `counted` has run; only `constructor_runs_once` uses it.
    static COUNTED_CALLS: AtomicUsize = AtomicUsize::new(0);

    fn counted(_ctx: &AnyCtx<()>) -> usize {
        COUNTED_CALLS.fetch_add(1, Ordering::SeqCst);
        42
    }

    type BoolField = fn(&AnyCtx<()>) -> AtomicBool;
    // The two fields deliberately have different bodies. Function addresses are not guaranteed
    // to be unique, and the optimizer may merge functions with identical code, which would make
    // two identical fields share one cache entry.
    static FIELD_A: BoolField = |_| AtomicBool::new(false);
    static FIELD_B: BoolField = |_| AtomicBool::new(true);

    #[test]
    fn simple() {
        let ctx = AnyCtx::new(());
        assert_eq!(ctx.get(two), &2);
        assert_eq!(ctx.get(hello), "hello")
    }

    #[test]
    fn constructor_runs_once() {
        let ctx = AnyCtx::new(());
        let before = COUNTED_CALLS.load(Ordering::SeqCst);
        assert_eq!(ctx.get(counted), &42);
        assert!(std::ptr::eq(ctx.get(counted), ctx.get(counted)));
        assert_eq!(COUNTED_CALLS.load(Ordering::SeqCst) - before, 1);
    }

    #[test]
    fn clones_share_the_cache() {
        let ctx = AnyCtx::new(());
        let clone = ctx.clone();
        assert!(std::ptr::eq(ctx.get(one), clone.get(one)));
    }

    #[test]
    fn fn_items_have_distinct_zero_sized_types() {
        fn a() -> usize {
            1
        }
        fn b() -> usize {
            1
        }
        assert_ne!(a.type_id(), b.type_id());
        assert_eq!(std::mem::size_of_val(&a), 0);
    }

    #[test]
    fn distinct_fn_items_get_distinct_values() {
        let ctx = AnyCtx::new(());
        assert_eq!(ctx.get(one), ctx.get(also_one));
        assert!(!std::ptr::eq(ctx.get(one), ctx.get(also_one)));
    }

    #[test]
    fn same_function_pointer_field_gets_same_value() {
        let ctx = AnyCtx::new(());
        ctx.get(FIELD_A).store(true, Ordering::SeqCst);
        assert!(std::ptr::eq(ctx.get(FIELD_A), ctx.get(FIELD_A)));
        assert!(ctx.get(FIELD_A).load(Ordering::SeqCst));
    }

    #[test]
    fn same_signature_function_pointer_fields_get_distinct_values() {
        let ctx = AnyCtx::new(());
        // Request A first: if B were conflated with A, it would return A's `false`.
        assert!(!ctx.get(FIELD_A).load(Ordering::SeqCst));
        assert!(ctx.get(FIELD_B).load(Ordering::SeqCst));
        assert!(!std::ptr::eq(ctx.get(FIELD_A), ctx.get(FIELD_B)));
    }
}