state_department/
state.rs1use crate::{
2 lazy::LazyState,
3 manager::{StateManager, StateRef},
4 INITIALIZED,
5};
6use std::{
7 any::{Any, TypeId},
8 marker::PhantomData,
9};
10
/// Marker type used as the context parameter of [`StateManager`].
///
/// `StateManager<AnyContext>` is the flavour whose accessors in this file
/// accept any `T: Send + Sync + 'static`. The type carries no data; it only
/// exists at the type level, so the common traits are derived to make it
/// usable in diagnostics, comparisons and generic code without extra bounds.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct AnyContext;
14
15impl StateManager<AnyContext> {
16 pub const fn new() -> Self {
27 Self::new_()
28 }
29
30 #[must_use]
58 pub fn get<T: Send + Sync + 'static>(&self) -> StateRef<'_, T, AnyContext> {
59 match self.try_get() {
60 Some(v) => v,
61 None => panic!("State for {:?} not found", std::any::type_name::<T>()),
62 }
63 }
64
65 #[must_use]
93 pub fn try_get<T: Send + Sync + 'static>(&self) -> Option<StateRef<'_, T, AnyContext>> {
94 if self.initialized.load(std::sync::atomic::Ordering::Acquire) != INITIALIZED {
95 return None;
96 }
97
98 let state = unsafe { (*self.state.get()).assume_init_ref() }.upgrade()?;
99
100 let value: &T = state.get(&TypeId::of::<T>()).and_then(|v| {
101 let v = v.as_ref() as &dyn Any;
102
103 v.downcast_ref::<T>()
104 .or_else(|| v.downcast_ref::<LazyState<T>>().map(|v| v.get()))
105 })?;
106
107 Some(StateRef {
108 value,
109 _state: state,
110 _phantom: PhantomData,
111 })
112 }
113}
114impl Default for StateManager<AnyContext> {
115 fn default() -> Self {
116 Self::new()
117 }
118}
119
/// Values inserted during `init` can be fetched by type, mutated through
/// interior mutability, and the lifetime can be dropped once no refs remain.
#[test]
fn test_state() {
    use std::sync::atomic::{AtomicU8, Ordering};

    struct Foo {
        bar: AtomicU8,
    }

    struct Baz {
        qux: i32,
    }

    let manager = StateManager::<AnyContext>::new();

    let guard = manager.init(|builder| {
        builder.insert(Foo {
            bar: AtomicU8::new(42),
        });

        builder.insert(Baz { qux: 24 });
    });

    // First access: see the initial value, then overwrite it.
    let first = manager.get::<Foo>();
    assert_eq!(first.bar.load(Ordering::Relaxed), 42);
    first.bar.store(24, Ordering::Release);
    drop(first);

    // Second access: the write above is visible.
    let second = manager.get::<Foo>();
    assert_eq!(second.bar.load(Ordering::Acquire), 24);
    drop(second);

    // A different type is looked up independently.
    let baz = manager.get::<Baz>();
    assert_eq!(baz.qux, 24);
    drop(baz);

    // Every reference has been released, so this must succeed.
    guard.try_drop().unwrap();
}
164
/// `try_drop` must report an error while a state reference is still alive.
#[test]
fn test_state_drop_with_ref() {
    struct Foo;

    let manager = StateManager::<AnyContext>::new();

    let guard = manager.init(|builder| {
        builder.insert(Foo);
    });

    // Held until the end of the test so the drop attempt below sees it.
    let _held = manager.get::<Foo>();

    let outcome = guard.try_drop();
    let _ = outcome.unwrap_err();
}
179
/// After the lifetime has been dropped successfully, lookups yield `None`.
#[test]
fn test_state_use_after_lifetime_drop() {
    struct Foo;

    let manager = StateManager::<AnyContext>::new();

    let guard = manager.init(|builder| {
        builder.insert(Foo);
    });

    guard.try_drop().unwrap();

    let lookup = manager.try_get::<Foo>();
    assert!(lookup.is_none());
}
194
/// The stored value is destroyed only when the last outstanding reference
/// goes away, even if the lifetime was dropped earlier.
#[test]
fn test_state_drop_without_lifetime() {
    use std::sync::atomic::{AtomicU8, Ordering};

    static DROPPED: AtomicU8 = AtomicU8::new(0);

    struct Foo;
    impl Drop for Foo {
        fn drop(&mut self) {
            DROPPED.store(1, Ordering::Release);
        }
    }

    // Reads the drop flag set by `Foo::drop`.
    let dropped = || DROPPED.load(Ordering::Acquire);

    let manager = StateManager::<AnyContext>::new();

    let guard = manager.init(|builder| {
        builder.insert(Foo);
    });

    let held = manager.get::<Foo>();
    assert_eq!(dropped(), 0);

    // Dropping the lifetime alone does not destroy the value...
    drop(guard);
    assert_eq!(dropped(), 0);

    // ...releasing the last reference does.
    drop(held);
    assert_eq!(dropped(), 1);

    // Dropping the manager afterwards changes nothing.
    drop(manager);
    assert_eq!(dropped(), 1);
}
230
/// A lazily inserted value is constructed on first access, not during `init`.
#[test]
fn test_lazy_initialization() {
    use std::sync::atomic::{AtomicU8, Ordering};

    static FOO_INITIALIZED: AtomicU8 = AtomicU8::new(0);

    struct Foo {
        bar: i32,
    }

    // Reads the flag set by the lazy constructor.
    let constructed = || FOO_INITIALIZED.load(Ordering::Acquire);

    let manager = StateManager::<AnyContext>::new();

    let _guard = manager.init(|builder| {
        builder.insert_lazy(|| {
            FOO_INITIALIZED.store(1, Ordering::Release);

            Foo { bar: 42 }
        });
    });

    // Nothing has asked for `Foo` yet.
    assert_eq!(constructed(), 0);

    let foo = manager.get::<Foo>();

    // The lookup above ran the constructor.
    assert_eq!(constructed(), 1);

    assert_eq!(foo.bar, 42);
}
265
/// A `static` manager can be read concurrently from several threads.
#[test]
fn test_state_across_threads() {
    use std::sync::{
        atomic::{AtomicU8, Ordering},
        Arc, Barrier,
    };

    static STATE: StateManager<AnyContext> = StateManager::<AnyContext>::new();

    struct Foo {
        bar: AtomicU8,
    }

    let _guard = STATE.init(|builder| {
        builder.insert(Foo {
            bar: AtomicU8::new(0),
        });
    });

    let thread_count = 10;

    // All workers wait here so their increments start together.
    let barrier = Arc::new(Barrier::new(thread_count));

    let mut workers = Vec::with_capacity(thread_count);
    for _ in 0..thread_count {
        let gate = Arc::clone(&barrier);

        workers.push(std::thread::spawn(move || {
            gate.wait();

            let foo = STATE.get::<Foo>();
            foo.bar.fetch_add(1, Ordering::Release);
        }));
    }

    for worker in workers {
        worker.join().unwrap();
    }

    // Each worker incremented exactly once.
    let total = STATE.get::<Foo>().bar.load(Ordering::Acquire);
    assert_eq!(total, thread_count as u8);
}
313
/// Calling `get` from inside the `init` closure panics: the value is not
/// visible through the manager at that point.
#[test]
#[should_panic = "State for \"()\" not found"]
fn test_state_get_inside_init() {
    let manager = StateManager::<AnyContext>::new();

    let _ = manager.init(|builder| {
        builder.insert(());

        // Expected to panic here.
        let _ = manager.get::<()>();
    });
}
324
/// While a stored value is being dropped, lookups through the manager
/// return `None`.
#[test]
fn test_state_get_inside_drop() {
    static STATE: StateManager<AnyContext> = StateManager::<AnyContext>::new();

    struct Foo {
        bar: i32,
    }
    impl Drop for Foo {
        fn drop(&mut self) {
            let lookup = STATE.try_get::<Foo>();
            assert!(lookup.is_none());
        }
    }

    let guard = STATE.init(|builder| {
        builder.insert(Foo { bar: 42 });
    });

    // The reference is released at the end of this scope, before the guard.
    {
        let foo = STATE.get::<Foo>();

        assert_eq!(foo.bar, 42);
    }

    drop(guard);
}
350
/// While a stored value is being dropped, `try_init` yields `None` and
/// lookups keep returning `None`.
#[test]
fn test_state_init_inside_drop() {
    static STATE: StateManager<AnyContext> = StateManager::<AnyContext>::new();

    struct Foo {
        bar: i32,
    }
    impl Drop for Foo {
        fn drop(&mut self) {
            assert!(STATE.try_get::<Foo>().is_none());

            let reinit = STATE.try_init(|builder| {
                builder.insert(Foo { bar: 42 });

                Ok::<_, ()>(())
            });
            assert!(reinit.is_none());

            // The attempt above made nothing available.
            assert!(STATE.try_get::<Foo>().is_none());
        }
    }

    let guard = STATE.init(|builder| {
        builder.insert(Foo { bar: 42 });
    });

    // The reference is released at the end of this scope, before the guard.
    {
        let foo = STATE.get::<Foo>();

        assert_eq!(foo.bar, 42);
    }

    drop(guard);
}
386
/// Calling `init` from inside another `init` closure panics.
#[test]
#[should_panic = "State already initialized or is currently initializing"]
fn test_state_init_inside_init() {
    let manager = StateManager::<AnyContext>::new();

    let _ = manager.init(|_outer| {
        // Expected to panic here: the outer initialization is still running.
        let _ = manager.init(|_inner| {});
    });
}
395
/// A second `init` on the same manager panics.
#[test]
#[should_panic = "State already initialized or is currently initializing"]
fn test_state_already_initialized() {
    let manager = StateManager::<AnyContext>::new();

    // First initialization; its result is discarded immediately.
    let _ = manager.init(|_first| {});

    // Expected to panic here.
    let _ = manager.init(|_second| {});
}