channel_server/
extensions.rs

1use std::{
2    any::{Any, TypeId},
3    fmt,
4};
5
6use ahash::AHashMap;
7
8/// A type map for request extensions.
9///
10/// All entries into this map must be owned types (or static references).
11#[derive(Default)]
12pub struct Extensions {
13    /// Use AHasher with a std HashMap with for faster lookups on the small `TypeId` keys.
14    map: AHashMap<TypeId, Box<dyn Any + Send>>,
15}
16
17impl Extensions {
18    /// Creates an empty `Extensions`.
19    #[inline]
20    pub fn new() -> Extensions {
21        Extensions {
22            map: AHashMap::new(),
23        }
24    }
25
26    /// Insert an item into the map.
27    ///
28    /// If an item of this type was already stored, it will be replaced and returned.
29    ///
30    pub fn insert<T: 'static + Send>(&mut self, val: T) -> Option<T> {
31        self.map
32            .insert(TypeId::of::<T>(), Box::new(val))
33            .and_then(downcast_owned)
34    }
35
36    /// Check if map contains an item of a given type.
37    ///
38    pub fn contains<T: 'static>(&self) -> bool {
39        self.map.contains_key(&TypeId::of::<T>())
40    }
41
42    /// Get a reference to an item of a given type.
43    ///
44    pub fn get<T: 'static>(&self) -> Option<&T> {
45        self.map
46            .get(&TypeId::of::<T>())
47            .and_then(|boxed| boxed.downcast_ref())
48    }
49
50    /// Get a mutable reference to an item of a given type.
51    ///
52    pub fn get_mut<T: 'static>(&mut self) -> Option<&mut T> {
53        self.map
54            .get_mut(&TypeId::of::<T>())
55            .and_then(|boxed| boxed.downcast_mut())
56    }
57
58    /// Remove an item from the map of a given type.
59    ///
60    /// If an item of this type was already stored, it will be returned.
61    ///
62    pub fn remove<T: 'static + Send>(&mut self) -> Option<T> {
63        self.map.remove(&TypeId::of::<T>()).and_then(downcast_owned)
64    }
65
66    /// Clear the `Extensions` of all inserted extensions.
67    ///
68    #[inline]
69    pub fn clear(&mut self) {
70        self.map.clear();
71    }
72
73    /// Extends self with the items from another `Extensions`.
74    pub fn extend(&mut self, other: Extensions) {
75        self.map.extend(other.map);
76    }
77}
78
79impl fmt::Debug for Extensions {
80    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
81        f.debug_struct("Extensions").finish()
82    }
83}
84
/// Converts a type-erased box back into an owned `T`.
///
/// Returns `Some(value)` when the boxed value really is a `T`, and `None` otherwise (the box
/// is dropped in that case).
fn downcast_owned<T: 'static + Send>(boxed: Box<dyn Any + Send>) -> Option<T> {
    match boxed.downcast::<T>() {
        Ok(concrete) => Some(*concrete),
        Err(_) => None,
    }
}
88
#[cfg(test)]
mod tests {
    use super::*;

    /// A removed value can no longer be looked up.
    #[test]
    fn test_remove() {
        let mut ext = Extensions::new();

        ext.insert(123i8);
        assert!(ext.get::<i8>().is_some());

        ext.remove::<i8>();
        assert!(ext.get::<i8>().is_none());
    }

    /// `clear` drops every entry and leaves the container usable afterwards.
    #[test]
    fn test_clear() {
        let mut ext = Extensions::new();

        ext.insert(8i8);
        ext.insert(16i16);
        ext.insert(32i32);

        assert!(ext.contains::<i8>());
        assert!(ext.contains::<i16>());
        assert!(ext.contains::<i32>());

        ext.clear();

        assert!(!ext.contains::<i8>());
        assert!(!ext.contains::<i16>());
        assert!(!ext.contains::<i32>());

        ext.insert(10i8);
        assert_eq!(*ext.get::<i8>().unwrap(), 10);
    }

    /// Each integer type, and a static reference, gets its own independent slot.
    #[test]
    fn test_integers() {
        static A: u32 = 8;

        let mut ext = Extensions::new();

        ext.insert(8i8);
        ext.insert(16i16);
        ext.insert(32i32);
        ext.insert(64i64);
        ext.insert(128i128);
        ext.insert(8u8);
        ext.insert(16u16);
        ext.insert(32u32);
        ext.insert(64u64);
        ext.insert(128u128);
        ext.insert::<&'static u32>(&A);

        assert!(ext.get::<i8>().is_some());
        assert!(ext.get::<i16>().is_some());
        assert!(ext.get::<i32>().is_some());
        assert!(ext.get::<i64>().is_some());
        assert!(ext.get::<i128>().is_some());
        assert!(ext.get::<u8>().is_some());
        assert!(ext.get::<u16>().is_some());
        assert!(ext.get::<u32>().is_some());
        assert!(ext.get::<u64>().is_some());
        assert!(ext.get::<u128>().is_some());
        assert!(ext.get::<&'static u32>().is_some());
    }

    /// Different instantiations of one generic wrapper are stored as distinct types.
    #[test]
    fn test_composition() {
        struct Magi<T>(pub T);

        struct Madoka {
            pub god: bool,
        }

        struct Homura {
            pub attempts: usize,
        }

        struct Mami {
            pub guns: usize,
        }

        let mut ext = Extensions::new();

        ext.insert(Magi(Madoka { god: false }));
        ext.insert(Magi(Homura { attempts: 0 }));
        ext.insert(Magi(Mami { guns: 999 }));

        let madoka = ext.get::<Magi<Madoka>>().unwrap();
        assert!(!madoka.0.god);

        let homura = ext.get::<Magi<Homura>>().unwrap();
        assert_eq!(0, homura.0.attempts);

        let mami = ext.get::<Magi<Mami>>().unwrap();
        assert_eq!(999, mami.0.guns);
    }

    /// Round trip through insert, get, get_mut and remove with a primitive and a custom type.
    #[test]
    fn test_extensions() {
        #[derive(Debug, PartialEq)]
        struct MyType(i32);

        let mut ext = Extensions::new();

        ext.insert(5i32);
        ext.insert(MyType(10));

        assert_eq!(ext.get::<i32>(), Some(&5i32));
        assert_eq!(ext.get_mut::<i32>(), Some(&mut 5i32));

        assert_eq!(ext.remove::<i32>(), Some(5i32));
        assert!(ext.get::<i32>().is_none());

        assert_eq!(ext.get::<bool>(), None);
        assert_eq!(ext.get::<MyType>(), Some(&MyType(10)));
    }

    /// `extend` overwrites same-typed entries and adds the new ones.
    #[test]
    fn test_extend() {
        #[derive(Debug, PartialEq)]
        struct MyType(i32);

        let mut ext = Extensions::new();

        ext.insert(5i32);
        ext.insert(MyType(10));

        let mut incoming = Extensions::new();

        incoming.insert(15i32);
        incoming.insert(20u8);

        ext.extend(incoming);

        assert_eq!(ext.get::<i32>(), Some(&15i32));
        assert_eq!(ext.get_mut::<i32>(), Some(&mut 15i32));

        assert_eq!(ext.remove::<i32>(), Some(15i32));
        assert!(ext.get::<i32>().is_none());

        assert_eq!(ext.get::<bool>(), None);
        assert_eq!(ext.get::<MyType>(), Some(&MyType(10)));

        assert_eq!(ext.get::<u8>(), Some(&20u8));
        assert_eq!(ext.get_mut::<u8>(), Some(&mut 20u8));
    }
}