cat_dev/net/ext_map.rs
//! A keeper of "extensions" attached to a request; this is a copy of hyper's (`hyperium/http`) extensions type.
2//!
//! `hyperium/http` was dual-licensed under MIT/Apache-2.0 when we copied this on
//! 9/15/2022.
5//!
6//! <https://github.com/hyperium/http/blob/master/LICENSE-MIT>
7//! <https://github.com/hyperium/http/blob/master/LICENSE-APACHE>
8//!
//! It was copied so that this type keeps the same interface as our axum extensions.
10
11use std::any::{Any, TypeId};
12use std::collections::HashMap;
13use std::fmt;
14use std::hash::{BuildHasherDefault, Hasher};
15
/// Backing storage for [`Extensions`]: at most one boxed value per concrete
/// type, keyed by that type's [`TypeId`].
type AnyMap = HashMap<TypeId, Box<dyn Any + Send + Sync>, BuildHasherDefault<IdHasher>>;

// With TypeIds as keys, there's no need to hash them again. They are already
// hashes themselves, coming from the compiler. `IdHasher` keeps the `u64` that
// `TypeId` feeds it through `write_u64` and returns it unchanged, instead of
// doing any bit fiddling.
//
// That `TypeId` hashes itself via `write_u64` is an implementation detail of
// its `Hash` impl, not a documented guarantee. `write` therefore folds
// arbitrary bytes (FNV-1a) instead of panicking, so the map keeps working if a
// toolchain ever hashes `TypeId` through `write`.
#[derive(Default)]
struct IdHasher(u64);

impl Hasher for IdHasher {
    /// Fallback for byte-oriented hashing: folds `bytes` into the state with FNV-1a.
    fn write(&mut self, bytes: &[u8]) {
        // 64-bit FNV prime.
        const FNV_PRIME: u64 = 0x0000_0100_0000_01b3;
        for &byte in bytes {
            self.0 = (self.0 ^ u64::from(byte)).wrapping_mul(FNV_PRIME);
        }
    }

    /// Stores `id` verbatim; it is expected to already be a well-mixed hash.
    #[inline]
    fn write_u64(&mut self, id: u64) {
        self.0 = id;
    }

