reinhardt_http/
extensions.rs1use std::any::{Any, TypeId};
7use std::collections::HashMap;
8use std::sync::{Arc, Mutex};
9
/// Boolean flag for the authentication status, stored in [`Extensions`].
///
/// Using a dedicated newtype (instead of a bare `bool`) gives the flag its own
/// `TypeId`, so it can live in the same [`Extensions`] map as other boolean
/// flags without overwriting them.
#[derive(Debug)]
#[derive(Clone, Copy, PartialEq, Eq)]
pub struct IsAuthenticated(pub bool);
14
/// Boolean flag for the admin status, stored in [`Extensions`].
///
/// Being its own type, it is keyed separately from the other boolean flags
/// in the [`Extensions`] map.
#[derive(Debug)]
#[derive(Clone, Copy, PartialEq, Eq)]
pub struct IsAdmin(pub bool);
19
/// Boolean flag for the active status, stored in [`Extensions`].
///
/// Being its own type, it is keyed separately from the other boolean flags
/// in the [`Extensions`] map.
#[derive(Debug)]
#[derive(Clone, Copy, PartialEq, Eq)]
pub struct IsActive(pub bool);
24
/// Type-erased storage behind [`Extensions`]: at most one boxed value per `TypeId`.
type AnyMap = HashMap<TypeId, Box<dyn Any + Send + Sync>>;

/// A type-keyed map for attaching arbitrary typed values (for example the
/// `IsAuthenticated` / `IsAdmin` / `IsActive` flags) to a shared context.
///
/// At most one value is stored per concrete type: inserting a second value of
/// the same type replaces the first.
///
/// Cloning an `Extensions` is cheap and does **not** copy the stored values:
/// every clone shares the same backing map through an `Arc`, so an insert or
/// removal made through one clone is visible through all of them. Access is
/// serialized by an internal `Mutex`, and because every stored value is
/// `Send + Sync`, clones can be moved to and used from other threads.
#[derive(Clone, Default)]
pub struct Extensions {
    map: Arc<Mutex<AnyMap>>,
}

impl std::fmt::Debug for Extensions {
    // Stored values are type-erased (`dyn Any`) and cannot be printed, so only
    // the number of entries is shown, e.g. `Extensions { len: 2, .. }`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Read the count first so the lock is released before writing output.
        let len = self.lock_map().len();
        f.debug_struct("Extensions")
            .field("len", &len)
            .finish_non_exhaustive()
    }
}

impl Extensions {
    /// Creates an empty `Extensions` with its own, unshared backing map.
    pub fn new() -> Self {
        Self::default()
    }

    /// Locks the shared map, recovering the guard if the mutex is poisoned.
    ///
    /// A panic while the lock is held (e.g. inside a user `Clone` impl invoked
    /// by [`Extensions::get`]) poisons the mutex. None of the methods here can
    /// leave the `HashMap` half-modified, so the data is still usable and the
    /// poison is deliberately ignored rather than propagated.
    fn lock_map(&self) -> std::sync::MutexGuard<'_, AnyMap> {
        self.map
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Stores `value`, keyed by its type `T`.
    ///
    /// Any value of the same type that was already present is replaced.
    pub fn insert<T: Send + Sync + 'static>(&self, value: T) {
        // The guard is a temporary of this `let`, so it is released before
        // `replaced` is dropped. A user `Drop` impl therefore never runs under
        // the (non-reentrant) lock and cannot deadlock or poison it.
        let replaced = self.lock_map().insert(TypeId::of::<T>(), Box::new(value));
        drop(replaced);
    }

    /// Returns a clone of the stored value of type `T`, or `None` if absent.
    ///
    /// The clone is taken while the internal lock is held, so `T::clone` must
    /// not call back into this same `Extensions`. Non-`Clone` values can be
    /// taken out with [`Extensions::remove`], or stored behind an `Arc`.
    pub fn get<T>(&self) -> Option<T>
    where
        T: Clone + Send + Sync + 'static,
    {
        let map = self.lock_map();
        map.get(&TypeId::of::<T>())
            .and_then(|boxed| boxed.downcast_ref::<T>())
            .cloned()
    }

    /// Returns `true` if a value of type `T` is currently stored.
    ///
    /// Only `T: 'static` is required (to obtain its `TypeId`). A type that is
    /// not `Send + Sync` can never have been inserted, so it reports `false`.
    pub fn contains<T: 'static>(&self) -> bool {
        self.lock_map().contains_key(&TypeId::of::<T>())
    }

    /// Removes and returns the stored value of type `T`, or `None` if absent.
    pub fn remove<T>(&self) -> Option<T>
    where
        T: Send + Sync + 'static,
    {
        let key = TypeId::of::<T>();
        let mut map = self.lock_map();
        let boxed = map.remove(&key)?;
        match boxed.downcast::<T>() {
            Ok(value) => Some(*value),
            Err(boxed) => {
                // Not expected: `insert` keys every value by its own `TypeId`.
                // Put the entry back instead of silently discarding it.
                map.insert(key, boxed);
                None
            }
        }
    }

    /// Removes every stored value, for this instance and all of its clones.
    pub fn clear(&self) {
        // Move the entries out under the lock, then drop them after the guard
        // (a temporary of this statement) is released, for the same reason
        // as in `insert`.
        let drained = std::mem::take(&mut *self.lock_map());
        drop(drained);
    }
}
164
#[cfg(test)]
mod tests {
    use super::*;
    use rstest::rstest;

