Skip to main content

runtime_context/
context.rs

1use super::{Data, ShareableTid, TypeMap};
2use std::any::TypeId;
3
4/// Runtime context storing values by type.
5///
6/// The context can store owned values as well as borrowed references (immutable
7/// or mutable). Values are keyed by `TypeId` using a specialized hasher for
8/// fast lookups.
9pub struct Context<'ty, 'r> {
10    data: TypeMap<Data<'ty, 'r>>,
11}
12
13impl Default for Context<'_, '_> {
14    #[inline]
15    fn default() -> Self {
16        Self::new()
17    }
18}
19
20impl<'ty, 'r> Context<'ty, 'r> {
21    /// Create a new empty `Context`.
22    #[inline]
23    pub fn new() -> Self {
24        Self {
25            data: TypeMap::default(),
26        }
27    }
28
29    /// Insert a value into the context without checking the type.
30    ///
31    /// This is a low-level escape hatch for advanced use-cases.
32    #[inline]
33    pub fn insert_unchecked(&mut self, key: TypeId, data: Data<'ty, 'r>) {
34        self.data.insert(key, data);
35    }
36
37    /// Insert a borrowed value into the context.
38    #[inline]
39    pub fn insert_ref<T: ShareableTid<'ty>>(&mut self, value: &'r T) {
40        self.data.insert(T::id(), Data::Borrowed(value));
41    }
42
43    /// Insert a mutable reference into the context.
44    #[inline]
45    pub fn insert_mut<T: ShareableTid<'ty>>(&mut self, value: &'r mut T) {
46        self.data.insert(T::id(), Data::Mut(value));
47    }
48
49    /// Insert an owned value into the context.
50    #[inline]
51    pub fn insert<T: ShareableTid<'ty>>(&mut self, value: T) {
52        self.data.insert(T::id(), Data::Owned(Box::new(value)));
53    }
54
55    /// Get a shared reference to a stored value by type.
56    #[inline]
57    pub fn get<'b, T: ShareableTid<'ty>>(&'b self) -> Option<&'b T> {
58        self.data.get(&T::id()).and_then(|v| v.downcast_ref())
59    }
60
61    /// Get a mutable reference to a stored value by type.
62    #[inline]
63    pub fn get_mut<'b, T: ShareableTid<'ty>>(&'b mut self) -> Option<&'b mut T> {
64        self.data
65            .get_mut(&T::id())
66            .and_then(|v| v.downcast_mut())
67    }
68
69    /// Get a stored `Data` by `TypeId`.
70    #[inline]
71    pub fn get_data<'b>(&'b self, id: &TypeId) -> Option<&'b Data<'ty, 'r>> {
72        self.data.get(id)
73    }
74
75    /// Get a mutable `Data` by `TypeId`.
76    #[inline]
77    pub fn get_data_mut<'b>(&'b mut self, id: &TypeId) -> Option<&'b mut Data<'ty, 'r>> {
78        self.data.get_mut(id)
79    }
80
81    /// Get multiple mutable `Data` entries by distinct `TypeId`s.
82    #[inline]
83    pub fn get_disjoint_mut<'b, const N: usize>(
84        &'b mut self,
85        keys: [&TypeId; N],
86    ) -> [Option<&'b mut Data<'ty, 'r>>; N] {
87        self.data.get_disjoint_mut(keys)
88    }
89
90    /// Remove an owned value from the context and return it.
91    #[inline]
92    pub fn take<T: ShareableTid<'ty>>(&mut self) -> Option<T> {
93        let id = T::id();
94        match self.data.remove(&id) {
95            Some(data) => data.try_take_owned::<T>().ok(),
96            None => None,
97        }
98    }
99
100    /// Remove any stored value for the given type and return the raw `Data`.
101    #[inline]
102    pub fn remove<T: ShareableTid<'ty>>(&mut self) -> Option<Data<'ty, 'r>> {
103        self.data.remove(&T::id())
104    }
105
106    /// Check if a value of a specific type is present.
107    #[inline]
108    pub fn contains<T: ShareableTid<'ty>>(&self) -> bool {
109        self.data.contains_key(&T::id())
110    }
111
112    /// Clear all values from the context.
113    #[inline]
114    pub fn clear(&mut self) {
115        self.data.clear();
116    }
117}
118
#[cfg(test)]
mod tests {
    use better_any::{Tid, tid};

    use super::*;

