// state_department/state.rs

use crate::{
    lazy::LazyState,
    manager::{StateManager, StateRef},
    INITIALIZED,
};
use std::{
    any::{Any, TypeId},
    marker::PhantomData,
};

/// Context marker for a [`StateManager`] that can be used from both
/// synchronous and asynchronous code.
///
/// This is a zero-sized unit type. In this module it only appears as the
/// context type parameter, e.g. `StateManager<AnyContext>` and
/// `StateRef<'_, T, AnyContext>`.
#[derive(Debug, Clone, Copy, Default)]
pub struct AnyContext;

15impl StateManager<AnyContext> {
16    /// Creates a new [`StateManager`] appropriate for both synchronous and
17    /// asynchronous contexts.
18    ///
19    /// # Example
20    ///
21    /// ```rust
22    /// use state_department::State;
23    ///
24    /// static STATE: State = State::new();
25    /// ```
26    pub const fn new() -> Self {
27        Self::new_()
28    }
29
30    /// Returns a reference to a value stored in the state.
31    ///
32    /// # Panics
33    ///
34    /// * If the state has not yet been initialized.
35    /// * If the state has been dropped.
36    /// * If the state does not contain a value of the requested type.
37    ///
38    /// # Example
39    ///
40    /// ```rust
41    /// use state_department::State;
42    ///
43    /// static STATE: State = State::new();
44    ///
45    /// struct Foo {
46    ///     bar: i32
47    /// }
48    ///
49    /// let _lifetime = STATE.init(|state| {
50    ///     state.insert(Foo { bar: 42 });
51    /// });
52    ///
53    /// let foo = STATE.get::<Foo>();
54    ///
55    /// assert_eq!(foo.bar, 42);
56    /// ```
57    #[must_use]
58    pub fn get<T: Send + Sync + 'static>(&self) -> StateRef<'_, T, AnyContext> {
59        match self.try_get() {
60            Some(v) => v,
61            None => panic!("State for {:?} not found", std::any::type_name::<T>()),
62        }
63    }
64
65    /// Attempts to get a reference to a value stored in the state.
66    ///
67    /// This function does not panic.
68    ///
69    /// # Example
70    ///
71    /// ```rust
72    /// use state_department::State;
73    ///
74    /// static STATE: State = State::new();
75    ///
76    /// struct Foo {
77    ///     bar: i32
78    /// }
79    ///
80    /// let _lifetime = STATE.init(|state| {
81    ///     state.insert(Foo { bar: 42 });
82    /// });
83    ///
84    /// let foo = STATE.try_get::<Foo>();
85    ///
86    /// assert_eq!(foo.unwrap().bar, 42);
87    ///
88    /// let str = STATE.try_get::<String>();
89    ///
90    /// assert!(str.is_none());
91    /// ```
92    #[must_use]
93    pub fn try_get<T: Send + Sync + 'static>(&self) -> Option<StateRef<'_, T, AnyContext>> {
94        if self.initialized.load(std::sync::atomic::Ordering::Acquire) != INITIALIZED {
95            return None;
96        }
97
98        let state = unsafe { (*self.state.get()).assume_init_ref() }.upgrade()?;
99
100        let value: &T = state.get(&TypeId::of::<T>()).and_then(|v| {
101            let v = v.as_ref() as &dyn Any;
102
103            v.downcast_ref::<T>()
104                .or_else(|| v.downcast_ref::<LazyState<T>>().map(|v| v.get()))
105        })?;
106
107        Some(StateRef {
108            value,
109            _state: state,
110            _phantom: PhantomData,
111        })
112    }
113}
114impl Default for StateManager<AnyContext> {
115    fn default() -> Self {
116        Self::new()
117    }
118}
119
#[test]
fn test_state() {
    use std::sync::atomic::{AtomicU8, Ordering};

    struct Foo {
        bar: AtomicU8,
    }

    struct Baz {
        qux: i32,
    }

    let manager = StateManager::<AnyContext>::new();

    let lifetime = manager.init(|registry| {
        registry.insert(Foo {
            bar: AtomicU8::new(42),
        });

        registry.insert(Baz { qux: 24 });
    });

    // Mutate through one reference, then observe the change through a fresh one.
    {
        let first = manager.get::<Foo>();
        assert_eq!(first.bar.load(Ordering::Relaxed), 42);
        first.bar.store(24, Ordering::Release);
    }

    {
        let second = manager.get::<Foo>();
        assert_eq!(second.bar.load(Ordering::Acquire), 24);
    }

    {
        let baz = manager.get::<Baz>();
        assert_eq!(baz.qux, 24);
    }

    lifetime.try_drop().unwrap();
}

#[test]
fn test_state_drop_with_ref() {
    struct Foo;

    let manager = StateManager::<AnyContext>::new();

    let lifetime = manager.init(|registry| {
        registry.insert(Foo);
    });

    // While a `StateRef` is outstanding, `try_drop` must report failure.
    let _foo = manager.get::<Foo>();
    let _ = lifetime.try_drop().unwrap_err();
}

#[test]
fn test_state_use_after_lifetime_drop() {
    struct Foo;

    let manager = StateManager::<AnyContext>::new();

    let lifetime = manager.init(|registry| {
        registry.insert(Foo);
    });
    lifetime.try_drop().unwrap();

    // Once the lifetime is gone, lookups report absence instead of panicking.
    assert!(manager.try_get::<Foo>().is_none());
}