    #[derive(Clone, Debug, PartialEq)]
    struct TestData {
        value: String,
    }

    /// Value whose `Clone` impl always panics; storing it and calling `get`
    /// panics while the internal lock is held, which poisons the mutex.
    struct PanicOnClone;

    impl Clone for PanicOnClone {
        fn clone(&self) -> Self {
            panic!("PanicOnClone::clone called on purpose");
        }
    }

    #[rstest]
    fn test_newtype_bools_coexist_in_extensions() {
        let extensions = Extensions::new();

        // Each newtype has its own TypeId, so none of them overwrites another.
        extensions.insert(IsAuthenticated(true));
        extensions.insert(IsAdmin(false));
        extensions.insert(IsActive(true));

        assert_eq!(
            extensions.get::<IsAuthenticated>(),
            Some(IsAuthenticated(true))
        );
        assert_eq!(extensions.get::<IsAdmin>(), Some(IsAdmin(false)));
        assert_eq!(extensions.get::<IsActive>(), Some(IsActive(true)));
    }

    #[rstest]
    fn test_insert_and_get() {
        let extensions = Extensions::new();
        let data = TestData {
            value: "test".to_string(),
        };

        extensions.insert(data.clone());
        let retrieved = extensions.get::<TestData>();

        assert_eq!(retrieved, Some(data));
    }

    #[rstest]
    fn test_insert_overwrites_same_type() {
        let extensions = Extensions::new();

        extensions.insert(1u32);
        extensions.insert(2u32);

        // Only one value per type is kept; the latest insert wins.
        assert_eq!(extensions.get::<u32>(), Some(2));
    }

    #[rstest]
    fn test_get_nonexistent() {
        let extensions = Extensions::new();
        let retrieved = extensions.get::<TestData>();

        assert_eq!(retrieved, None);
    }

    #[rstest]
    fn test_contains() {
        let extensions = Extensions::new();
        extensions.insert(TestData {
            value: "test".to_string(),
        });

        assert!(extensions.contains::<TestData>());
        assert!(!extensions.contains::<String>());
    }

    #[rstest]
    fn test_remove() {
        let extensions = Extensions::new();
        let data = TestData {
            value: "test".to_string(),
        };

        extensions.insert(data.clone());
        let removed = extensions.remove::<TestData>();

        assert_eq!(removed, Some(data));
        assert!(!extensions.contains::<TestData>());
    }

    #[rstest]
    fn test_remove_nonexistent() {
        let extensions = Extensions::new();

        assert_eq!(extensions.remove::<TestData>(), None);
    }

    #[rstest]
    fn test_clear() {
        let extensions = Extensions::new();
        extensions.insert(TestData {
            value: "test".to_string(),
        });
        extensions.insert("another value".to_string());

        extensions.clear();

        assert!(!extensions.contains::<TestData>());
        assert!(!extensions.contains::<String>());
    }

    #[rstest]
    fn test_remove_wrong_type_preserves_value() {
        let extensions = Extensions::new();
        extensions.insert(42u32);

        // Removing a different type must not disturb the stored u32.
        let removed = extensions.remove::<String>();

        assert_eq!(removed, None);
        assert!(extensions.contains::<u32>());
        assert_eq!(extensions.get::<u32>(), Some(42));
    }

    #[rstest]
    fn test_multiple_types() {
        let extensions = Extensions::new();
        extensions.insert(TestData {
            value: "test".to_string(),
        });
        extensions.insert(42u32);
        extensions.insert("string value".to_string());

        assert_eq!(
            extensions.get::<TestData>(),
            Some(TestData {
                value: "test".to_string()
            })
        );
        assert_eq!(extensions.get::<u32>(), Some(42));
        assert_eq!(extensions.get::<String>(), Some("string value".to_string()));
    }

    #[rstest]
    fn test_clone_shares_backing_store() {
        let original = Extensions::new();
        let cloned = original.clone();

        // A write through the clone is visible through the original...
        cloned.insert(42u32);

        assert_eq!(original.get::<u32>(), Some(42));

        // ...and a removal through the original is visible through the clone.
        let removed = original.remove::<u32>();

        assert_eq!(removed, Some(42));
        assert!(!cloned.contains::<u32>());
    }

    #[rstest]
    fn test_clone_shared_across_threads() {
        let extensions = Extensions::new();
        let worker = {
            let shared = extensions.clone();
            std::thread::spawn(move || shared.insert(7u64))
        };

        worker.join().expect("worker thread panicked");

        assert_eq!(extensions.get::<u64>(), Some(7));
    }

    #[rstest]
    fn test_recovers_from_poisoned_lock() {
        let extensions = Extensions::new();
        extensions.insert(PanicOnClone);

        // `get` clones while holding the lock, so this panic poisons the mutex.
        let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
            let _ = extensions.get::<PanicOnClone>();
        }));
        assert!(result.is_err());

        // Every operation must keep working after the poisoning.
        extensions.insert(42u32);
        assert_eq!(extensions.get::<u32>(), Some(42));
        assert!(extensions.contains::<PanicOnClone>());
        extensions.clear();
        assert!(!extensions.contains::<u32>());
    }
}