generic_global_variables/
lib.rs

1use core::any::{Any, TypeId};
2use core::fmt;
3use core::marker::PhantomData;
4use core::ops::Deref;
5
6use std::collections::HashMap;
7use std::sync::Arc;
8
9use parking_lot::{RwLock, RwLockUpgradableReadGuard};
10
11/// ```
12/// use once_cell::sync::OnceCell;
13/// use generic_global_variables::*;
14///
15/// use std::thread::spawn;
16/// use std::sync::Mutex;
17///
18/// fn get_buffer<T: Send + Sync>(f: impl FnOnce() -> T) -> Entry<T> {
19///     static GLOBALS: OnceCell<GenericGlobal> = OnceCell::new();
20///
21///     let globals = GLOBALS.get_or_init(GenericGlobal::new);
22///     globals.get_or_init(f)
23/// }
24///
25/// let handles1: Vec<_> = (0..24).map(|_| {
26///     spawn(|| {
27///         let arc = get_buffer(Mutex::<Vec::<Box<[u8]>>>::default);
28///         let buffer = arc.lock()
29///             .unwrap()
30///             .pop()
31///             .unwrap_or_else(|| vec![0 as u8; 20].into_boxed_slice());
32///         // Read some data into buffer and process it
33///         // ...
34///
35///         arc.lock().unwrap().push(buffer);
36///     })
37/// }).collect();
38///
39/// let handles2: Vec<_> = (0..50).map(|_| {
40///     spawn(|| {
41///         let arc = get_buffer(Mutex::<Vec::<Box<[u32]>>>::default);
42///         let buffer = arc.lock()
43///             .unwrap()
44///             .pop()
45///             .unwrap_or_else(|| vec![1 as u32; 20].into_boxed_slice());
46///         // Read some data into buffer and process it
47///         // ...
48///
49///         arc.lock().unwrap().push(buffer);
50///     })
51/// }).collect();
52///
53/// for handle in handles1 {
54///     handle.join();
55/// }
56///
57/// for handle in handles2 {
58///     handle.join();
59/// }
60/// ```
61#[derive(Default, Debug)]
62pub struct GenericGlobal(RwLock<HashMap<TypeId, Arc<dyn Any>>>);
63
64impl GenericGlobal {
65    pub fn new() -> Self {
66        Self::default()
67    }
68
69    pub fn get_or_init<T: 'static + Send + Sync>(&self, f: impl FnOnce() -> T) -> Entry<T> {
70        let typeid = TypeId::of::<T>();
71
72        if let Some(val) = self.0.read().get(&typeid) {
73            return Entry::new(Arc::clone(val));
74        }
75
76        // Use an upgradable_read to check if the key has already
77        // been added by another thread.
78        //
79        // Unlike write guard, this UpgradableReadGuard only blocks
80        // other UpgradableReadGuard and WriteGuard, so the readers
81        // will not be blocked while ensuring that there is no other
82        // writer.
83        let guard = self.0.upgradable_read();
84
85        // If another writer has already added that typeid, return.
86        if let Some(val) = guard.get(&typeid) {
87            return Entry::new(Arc::clone(val));
88        }
89
90        // If no other writer has added that typeid, add one now.
91        let mut guard = RwLockUpgradableReadGuard::upgrade(guard);
92        let arc: Arc<dyn Any> = Arc::new(f());
93        let option = guard.insert(typeid, Arc::clone(&arc));
94
95        // There cannot be any other write that insert the key.
96        debug_assert!(option.is_none());
97
98        Entry::new(arc)
99    }
100}
101
102unsafe impl Send for GenericGlobal {}
103unsafe impl Sync for GenericGlobal {}
104
/// A cheaply clonable, shared handle to a value of type `T` stored in a
/// [`GenericGlobal`].
///
/// `Entry<T>` dereferences to `T`, and its `Debug` and `Display` impls forward
/// to `T`. Formatting with `{:p}` prints the address of the shared allocation,
/// so clones of one handle format identically.
pub struct Entry<T: 'static>(Arc<dyn Any>, PhantomData<T>);

