node_flow/context/storage/shared_storage/
implementation.rs1use std::{
2 any::{Any, TypeId},
3 collections::{HashMap, hash_map::Entry},
4 fmt::Debug,
5 ops::{Deref, DerefMut},
6 sync::{Arc, Mutex},
7};
8
9use async_lock::RwLock;
10use futures_util::FutureExt;
11
12use crate::context::{Fork, Join, Update, storage::shared_storage::SharedStorage};
13
14type StorageItem = Arc<RwLock<Option<Box<dyn Any + Send + Sync>>>>;
15
16#[derive(Default, Clone)]
25pub struct SharedStorageImpl {
26 inner: Arc<Mutex<HashMap<TypeId, StorageItem>>>,
27}
28
29impl Debug for SharedStorageImpl {
30 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
31 f.debug_struct("SharedStorageImpl").finish_non_exhaustive()
32 }
33}
34
35impl SharedStorageImpl {
36 #[must_use]
38 pub fn new() -> Self {
39 Self::default()
40 }
41}
42
43impl SharedStorage for SharedStorageImpl {
44 fn get<T>(&self) -> impl Future<Output = Option<impl Deref<Target = T>>> + Send
45 where
46 T: 'static,
47 {
48 let rw_lock = {
49 let guard = self.inner.lock().unwrap();
50 guard.get(&TypeId::of::<T>()).cloned()
51 };
52
53 async move {
54 let rw_lock = rw_lock?;
55 let rw_lock_guard = rw_lock.read_arc().await;
56 if rw_lock_guard.is_none() {
57 return None;
58 }
59 let read_guard = guards::ReadGuard {
60 guard: rw_lock_guard,
61 _item_type: std::marker::PhantomData,
62 };
63
64 Some(read_guard)
65 }
66 }
67
68 fn get_mut<T>(&mut self) -> impl Future<Output = Option<impl DerefMut<Target = T>>> + Send
69 where
70 T: 'static,
71 {
72 let rw_lock = {
73 let guard = self.inner.lock().unwrap();
74 guard.get(&TypeId::of::<T>()).cloned()
75 };
76
77 async move {
78 let rw_lock = rw_lock?;
79 let rw_lock_guard = rw_lock.write_arc().await;
80 if rw_lock_guard.is_none() {
81 return None;
82 }
83 let write_guard = guards::WriteGuard {
84 guard: rw_lock_guard,
85 _item_type: std::marker::PhantomData,
86 };
87
88 Some(write_guard)
89 }
90 }
91
92 fn insert<T>(&mut self, val: T) -> impl Future<Output = Option<T>> + Send
93 where
94 T: Send + Sync + 'static,
95 {
96 let rw_lock = {
97 let mut guard = self.inner.lock().unwrap();
98 match guard.entry(TypeId::of::<T>()) {
99 Entry::Occupied(occupied_entry) => occupied_entry.get().clone(),
100 Entry::Vacant(vacant_entry) => {
101 vacant_entry.insert(Arc::new(RwLock::new(Some(Box::new(val)))));
102 return futures_util::future::ready(None).left_future();
103 }
104 }
105 };
106
107 async move {
108 let val = rw_lock.write().await.replace(Box::new(val))?;
109 let val = *val.downcast::<T>().unwrap();
110 Some(val)
111 }
112 .right_future()
113 }
114
115 fn insert_with_if_absent<T, E>(
116 &self,
117 fut: impl Future<Output = Result<T, E>> + Send,
118 ) -> impl Future<Output = Result<(), E>> + Send
119 where
120 T: Send + Sync + 'static,
121 E: Send,
122 {
123 let mut guard = self.inner.lock().unwrap();
124 match guard.entry(TypeId::of::<T>()) {
125 Entry::Occupied(_) => futures_util::future::ready(Ok(())).left_future(),
126 Entry::Vacant(vacant_entry) => {
127 let rw_lock = Arc::new(RwLock::new(None));
128 let mut rw_lock_guard = rw_lock.write_arc_blocking();
129 vacant_entry.insert(rw_lock);
130 async move {
131 let val = fut.await?;
132 *rw_lock_guard = Some(Box::new(val));
133 Ok(())
134 }
135 .right_future()
136 }
137 }
138 }
139
140 fn remove<T>(&mut self) -> impl Future<Output = Option<T>> + Send
141 where
142 T: 'static,
143 {
144 let rw_lock = {
145 let mut guard = self.inner.lock().unwrap();
146 guard.remove(&TypeId::of::<T>())
147 };
148
149 async move {
150 let rw_lock = rw_lock?;
151 let val = rw_lock.write().await.take()?;
152 let val = *val.downcast::<T>().unwrap();
153 Some(val)
154 }
155 }
156}
157
158impl Fork for SharedStorageImpl {
159 fn fork(&self) -> Self {
160 self.clone()
161 }
162}
163
164impl Update for SharedStorageImpl {
165 fn update_from(&mut self, _other: Self) {}
166}
167
168impl Join for SharedStorageImpl {
169 fn join(&mut self, _others: Box<[Self]>) {}
170}
171
172mod guards {
173 use std::{
174 any::Any,
175 ops::{Deref, DerefMut},
176 };
177
178 pub struct ReadGuard<T: 'static> {
179 pub guard: async_lock::RwLockReadGuardArc<Option<Box<dyn Any + Send + Sync>>>,
180 pub _item_type: std::marker::PhantomData<T>,
181 }
182
183 impl<T: 'static> Deref for ReadGuard<T> {
184 type Target = T;
185
186 fn deref(&self) -> &Self::Target {
187 let any_ref: &dyn Any = &**self.guard.as_ref().unwrap();
188 any_ref.downcast_ref::<T>().unwrap()
189 }
190 }
191
192 pub struct WriteGuard<T: 'static> {
193 pub guard: async_lock::RwLockWriteGuardArc<Option<Box<dyn Any + Send + Sync>>>,
194 pub _item_type: std::marker::PhantomData<T>,
195 }
196
197 impl<T: 'static> Deref for WriteGuard<T> {
198 type Target = T;
199
200 fn deref(&self) -> &Self::Target {
201 let any_ref: &dyn Any = &**self.guard.as_ref().unwrap();
202 any_ref.downcast_ref::<T>().unwrap()
203 }
204 }
205
206 impl<T: 'static> DerefMut for WriteGuard<T> {
207 fn deref_mut(&mut self) -> &mut Self::Target {
208 let any_ref: &mut dyn Any = &mut **self.guard.as_mut().unwrap();
209 any_ref.downcast_mut::<T>().unwrap()
210 }
211 }
212}
213
#[cfg(test)]
#[doc(hidden)]
pub mod tests {
    use super::*;

