// node_flow/context/storage/local_storage/implementation.rs

1use std::{
2    any::{Any, TypeId},
3    collections::{HashMap, HashSet},
4    fmt::Debug,
5};
6
7use crate::context::{
8    Fork, Join, Update,
9    storage::local_storage::{LocalStorage, Merge, MergeResult},
10};
11
12trait StorageItem: Any + Send {
13    fn duplicate(&self) -> Box<dyn StorageItem>;
14    fn merge(
15        &self,
16        parent: Option<&dyn StorageItem>,
17        others: Box<[Box<dyn StorageItem>]>,
18    ) -> MergeResult<Box<dyn StorageItem>>;
19}
20
21impl<T> StorageItem for T
22where
23    T: Merge + Any + Send + Clone,
24{
25    fn duplicate(&self) -> Box<dyn StorageItem> {
26        Box::new(self.clone())
27    }
28
29    // SAFETY: self can never be used otherwise it can lead to UB
30    fn merge(
31        &self,
32        parent: Option<&dyn StorageItem>,
33        others: Box<[Box<dyn StorageItem>]>,
34    ) -> MergeResult<Box<dyn StorageItem>> {
35        let others = others
36            .into_iter()
37            .map(|b| *(b as Box<dyn Any>).downcast::<T>().unwrap())
38            .collect::<Box<_>>();
39        let parent = parent.map(|v| (v as &dyn Any).downcast_ref::<T>().unwrap());
40        match <T as Merge>::merge(parent, others) {
41            MergeResult::ReplaceOrInsert(val) => MergeResult::ReplaceOrInsert(Box::new(val)),
42            MergeResult::KeepParent => MergeResult::KeepParent,
43            MergeResult::Remove => MergeResult::Remove,
44        }
45    }
46}
47
48impl Clone for Box<dyn StorageItem> {
49    fn clone(&self) -> Self {
50        self.duplicate()
51    }
52}
53
54/// An implementation of type-based local storage.
55///
56/// See [`LocalStorage`] for more info.\
57/// See also [`Fork`], [`Update`], [`Join`], [`Merge`].
58///
59/// # Internal Structure
60/// - `inner`: Stores the mapping of `TypeId` -> type-erased boxed value.
61/// - `changed`: Tracks which entries have been modified.
62#[derive(Default)]
63pub struct LocalStorageImpl {
64    inner: HashMap<TypeId, Box<dyn StorageItem>>,
65    changed: HashSet<TypeId>,
66}
67
68impl Debug for LocalStorageImpl {
69    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
70        f.debug_struct("LocalStorageImpl").finish_non_exhaustive()
71    }
72}
73
74impl LocalStorageImpl {
75    /// Constructs new `LocalStorageImpl`.
76    #[must_use]
77    pub fn new() -> Self {
78        Self::default()
79    }
80}
81
82impl LocalStorage for LocalStorageImpl {
83    fn get<T>(&self) -> Option<&T>
84    where
85        T: 'static,
86    {
87        self.inner.get(&TypeId::of::<T>()).map(|val| {
88            let any_ref: &dyn Any = &**val;
89            any_ref.downcast_ref::<T>().unwrap()
90        })
91    }
92
93    fn get_mut<T>(&mut self) -> Option<&mut T>
94    where
95        T: 'static,
96    {
97        self.inner.get_mut(&TypeId::of::<T>()).map(|val| {
98            self.changed.insert(TypeId::of::<T>());
99            let any_debug_ref: &mut dyn Any = &mut **val;
100            any_debug_ref.downcast_mut::<T>().unwrap()
101        })
102    }
103
104    fn insert<T>(&mut self, val: T) -> Option<T>
105    where
106        T: Merge + Clone + Send + 'static,
107    {
108        self.changed.insert(TypeId::of::<T>());
109        self.inner
110            .insert(TypeId::of::<T>(), Box::new(val))
111            .map(|val| *(val as Box<dyn Any>).downcast::<T>().unwrap())
112    }
113
114    fn remove<T>(&mut self) -> Option<T>
115    where
116        T: 'static,
117    {
118        self.inner.remove(&TypeId::of::<T>()).map(|val| {
119            self.changed.insert(TypeId::of::<T>());
120            *(val as Box<dyn Any>).downcast::<T>().unwrap()
121        })
122    }
123}
124
125impl Fork for LocalStorageImpl {
126    fn fork(&self) -> Self {
127        Self {
128            inner: self.inner.clone(),
129            changed: HashSet::new(),
130        }
131    }
132}
133
134impl Update for LocalStorageImpl {
135    fn update_from(&mut self, other: Self) {
136        self.inner = other.inner;
137        self.changed.extend(other.changed.iter());
138    }
139}
140
141impl Join for LocalStorageImpl {
142    fn join(&mut self, mut others: Box<[Self]>) {
143        if others.is_empty() {
144            return;
145        }
146
147        // gather TypeId of all changed items
148        let mut changed = others[0].changed.clone();
149        others
150            .iter_mut()
151            .skip(1)
152            .for_each(|s| changed.extend(s.changed.iter()));
153
154        for key in changed {
155            // collect items from self and from other_items if the item was changed
156            let parent = self.inner.get(&key).map(Box::as_ref);
157            let other_items = others
158                .iter_mut()
159                .filter_map(|s| {
160                    s.changed
161                        .remove(&key)
162                        .then(|| s.inner.remove(&key))
163                        .flatten()
164                })
165                .collect::<Box<[_]>>();
166
167            // decide if and how the items are merged
168            // allow match_same_arms for comments
169            #[expect(clippy::match_same_arms)]
170            match (parent.is_none(), other_items.is_empty()) {
171                // parent and other_items are empty
172                //     => item was inserted in a branch and then removed
173                // = skip item
174                (true, true) => continue,
175                // parent is empty and other_items contain exactly one item
176                //     => item was inserted in exactly one branch
177                //  or => item was inserted in multiple branches, but later it was removed from all but one branch
178                // = insert first and only item
179                (true, false) if other_items.len() == 1 => {
180                    let first = other_items.into_iter().next().unwrap();
181                    self.inner.insert(key, first);
182                    self.changed.insert(key);
183                    continue;
184                }
185                // parent is empty and other_items contain more than one item
186                //     => more than one branch inserted item
187                // = merge needed
188                (true, false) => {}
189                // parent is not empty and other_items is empty
190                //     => item was removed in all branches
191                // = remove item
192                (false, true) => {
193                    self.inner.remove(&key);
194                    self.changed.insert(key);
195                    continue;
196                }
197                // parent and other_items are not empty
198                //     => at least one branch inserted item
199                // = merge needed
200                (false, false) => {}
201            }
202
203            // Merge trait is needed for merging
204            let res = {
205                // All types (inside of a `parent` and `other_items[...]`) have the same type
206                let dispatcher: &dyn StorageItem = parent.map_or_else(|| &*other_items[0], |p| p);
207
208                // Call merge on dyn StorageItem type
209                // SAFETY: reference is only used for VTable lookup, the self type is otherwise unused,
210                //         this reference is then dropped and never used since it will most likely point to a non-existent data
211                let dispatcher: &dyn StorageItem = unsafe { &*std::ptr::from_ref(dispatcher) };
212                dispatcher.merge(parent, other_items)
213            };
214            match res {
215                MergeResult::KeepParent => {}
216                MergeResult::ReplaceOrInsert(val) => {
217                    self.inner.insert(key, val);
218                    self.changed.insert(key);
219                }
220                MergeResult::Remove => {
221                    if self.inner.remove(&key).is_some() {
222                        self.changed.insert(key);
223                    }
224                }
225            }
226        }
227    }
228}
229
#[cfg(test)]
#[doc(hidden)]
pub mod tests {
    use super::*;

