apollo_router/context/extensions/mod.rs

1pub(crate) mod sync;
2
// NOTE: this module is taken from the span extensions in tokio's `tracing`,
//       which were in turn taken from https://github.com/hyperium/http/blob/master/src/extensions.rs
5
6use std::any::Any;
7use std::any::TypeId;
8use std::collections::HashMap;
9use std::fmt;
10use std::hash::BuildHasherDefault;
11use std::hash::Hasher;
12
/// Storage behind [`Extensions`]: one boxed, type-erased value per `TypeId`.
///
/// Keys are hashed with [`IdHasher`], which forwards the `TypeId`'s own bits
/// instead of running them through a general-purpose hash function.
type AnyMap = HashMap<TypeId, Box<dyn Any + Send + Sync>, BuildHasherDefault<IdHasher>>;

/// Pass-through [`Hasher`] for `TypeId` keys.
///
/// A `TypeId` is already a compiler-generated hash, so mixing it again is wasted
/// work. This hasher simply remembers the last `u64` written to it and reports
/// that value back from [`Hasher::finish`].
#[derive(Default)]
struct IdHasher(u64);

impl Hasher for IdHasher {
    /// `TypeId`'s `Hash` impl is expected to feed the id through `write_u64`,
    /// so arbitrary byte input is treated as unreachable.
    fn write(&mut self, _: &[u8]) {
        unreachable!("TypeId calls write_u64");
    }

    /// Record `id` verbatim; it becomes the final hash.
    #[inline]
    fn write_u64(&mut self, id: u64) {
        *self = Self(id);
    }

    /// Report the recorded id unchanged (`0` if nothing was written).
    #[inline]
    fn finish(&self) -> u64 {
        let Self(id) = self;
        *id
    }
}
36
37/// A type map of protocol extensions.
38///
39/// `Extensions` can be used by `Request` and `Response` to store
40/// extra data derived from the underlying protocol.
41#[derive(Default)]
42pub struct Extensions {
43    // If extensions are never used, no need to carry around an empty HashMap.
44    // That's 3 words. Instead, this is only 1 word.
45    map: Option<Box<AnyMap>>,
46}
47
48#[allow(unused)]
49impl Extensions {
50    /// Create an empty `Extensions`.
51    #[inline]
52    pub(crate) fn new() -> Extensions {
53        Extensions { map: None }
54    }
55
56    /// Insert a type into this `Extensions`.
57    ///
58    /// If a extension of this type already existed, it will
59    /// be returned.
60    pub fn insert<T: Send + Sync + 'static>(&mut self, val: T) -> Option<T> {
61        self.map
62            .get_or_insert_with(Box::default)
63            .insert(TypeId::of::<T>(), Box::new(val))
64            .and_then(|boxed| {
65                (boxed as Box<dyn Any + 'static>)
66                    .downcast()
67                    .ok()
68                    .map(|boxed| *boxed)
69            })
70    }
71
72    /// Get a reference to a type previously inserted on this `Extensions`.
73    pub fn get<T: Send + Sync + 'static>(&self) -> Option<&T> {
74        self.map
75            .as_ref()
76            .and_then(|map| map.get(&TypeId::of::<T>()))
77            .and_then(|boxed| (&**boxed as &(dyn Any + 'static)).downcast_ref())
78    }
79
80    /// Get a mutable reference to a type previously inserted on this `Extensions`.
81    pub fn get_mut<T: Send + Sync + 'static>(&mut self) -> Option<&mut T> {
82        self.map
83            .as_mut()
84            .and_then(|map| map.get_mut(&TypeId::of::<T>()))
85            .and_then(|boxed| (&mut **boxed as &mut (dyn Any + 'static)).downcast_mut())
86    }
87
88    /// Get a mutable reference to a type or insert and return the value if it does not exist
89    pub fn get_or_default_mut<T: Default + Send + Sync + 'static>(&mut self) -> &mut T {
90        let map = self.map.get_or_insert_with(Box::default);
91        let value = map
92            .entry(TypeId::of::<T>())
93            .or_insert_with(|| Box::<T>::default());
94        // It should be impossible for the entry to be the wrong type as we don't allow direct access to the map.
95        value
96            .downcast_mut()
97            .expect("default value should be inserted and we should be able to downcast it")
98    }
99
100    /// Returns `true` type has been stored in `Extensions`.
101    pub fn contains_key<T: Send + Sync + 'static>(&self) -> bool {
102        self.map
103            .as_ref()
104            .map(|map| map.contains_key(&TypeId::of::<T>()))
105            .unwrap_or_default()
106    }
107
108    /// Remove a type from this `Extensions`.
109    ///
110    /// If a extension of this type existed, it will be returned.
111    pub fn remove<T: Send + Sync + 'static>(&mut self) -> Option<T> {
112        self.map
113            .as_mut()
114            .and_then(|map| map.remove(&TypeId::of::<T>()))
115            .and_then(|boxed| {
116                (boxed as Box<dyn Any + 'static>)
117                    .downcast()
118                    .ok()
119                    .map(|boxed| *boxed)
120            })
121    }
122
123    /// Clear the `Extensions` of all inserted extensions.
124    #[inline]
125    pub fn clear(&mut self) {
126        if let Some(ref mut map) = self.map {
127            map.clear();
128        }
129    }
130
131    /// Check whether the extension set is empty or not.
132    #[inline]
133    pub fn is_empty(&self) -> bool {
134        self.map.as_ref().is_none_or(|map| map.is_empty())
135    }
136
137    /// Get the numer of extensions available.
138    #[inline]
139    pub fn len(&self) -> usize {
140        self.map.as_ref().map_or(0, |map| map.len())
141    }
142}
143
144impl fmt::Debug for Extensions {
145    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
146        f.debug_struct("Extensions").finish()
147    }
148}
149
/// Exercises the public `Extensions` API: insertion (including replacement),
/// lookup, mutation, default-on-miss access, removal, and clearing.
#[test]
fn test_extensions() {
    #[derive(Debug, PartialEq)]
    struct MyType(i32);

    let mut extensions = Extensions::new();
    assert!(extensions.is_empty());
    assert_eq!(extensions.len(), 0);

    assert_eq!(extensions.insert(5i32), None);
    assert_eq!(extensions.insert(MyType(10)), None);
    assert_eq!(extensions.len(), 2);
    assert!(extensions.contains_key::<i32>());

    assert_eq!(extensions.get(), Some(&5i32));
    assert_eq!(extensions.get_mut(), Some(&mut 5i32));

    assert_eq!(extensions.remove::<i32>(), Some(5i32));
    assert!(extensions.get::<i32>().is_none());
    assert!(!extensions.contains_key::<i32>());

    assert_eq!(extensions.get::<bool>(), None);
    assert_eq!(extensions.get(), Some(&MyType(10)));

    // Inserting a second value of the same type replaces it and returns the old one.
    assert_eq!(extensions.insert(MyType(20)), Some(MyType(10)));
    assert_eq!(extensions.get(), Some(&MyType(20)));

    // `get_or_default_mut` creates the entry on first use and reuses it afterwards.
    *extensions.get_or_default_mut::<u64>() += 1;
    *extensions.get_or_default_mut::<u64>() += 1;
    assert_eq!(extensions.get::<u64>(), Some(&2u64));

    extensions.clear();
    assert!(extensions.is_empty());
    assert_eq!(extensions.len(), 0);
    assert_eq!(extensions.get::<MyType>(), None);
}