    #[derive(Debug, Clone, PartialEq, Eq)]
    struct Dummy<'a>(&'a str);
    tid!(Dummy<'_>);

    #[test]
    fn test_context_owned() {
        let dummy = Dummy("Hello, World!");
        let mut context = Context::new();

        context.insert(dummy);
        assert_eq!(context.get::<Dummy>(), Some(&Dummy("Hello, World!")));
        assert!(context.get_mut::<Dummy>().is_some());
        assert!(context.contains::<Dummy>());
    }

    #[test]
    fn test_context_ref() {
        let dummy = Dummy("Hello, World!");
        let mut context = Context::new();

        context.insert_ref(&dummy);
        assert_eq!(context.get::<Dummy>(), Some(&dummy));
        // Shared borrows are read-only.
        assert_eq!(context.get_mut::<Dummy>(), None);
        assert!(context.contains::<Dummy>());
        // A borrowed entry cannot be moved out by value.
        assert_eq!(context.take::<Dummy>(), None);
    }

    #[test]
    fn test_context_mut() {
        let mut dummy = Dummy("Hello, World!");
        let mut context = Context::new();

        context.insert_mut(&mut dummy);
        assert!(context.get::<Dummy>().is_some());
        assert!(context.get_mut::<Dummy>().is_some());
        assert!(context.contains::<Dummy>());
    }

    #[test]
    fn test_context_mut_writes_through() {
        let mut dummy = Dummy("Hello, World!");
        {
            let mut context = Context::new();
            context.insert_mut(&mut dummy);
            // Mutation through the context must reach the borrowed original.
            context.get_mut::<Dummy>().expect("mutable entry").0 = "Goodbye!";
        }

        assert_eq!(dummy.0, "Goodbye!");
    }

    #[test]
    fn test_context_no_immutable_err() {
        let mut dummy = Dummy("Hello, World!");
        {
            let mut context = Context::new();
            context.insert_mut(&mut dummy);
        }

        assert_eq!(dummy.0, "Hello, World!");
    }

    #[test]
    fn test_downcast_to_trait() {
        trait Foo {
            fn foo(&self) -> &str;
        }

        impl Foo for Dummy<'_> {
            fn foo(&self) -> &str {
                self.0
            }
        }

        struct FooWrapper<'a, T: Foo + 'static>(&'a mut T);
        tid! { impl<'a, T: 'static> TidAble<'a> for FooWrapper<'a, T> where T: Foo }

        let mut dummy = Dummy("Hello, World!");
        let mut context = Context::new();
        context.insert(FooWrapper(&mut dummy));

        fn inner_ref_fn<T: Foo + 'static>(context: &Context) {
            let data = context
                .get_data(&FooWrapper::<T>::id())
                .expect("Data not found");
            let wrapper = data
                .downcast_ref::<FooWrapper<T>>()
                .expect("Downcast failed");
            assert_eq!(wrapper.0.foo(), "Hello, World!");
        }

        inner_ref_fn::<Dummy>(&context);

        fn inner_mut_fn<T: Foo + 'static>(context: &mut Context) {
            let data = context
                .get_data_mut(&FooWrapper::<T>::id())
                .expect("Data not found");
            let wrapper = data
                .downcast_mut::<FooWrapper<T>>()
                .expect("Downcast failed");
            assert_eq!(wrapper.0.foo(), "Hello, World!");
        }

        inner_mut_fn::<Dummy>(&mut context);
    }

    #[test]
    fn test_take_and_remove() {
        #[derive(Debug, Clone, PartialEq, Eq)]
        struct TakeMe(u64);
        tid!(TakeMe);

        let mut context = Context::new();
        context.insert(TakeMe(7));

        let owned = context.take::<TakeMe>().unwrap();
        assert_eq!(owned, TakeMe(7));
        assert!(!context.contains::<TakeMe>());

        context.insert(TakeMe(9));
        let data = context.remove::<TakeMe>().unwrap();
        assert!(matches!(data.try_take_owned::<TakeMe>(), Ok(TakeMe(9))));
    }

    #[test]
    fn test_get_disjoint_mut() {
        #[derive(Debug, Clone, PartialEq, Eq)]
        struct A(u8);
        #[derive(Debug, Clone, PartialEq, Eq)]
        struct B(u8);
        tid!(A);
        tid!(B);

        let mut context = Context::new();
        context.insert(A(1));
        context.insert(B(2));

        let [a, b] = context.get_disjoint_mut([&A::id(), &B::id()]);
        let a = a.unwrap().downcast_mut::<A>().unwrap();
        let b = b.unwrap().downcast_mut::<B>().unwrap();

        a.0 += 1;
        b.0 += 2;

        assert_eq!(context.get::<A>().unwrap().0, 2);
        assert_eq!(context.get::<B>().unwrap().0, 4);
    }

    #[test]
    fn test_clear_and_get_data() {
        #[derive(Debug, Clone, PartialEq, Eq)]
        struct C(i32);
        tid!(C);

        let mut context = Context::new();
        context.insert(C(10));

        let data = context.get_data(&C::id()).unwrap();
        assert_eq!(data.downcast_ref::<C>().unwrap().0, 10);

        context.clear();
        assert_eq!(context.get::<C>(), None);
    }
}
274}