linked/per_access_static.rs
1// Copyright (c) Microsoft Corporation.
2// Copyright (c) Folo authors.
3
4use std::any::{Any, TypeId};
5use std::cell::RefCell;
6use std::collections::hash_map;
7use std::sync::{LazyLock, RwLock};
8
9use hash_hasher::HashedMap;
10
11use crate::{ERR_POISONED_LOCK, Handle};
12
13/// This is the real type of variables wrapped in the [`linked::instance_per_access!` macro][1].
14/// See macro documentation for more details.
15///
16/// Instances of this type are created by the [`linked::instance_per_access!` macro][1],
17/// never directly by user code. User code will simply call `.get()` on this type when it
18/// wants to obtain an instance of the linked object `T`.
19///
20/// [1]: [crate::instance_per_access]
21#[derive(Debug)]
22pub struct PerAccessStatic<T>
23where
24 T: linked::Object,
25{
26 /// A function we can call to obtain the lookup key for the family of linked objects.
27 ///
28 /// The family key is a `TypeId` because the expectation is that a unique empty type is
29 /// generated for each static variable in a `linked::instance_per_access!` block, to be used
30 /// as the family key for all instances linked to each other via that static variable.
31 family_key_provider: fn() -> TypeId,
32
33 /// Used to create the first instance of the linked object family. May be called
34 /// concurrently multiple times due to optimistic concurrency control, so it must
35 /// be idempotent and returned instances must be functionally equivalent. Even though
36 /// this may be called multiple times, only one return value will ever be exposed to
37 /// user code, with the others being dropped shortly after creation.
38 first_instance_provider: fn() -> T,
39}
40
41impl<T> PerAccessStatic<T>
42where
43 T: linked::Object,
44{
45 /// Note: this function exists to serve the inner workings of the
46 /// `linked::instance_per_access!` macro and should not be used directly.
47 /// It is not part of the public API and may be removed or changed at any time.
48 #[doc(hidden)]
49 pub const fn new(
50 family_key_provider: fn() -> TypeId,
51 first_instance_provider: fn() -> T,
52 ) -> Self {
53 Self {
54 family_key_provider,
55 first_instance_provider,
56 }
57 }
58
59 /// Gets a new `T` instance from the family of linked objects.
60 ///
61 /// # Performance
62 ///
63 /// This creates a new instance of `T` on every call so caching the return value is
64 /// performance-critical.
65 ///
66 /// Consider using the [`linked::instance_per_thread!` macro][1] if you want to
67 /// maintain only one instance per thread and return shared references to it.
68 ///
69 /// [1]: [crate::instance_per_thread]
70 pub fn get(&self) -> T {
71 if let Some(instance) = self.new_from_local_registry() {
72 return instance;
73 }
74
75 // TODO: This global registry step feels too smeared out.
76 // Can we draw it together into one step under one lock?
77 self.try_initialize_global_registry(self.first_instance_provider);
78
79 let handle = self
80 .get_handle_global()
81 .expect("we just initialized it, the handle must exist");
82
83 self.set_local(handle);
84
85 // We can now be certain the local registry has the value.
86 self.new_from_local_registry()
87 .expect("we just set the value, it must be there")
88 }
89
90 fn set_local(&self, value: Handle<T>) {
91 LOCAL_REGISTRY.with(|local_registry| {
92 local_registry
93 .borrow_mut()
94 .insert((self.family_key_provider)(), Box::new(value));
95 });
96 }
97
98 fn get_handle_global(&self) -> Option<Handle<T>> {
99 GLOBAL_REGISTRY
100 .read()
101 .expect(ERR_POISONED_LOCK)
102 .get(&(self.family_key_provider)())
103 .and_then(|w| w.downcast_ref::<Handle<T>>())
104 .cloned()
105 }
106
107 /// Attempts to register the object family in the global registry (if not already registered).
108 ///
109 /// This may use the provided provider to create a "first" instance of `T` even if another
110 /// "first" instance has been created already - optimistic concurrency is used to throw
111 /// away the extra instance if this proves necessary. Type-level documented requirements
112 /// give us the right to do this - the "first" instances must be functionally equivalent.
113 ///
114 /// # Performance
115 ///
116 /// This is a slightly expensive operation, so should be avoided on the hot path. We only do
117 /// this once per family per thread.
118 fn try_initialize_global_registry<FIP>(&self, first_instance_provider: FIP)
119 where
120 FIP: FnOnce() -> T,
121 {
122 // We do not today make use of our right to create a "first" instance of `T` even when
123 // we do not need it. This is a potential future optimization if it proves valuable.
124
125 let mut global_registry = GLOBAL_REGISTRY.write().expect(ERR_POISONED_LOCK);
126
127 // TODO: We are repeatedly acquiring the family key here and in sibling functions.
128 // Perhaps a trivial cost but explore the value of eliminating the duplicate access.
129 let family_key = (self.family_key_provider)();
130 let entry = global_registry.entry(family_key);
131
132 match entry {
133 hash_map::Entry::Occupied(_) => (),
134 hash_map::Entry::Vacant(entry) => {
135 // TODO: We create an instance here, only to immediately transform it back to
136 // a handle. Can we skip the middle step and just create a handle?
137 let first_instance = first_instance_provider();
138 entry.insert(Box::new(first_instance.handle()));
139 }
140 }
141 }
142
143 // Attempts to obtain a new instance of `T` using the current thread's family registry,
144 // returning `None` if the linked variable has not yet been seen by this thread and is
145 // therefore not present in the local registry.
146 fn new_from_local_registry(&self) -> Option<T> {
147 LOCAL_REGISTRY.with_borrow(|registry| {
148 let family_key = (self.family_key_provider)();
149
150 registry
151 .get(&family_key)
152 .and_then(|w| w.downcast_ref::<Handle<T>>())
153 .map(|handle| handle.clone().into())
154 })
155 }
156}
157
/// Declares that all static variables within the macro body contain [linked objects][crate],
/// where every access to such a variable yields a new instance from the same family.
///
/// Calling [`.get()`][1] on one of the declared variables returns a fresh linked object
/// instance. All instances obtained from the same variable belong to the same family.
///
/// # Dynamic family relationships
///
/// If you need `Arc`-style dynamic multithreaded storage (i.e. not a single static variable),
/// pass instances of [`Handle<T>`][2] between threads instead of (or in addition to)
/// using this macro. A [`Handle<T>`][2] can be obtained from any linked object through the
/// [`.handle()`][3] method of the [`linked::Object` trait][4]. This works even if the linked
/// object instance originally came from a static variable.
///
/// # Example
///
/// ```
/// # #[linked::object]
/// # struct TokenCache { }
/// # impl TokenCache { fn with_capacity(capacity: usize) -> Self { linked::new!(Self { } ) } fn get_token(&self) -> usize { 42 } }
/// linked::instance_per_access!(static TOKEN_CACHE: TokenCache = TokenCache::with_capacity(1000));
///
/// fn do_something() {
///     // `.get()` returns a unique instance of the linked object on every call.
///     let token_cache = TOKEN_CACHE.get();
///
///     let token = token_cache.get_token();
/// }
/// ```
///
/// [1]: PerAccessStatic::get
/// [2]: crate::Handle
/// [3]: crate::Object::handle
/// [4]: crate::Object
#[macro_export]
macro_rules! instance_per_access {
    // Nothing left to expand.
    () => {};

    // A declaration followed by `;` and possibly more declarations: expand the first one,
    // then recurse on whatever remains.
    ($(#[$meta:meta])* $vis:vis static $name:ident: $ty:ty = $init:expr; $($tail:tt)*) => {
        ::linked::instance_per_access!($(#[$meta])* $vis static $name: $ty = $init);
        ::linked::instance_per_access!($($tail)*);
    };

    // Exactly one declaration with no trailing semicolon. A unique, hidden marker type is
    // generated per variable; its `TypeId` acts as the family key.
    ($(#[$meta:meta])* $vis:vis static $name:ident: $ty:ty = $init:expr) => {
        ::linked::__private::paste! {
            #[doc(hidden)]
            #[allow(non_camel_case_types)]
            struct [<__lookup_key_ $name>];

            $(#[$meta])* $vis const $name: ::linked::PerAccessStatic<$ty> =
                ::linked::PerAccessStatic::new(
                    ::std::any::TypeId::of::<[<__lookup_key_ $name>]>,
                    move || $init,
                );
        }
    };
}
214
215// We use HashedMap which takes the raw value from Hash::hash() and uses it directly as the key.
216// This is OK because TypeId already returns a hashed value as its raw value, no need to hash more.
217// We also do not care about any hash manipulation because none of this is untrusted user input.
218type HandleRegistry = HashedMap<TypeId, Box<dyn Any + Send + Sync>>;
219
220// Global registry that is the ultimate authority where all linked variables are registered.
221// The values inside are `Handle<T>` where T may be different for each entry.
222static GLOBAL_REGISTRY: LazyLock<RwLock<HandleRegistry>> =
223 LazyLock::new(|| RwLock::new(HandleRegistry::default()));
224
225thread_local! {
226 // Thread-local registry where we cache any linked variables that have been seen by the current
227 // thread. The values inside are `Handle<T>` where T may be different for each entry.
228 static LOCAL_REGISTRY: RefCell<HandleRegistry> = RefCell::new(HandleRegistry::default());
229}
230
231/// Clears all data stored in the linked variable system from the current thread's point of view.
232///
233/// This is intended for use in tests only. It is publicly exposed because it may need to be called
234/// from integration tests and benchmarks, which cannot access private functions.
235#[doc(hidden)]
236pub fn __private_clear_linked_variables() {
237 let mut global_registry = GLOBAL_REGISTRY.write().expect(ERR_POISONED_LOCK);
238 global_registry.clear();
239
240 LOCAL_REGISTRY.with(|local_registry| {
241 local_registry.borrow_mut().clear();
242 });
243}
244
#[cfg(test)]
mod tests {
    use std::any::TypeId;
    use std::rc::Rc;
    use std::sync::{Arc, Mutex};
    use std::thread;

    use crate::PerAccessStatic;

    // Test linked object holding a counter behind `Arc<Mutex<usize>>`. The tests below assert
    // that an increment through one instance is visible through other instances of the same
    // family, including on other threads.
    #[linked::object]
    struct TokenCache {
        value: Arc<Mutex<usize>>,
    }

    impl TokenCache {
        fn new(value: usize) -> Self {
            let value = Arc::new(Mutex::new(value));

            linked::new!(Self {
                value: Arc::clone(&value),
            })
        }

        fn value(&self) -> usize {
            *self.value.lock().unwrap()
        }

        fn increment(&self) {
            let mut writer = self.value.lock().unwrap();
            *writer = writer.saturating_add(1);
        }
    }

    #[test]
    fn linked_lazy() {
        // Here we test the inner logic of the `instance_per_access!` macro without applying
        // the macro. There is a separate test for executing the same logic via the macro itself.
        // The empty key types stand in for the hidden per-variable marker types that the macro
        // generates; their `TypeId`s serve as the family keys.
        struct RedKey;
        struct GreenKey;

        const RED_TOKEN_CACHE: PerAccessStatic<TokenCache> =
            PerAccessStatic::new(TypeId::of::<RedKey>, || TokenCache::new(42));
        const GREEN_TOKEN_CACHE: PerAccessStatic<TokenCache> =
            PerAccessStatic::new(TypeId::of::<GreenKey>, || TokenCache::new(99));

        assert_eq!(RED_TOKEN_CACHE.get().value(), 42);
        assert_eq!(GREEN_TOKEN_CACHE.get().value(), 99);

        RED_TOKEN_CACHE.get().increment();
        GREEN_TOKEN_CACHE.get().increment();

        // A different thread must observe the increments made through instances of the same
        // families on this thread.
        thread::spawn(move || {
            assert_eq!(RED_TOKEN_CACHE.get().value(), 43);
            assert_eq!(GREEN_TOKEN_CACHE.get().value(), 100);

            RED_TOKEN_CACHE.get().increment();
            GREEN_TOKEN_CACHE.get().increment();
        })
        .join()
        .unwrap();

        // ...and vice versa.
        assert_eq!(RED_TOKEN_CACHE.get().value(), 44);
        assert_eq!(GREEN_TOKEN_CACHE.get().value(), 101);
    }

    #[test]
    fn linked_smoke_test() {
        // Same scenario as `linked_lazy`, but going through the macro, with two variables
        // declared in a single block.
        linked::instance_per_access! {
            static BLUE_TOKEN_CACHE: TokenCache = TokenCache::new(1000);
            static YELLOW_TOKEN_CACHE: TokenCache = TokenCache::new(2000);
        }

        assert_eq!(BLUE_TOKEN_CACHE.get().value(), 1000);
        assert_eq!(YELLOW_TOKEN_CACHE.get().value(), 2000);

        assert_eq!(BLUE_TOKEN_CACHE.get().clone().value(), 1000);
        assert_eq!(YELLOW_TOKEN_CACHE.get().clone().value(), 2000);

        BLUE_TOKEN_CACHE.get().increment();
        YELLOW_TOKEN_CACHE.get().increment();

        thread::spawn(move || {
            assert_eq!(BLUE_TOKEN_CACHE.get().value(), 1001);
            assert_eq!(YELLOW_TOKEN_CACHE.get().value(), 2001);

            BLUE_TOKEN_CACHE.get().increment();
            YELLOW_TOKEN_CACHE.get().increment();
        })
        .join()
        .unwrap();

        assert_eq!(BLUE_TOKEN_CACHE.get().value(), 1002);
        assert_eq!(YELLOW_TOKEN_CACHE.get().value(), 2002);
    }

    #[test]
    fn thread_local_from_linked() {
        // A `thread_local!` initialized from a linked variable: each thread gets its own
        // `Rc`-shared instance, and all of those instances belong to the same family.
        linked::instance_per_access!(static LINKED_CACHE: TokenCache = TokenCache::new(1000));
        thread_local!(static LOCAL_CACHE: Rc<TokenCache> = Rc::new(LINKED_CACHE.get()));

        let cache = LOCAL_CACHE.with(Rc::clone);
        assert_eq!(cache.value(), 1000);
        cache.increment();

        thread::spawn(move || {
            let cache = LOCAL_CACHE.with(Rc::clone);
            assert_eq!(cache.value(), 1001);
            cache.increment();
            assert_eq!(cache.value(), 1002);
        })
        .join()
        .unwrap();

        assert_eq!(cache.value(), 1002);

        let cache_rc_clone = LOCAL_CACHE.with(Rc::clone);
        assert_eq!(cache_rc_clone.value(), 1002);
    }
}