rama_core/context/extensions.rs

1use std::any::{Any, TypeId};
2use std::collections::HashMap;
3use std::fmt;
4use std::hash::{BuildHasherDefault, Hasher};
5
6type AnyMap = HashMap<TypeId, Box<dyn AnyClone + Send + Sync>, BuildHasherDefault<IdHasher>>;
7
/// A pass-through [`Hasher`] for `TypeId` keys.
///
/// `TypeId`s are already hashes produced by the compiler, so there is no point
/// in mixing their bits again: this hasher just remembers the `u64` it is fed
/// and hands it straight back from `finish`.
#[derive(Default)]
struct IdHasher(u64);

impl Hasher for IdHasher {
    #[inline]
    fn finish(&self) -> u64 {
        let Self(id) = self;
        *id
    }

    #[inline]
    fn write_u64(&mut self, id: u64) {
        *self = IdHasher(id);
    }

    // Only `write_u64` is expected to be called for `TypeId` keys; any other
    // input means the hasher is being misused.
    fn write(&mut self, _: &[u8]) {
        unreachable!("TypeId calls write_u64");
    }
}
29
30/// A type map of protocol extensions.
31///
32/// `Extensions` can be used by `Request` and `Response` to store
33/// extra data derived from the underlying protocol.
34#[derive(Clone, Default)]
35pub struct Extensions {
36    // If extensions are never used, no need to carry around an empty HashMap.
37    // That's 3 words. Instead, this is only 1 word.
38    map: Option<Box<AnyMap>>,
39}
40
41impl Extensions {
42    /// Create an empty `Extensions`.
43    #[inline]
44    pub const fn new() -> Extensions {
45        Extensions { map: None }
46    }
47
48    /// Insert a type into this `Extensions`.
49    ///
50    /// If a extension of this type already existed, it will
51    /// be returned.
52    pub fn insert<T: Clone + Send + Sync + 'static>(&mut self, val: T) -> Option<T> {
53        self.map
54            .get_or_insert_with(Box::default)
55            .insert(TypeId::of::<T>(), Box::new(val))
56            .and_then(|boxed| boxed.into_any().downcast().ok().map(|boxed| *boxed))
57    }
58
59    /// Insert a type only into this `Extensions`, if the value is `Some(T)`.
60    ///
61    /// See [`Self::insert`] for more information.
62    pub fn maybe_insert<T: Clone + Send + Sync + 'static>(
63        &mut self,
64        mut val: Option<T>,
65    ) -> Option<T> {
66        val.take().and_then(|val| self.insert(val))
67    }
68
69    /// Extend these extensions with another Extensions.
70    pub fn extend(&mut self, other: Extensions) {
71        if let Some(other_map) = other.map {
72            let map = self.map.get_or_insert_with(Box::default);
73            #[allow(clippy::useless_conversion)]
74            map.extend(other_map.into_iter());
75        }
76    }
77
78    /// Clear the `Extensions` of all inserted extensions.
79    pub fn clear(&mut self) {
80        if let Some(map) = self.map.as_mut() {
81            map.clear();
82        }
83    }
84
85    /// Returns true if the `Extensions` contains the given type.
86    pub fn contains<T: Send + Sync + 'static>(&self) -> bool {
87        self.map
88            .as_ref()
89            .map(|map| map.contains_key(&TypeId::of::<T>()))
90            .unwrap_or_default()
91    }
92
93    /// Get a shared reference to a type previously inserted on this `Extensions`.
94    pub fn get<T: Send + Sync + 'static>(&self) -> Option<&T> {
95        self.map
96            .as_ref()
97            .and_then(|map| map.get(&TypeId::of::<T>()))
98            .and_then(|boxed| (**boxed).as_any().downcast_ref())
99    }
100
101    /// Get an exclusive reference to a type previously inserted on this `Extensions`.
102    pub fn get_mut<T: Send + Sync + 'static>(&mut self) -> Option<&mut T> {
103        self.map
104            .as_mut()
105            .and_then(|map| map.get_mut(&TypeId::of::<T>()))
106            .and_then(|boxed| (**boxed).as_any_mut().downcast_mut())
107    }
108
109    /// Inserts a value into the map computed from `f` into if it is [`None`],
110    /// then returns an exclusive reference to the contained value.
111    ///
112    /// Use the cheaper [`Self::get_or_insert_with`] in case you do not need access to
113    /// the extensions for the creation of `T`, as this function comes with
114    /// an extra cost.
115    pub fn get_or_insert_with_ext<T: Clone + Send + Sync + 'static>(
116        &mut self,
117        f: impl FnOnce(&Self) -> T,
118    ) -> &mut T {
119        if self.contains::<T>() {
120            // NOTE: once <https://github.com/rust-lang/polonius>
121            // is merged into rust we can use directly `if let Some(v) = self.extensions.get_mut()`,
122            // until then we need this work around.
123            return self.get_mut().unwrap();
124        }
125        let v = f(self);
126        self.insert(v);
127        self.get_mut().unwrap()
128    }
129
130    /// Inserts a value into the map computed from `f` into if it is [`None`],
131    /// then returns an exclusive reference to the contained value.
132    pub fn get_or_insert_with<T: Send + Sync + Clone + 'static>(
133        &mut self,
134        f: impl FnOnce() -> T,
135    ) -> &mut T {
136        let map = self.map.get_or_insert_with(Box::default);
137        let entry = map.entry(TypeId::of::<T>());
138        let boxed = entry.or_insert_with(|| Box::new(f()));
139        (**boxed)
140            .as_any_mut()
141            .downcast_mut()
142            .expect("type mismatch")
143    }
144
145    /// Inserts a value into the map computed by converting `U` into `T` if it is `None`
146    /// then returns an exclusive reference to the contained value.
147    pub fn get_or_insert_from<T, U>(&mut self, src: U) -> &mut T
148    where
149        T: Send + Sync + Clone + 'static,
150        U: Into<T>,
151    {
152        let map = self.map.get_or_insert_with(Box::default);
153        let entry = map.entry(TypeId::of::<T>());
154        let boxed = entry.or_insert_with(|| Box::new(src.into()));
155        (**boxed)
156            .as_any_mut()
157            .downcast_mut()
158            .expect("type mismatch")
159    }
160
161    /// Retrieves a value of type `T` from the context.
162    ///
163    /// If the value does not exist, the given value is inserted and an exclusive reference to it is returned.
164    pub fn get_or_insert<T: Clone + Send + Sync + 'static>(&mut self, fallback: T) -> &mut T {
165        self.get_or_insert_with(|| fallback)
166    }
167
168    /// Get an extension or `T`'s [`Default`].
169    ///
170    /// see [`Extensions::get`] for more details.
171    pub fn get_or_insert_default<T: Default + Clone + Send + Sync + 'static>(&mut self) -> &mut T {
172        self.get_or_insert_with(T::default)
173    }
174
175    /// Remove a type from this `Extensions`.
176    pub fn remove<T: Clone + Send + Sync + 'static>(&mut self) -> Option<T> {
177        self.map
178            .as_mut()
179            .and_then(|map| map.remove(&TypeId::of::<T>()))
180            .and_then(|boxed| boxed.into_any().downcast().ok().map(|boxed| *boxed))
181    }
182}
183
184impl fmt::Debug for Extensions {
185    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
186        f.debug_struct("Extensions").finish()
187    }
188}
189
/// Object-safe bridge that lets `Extensions` store heterogeneous values while
/// still being able to clone them and recover their concrete type.
trait AnyClone: Any {
    /// Clone the concrete value behind the trait object into a fresh box.
    fn clone_box(&self) -> Box<dyn AnyClone + Send + Sync>;
    /// View the value as `&dyn Any` so it can be downcast by reference.
    fn as_any(&self) -> &dyn Any;
    /// View the value as `&mut dyn Any` so it can be downcast by mutable reference.
    fn as_any_mut(&mut self) -> &mut dyn Any;
    /// Turn the box into a `Box<dyn Any>` so it can be downcast by value.
    fn into_any(self: Box<Self>) -> Box<dyn Any>;
}

