Skip to main content

reinhardt_http/
extensions.rs

1//! Type-safe extensions for Request
2//!
3//! Provides a simple type-safe storage mechanism for arbitrary data
4//! that can be attached to requests.
5
6use std::any::{Any, TypeId};
7use std::collections::HashMap;
8use std::sync::{Arc, Mutex};
9
/// Type-safe extension storage
///
/// Stores at most one value per concrete type, keyed by the type's
/// [`TypeId`]. The map lives behind an `Arc<Mutex<..>>`, so every method
/// takes `&self`, and cloning an `Extensions` yields another handle to the
/// *same* underlying storage rather than an independent copy.
///
/// A poisoned lock is recovered rather than propagated: if another thread
/// panicked while holding the lock, later calls keep operating on whatever
/// state the map was left in.
#[derive(Clone, Default)]
pub struct Extensions {
	map: Arc<Mutex<HashMap<TypeId, Box<dyn Any + Send + Sync>>>>,
}

impl Extensions {
	/// Create a new, empty Extensions instance
	///
	/// Equivalent to [`Extensions::default`].
	///
	/// # Examples
	///
	/// ```
	/// use reinhardt_http::Extensions;
	///
	/// let extensions = Extensions::new();
	/// assert!(!extensions.contains::<String>());
	/// ```
	pub fn new() -> Self {
		Self::default()
	}

	/// Acquire the map lock, recovering the guard if the mutex is poisoned.
	///
	/// Recovery (instead of `unwrap`) keeps the storage usable after a panic
	/// in another thread that was holding the lock.
	fn lock_map(&self) -> std::sync::MutexGuard<'_, HashMap<TypeId, Box<dyn Any + Send + Sync>>> {
		self.map
			.lock()
			.unwrap_or_else(|poisoned| poisoned.into_inner())
	}

	/// Insert a value into extensions
	///
	/// Any value of the same type that was stored previously is replaced
	/// and dropped.
	///
	/// # Examples
	///
	/// ```
	/// use reinhardt_http::Extensions;
	///
	/// let extensions = Extensions::new();
	/// extensions.insert(42u32);
	/// extensions.insert("hello".to_string());
	///
	/// assert!(extensions.contains::<u32>());
	/// assert!(extensions.contains::<String>());
	/// ```
	pub fn insert<T: Send + Sync + 'static>(&self, value: T) {
		let mut map = self.lock_map();
		map.insert(TypeId::of::<T>(), Box::new(value));
	}

	/// Get a cloned value from extensions
	///
	/// Returns `None` when no value of type `T` is stored. The stored value
	/// stays in place; the clone is made while the internal lock is held, so
	/// `T::clone` must not call back into the same `Extensions`.
	///
	/// # Examples
	///
	/// ```
	/// use reinhardt_http::Extensions;
	///
	/// let extensions = Extensions::new();
	/// extensions.insert(42u32);
	///
	/// assert_eq!(extensions.get::<u32>(), Some(42));
	/// assert_eq!(extensions.get::<String>(), None);
	/// ```
	pub fn get<T>(&self) -> Option<T>
	where
		T: Clone + Send + Sync + 'static,
	{
		let map = self.lock_map();
		map.get(&TypeId::of::<T>())
			.and_then(|boxed| boxed.downcast_ref::<T>())
			.cloned()
	}

	/// Check if a value of the given type exists
	///
	/// # Examples
	///
	/// ```
	/// use reinhardt_http::Extensions;
	///
	/// let extensions = Extensions::new();
	/// extensions.insert("hello".to_string());
	///
	/// assert!(extensions.contains::<String>());
	/// assert!(!extensions.contains::<u32>());
	/// ```
	pub fn contains<T: Send + Sync + 'static>(&self) -> bool {
		let map = self.lock_map();
		map.contains_key(&TypeId::of::<T>())
	}

	/// Remove a value from extensions and return it
	///
	/// The stored value is moved out, not cloned, so `T` does not need to
	/// implement `Clone`. Returns `None` when no value of type `T` is stored.
	///
	/// # Examples
	///
	/// ```
	/// use reinhardt_http::Extensions;
	///
	/// let extensions = Extensions::new();
	/// extensions.insert(42u32);
	///
	/// assert_eq!(extensions.remove::<u32>(), Some(42));
	/// assert!(!extensions.contains::<u32>());
	/// assert_eq!(extensions.remove::<u32>(), None);
	///
	/// // Values that are not `Clone` can be taken back out as well.
	/// struct Token(u32);
	/// extensions.insert(Token(7));
	/// assert_eq!(extensions.remove::<Token>().map(|t| t.0), Some(7));
	/// ```
	pub fn remove<T>(&self) -> Option<T>
	where
		T: Send + Sync + 'static,
	{
		// The guard is a temporary of this statement, so the lock is released
		// before the value is unboxed.
		let removed = self.lock_map().remove(&TypeId::of::<T>());
		removed
			// Entries are keyed by `TypeId::of::<T>()`, so the downcast is
			// expected to succeed; a mismatch is treated as "absent".
			.and_then(|boxed| boxed.downcast::<T>().ok())
			.map(|boxed| *boxed)
	}

	/// Clear all extensions
	///
	/// Every stored value, of every type, is removed and dropped.
	///
	/// # Examples
	///
	/// ```
	/// use reinhardt_http::Extensions;
	///
	/// let extensions = Extensions::new();
	/// extensions.insert(42u32);
	/// extensions.insert("hello".to_string());
	///
	/// assert!(extensions.contains::<u32>());
	/// assert!(extensions.contains::<String>());
	///
	/// extensions.clear();
	///
	/// assert!(!extensions.contains::<u32>());
	/// assert!(!extensions.contains::<String>());
	/// ```
	pub fn clear(&self) {
		let mut map = self.lock_map();
		map.clear();
	}
}
135
#[cfg(test)]
mod tests {
	use super::*;

