global_ref/
lib.rs

1use once_cell::sync::OnceCell;
2use std::{fmt, marker::PhantomData, sync::Mutex};
3
/// A cell to store an **immutable** reference so that code which cannot receive it as an
/// argument (for example a plain `fn` callback) can still reach it, typically through a
/// `static`.
///
/// # Safety
/// The registered reference is stored as a raw address (`usize`) and its lifetime is erased.
/// References returned by [`GlobalRef::get`] / [`GlobalRef::try_get`] borrow from the cell,
/// not from the registered item, so the compiler cannot stop them from outliving it. Callers
/// must ensure that:
/// - the registered item stays alive for as long as it is registered, and
/// - no reference obtained from the cell is used after the item has been cleared.
///
/// Prefer [`GlobalRef::with`], which clears the registration even if its closure panics.
pub struct GlobalRef<T> {
    // Address of the registered item, or `None` when nothing is registered.
    inner: OnceCell<Mutex<Option<usize>>>,
    // Ties the cell's auto traits (`Send`/`Sync`) to `T` without storing a `T`.
    _marker: PhantomData<T>,
}

// No `T: Debug` bound: the registered value is never printed.
impl<T> fmt::Debug for GlobalRef<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_tuple("GlobalRef").finish()
    }
}

impl<T> Default for GlobalRef<T> {
    fn default() -> Self {
        GlobalRef::new()
    }
}

impl<T> GlobalRef<T> {
    /// Create a new, empty instance. This is a `const fn`, so it can initialize a `static`.
    pub const fn new() -> Self {
        GlobalRef {
            inner: OnceCell::new(),
            _marker: PhantomData,
        }
    }

    /// Lock the slot holding the registered address, creating the slot on first use.
    ///
    /// Poisoning is ignored. The slot only ever holds an `Option<usize>`, which cannot be
    /// left half-updated. `with()` also clears the slot from a drop guard, where a second
    /// panic would abort the process.
    fn lock_slot(&self) -> std::sync::MutexGuard<'_, Option<usize>> {
        self.inner
            .get_or_init(|| Mutex::new(None))
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Register `item` so that other code can obtain it through this `GlobalRef`.
    ///
    /// Any previously registered reference is replaced. It is recommended to use
    /// [`GlobalRef::with`] instead.
    ///
    /// # Safety
    /// `item` must outlive its registration. **Be sure to call [`GlobalRef::clear`] before
    /// `item` is dropped**, and do not use references obtained from this cell after that.
    pub unsafe fn set(&self, item: &T) {
        *self.lock_slot() = Some(item as *const T as usize);
    }

    /// Clear the registered reference. This is a no-op if nothing is registered.
    pub fn clear(&self) {
        *self.lock_slot() = None;
    }

    /// Register `item`, call `f`, then clear the registration. Returns whatever `f` returns.
    ///
    /// The registration is cleared even if `f` panics, so no dangling address is left
    /// behind.
    ///
    /// The cell holds a single shared slot. A concurrent `set()`, `with()` or `clear()` on
    /// the same cell from another thread replaces or removes this registration.
    pub fn with<F, R>(&self, item: &T, f: F) -> R
    where
        F: FnOnce() -> R,
    {
        // Clears the cell when dropped, i.e. both on normal return and during unwinding.
        struct ClearOnDrop<'a, U>(&'a GlobalRef<U>);

        impl<U> Drop for ClearOnDrop<'_, U> {
            fn drop(&mut self) {
                self.0.clear();
            }
        }

        // SAFETY: `item` is borrowed for the whole call, and the guard below clears the
        // registration before this function returns or unwinds. References obtained through
        // `get()` inside `f` must not escape `f`.
        unsafe { self.set(item) };
        let _clear = ClearOnDrop(self);
        f()
    }

    /// Get an immutable reference to the registered item.
    ///
    /// # Panics
    /// Panics if nothing is registered, i.e. no `set()` or `with()` is currently in effect.
    pub fn get(&self) -> &T {
        self.try_get().expect("Call set() before calling get()!")
    }

    /// Get an immutable reference to the registered item, or `None` if nothing is
    /// registered.
    pub fn try_get(&self) -> Option<&T> {
        let addr = (*self.lock_slot())?;
        // SAFETY: `addr` came from a live `&T` passed to `set()`/`with()`. The caller
        // guarantees the item is not dropped while registered (see `set()`).
        unsafe { (addr as *const T).as_ref() }
    }
}

/// A cell to store a **mutable** reference so that code which cannot receive it as an
/// argument (for example a plain `fn` callback) can still reach it, typically through a
/// `static`.
///
/// # Safety
/// The registered reference is stored as a raw address (`usize`) and its lifetime is erased.
/// References returned by the getters borrow from the cell, not from the registered item, so
/// the compiler cannot stop them from outliving it or from aliasing each other. Callers must
/// ensure that:
/// - the registered item stays alive for as long as it is registered,
/// - no reference obtained from the cell is used after the item has been cleared, and
/// - a reference from [`GlobalMut::get_mut`] / [`GlobalMut::try_get_mut`] is never alive at
///   the same time as any other reference obtained from the cell. Every call returns a
///   reference to the same item.
///
/// Prefer [`GlobalMut::with`], which clears the registration even if its closure panics.
pub struct GlobalMut<T> {
    // Address of the registered item, or `None` when nothing is registered.
    inner: OnceCell<Mutex<Option<usize>>>,
    // Ties the cell's auto traits (`Send`/`Sync`) to `T` without storing a `T`.
    _marker: PhantomData<T>,
}

// No `T: Debug` bound: the registered value is never printed.
impl<T> fmt::Debug for GlobalMut<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_tuple("GlobalMut").finish()
    }
}