// Every clonable, thread-safe, `'static` type gets the bridge for free.
impl<V> AnyClone for V
where
    V: Clone + Send + Sync + 'static,
{
    fn clone_box(&self) -> Box<dyn AnyClone + Send + Sync> {
        let copy: V = self.clone();
        Box::new(copy)
    }

    fn as_any(&self) -> &dyn Any {
        self as &dyn Any
    }

    fn as_any_mut(&mut self) -> &mut dyn Any {
        self as &mut dyn Any
    }

    fn into_any(self: Box<Self>) -> Box<dyn Any> {
        self as Box<dyn Any>
    }
}

// Cloning the box must dispatch through the vtable to the *inner* value: the box
// itself also implements `AnyClone` (via the blanket impl above), and calling
// `clone_box` on it directly would wrap it in another layer of boxing.
impl Clone for Box<dyn AnyClone + Send + Sync> {
    fn clone(&self) -> Self {
        let inner: &(dyn AnyClone + Send + Sync) = &**self;
        inner.clone_box()
    }
}
220
#[test]
fn test_extensions() {
    #[derive(Clone, Debug, PartialEq, Eq, Hash)]
    struct MyType(i32);

    let mut extensions = Extensions::new();

    extensions.insert(5i32);
    extensions.insert(MyType(10));

    assert_eq!(extensions.get(), Some(&5i32));

    let mut ext2 = extensions.clone();

    ext2.insert(true);

    assert_eq!(ext2.get(), Some(&5i32));
    assert_eq!(ext2.get(), Some(&MyType(10)));
    assert_eq!(ext2.get(), Some(&true));
    // the clone is independent of the original
    assert_eq!(extensions.get::<bool>(), None);

    // test insert returns the replaced value
    assert_eq!(ext2.insert(7i32), Some(5i32));
    assert_eq!(ext2.get(), Some(&7i32));

    // test maybe_insert
    assert_eq!(ext2.maybe_insert::<i32>(None), None);
    assert_eq!(ext2.get(), Some(&7i32));
    assert_eq!(ext2.maybe_insert(Some(8i32)), Some(7i32));
    assert_eq!(ext2.get(), Some(&8i32));

    // test extend
    let mut extensions = Extensions::new();
    extensions.insert(5i32);
    extensions.insert(MyType(10));

    let mut extensions2 = Extensions::new();
    extensions2.extend(extensions);
    assert_eq!(extensions2.get(), Some(&5i32));
    assert_eq!(extensions2.get(), Some(&MyType(10)));

    // test extend overrides values of the same type
    let mut overrides = Extensions::new();
    overrides.insert(42i32);
    extensions2.extend(overrides);
    assert_eq!(extensions2.get(), Some(&42i32));
    assert_eq!(extensions2.get(), Some(&MyType(10)));

    // extending with an empty Extensions is a no-op
    extensions2.extend(Extensions::new());
    assert_eq!(extensions2.get(), Some(&42i32));

    // test contains / remove
    assert!(extensions2.contains::<MyType>());
    assert_eq!(extensions2.remove::<MyType>(), Some(MyType(10)));
    assert!(!extensions2.contains::<MyType>());
    assert_eq!(extensions2.remove::<MyType>(), None);

    // test get_or_insert_* variants
    let mut lazy = Extensions::new();
    assert_eq!(*lazy.get_or_insert_with(|| 1u8), 1);
    assert_eq!(*lazy.get_or_insert_with(|| 2u8), 1);
    assert_eq!(*lazy.get_or_insert(3u8), 1);
    *lazy.get_or_insert_default::<u16>() += 3;
    assert_eq!(lazy.get(), Some(&3u16));
    assert_eq!(*lazy.get_or_insert_from::<u32, u8>(9u8), 9);
    assert_eq!(
        *lazy.get_or_insert_with_ext(|ext| u64::from(*ext.get::<u32>().unwrap()) * 2),
        18
    );
    assert_eq!(*lazy.get_or_insert_with_ext(|_| 0u64), 18);

    // test clear
    extensions2.clear();
    assert_eq!(extensions2.get::<i32>(), None);
    assert_eq!(extensions2.get::<MyType>(), None);
}