anyctx/lib.rs
1use std::{
2 any::{Any, TypeId},
3 fmt::Debug,
4 mem::MaybeUninit,
5 ops::Deref,
6 sync::{Arc, RwLock},
7};
8
9use ahash::AHashMap;
10
11#[derive(Clone)]
12/// A context type for storing and retrieving "quasi-global" data in a type-safe and scope-respecting manner. Think of it as a "god object" that does not clutter up scope or generate spaghetti.
13///
14/// This context allows for the dynamic association of data with a key derived from a lazily evaluated constructor. It is designed to be thread-safe and can be shared across threads.
15///
16/// Moreover, the context also wraps a provided *initialization value*. This allows easy access to data that's more ergonomically passed in at initialization rather than lazily initialized later.
17///
18/// Generics:
19/// - `I`: Initialization info type, which must be `Send + Sync + 'static`
20///
21/// # Examples
22///
23/// ```
24/// use anyctx::AnyCtx;
25///
26/// fn forty_two(ctx: &AnyCtx<i32>) -> i32 {
27/// 40 + *ctx.init()
28/// }
29///
30/// let ctx = AnyCtx::new(2);
31/// let number = ctx.get(forty_two);
32/// assert_eq!(*number, 42);
33/// ```
34pub struct AnyCtx<I: Send + Sync + 'static> {
35 init: Arc<I>,
36 dynamic: Arc<RwLock<AHashMap<TypeId, Arc<RwLock<MaybeUninit<Box<dyn Any + Send + Sync>>>>>>>,
37}
38
39unsafe impl<T: Send + Sync + 'static> Send for AnyCtx<T> {}
40unsafe impl<T: Send + Sync + 'static> Sync for AnyCtx<T> {}
41
42impl<T: Send + Sync + 'static> Debug for AnyCtx<T> {
43 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
44 format!("AnyCtx({} keys)", self.dynamic.read().unwrap().len()).fmt(f)
45 }
46}
47
48impl<I: Send + Sync + 'static> AnyCtx<I> {
49 /// Creates a new context, wrapping the given initialization value.
50 pub fn new(init: I) -> Self {
51 Self {
52 init: init.into(),
53 dynamic: Default::default(),
54 }
55 }
56
57 /// Gets the initialization value.
58 pub fn init(&self) -> &I {
59 &self.init
60 }
61
62 /// Gets the value associated with the given constructor function. If there already is a value, the value will be retrieved. Otherwise, the constructor will be run to produce the value, which will be stored in the context.
63 ///
64 /// It is guaranteed that the constructor will be called at most once, even if `get` is called concurrently from multiple threads with the same key.
65 ///
66 /// The constructor itself should take in an AnyCtx as an argument, and is allowed to call `get` too. Take care to avoid infinite recursion, which will cause a deadlock.
67 pub fn get<T: 'static + Send + Sync, F: Fn(&Self) -> T + 'static + Send + Sync + Copy>(
68 &self,
69 construct: F,
70 ) -> &T {
71 loop {
72 if let Some(exists) = self.get_inner(construct) {
73 return exists;
74 } else {
75 let key = construct.type_id();
76 let mut inner = self.dynamic.write().unwrap();
77 if inner.contains_key(&key) {
78 // now get will return, so loop around
79 continue;
80 }
81 let to_init = Arc::new(RwLock::new(MaybeUninit::uninit()));
82 let mut entry = to_init.write().unwrap();
83 inner.insert(key, to_init.clone());
84 drop(inner);
85
86 // now inner is unlocked, so we're not blocking the whole map.
87 // but our particular entry is locked.
88 // we can init in peace.
89 let value = construct(self);
90 entry.write(Box::new(value));
91 // loop around again and we'll get it
92 }
93 }
94 }
95
96 fn get_inner<'a, T: 'static + Send + Sync, F: Fn(&Self) -> T + 'static + Send + Sync + Copy>(
97 &'a self,
98 init: F,
99 ) -> Option<&'a T> {
100 let inner = self.dynamic.read().unwrap();
101 let b = inner.get(&init.type_id())?;
102 let b = b.read().unwrap();
103 // SAFETY: by the time we can read-lock this value, we know that it has already initialized, since the initialization function holds a lock for the full duration.
104 let b = unsafe { b.assume_init_ref() };
105 let downcasted: &T = b
106 .downcast_ref()
107 .expect("downcast failed, this should not happen");
108 // SAFETY: we never remove items from inner without dropping Context first, and the address of what Box points to cannot change, so this is safe
109 let downcasted: &'a T = unsafe { std::mem::transmute(downcasted) };
110 Some(downcasted)
111 }
112}
113
114#[cfg(test)]
115mod tests {
116 use std::any::Any;
117
118 use crate::AnyCtx;
119
120 fn one(_ctx: &AnyCtx<()>) -> usize {
121 1
122 }
123
124 fn hello(_ctx: &AnyCtx<()>) -> String {
125 "hello".to_string()
126 }
127
128 fn two(ctx: &AnyCtx<()>) -> usize {
129 ctx.get(one) + ctx.get(one)
130 }
131
132 #[test]
133 fn simple() {
134 let ctx = AnyCtx::new(());
135 assert_eq!(ctx.get(two), &2);
136 assert_eq!(ctx.get(hello), "hello")
137 }
138
139 #[test]
140 fn function_magic() {
141 fn a() -> usize {
142 1
143 }
144
145 fn b() -> usize {
146 1
147 }
148
149 eprintln!("{}", a as usize);
150 eprintln!("{}", b as usize);
151 eprintln!("{:?}", a.type_id());
152 eprintln!("{:?}", b.type_id());
153 }
154}