    /// Simple value type stored in the shared storage by these tests.
    #[derive(Debug, Clone, PartialEq, Eq)]
    #[allow(dead_code)]
    pub struct MyVal(pub String);

    impl Default for MyVal {
        fn default() -> Self {
            Self("|".to_owned())
        }
    }

    /// Round-trips a value through `insert`, `get`, `get_mut`, `insert` (replace) and
    /// `remove`.
    #[tokio::test]
    async fn works() {
        let mut s = SharedStorageImpl::new();
        // Await the insert instead of dropping its future with `let _ =`: dropping only
        // worked because a vacant slot happens to be filled eagerly.
        let previous = s.insert(MyVal("test".into())).await;
        assert!(previous.is_none());

        let v = s.get::<MyVal>().await;
        assert!(v.is_some());
        assert_eq!(v.unwrap().0, "test".to_string());

        let v = s.get_mut::<MyVal>().await;
        assert!(v.is_some());
        assert_eq!(v.as_ref().unwrap().0, "test".to_string());
        *v.unwrap() = MyVal("hmm".into());

        let v = s.insert(MyVal("jop".into())).await;
        assert!(v.is_some());
        assert_eq!(v.unwrap().0, "hmm".to_string());

        let v = s.remove::<MyVal>().await;
        assert!(v.is_some());
        assert_eq!(v.unwrap().0, "jop".to_string());
    }

    /// Forks share one map, so the last insert through any fork wins.
    #[tokio::test]
    async fn test_merge() {
        let mut parent = SharedStorageImpl::new();
        let mut child1 = parent.fork();
        let _ = child1.insert(MyVal("bbb".to_owned())).await;
        let mut child2 = parent.fork();
        let _ = child2.insert(MyVal("ccc".to_owned())).await;
        let mut child3 = parent.fork();
        let _ = child3.insert(MyVal("ddd".to_owned())).await;
        parent.join(Box::new([child1, child2, child3]));
        let mut child = parent.fork();
        let _ = child.insert(MyVal("aaa".to_owned())).await;
        parent.join(Box::new([child]));

        let res = parent.get::<MyVal>().await;
        assert_eq!(res.unwrap().0, "aaa".to_owned());
    }

    /// `insert_with_if_absent` fills an absent slot and never overwrites an existing value.
    #[tokio::test]
    async fn insert_with_if_absent_keeps_existing_value() {
        let s = SharedStorageImpl::new();

        let res = s
            .insert_with_if_absent(async { Ok::<_, ()>(MyVal("first".into())) })
            .await;
        assert_eq!(res, Ok(()));
        assert_eq!(s.get::<MyVal>().await.unwrap().0, "first");

        let res = s
            .insert_with_if_absent(async { Ok::<_, ()>(MyVal("second".into())) })
            .await;
        assert_eq!(res, Ok(()));
        assert_eq!(s.get::<MyVal>().await.unwrap().0, "first");
    }

    /// An error from the producing future is returned and no value becomes visible.
    #[tokio::test]
    async fn insert_with_if_absent_propagates_error() {
        let s = SharedStorageImpl::new();

        let res = s
            .insert_with_if_absent(async { Err::<MyVal, _>("boom") })
            .await;
        assert_eq!(res, Err("boom"));
        assert!(s.get::<MyVal>().await.is_none());
    }
}