	/// Simple cloneable payload used as a distinct stored type in the tests.
	#[derive(Clone, Debug, PartialEq)]
	struct TestData {
		value: String,
	}

	/// Build the fixture value shared by most tests.
	fn sample() -> TestData {
		TestData {
			value: "test".to_string(),
		}
	}

	#[test]
	fn test_insert_and_get() {
		let extensions = Extensions::new();
		let data = sample();

		extensions.insert(data.clone());
		let retrieved = extensions.get::<TestData>();

		assert_eq!(retrieved, Some(data));
	}

	#[test]
	fn test_get_nonexistent() {
		let extensions = Extensions::new();
		let retrieved = extensions.get::<TestData>();

		assert_eq!(retrieved, None);
	}

	#[test]
	fn test_contains() {
		let extensions = Extensions::new();
		extensions.insert(sample());

		assert!(extensions.contains::<TestData>());
		assert!(!extensions.contains::<String>());
	}

	#[test]
	fn test_remove() {
		let extensions = Extensions::new();
		let data = sample();

		extensions.insert(data.clone());
		let removed = extensions.remove::<TestData>();

		assert_eq!(removed, Some(data));
		assert!(!extensions.contains::<TestData>());
	}

	#[test]
	fn test_remove_twice_returns_none() {
		let extensions = Extensions::new();
		extensions.insert(sample());

		assert_eq!(extensions.remove::<TestData>(), Some(sample()));
		// The entry is gone after the first removal.
		assert_eq!(extensions.remove::<TestData>(), None);
	}

	#[test]
	fn test_clear() {
		let extensions = Extensions::new();
		extensions.insert(sample());
		extensions.insert("another value".to_string());

		extensions.clear();

		assert!(!extensions.contains::<TestData>());
		assert!(!extensions.contains::<String>());
	}

	#[test]
	fn test_multiple_types() {
		let extensions = Extensions::new();
		extensions.insert(sample());
		extensions.insert(42u32);
		extensions.insert("string value".to_string());

		assert_eq!(extensions.get::<TestData>(), Some(sample()));
		assert_eq!(extensions.get::<u32>(), Some(42));
		assert_eq!(extensions.get::<String>(), Some("string value".to_string()));
	}

	#[test]
	fn test_insert_overwrites_same_type() {
		let extensions = Extensions::new();
		extensions.insert(1u32);
		extensions.insert(2u32);

		// Only one value per type is kept; the later insert wins.
		assert_eq!(extensions.get::<u32>(), Some(2));
	}

	#[test]
	fn test_default_is_empty() {
		let extensions = Extensions::default();

		assert!(!extensions.contains::<TestData>());
		assert_eq!(extensions.get::<u32>(), None);
	}

	#[test]
	fn test_clone_shares_storage() {
		let original = Extensions::new();
		let handle = original.clone();

		// Both handles point at the same map, so writes through one are
		// visible through the other.
		original.insert(5u32);
		assert_eq!(handle.get::<u32>(), Some(5));

		handle.clear();
		assert!(!original.contains::<u32>());
	}
}