1use crate::stub::{Arc, Mutex, RwLock, RwLockReadGuard, RwLockWriteGuard};
2use hashbrown::HashMap;
3
/// A map from keys to shared values, where each value sits behind its own lock.
///
/// The outer [`Mutex`] guards the map structure itself (lookups, insertions, clearing).
/// Every value is stored as an `Arc<RwLock<V>>`, so a handle returned from the map can be
/// read or written after the map's mutex has been released.
pub struct SharedStateMap<K, V> {
    /// Backing map; `None` until the map is first created on demand.
    // NOTE(review): presumably wrapped in `Option` so that `new` can be a `const fn`
    // (the hash map itself is not const-constructible) — confirm.
    state: Mutex<Option<State<K, V>>>,
}
8
/// Backing storage of a [`SharedStateMap`]: each value is wrapped in `Arc<RwLock<_>>` so
/// it can be handed out as a [`SharedState`] that holds its own reference to the value.
type State<K, V> = HashMap<K, Arc<RwLock<V>>>;
10
/// A handle to one value stored in a [`SharedStateMap`].
///
/// The handle holds its own reference-counted pointer to the value's lock. It therefore
/// stays usable after its map entry is replaced or the map is cleared; from then on it
/// refers to the old value rather than to whatever the map currently holds.
pub struct SharedState<V> {
    /// The shared, lock-protected value.
    val: Arc<RwLock<V>>,
}
15
16impl<V> SharedState<V> {
17 pub fn read(&self) -> RwLockReadGuard<'_, V> {
19 self.val.read().unwrap()
20 }
21
22 pub fn write(&self) -> RwLockWriteGuard<'_, V> {
24 self.val.write().unwrap()
25 }
26}
27
28impl<K, V> Default for SharedStateMap<K, V>
29where
30 K: core::hash::Hash + core::cmp::PartialEq + core::cmp::Eq,
31{
32 fn default() -> Self {
33 Self::new()
34 }
35}
36
37impl<K, V> SharedStateMap<K, V>
38where
39 K: core::hash::Hash + core::cmp::PartialEq + core::cmp::Eq,
40{
41 pub const fn new() -> Self {
43 Self {
44 state: Mutex::new(None),
45 }
46 }
47
48 pub fn get(&self, k: &K) -> Option<SharedState<V>> {
50 let mut state = self.state.lock().unwrap();
51 let map = get_or_init::<K, V>(&mut state);
52
53 match map.get(k) {
54 Some(val) => Some(SharedState { val: val.clone() }),
55 None => None,
56 }
57 }
58
59 pub fn get_or_init<Fn: FnMut(&K) -> V>(&self, k: &K, mut init: Fn) -> SharedState<V>
62 where
63 K: Clone,
64 {
65 let mut state = self.state.lock().unwrap();
66 let map = get_or_init::<K, V>(&mut state);
67
68 match map.get(k) {
69 Some(val) => SharedState { val: val.clone() },
70 None => {
71 let val = init(k);
72 let val = Arc::new(RwLock::new(val));
73 map.insert(k.clone(), val.clone());
74 SharedState { val: val.clone() }
75 }
76 }
77 }
78
79 pub fn insert(&self, k: K, v: V) {
81 let mut state = self.state.lock().unwrap();
82 let map = get_or_init::<K, V>(&mut state);
83
84 map.insert(k, Arc::new(RwLock::new(v)));
85 }
86
87 pub fn clear(&self) {
89 let mut state = self.state.lock().unwrap();
90 let map = get_or_init::<K, V>(&mut state);
91 map.clear();
92 }
93}
94
95fn get_or_init<K, V>(state: &mut Option<State<K, V>>) -> &mut State<K, V> {
96 match state {
97 Some(state) => state,
98 None => {
99 *state = Some(State::<K, V>::default());
100 state.as_mut().unwrap()
101 }
102 }
103}