rama_core/context/
extensions.rs

1use std::any::{Any, TypeId};
2use std::collections::HashMap;
3use std::fmt;
4use std::hash::{BuildHasherDefault, Hasher};
5
/// Type-keyed storage used by [`Extensions`].
///
/// Each [`TypeId`] maps to a single boxed value that can be cloned and shared
/// across threads. Keys are hashed with the pass-through [`IdHasher`].
type AnyMap = HashMap<TypeId, Box<dyn AnyClone + Send + Sync>, BuildHasherDefault<IdHasher>>;
7
/// Pass-through hasher for maps keyed by `TypeId`.
///
/// With `TypeId`s as keys, there is no need to hash them: they are already
/// hashes themselves, coming from the compiler. The `IdHasher` just holds the
/// `u64` of the `TypeId` and returns it unchanged, instead of doing any bit
/// fiddling.
#[derive(Default)]
struct IdHasher(u64);
13
14impl Hasher for IdHasher {
15    fn write(&mut self, _: &[u8]) {
16        unreachable!("TypeId calls write_u64");
17    }
18
19    #[inline]
20    fn write_u64(&mut self, id: u64) {
21        self.0 = id;
22    }
23
24    #[inline]
25    fn finish(&self) -> u64 {
26        self.0
27    }
28}
29
/// A type map of protocol extensions.
///
/// `Extensions` can be used by `Request` and `Response` to store
/// extra data derived from the underlying protocol.
///
/// At most one value is stored per type. Cloning an `Extensions` clones
/// every value stored in it, and the [`Default`] value is empty.
#[derive(Clone, Default)]
pub struct Extensions {
    // If extensions are never used, no need to carry around an empty HashMap.
    // That's 3 words. Instead, this is only 1 word.
    // `None` means no map has been allocated yet.
    map: Option<Box<AnyMap>>,
}
40
41impl Extensions {
42    /// Create an empty `Extensions`.
43    #[inline]
44    #[must_use]
45    pub const fn new() -> Self {
46        Self { map: None }
47    }
48
49    /// Insert a type into this `Extensions`.
50    ///
51    /// If a extension of this type already existed, it will
52    /// be returned.
53    pub fn insert<T: Clone + Send + Sync + 'static>(&mut self, val: T) -> Option<T> {
54        self.map
55            .get_or_insert_with(Box::default)
56            .insert(TypeId::of::<T>(), Box::new(val))
57            .and_then(|boxed| boxed.into_any().downcast().ok().map(|boxed| *boxed))
58    }
59
60    /// Insert a type only into this `Extensions`, if the value is `Some(T)`.
61    ///
62    /// See [`Self::insert`] for more information.
63    pub fn maybe_insert<T: Clone + Send + Sync + 'static>(
64        &mut self,
65        mut val: Option<T>,
66    ) -> Option<T> {
67        val.take().and_then(|val| self.insert(val))
68    }
69
70    /// Extend these extensions with another Extensions.
71    pub fn extend(&mut self, other: Self) {
72        if let Some(other_map) = other.map {
73            let map = self.map.get_or_insert_with(Box::default);
74            #[allow(clippy::useless_conversion)]
75            map.extend(other_map.into_iter());
76        }
77    }
78
79    /// Clear the `Extensions` of all inserted extensions.
80    pub fn clear(&mut self) {
81        if let Some(map) = self.map.as_mut() {
82            map.clear();
83        }
84    }
85
86    /// Returns true if the `Extensions` contains the given type.
87    #[must_use]
88    pub fn contains<T: Send + Sync + 'static>(&self) -> bool {
89        self.map
90            .as_ref()
91            .map(|map| map.contains_key(&TypeId::of::<T>()))
92            .unwrap_or_default()
93    }
94
95    /// Get a shared reference to a type previously inserted on this `Extensions`.
96    #[must_use]
97    pub fn get<T: Send + Sync + 'static>(&self) -> Option<&T> {
98        self.map
99            .as_ref()
100            .and_then(|map| map.get(&TypeId::of::<T>()))
101            .and_then(|boxed| (**boxed).as_any().downcast_ref())
102    }
103
104    /// Get an exclusive reference to a type previously inserted on this `Extensions`.
105    pub fn get_mut<T: Send + Sync + 'static>(&mut self) -> Option<&mut T> {
106        self.map
107            .as_mut()
108            .and_then(|map| map.get_mut(&TypeId::of::<T>()))
109            .and_then(|boxed| (**boxed).as_any_mut().downcast_mut())
110    }
111
112    /// Inserts a value into the map computed from `f` into if it is [`None`],
113    /// then returns an exclusive reference to the contained value.
114    ///
115    /// Use the cheaper [`Self::get_or_insert_with`] in case you do not need access to
116    /// the extensions for the creation of `T`, as this function comes with
117    /// an extra cost.
118    pub fn get_or_insert_with_ext<T: Clone + Send + Sync + 'static>(
119        &mut self,
120        f: impl FnOnce(&Self) -> T,
121    ) -> &mut T {
122        if self.contains::<T>() {
123            // NOTE: once <https://github.com/rust-lang/polonius>
124            // is merged into rust we can use directly `if let Some(v) = self.extensions.get_mut()`,
125            // until then we need this work around.
126            return self.get_mut().unwrap();
127        }
128        let v = f(self);
129        self.insert(v);
130        self.get_mut().unwrap()
131    }
132
133    /// Inserts a value into the map computed from `f` into if it is [`None`],
134    /// then returns an exclusive reference to the contained value.
135    pub fn get_or_insert_with<T: Send + Sync + Clone + 'static>(
136        &mut self,
137        f: impl FnOnce() -> T,
138    ) -> &mut T {
139        let map = self.map.get_or_insert_with(Box::default);
140        let entry = map.entry(TypeId::of::<T>());
141        let boxed = entry.or_insert_with(|| Box::new(f()));
142        (**boxed)
143            .as_any_mut()
144            .downcast_mut()
145            .expect("type mismatch")
146    }
147
148    /// Inserts a value into the map computed by converting `U` into `T` if it is `None`
149    /// then returns an exclusive reference to the contained value.
150    pub fn get_or_insert_from<T, U>(&mut self, src: U) -> &mut T
151    where
152        T: Send + Sync + Clone + 'static,
153        U: Into<T>,
154    {
155        let map = self.map.get_or_insert_with(Box::default);
156        let entry = map.entry(TypeId::of::<T>());
157        let boxed = entry.or_insert_with(|| Box::new(src.into()));
158        (**boxed)
159            .as_any_mut()
160            .downcast_mut()
161            .expect("type mismatch")
162    }
163
164    /// Retrieves a value of type `T` from the context.
165    ///
166    /// If the value does not exist, the given value is inserted and an exclusive reference to it is returned.
167    pub fn get_or_insert<T: Clone + Send + Sync + 'static>(&mut self, fallback: T) -> &mut T {
168        self.get_or_insert_with(|| fallback)
169    }
170
171    /// Get an extension or `T`'s [`Default`].
172    ///
173    /// see [`Extensions::get`] for more details.
174    pub fn get_or_insert_default<T: Default + Clone + Send + Sync + 'static>(&mut self) -> &mut T {
175        self.get_or_insert_with(T::default)
176    }
177
178    /// Remove a type from this `Extensions`.
179    pub fn remove<T: Clone + Send + Sync + 'static>(&mut self) -> Option<T> {
180        self.map
181            .as_mut()
182            .and_then(|map| map.remove(&TypeId::of::<T>()))
183            .and_then(|boxed| boxed.into_any().downcast().ok().map(|boxed| *boxed))
184    }
185}
186
187impl fmt::Debug for Extensions {
188    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
189        f.debug_struct("Extensions").finish()
190    }
191}
192
/// Helper trait for type-erased values that can still be cloned and downcast.
///
/// It is used as `dyn AnyClone + Send + Sync` for the values stored in
/// [`Extensions`], and is implemented for every `T: Clone + Send + Sync + 'static`.
trait AnyClone: Any {
    /// Clones the value into a new box.
    fn clone_box(&self) -> Box<dyn AnyClone + Send + Sync>;
    /// Borrows the value as `&dyn Any`, for downcasting by shared reference.
    fn as_any(&self) -> &dyn Any;
    /// Borrows the value as `&mut dyn Any`, for downcasting by exclusive reference.
    fn as_any_mut(&mut self) -> &mut dyn Any;
    /// Converts the boxed value into `Box<dyn Any>`, for downcasting by value.
    fn into_any(self: Box<Self>) -> Box<dyn Any>;
}
199
200impl<T: Clone + Send + Sync + 'static> AnyClone for T {
201    fn clone_box(&self) -> Box<dyn AnyClone + Send + Sync> {
202        Box::new(self.clone())
203    }
204
205    fn as_any(&self) -> &dyn Any {
206        self
207    }
208
209    fn as_any_mut(&mut self) -> &mut dyn Any {
210        self
211    }
212
213    fn into_any(self: Box<Self>) -> Box<dyn Any> {
214        self
215    }
216}
217
218impl Clone for Box<dyn AnyClone + Send + Sync> {
219    fn clone(&self) -> Self {
220        (**self).clone_box()
221    }
222}
223
/// Exercises insert/get, cloning, `extend` and `clear` on [`Extensions`].
#[test]
fn test_extensions() {
    #[derive(Clone, Debug, PartialEq, Eq, Hash)]
    struct MyType(i32);

    // Values are stored and looked up by their type.
    let mut base = Extensions::new();
    base.insert(5i32);
    base.insert(MyType(10));
    assert_eq!(base.get::<i32>(), Some(&5i32));

    // A clone carries over the existing values and accepts new ones.
    let mut cloned = base.clone();
    cloned.insert(true);
    assert_eq!(cloned.get::<i32>(), Some(&5i32));
    assert_eq!(cloned.get::<MyType>(), Some(&MyType(10)));
    assert_eq!(cloned.get::<bool>(), Some(&true));

    // `extend` moves every value of the source into the target.
    let mut source = Extensions::new();
    source.insert(5i32);
    source.insert(MyType(10));

    let mut target = Extensions::new();
    target.extend(source);
    assert_eq!(target.get::<i32>(), Some(&5i32));
    assert_eq!(target.get::<MyType>(), Some(&MyType(10)));

    // `clear` drops all stored values.
    target.clear();
    assert_eq!(target.get::<i32>(), None);
    assert_eq!(target.get::<MyType>(), None);
}