// SAFETY: the `Arc` always points at a `T` (the invariant of the private
// `Entry::new`). Sending an `Entry<T>` therefore moves a reference-counted,
// shared `T` to another thread, which `Arc<T>: Send` allows when `T: Send + Sync`.
unsafe impl<T: 'static + Send + Sync> Send for Entry<T> {}
// SAFETY: same invariant; `&Entry<T>` only hands out `&T` and clones of the
// `Arc`, which `Arc<T>: Sync` allows when `T: Send + Sync`.
unsafe impl<T: 'static + Send + Sync> Sync for Entry<T> {}

impl<T: 'static> Clone for Entry<T> {
    /// Returns another handle to the same value; only the reference count changes.
    fn clone(&self) -> Self {
        Self::new(Arc::clone(&self.0))
    }
}

impl<T: 'static> Entry<T> {
    /// Wraps `arc` as a typed handle.
    ///
    /// Invariant: `arc` must point at a value of type `T`. `Deref` and the
    /// `Send`/`Sync` impls rely on it, which is why this constructor is private.
    fn new(arc: Arc<dyn Any>) -> Self {
        debug_assert!(<dyn Any>::is::<T>(&*arc));
        Self(arc, PhantomData)
    }
}

impl<T: 'static> Deref for Entry<T> {
    type Target = T;

    fn deref(&self) -> &T {
        // Called through `<dyn Any>::` so that the downcast targets the
        // pointee and never the `Arc` itself (which is also `Any`).
        <dyn Any>::downcast_ref::<T>(&*self.0)
            .expect("Entry<T> invariant violated: stored value is not a T")
    }
}

impl<T: 'static + fmt::Debug> fmt::Debug for Entry<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_tuple("Entry").field(&**self).finish()
    }
}

impl<T: 'static + fmt::Display> fmt::Display for Entry<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt::Display::fmt(&**self, f)
    }
}

impl<T: 'static> fmt::Pointer for Entry<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt::Pointer::fmt(&self.0, f)
    }
}
143
#[cfg(test)]
mod tests {
    use super::GenericGlobal;

    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;
    use std::thread;

    /// The initialiser runs once per type and later calls share its value.
    #[test]
    fn initialises_once_per_type() {
        let globals = GenericGlobal::new();
        let calls = AtomicUsize::new(0);

        let first = globals.get_or_init(|| {
            calls.fetch_add(1, Ordering::SeqCst);
            41_u32
        });
        let second = globals.get_or_init(|| {
            calls.fetch_add(1, Ordering::SeqCst);
            0_u32
        });

        assert_eq!(*first, 41);
        assert_eq!(*second, 41);
        assert_eq!(calls.load(Ordering::SeqCst), 1);
        // Both handles point at the same allocation.
        assert_eq!(format!("{:p}", first), format!("{:p}", second));
    }

    /// Different types get independent slots in the same registry.
    #[test]
    fn distinct_types_get_distinct_values() {
        let globals = GenericGlobal::new();

        let number = globals.get_or_init(|| 7_u64);
        let text = globals.get_or_init(|| String::from("seven"));

        assert_eq!(*number, 7);
        assert_eq!(*text, "seven");
        assert_eq!(text.to_string(), "seven");
    }

    /// Racing threads all observe the single value produced by one initialiser call.
    #[test]
    fn concurrent_callers_share_one_value() {
        let globals = Arc::new(GenericGlobal::new());
        let calls = Arc::new(AtomicUsize::new(0));

        let handles: Vec<_> = (0..8)
            .map(|_| {
                let globals = Arc::clone(&globals);
                let calls = Arc::clone(&calls);
                thread::spawn(move || {
                    let entry = globals.get_or_init(|| {
                        calls.fetch_add(1, Ordering::SeqCst);
                        vec![1_u8, 2, 3]
                    });
                    entry.iter().map(|&b| u32::from(b)).sum::<u32>()
                })
            })
            .collect();

        for handle in handles {
            assert_eq!(handle.join().unwrap(), 6);
        }
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }
}