Skip to main content

node_flow/context/storage/shared_storage/implementation.rs

1use std::{
2    any::{Any, TypeId},
3    collections::{HashMap, hash_map::Entry},
4    fmt::Debug,
5    ops::{Deref, DerefMut},
6    sync::{Arc, Mutex},
7};
8
9use async_lock::RwLock;
10use futures_util::FutureExt;
11
12use crate::context::{Fork, Join, Update, storage::shared_storage::SharedStorage};
13
14type StorageItem = Arc<RwLock<Option<Box<dyn Any + Send + Sync>>>>;
15
16/// An implementation of type-based shared storage.
17///
18/// See [`SharedStorage`] for more info.\
19/// See also [`Fork`], [`Update`], [`Join`].
20///
21/// # Internal Structure
22/// Item entries are stored in `Arc<std::sync::Mutex<HashMap<TypeId, StorageItem>>>` which is cloned in [`Fork`] and shared between branches.\
23/// `StorageItem` has type `Arc<async_lock::RwLock<...>>` which allows for per entry locking without holding a lock for the entire `HashMap`.
24#[derive(Default, Clone)]
25pub struct SharedStorageImpl {
26    inner: Arc<Mutex<HashMap<TypeId, StorageItem>>>,
27}
28
29impl Debug for SharedStorageImpl {
30    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
31        f.debug_struct("SharedStorageImpl").finish_non_exhaustive()
32    }
33}
34
35impl SharedStorageImpl {
36    /// Constructs new `SharedStorageImpl`.
37    #[must_use]
38    pub fn new() -> Self {
39        Self::default()
40    }
41}
42
43impl SharedStorage for SharedStorageImpl {
44    fn get<T>(&self) -> impl Future<Output = Option<impl Deref<Target = T>>> + Send
45    where
46        T: 'static,
47    {
48        let rw_lock = {
49            let guard = self.inner.lock().unwrap();
50            guard.get(&TypeId::of::<T>()).cloned()
51        };
52
53        async move {
54            let rw_lock = rw_lock?;
55            let rw_lock_guard = rw_lock.read_arc().await;
56            if rw_lock_guard.is_none() {
57                return None;
58            }
59            let read_guard = guards::ReadGuard {
60                guard: rw_lock_guard,
61                _item_type: std::marker::PhantomData,
62            };
63
64            Some(read_guard)
65        }
66    }
67
68    fn get_mut<T>(&mut self) -> impl Future<Output = Option<impl DerefMut<Target = T>>> + Send
69    where
70        T: 'static,
71    {
72        let rw_lock = {
73            let guard = self.inner.lock().unwrap();
74            guard.get(&TypeId::of::<T>()).cloned()
75        };
76
77        async move {
78            let rw_lock = rw_lock?;
79            let rw_lock_guard = rw_lock.write_arc().await;
80            if rw_lock_guard.is_none() {
81                return None;
82            }
83            let write_guard = guards::WriteGuard {
84                guard: rw_lock_guard,
85                _item_type: std::marker::PhantomData,
86            };
87
88            Some(write_guard)
89        }
90    }
91
92    fn insert<T>(&mut self, val: T) -> impl Future<Output = Option<T>> + Send
93    where
94        T: Send + Sync + 'static,
95    {
96        let rw_lock = {
97            let mut guard = self.inner.lock().unwrap();
98            match guard.entry(TypeId::of::<T>()) {
99                Entry::Occupied(occupied_entry) => occupied_entry.get().clone(),
100                Entry::Vacant(vacant_entry) => {
101                    vacant_entry.insert(Arc::new(RwLock::new(Some(Box::new(val)))));
102                    return futures_util::future::ready(None).left_future();
103                }
104            }
105        };
106
107        async move {
108            let val = rw_lock.write().await.replace(Box::new(val))?;
109            let val = *val.downcast::<T>().unwrap();
110            Some(val)
111        }
112        .right_future()
113    }
114
115    fn insert_with_if_absent<T, E>(
116        &self,
117        fut: impl Future<Output = Result<T, E>> + Send,
118    ) -> impl Future<Output = Result<(), E>> + Send
119    where
120        T: Send + Sync + 'static,
121        E: Send,
122    {
123        let mut guard = self.inner.lock().unwrap();
124        match guard.entry(TypeId::of::<T>()) {
125            Entry::Occupied(_) => futures_util::future::ready(Ok(())).left_future(),
126            Entry::Vacant(vacant_entry) => {
127                let rw_lock = Arc::new(RwLock::new(None));
128                let mut rw_lock_guard = rw_lock.write_arc_blocking();
129                vacant_entry.insert(rw_lock);
130                async move {
131                    let val = fut.await?;
132                    *rw_lock_guard = Some(Box::new(val));
133                    Ok(())
134                }
135                .right_future()
136            }
137        }
138    }
139
140    fn remove<T>(&mut self) -> impl Future<Output = Option<T>> + Send
141    where
142        T: 'static,
143    {
144        let rw_lock = {
145            let mut guard = self.inner.lock().unwrap();
146            guard.remove(&TypeId::of::<T>())
147        };
148
149        async move {
150            let rw_lock = rw_lock?;
151            let val = rw_lock.write().await.take()?;
152            let val = *val.downcast::<T>().unwrap();
153            Some(val)
154        }
155    }
156}
157
158impl Fork for SharedStorageImpl {
159    fn fork(&self) -> Self {
160        self.clone()
161    }
162}
163
164impl Update for SharedStorageImpl {
165    fn update_from(&mut self, _other: Self) {}
166}
167
168impl Join for SharedStorageImpl {
169    fn join(&mut self, _others: Box<[Self]>) {}
170}
171
172mod guards {
173    use std::{
174        any::Any,
175        ops::{Deref, DerefMut},
176    };
177
178    pub struct ReadGuard<T: 'static> {
179        pub guard: async_lock::RwLockReadGuardArc<Option<Box<dyn Any + Send + Sync>>>,
180        pub _item_type: std::marker::PhantomData<T>,
181    }
182
183    impl<T: 'static> Deref for ReadGuard<T> {
184        type Target = T;
185
186        fn deref(&self) -> &Self::Target {
187            let any_ref: &dyn Any = &**self.guard.as_ref().unwrap();
188            any_ref.downcast_ref::<T>().unwrap()
189        }
190    }
191
192    pub struct WriteGuard<T: 'static> {
193        pub guard: async_lock::RwLockWriteGuardArc<Option<Box<dyn Any + Send + Sync>>>,
194        pub _item_type: std::marker::PhantomData<T>,
195    }
196
197    impl<T: 'static> Deref for WriteGuard<T> {
198        type Target = T;
199
200        fn deref(&self) -> &Self::Target {
201            let any_ref: &dyn Any = &**self.guard.as_ref().unwrap();
202            any_ref.downcast_ref::<T>().unwrap()
203        }
204    }
205
206    impl<T: 'static> DerefMut for WriteGuard<T> {
207        fn deref_mut(&mut self) -> &mut Self::Target {
208            let any_ref: &mut dyn Any = &mut **self.guard.as_mut().unwrap();
209            any_ref.downcast_mut::<T>().unwrap()
210        }
211    }
212}
213
#[cfg(test)]
#[doc(hidden)]
pub mod tests {
    use super::*;

