linked/
per_access_static.rs

1// Copyright (c) Microsoft Corporation.
2// Copyright (c) Folo authors.
3
4use std::any::{Any, TypeId};
5use std::cell::RefCell;
6use std::collections::hash_map;
7use std::sync::{LazyLock, RwLock};
8
9use hash_hasher::HashedMap;
10
11use crate::{ERR_POISONED_LOCK, Handle};
12
13/// This is the real type of variables wrapped in the [`linked::instance_per_access!` macro][1].
14/// See macro documentation for more details.
15///
16/// Instances of this type are created by the [`linked::instance_per_access!` macro][1],
17/// never directly by user code. User code will simply call `.get()` on this type when it
18/// wants to obtain an instance of the linked object `T`.
19///
20/// [1]: [crate::instance_per_access]
21#[derive(Debug)]
22pub struct PerAccessStatic<T>
23where
24    T: linked::Object,
25{
26    /// A function we can call to obtain the lookup key for the family of linked objects.
27    ///
28    /// The family key is a `TypeId` because the expectation is that a unique empty type is
29    /// generated for each static variable in a `linked::instance_per_access!` block, to be used
30    /// as the family key for all instances linked to each other via that static variable.
31    family_key_provider: fn() -> TypeId,
32
33    /// Used to create the first instance of the linked object family. May be called
34    /// concurrently multiple times due to optimistic concurrency control, so it must
35    /// be idempotent and returned instances must be functionally equivalent. Even though
36    /// this may be called multiple times, only one return value will ever be exposed to
37    /// user code, with the others being dropped shortly after creation.
38    first_instance_provider: fn() -> T,
39}
40
41impl<T> PerAccessStatic<T>
42where
43    T: linked::Object,
44{
45    /// Note: this function exists to serve the inner workings of the
46    /// `linked::instance_per_access!` macro and should not be used directly.
47    /// It is not part of the public API and may be removed or changed at any time.
48    #[doc(hidden)]
49    pub const fn new(
50        family_key_provider: fn() -> TypeId,
51        first_instance_provider: fn() -> T,
52    ) -> Self {
53        Self {
54            family_key_provider,
55            first_instance_provider,
56        }
57    }
58
59    /// Gets a new `T` instance from the family of linked objects.
60    ///
61    /// # Performance
62    ///
63    /// This creates a new instance of `T` on every call so caching the return value is
64    /// performance-critical.
65    ///
66    /// Consider using the [`linked::instance_per_thread!` macro][1] if you want to
67    /// maintain only one instance per thread and return shared references to it.
68    ///
69    /// [1]: [crate::instance_per_thread]
70    pub fn get(&self) -> T {
71        if let Some(instance) = self.new_from_local_registry() {
72            return instance;
73        }
74
75        // TODO: This global registry step feels too smeared out.
76        // Can we draw it together into one step under one lock?
77        self.try_initialize_global_registry(self.first_instance_provider);
78
79        let handle = self
80            .get_handle_global()
81            .expect("we just initialized it, the handle must exist");
82
83        self.set_local(handle);
84
85        // We can now be certain the local registry has the value.
86        self.new_from_local_registry()
87            .expect("we just set the value, it must be there")
88    }
89
90    fn set_local(&self, value: Handle<T>) {
91        LOCAL_REGISTRY.with(|local_registry| {
92            local_registry
93                .borrow_mut()
94                .insert((self.family_key_provider)(), Box::new(value));
95        });
96    }
97
98    fn get_handle_global(&self) -> Option<Handle<T>> {
99        GLOBAL_REGISTRY
100            .read()
101            .expect(ERR_POISONED_LOCK)
102            .get(&(self.family_key_provider)())
103            .and_then(|w| w.downcast_ref::<Handle<T>>())
104            .cloned()
105    }
106
107    /// Attempts to register the object family in the global registry (if not already registered).
108    ///
109    /// This may use the provided provider to create a "first" instance of `T` even if another
110    /// "first" instance has been created already - optimistic concurrency is used to throw
111    /// away the extra instance if this proves necessary. Type-level documented requirements
112    /// give us the right to do this - the "first" instances must be functionally equivalent.
113    ///
114    /// # Performance
115    ///
116    /// This is a slightly expensive operation, so should be avoided on the hot path. We only do
117    /// this once per family per thread.
118    fn try_initialize_global_registry<FIP>(&self, first_instance_provider: FIP)
119    where
120        FIP: FnOnce() -> T,
121    {
122        // We do not today make use of our right to create a "first" instance of `T` even when
123        // we do not need it. This is a potential future optimization if it proves valuable.
124
125        let mut global_registry = GLOBAL_REGISTRY.write().expect(ERR_POISONED_LOCK);
126
127        // TODO: We are repeatedly acquiring the family key here and in sibling functions.
128        // Perhaps a trivial cost but explore the value of eliminating the duplicate access.
129        let family_key = (self.family_key_provider)();
130        let entry = global_registry.entry(family_key);
131
132        match entry {
133            hash_map::Entry::Occupied(_) => (),
134            hash_map::Entry::Vacant(entry) => {
135                // TODO: We create an instance here, only to immediately transform it back to
136                // a handle. Can we skip the middle step and just create a handle?
137                let first_instance = first_instance_provider();
138                entry.insert(Box::new(first_instance.handle()));
139            }
140        }
141    }
142
143    // Attempts to obtain a new instance of `T` using the current thread's family registry,
144    // returning `None` if the linked variable has not yet been seen by this thread and is
145    // therefore not present in the local registry.
146    fn new_from_local_registry(&self) -> Option<T> {
147        LOCAL_REGISTRY.with_borrow(|registry| {
148            let family_key = (self.family_key_provider)();
149
150            registry
151                .get(&family_key)
152                .and_then(|w| w.downcast_ref::<Handle<T>>())
153                .map(|handle| handle.clone().into())
154        })
155    }
156}
157
/// Declares that all static variables within the macro body contain [linked objects][crate],
/// with each access to this variable returning a new instance from the same family.
///
/// Each [`.get()`][1] on an included static variable returns a new linked object instance,
/// with all instances obtained from the same static variable being part of the same family.
///
/// # Dynamic family relationships
///
/// If you need `Arc`-style dynamic multithreaded storage (i.e. not a single static variable),
/// pass instances of [`Handle<T>`][2] between threads instead of (or in addition to)
/// using this macro. You can obtain a [`Handle<T>`][2] from any linked object via the
/// [`.handle()`][3] method of the [`linked::Object` trait][4], even if the instance of the
/// linked object originally came from a static variable.
///
/// # Example
///
/// ```
/// # #[linked::object]
/// # struct TokenCache { }
/// # impl TokenCache { fn with_capacity(capacity: usize) -> Self { linked::new!(Self { } ) } fn get_token(&self) -> usize { 42 } }
/// linked::instance_per_access!(static TOKEN_CACHE: TokenCache = TokenCache::with_capacity(1000));
///
/// fn do_something() {
///     // `.get()` returns a unique instance of the linked object on every call.
///     let token_cache = TOKEN_CACHE.get();
///
///     let token = token_cache.get_token();
/// }
/// ```
///
/// [1]: crate::PerAccessStatic::get
/// [2]: crate::Handle
/// [3]: crate::Object::handle
/// [4]: crate::Object
#[macro_export]
macro_rules! instance_per_access {
    () => {};

    // Several declarations: expand the first one, then recurse on the remainder.
    ($(#[$attr:meta])* $vis:vis static $NAME:ident: $t:ty = $e:expr; $($rest:tt)*) => (
        $crate::instance_per_access!($(#[$attr])* $vis static $NAME: $t = $e);
        $crate::instance_per_access!($($rest)*);
    );

    // A single declaration. `$crate` (rather than a hard-coded `::linked`) keeps the expansion
    // working even when a downstream crate imports this crate under a different name.
    ($(#[$attr:meta])* $vis:vis static $NAME:ident: $t:ty = $e:expr) => {
        $crate::__private::paste! {
            // Unique zero-sized type whose `TypeId` serves as the family key of this static.
            #[doc(hidden)]
            #[allow(non_camel_case_types)]
            struct [<__lookup_key_ $NAME>];

            $(#[$attr])* $vis const $NAME: $crate::PerAccessStatic<$t> =
                $crate::PerAccessStatic::new(
                    ::std::any::TypeId::of::<[<__lookup_key_ $NAME>]>,
                    move || $e);
        }
    };
}
214
215// We use HashedMap which takes the raw value from Hash::hash() and uses it directly as the key.
216// This is OK because TypeId already returns a hashed value as its raw value, no need to hash more.
217// We also do not care about any hash manipulation because none of this is untrusted user input.
218type HandleRegistry = HashedMap<TypeId, Box<dyn Any + Send + Sync>>;
219
220// Global registry that is the ultimate authority where all linked variables are registered.
221// The values inside are `Handle<T>` where T may be different for each entry.
222static GLOBAL_REGISTRY: LazyLock<RwLock<HandleRegistry>> =
223    LazyLock::new(|| RwLock::new(HandleRegistry::default()));
224
225thread_local! {
226    // Thread-local registry where we cache any linked variables that have been seen by the current
227    // thread. The values inside are `Handle<T>` where T may be different for each entry.
228    static LOCAL_REGISTRY: RefCell<HandleRegistry> = RefCell::new(HandleRegistry::default());
229}
230
231/// Clears all data stored in the linked variable system from the current thread's point of view.
232///
233/// This is intended for use in tests only. It is publicly exposed because it may need to be called
234/// from integration tests and benchmarks, which cannot access private functions.
235#[doc(hidden)]
236pub fn __private_clear_linked_variables() {
237    let mut global_registry = GLOBAL_REGISTRY.write().expect(ERR_POISONED_LOCK);
238    global_registry.clear();
239
240    LOCAL_REGISTRY.with(|local_registry| {
241        local_registry.borrow_mut().clear();
242    });
243}
244
#[cfg(test)]
mod tests {
    use std::any::TypeId;
    use std::rc::Rc;
    use std::sync::{Arc, Mutex};
    use std::thread;

    use crate::PerAccessStatic;

    /// Linked object whose instances share one counter, so increments made through any
    /// instance of the family are visible through every other instance.
    #[linked::object]
    struct TokenCache {
        value: Arc<Mutex<usize>>,
    }

    impl TokenCache {
        fn new(value: usize) -> Self {
            let value = Arc::new(Mutex::new(value));

            linked::new!(Self {
                value: Arc::clone(&value),
            })
        }

        fn value(&self) -> usize {
            *self.value.lock().unwrap()
        }

        fn increment(&self) {
            let mut writer = self.value.lock().unwrap();
            *writer = writer.saturating_add(1);
        }
    }

    #[test]
    fn linked_lazy() {
        // Here we test the inner logic of the `instance_per_access!` macro without applying
        // the macro, by constructing `PerAccessStatic` directly. `linked_smoke_test` executes
        // the same logic via the macro itself.
        struct RedKey;
        struct GreenKey;

        const RED_TOKEN_CACHE: PerAccessStatic<TokenCache> =
            PerAccessStatic::new(TypeId::of::<RedKey>, || TokenCache::new(42));
        const GREEN_TOKEN_CACHE: PerAccessStatic<TokenCache> =
            PerAccessStatic::new(TypeId::of::<GreenKey>, || TokenCache::new(99));

        assert_eq!(RED_TOKEN_CACHE.get().value(), 42);
        assert_eq!(GREEN_TOKEN_CACHE.get().value(), 99);

        RED_TOKEN_CACHE.get().increment();
        GREEN_TOKEN_CACHE.get().increment();

        thread::spawn(move || {
            assert_eq!(RED_TOKEN_CACHE.get().value(), 43);
            assert_eq!(GREEN_TOKEN_CACHE.get().value(), 100);

            RED_TOKEN_CACHE.get().increment();
            GREEN_TOKEN_CACHE.get().increment();
        })
        .join()
        .unwrap();

        assert_eq!(RED_TOKEN_CACHE.get().value(), 44);
        assert_eq!(GREEN_TOKEN_CACHE.get().value(), 101);
    }

    #[test]
    fn linked_smoke_test() {
        linked::instance_per_access! {
            static BLUE_TOKEN_CACHE: TokenCache = TokenCache::new(1000);
            static YELLOW_TOKEN_CACHE: TokenCache = TokenCache::new(2000);
        }

        assert_eq!(BLUE_TOKEN_CACHE.get().value(), 1000);
        assert_eq!(YELLOW_TOKEN_CACHE.get().value(), 2000);

        assert_eq!(BLUE_TOKEN_CACHE.get().clone().value(), 1000);
        assert_eq!(YELLOW_TOKEN_CACHE.get().clone().value(), 2000);

        BLUE_TOKEN_CACHE.get().increment();
        YELLOW_TOKEN_CACHE.get().increment();

        thread::spawn(move || {
            assert_eq!(BLUE_TOKEN_CACHE.get().value(), 1001);
            assert_eq!(YELLOW_TOKEN_CACHE.get().value(), 2001);

            BLUE_TOKEN_CACHE.get().increment();
            YELLOW_TOKEN_CACHE.get().increment();
        })
        .join()
        .unwrap();

        assert_eq!(BLUE_TOKEN_CACHE.get().value(), 1002);
        assert_eq!(YELLOW_TOKEN_CACHE.get().value(), 2002);
    }

    #[test]
    fn thread_local_from_linked() {
        linked::instance_per_access!(static LINKED_CACHE: TokenCache = TokenCache::new(1000));
        thread_local!(static LOCAL_CACHE: Rc<TokenCache> = Rc::new(LINKED_CACHE.get()));

        let cache = LOCAL_CACHE.with(Rc::clone);
        assert_eq!(cache.value(), 1000);
        cache.increment();

        thread::spawn(move || {
            let cache = LOCAL_CACHE.with(Rc::clone);
            assert_eq!(cache.value(), 1001);
            cache.increment();
            assert_eq!(cache.value(), 1002);
        })
        .join()
        .unwrap();

        assert_eq!(cache.value(), 1002);

        let cache_rc_clone = LOCAL_CACHE.with(Rc::clone);
        assert_eq!(cache_rc_clone.value(), 1002);
    }
}