    /// Returns the stored state unchanged.
    #[inline]
    fn finish(&self) -> u64 {
        self.0
    }
}
39
40/// A type map of protocol extensions.
41///
42/// `Extensions` can be used by `Request` and `Response` to store
43/// extra data derived from the underlying protocol.
44#[derive(Default)]
45pub struct Extensions {
46 // If extensions are never used, no need to carry around an empty HashMap.
47 // That's 3 words. Instead, this is only 1 word.
48 map: Option<Box<AnyMap>>,
49}
50
51impl Extensions {
52 /// Create an empty `Extensions`.
53 #[inline]
54 #[must_use]
55 pub const fn new() -> Extensions {
56 Extensions { map: None }
57 }
58
59 /// Insert a type into this `Extensions`.
60 ///
61 /// If a extension of this type already existed, it will
62 /// be returned.
63 ///
64 /// # Example
65 ///
66 /// ```
67 /// # use cat_dev::net::Extensions;
68 /// let mut ext = Extensions::new();
69 /// assert!(ext.insert(5i32).is_none());
70 /// assert!(ext.insert(4u8).is_none());
71 /// assert_eq!(ext.insert(9i32), Some(5i32));
72 /// ```
73 pub fn insert<T: Send + Sync + 'static>(&mut self, val: T) -> Option<T> {
74 self.map
75 .get_or_insert_with(Box::default)
76 .insert(TypeId::of::<T>(), Box::new(val))
77 .and_then(|boxed| {
78 (boxed as Box<dyn Any + 'static>)
79 .downcast()
80 .ok()
81 .map(|boxed| *boxed)
82 })
83 }
84
85 /// Get a reference to a type previously inserted on this `Extensions`.
86 ///
87 /// # Example
88 ///
89 /// ```
90 /// # use cat_dev::net::Extensions;
91 /// let mut ext = Extensions::new();
92 /// assert!(ext.get::<i32>().is_none());
93 /// ext.insert(5i32);
94 ///
95 /// assert_eq!(ext.get::<i32>(), Some(&5i32));
96 /// ```
97 #[must_use]
98 pub fn get<T: Send + Sync + 'static>(&self) -> Option<&T> {
99 self.map
100 .as_ref()
101 .and_then(|map| map.get(&TypeId::of::<T>()))
102 .and_then(|boxed| (&**boxed as &(dyn Any + 'static)).downcast_ref())
103 }
104
105 /// Get a mutable reference to a type previously inserted on this `Extensions`.
106 ///
107 /// # Example
108 ///
109 /// ```
110 /// # use cat_dev::net::Extensions;
111 /// let mut ext = Extensions::new();
112 /// ext.insert(String::from("Hello"));
113 /// ext.get_mut::<String>().unwrap().push_str(" World");
114 ///
115 /// assert_eq!(ext.get::<String>().unwrap(), "Hello World");
116 /// ```
117 #[must_use]
118 pub fn get_mut<T: Send + Sync + 'static>(&mut self) -> Option<&mut T> {
119 self.map
120 .as_mut()
121 .and_then(|map| map.get_mut(&TypeId::of::<T>()))
122 .and_then(|boxed| (&mut **boxed as &mut (dyn Any + 'static)).downcast_mut())
123 }
124
125 /// Remove a type from this `Extensions`.
126 ///
127 /// If a extension of this type existed, it will be returned.
128 ///
129 /// # Example
130 ///
131 /// ```
132 /// # use cat_dev::net::Extensions;
133 /// let mut ext = Extensions::new();
134 /// ext.insert(5i32);
135 /// assert_eq!(ext.remove::<i32>(), Some(5i32));
136 /// assert!(ext.get::<i32>().is_none());
137 /// ```
138 #[must_use]
139 pub fn remove<T: Send + Sync + 'static>(&mut self) -> Option<T> {
140 self.map
141 .as_mut()
142 .and_then(|map| map.remove(&TypeId::of::<T>()))
143 .and_then(|boxed| {
144 (boxed as Box<dyn Any + 'static>)
145 .downcast()
146 .ok()
147 .map(|boxed| *boxed)
148 })
149 }
150
151 /// Clear the `Extensions` of all inserted extensions.
152 ///
153 /// # Example
154 ///
155 /// ```
156 /// # use cat_dev::net::Extensions;
157 /// let mut ext = Extensions::new();
158 /// ext.insert(5i32);
159 /// ext.clear();
160 ///
161 /// assert!(ext.get::<i32>().is_none());
162 /// ```
163 #[inline]
164 pub fn clear(&mut self) {
165 if let Some(ref mut map) = self.map {
166 map.clear();
167 }
168 }
169
170 /// Check whether the extension set is empty or not.
171 ///
172 /// # Example
173 ///
174 /// ```
175 /// # use cat_dev::net::Extensions;
176 /// let mut ext = Extensions::new();
177 /// assert!(ext.is_empty());
178 /// ext.insert(5i32);
179 /// assert!(!ext.is_empty());
180 /// ```
181 #[inline]
182 #[must_use]
183 pub fn is_empty(&self) -> bool {
184 self.map.as_ref().is_none_or(|map| map.is_empty())
185 }
186
187 /// Get the numer of extensions available.
188 ///
189 /// # Example
190 ///
191 /// ```
192 /// # use cat_dev::net::Extensions;
193 /// let mut ext = Extensions::new();
194 /// assert_eq!(ext.len(), 0);
195 /// ext.insert(5i32);
196 /// assert_eq!(ext.len(), 1);
197 /// ```
198 #[inline]
199 #[must_use]
200 pub fn len(&self) -> usize {
201 self.map.as_ref().map_or(0, |map| map.len())
202 }
203
204 /// Extends `self` with another `Extensions`.
205 ///
206 /// If an instance of a specific type exists in both, the one in `self` is overwritten with the
207 /// one from `other`.
208 ///
209 /// # Example
210 ///
211 /// ```
212 /// # use cat_dev::net::Extensions;
213 /// let mut ext_a = Extensions::new();
214 /// ext_a.insert(8u8);
215 /// ext_a.insert(16u16);
216 ///
217 /// let mut ext_b = Extensions::new();
218 /// ext_b.insert(4u8);
219 /// ext_b.insert("hello");
220 ///
221 /// ext_a.extend(ext_b);
222 /// assert_eq!(ext_a.len(), 3);
223 /// assert_eq!(ext_a.get::<u8>(), Some(&4u8));
224 /// assert_eq!(ext_a.get::<u16>(), Some(&16u16));
225 /// assert_eq!(ext_a.get::<&'static str>().copied(), Some("hello"));
226 /// ```
227 pub fn extend(&mut self, other: Self) {
228 if let Some(other) = other.map {
229 if let Some(map) = &mut self.map {
230 map.extend(*other);
231 } else {
232 self.map = Some(other);
233 }
234 }
235 }
236}
237
238impl fmt::Debug for Extensions {
239 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
240 f.debug_struct("Extensions").finish()
241 }
242}
243
#[cfg(test)]
#[test]
fn test_extensions() {
    /// Local non-primitive type, used to check that user types round-trip too.
    #[derive(Debug, PartialEq)]
    struct MyType(i32);

    let mut ext = Extensions::new();
    ext.insert(5i32);
    ext.insert(MyType(10));

    // Shared and mutable lookups both find the stored `i32`.
    assert_eq!(ext.get::<i32>(), Some(&5i32));
    assert_eq!(ext.get_mut::<i32>(), Some(&mut 5i32));

    // Removing hands ownership back and leaves no entry behind.
    assert_eq!(ext.remove::<i32>(), Some(5i32));
    assert!(ext.get::<i32>().is_none());

    // A never-inserted type is absent; the other stored type is untouched.
    assert_eq!(ext.get::<bool>(), None);
    assert_eq!(ext.get::<MyType>(), Some(&MyType(10)));
}