ic_kit/
storage.rs

1use std::any::{Any, TypeId};
2use std::collections::HashMap;
3
/// A storage implementation for the singleton design pattern: at most one value is kept
/// for each type, keyed by that type's [`TypeId`].
///
/// Values are held type-erased as `Box<dyn Any>`, so a single map can own values of
/// unrelated types.
#[derive(Debug, Default)]
pub struct Storage {
    // TODO(qti3e) put Box in a RefCell when we get rid of get_mut::
    /// Maps a type's `TypeId` to the single boxed value stored for that type.
    storage: HashMap<TypeId, Box<dyn Any>>,
}
11
12impl Storage {
13    #[inline]
14    pub(crate) fn get<T: 'static + Default>(&mut self) -> &T {
15        let tid = TypeId::of::<T>();
16        self.storage
17            .entry(tid)
18            .or_insert_with(|| Box::new(T::default()))
19            .downcast_ref::<T>()
20            .unwrap()
21    }
22
23    #[inline]
24    pub(crate) fn get_mut<T: 'static + Default>(&mut self) -> &mut T {
25        let tid = TypeId::of::<T>();
26        self.storage
27            .entry(tid)
28            .or_insert_with(|| Box::new(T::default()))
29            .downcast_mut::<T>()
30            .unwrap()
31    }
32
33    #[inline]
34    pub(crate) fn get_maybe<T: 'static>(&mut self) -> Option<&T> {
35        let tid = TypeId::of::<T>();
36        self.storage
37            .get(&tid)
38            .map(|c| c.downcast_ref::<T>().unwrap())
39    }
40
41    /// Pass an immutable reference to the stored data of the type `T` to the closure,
42    /// if there is no data associated with the type, store the `Default` and then perform the
43    /// operation.
44    #[inline]
45    pub fn with<T: 'static + Default, U, F: FnOnce(&T) -> U>(&mut self, callback: F) -> U {
46        let tid = TypeId::of::<T>();
47        let cell = &*self
48            .storage
49            .entry(tid)
50            .or_insert_with(|| Box::new(T::default()));
51        let borrow = cell.downcast_ref::<T>().unwrap();
52        callback(borrow)
53    }
54
55    /// Pass an immutable reference to the stored data of the type `T` to the closure,
56    /// if there is no data associated with the type, just return None.
57    #[inline]
58    pub fn maybe_with<T: 'static, U, F: FnOnce(&T) -> U>(&mut self, callback: F) -> Option<U> {
59        let tid = TypeId::of::<T>();
60        self.storage
61            .get(&tid)
62            .map(|cell| cell.downcast_ref::<T>().unwrap())
63            .map(callback)
64    }
65
66    /// Like [`Self::with`] but passes a mutable reference.
67    #[inline]
68    pub fn with_mut<T: 'static + Default, U, F: FnOnce(&mut T) -> U>(&mut self, callback: F) -> U {
69        let tid = TypeId::of::<T>();
70        let cell = self
71            .storage
72            .entry(tid)
73            .or_insert_with(|| Box::new(T::default()));
74        let borrow = cell.downcast_mut::<T>().unwrap();
75        callback(borrow)
76    }
77
78    /// Like [`Self::maybe_with`] but passes a mutable reference.
79    #[inline]
80    pub fn maybe_with_mut<T: 'static, U, F: FnOnce(&mut T) -> U>(
81        &mut self,
82        callback: F,
83    ) -> Option<U> {
84        let tid = TypeId::of::<T>();
85        self.storage
86            .get_mut(&tid)
87            .map(|cell| cell.downcast_mut::<T>().unwrap())
88            .map(callback)
89    }
90
91    /// Remove the data associated with the type `T`, and returns it if any.
92    #[inline]
93    pub fn take<T: 'static>(&mut self) -> Option<T> {
94        let tid = TypeId::of::<T>();
95        self.storage
96            .remove(&tid)
97            .map(|cell| *cell.downcast::<T>().unwrap())
98    }
99
100    /// Store the given value for type `T`, returns the previously stored value if any.
101    #[inline]
102    pub fn swap<T: 'static>(&mut self, value: T) -> Option<T> {
103        let tid = TypeId::of::<T>();
104        self.storage
105            .insert(tid, Box::new(value))
106            .map(|cell| *cell.downcast::<T>().unwrap())
107    }
108}
109
#[cfg(test)]
mod tests {
    use crate::storage::Storage;

    /// Minimal `Default`-constructible value used to exercise the storage API.
    #[derive(Default)]
    struct Counter {
        value: u64,
    }

    impl Counter {
        /// Current value of the counter.
        pub fn get(&self) -> u64 {
            self.value
        }

        /// Bumps the counter by one and returns the new value.
        pub fn increment(&mut self) -> u64 {
            self.value += 1;
            self.value
        }
    }

    /// A fresh storage hands out the default counter, and mutations through `with_mut`
    /// are applied to that stored value.
    #[test]
    fn test_storage() {
        let mut storage = Storage::default();

        let initial = storage.with(|counter: &Counter| counter.get());
        assert_eq!(initial, 0);

        let bumped = storage.with_mut(|counter: &mut Counter| counter.increment());
        assert_eq!(bumped, 1);
    }
}