    /// Simple payload type used to exercise the storage.
    #[derive(Debug, Clone, PartialEq, Eq)]
    // NOTE(review): presumably silences warnings for items only used by other test modules.
    #[allow(dead_code)]
    pub struct MyVal(pub String);

    impl Default for MyVal {
        fn default() -> Self {
            Self("|".to_owned())
        }
    }

    /// Round-trips a value through `insert`, `get`, `get_mut`, a replacing `insert`,
    /// and `remove`.
    #[tokio::test]
    async fn works() {
        let mut s = SharedStorageImpl::new();
        // Await the insert rather than dropping the future and relying on the
        // vacant-entry path having inserted eagerly.
        assert!(s.insert(MyVal("test".into())).await.is_none());

        let v = s.get::<MyVal>().await;
        assert!(v.is_some());
        // `unwrap` consumes the read guard here, so the `get_mut` below cannot deadlock.
        assert_eq!(v.unwrap().0, "test".to_string());

        let v = s.get_mut::<MyVal>().await;
        assert!(v.is_some());
        assert_eq!(v.as_ref().unwrap().0, "test".to_string());
        *v.unwrap() = MyVal("hmm".into());

        let v = s.insert(MyVal("jop".into())).await;
        assert!(v.is_some());
        assert_eq!(v.unwrap().0, "hmm".to_string());

        let v = s.remove::<MyVal>().await;
        assert!(v.is_some());
        assert_eq!(v.unwrap().0, "jop".to_string());
    }

    /// Forks share one map, so the most recent write is what the parent sees after joining.
    #[tokio::test]
    async fn test_merge() {
        let mut parent = SharedStorageImpl::new();
        let mut child1 = parent.fork();
        let _ = child1.insert(MyVal("bbb".to_owned())).await;
        let mut child2 = parent.fork();
        let _ = child2.insert(MyVal("ccc".to_owned())).await;
        let mut child3 = parent.fork();
        let _ = child3.insert(MyVal("ddd".to_owned())).await;
        parent.join(Box::new([child1, child2, child3]));
        let mut child = parent.fork();
        let _ = child.insert(MyVal("aaa".to_owned())).await;
        parent.join(Box::new([child]));

        let res = parent.get::<MyVal>().await;
        assert_eq!(res.unwrap().0, "aaa".to_owned());
    }
}
264        parent.join(Box::new([child]));
265
266        let res = parent.get::<MyVal>().await;
267        assert_eq!(res.unwrap().0, "aaa".to_owned());
268    }
269}