    /// Test value whose [`Merge`] concatenates the parent string followed by every
    /// branch string, in order.
    #[derive(Debug, Clone, PartialEq, Eq)]
    #[allow(dead_code)]
    pub struct MyVal(pub String);

    impl Default for MyVal {
        fn default() -> Self {
            Self("|".to_owned())
        }
    }

    impl Merge for MyVal {
        fn merge(parent: Option<&Self>, others: Box<[Self]>) -> MergeResult<Self> {
            let len = parent.map(|v| v.0.len()).unwrap_or_default()
                + others.iter().map(|v| v.0.len()).sum::<usize>();
            if len == 0 {
                return MergeResult::KeepParent;
            }
            let mut res = String::with_capacity(len);
            if let Some(v) = parent {
                res.push_str(&v.0);
            }
            for v in others {
                res.push_str(&v.0);
            }
            MergeResult::ReplaceOrInsert(MyVal(res))
        }
    }

    #[test]
    fn works() {
        let mut s = LocalStorageImpl::new();
        s.insert(MyVal("test".into()));
        let v = s.get::<MyVal>();
        assert!(v.is_some());
        assert_eq!(v.unwrap().0, "test".to_string());

        let v = s.get_mut::<MyVal>();
        assert!(v.is_some());
        assert_eq!(v.as_ref().unwrap().0, "test".to_string());
        *v.unwrap() = MyVal("hmm".into());

        let v = s.insert(MyVal("jop".into()));
        assert!(v.is_some());
        assert_eq!(v.unwrap().0, "hmm".to_string());

        let v = s.remove::<MyVal>();
        assert!(v.is_some());
        assert_eq!(v.unwrap().0, "jop".to_string());
    }

