cat_dev/net/
ext_map.rs

//! A keeper of "extensions" to a request; this is a copy of the `Extensions`
//! type from the hyperium `http` crate (as re-exported by hyper).
//!
//! hyperium/http was dual-licensed under MIT/Apache-2.0 when we copied this on
//! September 15, 2022.
//!
//! <https://github.com/hyperium/http/blob/master/LICENSE-MIT>
//! <https://github.com/hyperium/http/blob/master/LICENSE-APACHE>
//!
//! This was done to keep the same interface as our axum extensions.
10
11use std::any::{Any, TypeId};
12use std::collections::HashMap;
13use std::fmt;
14use std::hash::{BuildHasherDefault, Hasher};
15
16type AnyMap = HashMap<TypeId, Box<dyn Any + Send + Sync>, BuildHasherDefault<IdHasher>>;
17
18// With TypeIds as keys, there's no need to hash them. They are already hashes
19// themselves, coming from the compiler. The IdHasher just holds the u64 of
20// the TypeId, and then returns it, instead of doing any bit fiddling.
21#[derive(Default)]
22struct IdHasher(u64);
23
24impl Hasher for IdHasher {
25	fn write(&mut self, _: &[u8]) {
26		unreachable!("TypeId calls write_u64");
27	}
28
29	#[inline]
30	fn write_u64(&mut self, id: u64) {
31		self.0 = id;
32	}
33
34	#[inline]
35	fn finish(&self) -> u64 {
36		self.0
37	}
38}
39
40/// A type map of protocol extensions.
41///
42/// `Extensions` can be used by `Request` and `Response` to store
43/// extra data derived from the underlying protocol.
44#[derive(Default)]
45pub struct Extensions {
46	// If extensions are never used, no need to carry around an empty HashMap.
47	// That's 3 words. Instead, this is only 1 word.
48	map: Option<Box<AnyMap>>,
49}
50
51impl Extensions {
52	/// Create an empty `Extensions`.
53	#[inline]
54	#[must_use]
55	pub const fn new() -> Extensions {
56		Extensions { map: None }
57	}
58
59	/// Insert a type into this `Extensions`.
60	///
61	/// If a extension of this type already existed, it will
62	/// be returned.
63	///
64	/// # Example
65	///
66	/// ```
67	/// # use cat_dev::net::Extensions;
68	/// let mut ext = Extensions::new();
69	/// assert!(ext.insert(5i32).is_none());
70	/// assert!(ext.insert(4u8).is_none());
71	/// assert_eq!(ext.insert(9i32), Some(5i32));
72	/// ```
73	pub fn insert<T: Send + Sync + 'static>(&mut self, val: T) -> Option<T> {
74		self.map
75			.get_or_insert_with(Box::default)
76			.insert(TypeId::of::<T>(), Box::new(val))
77			.and_then(|boxed| {
78				(boxed as Box<dyn Any + 'static>)
79					.downcast()
80					.ok()
81					.map(|boxed| *boxed)
82			})
83	}
84
85	/// Get a reference to a type previously inserted on this `Extensions`.
86	///
87	/// # Example
88	///
89	/// ```
90	/// # use cat_dev::net::Extensions;
91	/// let mut ext = Extensions::new();
92	/// assert!(ext.get::<i32>().is_none());
93	/// ext.insert(5i32);
94	///
95	/// assert_eq!(ext.get::<i32>(), Some(&5i32));
96	/// ```
97	#[must_use]
98	pub fn get<T: Send + Sync + 'static>(&self) -> Option<&T> {
99		self.map
100			.as_ref()
101			.and_then(|map| map.get(&TypeId::of::<T>()))
102			.and_then(|boxed| (&**boxed as &(dyn Any + 'static)).downcast_ref())
103	}
104
105	/// Get a mutable reference to a type previously inserted on this `Extensions`.
106	///
107	/// # Example
108	///
109	/// ```
110	/// # use cat_dev::net::Extensions;
111	/// let mut ext = Extensions::new();
112	/// ext.insert(String::from("Hello"));
113	/// ext.get_mut::<String>().unwrap().push_str(" World");
114	///
115	/// assert_eq!(ext.get::<String>().unwrap(), "Hello World");
116	/// ```
117	#[must_use]
118	pub fn get_mut<T: Send + Sync + 'static>(&mut self) -> Option<&mut T> {
119		self.map
120			.as_mut()
121			.and_then(|map| map.get_mut(&TypeId::of::<T>()))
122			.and_then(|boxed| (&mut **boxed as &mut (dyn Any + 'static)).downcast_mut())
123	}
124
125	/// Remove a type from this `Extensions`.
126	///
127	/// If a extension of this type existed, it will be returned.
128	///
129	/// # Example
130	///
131	/// ```
132	/// # use cat_dev::net::Extensions;
133	/// let mut ext = Extensions::new();
134	/// ext.insert(5i32);
135	/// assert_eq!(ext.remove::<i32>(), Some(5i32));
136	/// assert!(ext.get::<i32>().is_none());
137	/// ```
138	#[must_use]
139	pub fn remove<T: Send + Sync + 'static>(&mut self) -> Option<T> {
140		self.map
141			.as_mut()
142			.and_then(|map| map.remove(&TypeId::of::<T>()))
143			.and_then(|boxed| {
144				(boxed as Box<dyn Any + 'static>)
145					.downcast()
146					.ok()
147					.map(|boxed| *boxed)
148			})
149	}
150
151	/// Clear the `Extensions` of all inserted extensions.
152	///
153	/// # Example
154	///
155	/// ```
156	/// # use cat_dev::net::Extensions;
157	/// let mut ext = Extensions::new();
158	/// ext.insert(5i32);
159	/// ext.clear();
160	///
161	/// assert!(ext.get::<i32>().is_none());
162	/// ```
163	#[inline]
164	pub fn clear(&mut self) {
165		if let Some(ref mut map) = self.map {
166			map.clear();
167		}
168	}
169
170	/// Check whether the extension set is empty or not.
171	///
172	/// # Example
173	///
174	/// ```
175	/// # use cat_dev::net::Extensions;
176	/// let mut ext = Extensions::new();
177	/// assert!(ext.is_empty());
178	/// ext.insert(5i32);
179	/// assert!(!ext.is_empty());
180	/// ```
181	#[inline]
182	#[must_use]
183	pub fn is_empty(&self) -> bool {
184		self.map.as_ref().is_none_or(|map| map.is_empty())
185	}
186
187	/// Get the numer of extensions available.
188	///
189	/// # Example
190	///
191	/// ```
192	/// # use cat_dev::net::Extensions;
193	/// let mut ext = Extensions::new();
194	/// assert_eq!(ext.len(), 0);
195	/// ext.insert(5i32);
196	/// assert_eq!(ext.len(), 1);
197	/// ```
198	#[inline]
199	#[must_use]
200	pub fn len(&self) -> usize {
201		self.map.as_ref().map_or(0, |map| map.len())
202	}
203
204	/// Extends `self` with another `Extensions`.
205	///
206	/// If an instance of a specific type exists in both, the one in `self` is overwritten with the
207	/// one from `other`.
208	///
209	/// # Example
210	///
211	/// ```
212	/// # use cat_dev::net::Extensions;
213	/// let mut ext_a = Extensions::new();
214	/// ext_a.insert(8u8);
215	/// ext_a.insert(16u16);
216	///
217	/// let mut ext_b = Extensions::new();
218	/// ext_b.insert(4u8);
219	/// ext_b.insert("hello");
220	///
221	/// ext_a.extend(ext_b);
222	/// assert_eq!(ext_a.len(), 3);
223	/// assert_eq!(ext_a.get::<u8>(), Some(&4u8));
224	/// assert_eq!(ext_a.get::<u16>(), Some(&16u16));
225	/// assert_eq!(ext_a.get::<&'static str>().copied(), Some("hello"));
226	/// ```
227	pub fn extend(&mut self, other: Self) {
228		if let Some(other) = other.map {
229			if let Some(map) = &mut self.map {
230				map.extend(*other);
231			} else {
232				self.map = Some(other);
233			}
234		}
235	}
236}
237
238impl fmt::Debug for Extensions {
239	fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
240		f.debug_struct("Extensions").finish()
241	}
242}
243
#[cfg(test)]
#[test]
fn test_extensions() {
	#[derive(Debug, PartialEq)]
	struct MyType(i32);

	let mut ext = Extensions::new();
	ext.insert(5_i32);
	ext.insert(MyType(10));

	// Shared and mutable lookups both see the stored `i32`.
	assert_eq!(ext.get::<i32>(), Some(&5_i32));
	assert_eq!(ext.get_mut::<i32>(), Some(&mut 5_i32));

	// Removing hands the value back and leaves no `i32` behind.
	assert_eq!(ext.remove::<i32>(), Some(5_i32));
	assert!(ext.get::<i32>().is_none());

	// Never-inserted types are absent; unrelated entries survive the removal.
	assert_eq!(ext.get::<bool>(), None);
	assert_eq!(ext.get::<MyType>(), Some(&MyType(10)));
}