#[test]
fn test_state_drop_without_lifetime() {
    use std::sync::atomic::{AtomicU8, Ordering};

    static DROPPED: AtomicU8 = AtomicU8::new(0);

    struct Foo;

    impl Drop for Foo {
        fn drop(&mut self) {
            DROPPED.store(1, Ordering::Release);
        }
    }

    let dropped = || DROPPED.load(Ordering::Acquire);

    let manager = StateManager::<AnyContext>::new();

    let lifetime = manager.init(|registry| {
        registry.insert(Foo);
    });

    let foo = manager.get::<Foo>();
    assert_eq!(dropped(), 0);

    // Dropping the lifetime while a reference is outstanding keeps the value alive...
    drop(lifetime);
    assert_eq!(dropped(), 0);

    // ...and releasing the last reference finally drops it.
    drop(foo);
    assert_eq!(dropped(), 1);

    // Dropping the manager afterwards leaves the flag unchanged.
    drop(manager);
    assert_eq!(dropped(), 1);
}

#[test]
fn test_lazy_initialization() {
    use std::sync::atomic::{AtomicU8, Ordering};

    static FOO_INITIALIZED: AtomicU8 = AtomicU8::new(0);

    struct Foo {
        bar: i32,
    }

    let manager = StateManager::<AnyContext>::new();

    let _lifetime = manager.init(|registry| {
        registry.insert_lazy(|| {
            FOO_INITIALIZED.store(1, Ordering::Release);
            Foo { bar: 42 }
        });
    });

    // Registering a lazy value must not construct it...
    assert_eq!(FOO_INITIALIZED.load(Ordering::Acquire), 0);

    // ...the first lookup does.
    let foo = manager.get::<Foo>();
    assert_eq!(FOO_INITIALIZED.load(Ordering::Acquire), 1);

    assert_eq!(foo.bar, 42);
}

#[test]
fn test_state_across_threads() {
    use std::sync::{
        atomic::{AtomicU8, Ordering},
        Arc, Barrier,
    };

    static STATE: StateManager<AnyContext> = StateManager::<AnyContext>::new();

    struct Foo {
        bar: AtomicU8,
    }

    let _lifetime = STATE.init(|registry| {
        registry.insert(Foo {
            bar: AtomicU8::new(0),
        });
    });

    let thread_count = 10;
    let barrier = Arc::new(Barrier::new(thread_count));

    let mut handles = Vec::with_capacity(thread_count);
    for _ in 0..thread_count {
        let gate = Arc::clone(&barrier);
        handles.push(std::thread::spawn(move || {
            // Line every thread up so the lookups genuinely overlap.
            gate.wait();
            STATE.get::<Foo>().bar.fetch_add(1, Ordering::Release);
        }));
    }

    for handle in handles {
        handle.join().unwrap();
    }

    let total = STATE.get::<Foo>().bar.load(Ordering::Acquire);
    assert_eq!(total, thread_count as u8);
}

#[test]
#[should_panic = "State for \"()\" not found"]
fn test_state_get_inside_init() {
    let manager = StateManager::<AnyContext>::new();

    // Values are not reachable until `init` has returned, so this `get` must panic.
    let _ = manager.init(|registry| {
        registry.insert(());

        let _ = manager.get::<()>();
    });
}

#[test]
fn test_state_get_inside_drop() {
    static STATE: StateManager<AnyContext> = StateManager::<AnyContext>::new();

    struct Foo {
        bar: i32,
    }

    impl Drop for Foo {
        fn drop(&mut self) {
            // While its own value is being destroyed, the state must not hand it out.
            assert!(STATE.try_get::<Foo>().is_none());
        }
    }

    let lifetime = STATE.init(|registry| {
        registry.insert(Foo { bar: 42 });
    });

    let foo = STATE.get::<Foo>();
    assert_eq!(foo.bar, 42);
    drop(foo);

    drop(lifetime);
}

#[test]
fn test_state_init_inside_drop() {
    static STATE: StateManager<AnyContext> = StateManager::<AnyContext>::new();

    struct Foo {
        bar: i32,
    }

    impl Drop for Foo {
        fn drop(&mut self) {
            assert!(STATE.try_get::<Foo>().is_none());

            // Re-initializing from inside a value's destructor must be refused...
            let reinit = STATE.try_init(|registry| {
                registry.insert(Foo { bar: 42 });

                Ok::<_, ()>(())
            });
            assert!(reinit.is_none());

            // ...and must not make anything reachable either.
            assert!(STATE.try_get::<Foo>().is_none());
        }
    }

    let lifetime = STATE.init(|registry| {
        registry.insert(Foo { bar: 42 });
    });

    let foo = STATE.get::<Foo>();
    assert_eq!(foo.bar, 42);
    drop(foo);

    drop(lifetime);
}

#[test]
#[should_panic = "State already initialized or is currently initializing"]
fn test_state_init_inside_init() {
    let manager = StateManager::<AnyContext>::new();

    // A nested `init` while the outer one is still running must panic.
    let _ = manager.init(|_| {
        let _ = manager.init(|_| {});
    });
}

#[test]
#[should_panic = "State already initialized or is currently initializing"]
fn test_state_already_initialized() {
    let manager = StateManager::<AnyContext>::new();

    // The first lifetime is discarded immediately, yet a second `init` must still panic.
    let _ = manager.init(|_| {});
    let _ = manager.init(|_| {});
}