    #[test]
    fn test_merge() {
        let mut parent = LocalStorageImpl::new();
        let mut child1 = parent.fork();
        child1.insert(MyVal("bbb".to_owned()));
        let mut child2 = parent.fork();
        child2.insert(MyVal("ccc".to_owned()));
        let mut child3 = parent.fork();
        child3.insert(MyVal("ddd".to_owned()));
        parent.join(Box::new([child1, child2, child3]));
        let mut child = parent.fork();
        child.insert(MyVal("aaa".to_owned()));
        parent.join(Box::new([child]));

        let res = parent.get::<MyVal>();
        assert_eq!(res.unwrap().0, "bbbcccdddaaa".to_owned());
    }

    #[test]
    fn join_inserts_value_from_single_branch() {
        let mut parent = LocalStorageImpl::new();
        let mut child = parent.fork();
        child.insert(MyVal("x".to_owned()));
        parent.join(Box::new([child]));
        assert_eq!(parent.get::<MyVal>(), Some(&MyVal("x".to_owned())));
    }

    #[test]
    fn join_removes_value_removed_in_branch() {
        let mut parent = LocalStorageImpl::new();
        parent.insert(MyVal("a".to_owned()));
        let mut child = parent.fork();
        assert!(child.remove::<MyVal>().is_some());
        parent.join(Box::new([child]));
        assert!(parent.get::<MyVal>().is_none());
    }

    #[test]
    fn join_skips_value_inserted_and_removed_in_branch() {
        let mut parent = LocalStorageImpl::new();
        let mut child = parent.fork();
        child.insert(MyVal("a".to_owned()));
        assert!(child.remove::<MyVal>().is_some());
        parent.join(Box::new([child]));
        assert!(parent.get::<MyVal>().is_none());
    }

    #[test]
    fn join_ignores_untouched_branches() {
        let mut parent = LocalStorageImpl::new();
        parent.insert(MyVal("a".to_owned()));
        let untouched = parent.fork();
        parent.join(Box::new([untouched]));
        assert_eq!(parent.get::<MyVal>().map(|v| v.0.as_str()), Some("a"));
    }
}
303}