impl<T> Default for GlobalMut<T> {
    fn default() -> Self {
        GlobalMut::new()
    }
}

impl<T> GlobalMut<T> {
    /// Create a new, empty instance. This is a `const fn`, so it can initialize a `static`.
    pub const fn new() -> Self {
        GlobalMut {
            inner: OnceCell::new(),
            _marker: PhantomData,
        }
    }

    /// Lock the slot holding the registered address, creating the slot on first use.
    ///
    /// Poisoning is ignored. The slot only ever holds an `Option<usize>`, which cannot be
    /// left half-updated. `with()` also clears the slot from a drop guard, where a second
    /// panic would abort the process.
    fn lock_slot(&self) -> std::sync::MutexGuard<'_, Option<usize>> {
        self.inner
            .get_or_init(|| Mutex::new(None))
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Register `item` so that other code can obtain it through this `GlobalMut`.
    ///
    /// Any previously registered reference is replaced. It is recommended to use
    /// [`GlobalMut::with`] instead.
    ///
    /// # Safety
    /// `item` must outlive its registration. **Be sure to call [`GlobalMut::clear`] before
    /// `item` is dropped**, and do not use references obtained from this cell after that.
    pub unsafe fn set(&self, item: &mut T) {
        *self.lock_slot() = Some(item as *mut T as usize);
    }

    /// Clear the registered reference. This is a no-op if nothing is registered.
    pub fn clear(&self) {
        *self.lock_slot() = None;
    }

    /// Register `item`, call `f`, then clear the registration. Returns whatever `f` returns.
    ///
    /// The registration is cleared even if `f` panics, so no dangling address is left
    /// behind.
    ///
    /// The cell holds a single shared slot. A concurrent `set()`, `with()` or `clear()` on
    /// the same cell from another thread replaces or removes this registration.
    pub fn with<F, R>(&self, item: &mut T, f: F) -> R
    where
        F: FnOnce() -> R,
    {
        // Clears the cell when dropped, i.e. both on normal return and during unwinding.
        struct ClearOnDrop<'a, U>(&'a GlobalMut<U>);

        impl<U> Drop for ClearOnDrop<'_, U> {
            fn drop(&mut self) {
                self.0.clear();
            }
        }

        // SAFETY: `item` is exclusively borrowed for the whole call, and the guard below
        // clears the registration before this function returns or unwinds. References
        // obtained through the getters inside `f` must not escape `f`.
        unsafe { self.set(item) };
        let _clear = ClearOnDrop(self);
        f()
    }

    /// Get an immutable reference to the registered item.
    ///
    /// # Panics
    /// Panics if nothing is registered, i.e. no `set()` or `with()` is currently in effect.
    pub fn get(&self) -> &T {
        self.try_get().expect("Call set() before calling get()!")
    }

    /// Get an immutable reference to the registered item, or `None` if nothing is
    /// registered.
    pub fn try_get(&self) -> Option<&T> {
        let addr = (*self.lock_slot())?;
        // SAFETY: `addr` came from a live `&mut T` passed to `set()`/`with()`. The caller
        // guarantees it is still alive and not mutably aliased (see the type docs).
        unsafe { (addr as *const T).as_ref() }
    }

    /// Get a mutable reference to the registered item.
    ///
    /// Every call returns a reference to the same item. Do not hold it across another call
    /// to any getter of this cell.
    ///
    /// # Panics
    /// Panics if nothing is registered, i.e. no `set()` or `with()` is currently in effect.
    pub fn get_mut(&self) -> &mut T {
        self.try_get_mut()
            .expect("Call set() before calling get_mut()!")
    }

    /// Get a mutable reference to the registered item, or `None` if nothing is registered.
    ///
    /// Every call returns a reference to the same item. Do not hold it across another call
    /// to any getter of this cell.
    pub fn try_get_mut(&self) -> Option<&mut T> {
        let addr = (*self.lock_slot())?;
        // SAFETY: `addr` came from a live `&mut T` passed to `set()`/`with()`. The caller
        // guarantees it is still alive and that no other reference to it is in use.
        unsafe { (addr as *mut T).as_mut() }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;

    /// `set` exposes the value through `get`; `clear` withdraws it again.
    #[test]
    fn global_ref() {
        static CELL: GlobalRef<i32> = GlobalRef::new();

        let negative_one = -1;
        // SAFETY: `negative_one` stays alive until after `clear()` below.
        unsafe { CELL.set(&negative_one) };
        let magnitude = CELL.get().abs();
        assert_eq!(magnitude, 1);

        CELL.clear();
        assert!(CELL.try_get().is_none());
    }

    /// Writes made through `get_mut` are visible through `get`.
    #[test]
    fn global_mut() {
        static CELL: GlobalMut<i32> = GlobalMut::new();

        let mut counter = 0;
        // SAFETY: `counter` stays alive until after `clear()` below.
        unsafe { CELL.set(&mut counter) };
        *CELL.get_mut() += 1;
        assert_eq!(*CELL.get(), 1);

        CELL.clear();
        assert!(CELL.try_get().is_none());
    }

    /// A registration made by `with` is usable from a spawned thread while the closure
    /// runs, and is gone once `with` returns.
    #[test]
    fn multi_thread() {
        static CELL: GlobalMut<i32> = GlobalMut::new();

        let mut counter = 0;
        CELL.with(&mut counter, || {
            thread::spawn(|| *CELL.get_mut() += 1)
                .join()
                .unwrap();
            assert_eq!(*CELL.get(), 1);
        });

        assert!(CELL.try_get().is_none());
    }
}
206}