Skip to main content

reinhardt_http/
extensions.rs

1//! Type-safe extensions for Request
2//!
3//! Provides a simple type-safe storage mechanism for arbitrary data
4//! that can be attached to requests.
5
6use std::any::{Any, TypeId};
7use std::collections::HashMap;
8use std::sync::{Arc, Mutex};
9
/// Marker recording whether the current user is authenticated.
///
/// A dedicated newtype is used instead of a bare `bool` so that this flag
/// gets its own `TypeId` and cannot collide with other boolean values
/// stored in extensions.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct IsAuthenticated(pub bool);
14
/// Marker recording whether the current user has admin privileges
/// (staff or superuser).
///
/// A dedicated newtype is used instead of a bare `bool` so that this flag
/// gets its own `TypeId` and cannot collide with other boolean values
/// stored in extensions.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct IsAdmin(pub bool);
19
/// Marker recording whether the current user account is active.
///
/// A dedicated newtype is used instead of a bare `bool` so that this flag
/// gets its own `TypeId` and cannot collide with other boolean values
/// stored in extensions.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct IsActive(pub bool);
24
/// Type-keyed storage for arbitrary values
///
/// At most one value is kept per concrete type; the value's `TypeId`
/// is the lookup key.
///
/// # Clone semantics
///
/// The backing store is an `Arc<Mutex<HashMap>>`. Cloning an
/// `Extensions` therefore yields another handle to the **same** store
/// rather than a deep copy of the stored values: whatever is inserted,
/// removed or cleared through one handle is observed by every other
/// handle.
#[derive(Clone, Default)]
pub struct Extensions {
	map: Arc<Mutex<HashMap<TypeId, Box<dyn Any + Send + Sync>>>>,
}

impl Extensions {
	/// Build an empty `Extensions`
	///
	/// # Examples
	///
	/// ```
	/// use reinhardt_http::Extensions;
	///
	/// let extensions = Extensions::new();
	/// assert!(!extensions.contains::<String>());
	/// ```
	pub fn new() -> Self {
		Self::default()
	}

	/// Store `value`, keyed by its type
	///
	/// A value of the same type that is already present is replaced.
	///
	/// # Examples
	///
	/// ```
	/// use reinhardt_http::Extensions;
	///
	/// let extensions = Extensions::new();
	/// extensions.insert(42u32);
	/// extensions.insert("hello".to_string());
	///
	/// assert!(extensions.contains::<u32>());
	/// assert!(extensions.contains::<String>());
	/// ```
	pub fn insert<T: Send + Sync + 'static>(&self, value: T) {
		let mut store = self.locked_map();
		store.insert(TypeId::of::<T>(), Box::new(value));
	}

	/// Return a clone of the stored value of type `T`, if there is one
	///
	/// The clone is produced while the internal lock is held.
	///
	/// # Examples
	///
	/// ```
	/// use reinhardt_http::Extensions;
	///
	/// let extensions = Extensions::new();
	/// extensions.insert(42u32);
	///
	/// assert_eq!(extensions.get::<u32>(), Some(42));
	/// assert_eq!(extensions.get::<String>(), None);
	/// ```
	pub fn get<T>(&self) -> Option<T>
	where
		T: Clone + Send + Sync + 'static,
	{
		let store = self.locked_map();
		let entry = store.get(&TypeId::of::<T>())?;
		entry.downcast_ref::<T>().cloned()
	}

	/// Report whether a value of type `T` is currently stored
	///
	/// # Examples
	///
	/// ```
	/// use reinhardt_http::Extensions;
	///
	/// let extensions = Extensions::new();
	/// extensions.insert("hello".to_string());
	///
	/// assert!(extensions.contains::<String>());
	/// assert!(!extensions.contains::<u32>());
	/// ```
	pub fn contains<T: Send + Sync + 'static>(&self) -> bool {
		let key = TypeId::of::<T>();
		let store = self.locked_map();
		store.contains_key(&key)
	}

	/// Take the stored value of type `T` out of the store
	///
	/// Returns `None` when no value of that type is present.
	///
	/// # Examples
	///
	/// ```
	/// use reinhardt_http::Extensions;
	///
	/// let extensions = Extensions::new();
	/// extensions.insert(42u32);
	///
	/// assert_eq!(extensions.remove::<u32>(), Some(42));
	/// assert!(!extensions.contains::<u32>());
	/// assert_eq!(extensions.remove::<u32>(), None);
	/// ```
	pub fn remove<T>(&self) -> Option<T>
	where
		T: Send + Sync + 'static,
	{
		let key = TypeId::of::<T>();
		let mut store = self.locked_map();
		let entry = store.remove(&key)?;
		match entry.downcast::<T>() {
			Ok(value) => Some(*value),
			Err(entry) => {
				// The entry was not a `T` after all: put it back under the
				// same key so the value is not lost.
				store.insert(key, entry);
				None
			}
		}
	}

	/// Drop every stored value
	///
	/// # Examples
	///
	/// ```
	/// use reinhardt_http::Extensions;
	///
	/// let extensions = Extensions::new();
	/// extensions.insert(42u32);
	/// extensions.insert("hello".to_string());
	///
	/// assert!(extensions.contains::<u32>());
	/// assert!(extensions.contains::<String>());
	///
	/// extensions.clear();
	///
	/// assert!(!extensions.contains::<u32>());
	/// assert!(!extensions.contains::<String>());
	/// ```
	pub fn clear(&self) {
		let mut store = self.locked_map();
		store.clear();
	}

	/// Lock the backing map, recovering the guard if the mutex is poisoned
	///
	/// A poisoned mutex is not treated as an error: the inner guard is
	/// extracted and used as-is, so a panic in one user of the store does
	/// not make every later access panic as well.
	fn locked_map(
		&self,
	) -> std::sync::MutexGuard<'_, HashMap<TypeId, Box<dyn Any + Send + Sync>>> {
		self.map
			.lock()
			.unwrap_or_else(std::sync::PoisonError::into_inner)
	}
}
164
#[cfg(test)]
mod tests {
	use super::*;
	use rstest::rstest;

	/// Small cloneable payload standing in for a user-defined extension type.
	#[derive(Clone, Debug, PartialEq)]
	struct TestData {
		value: String,
	}

	/// The payload shared by most tests below.
	fn sample() -> TestData {
		TestData {
			value: "test".to_string(),
		}
	}

	#[rstest]
	fn test_newtype_bools_coexist_in_extensions() {
		// Arrange
		let store = Extensions::new();

		// Act
		store.insert(IsAuthenticated(true));
		store.insert(IsAdmin(false));
		store.insert(IsActive(true));

		// Assert - each newtype keeps its own slot
		let authenticated = store.get::<IsAuthenticated>();
		let admin = store.get::<IsAdmin>();
		let active = store.get::<IsActive>();
		assert_eq!(authenticated, Some(IsAuthenticated(true)));
		assert_eq!(admin, Some(IsAdmin(false)));
		assert_eq!(active, Some(IsActive(true)));
	}

	#[test]
	fn test_insert_and_get() {
		let store = Extensions::new();
		let expected = sample();

		store.insert(expected.clone());

		assert_eq!(store.get::<TestData>(), Some(expected));
	}

	#[test]
	fn test_get_nonexistent() {
		let store = Extensions::new();

		let found: Option<TestData> = store.get();

		assert!(found.is_none());
	}

	#[test]
	fn test_contains() {
		let store = Extensions::new();
		store.insert(sample());

		let has_data = store.contains::<TestData>();
		let has_string = store.contains::<String>();

		assert!(has_data);
		assert!(!has_string);
	}

	#[test]
	fn test_remove() {
		let store = Extensions::new();
		let expected = sample();
		store.insert(expected.clone());

		let taken = store.remove::<TestData>();

		assert_eq!(taken, Some(expected));
		assert!(!store.contains::<TestData>());
	}

	#[test]
	fn test_clear() {
		let store = Extensions::new();
		store.insert(sample());
		store.insert("another value".to_string());

		store.clear();

		assert!(!store.contains::<TestData>());
		assert!(!store.contains::<String>());
	}

	#[test]
	fn test_remove_wrong_type_preserves_value() {
		// Arrange - only a u32 is stored
		let store = Extensions::new();
		store.insert(42u32);

		// Act - ask for a type that was never inserted
		let taken: Option<String> = store.remove();

		// Assert - nothing comes back and the u32 is untouched
		assert!(taken.is_none());
		assert!(store.contains::<u32>());
		assert_eq!(store.get::<u32>(), Some(42));
	}

	#[test]
	fn test_multiple_types() {
		let store = Extensions::new();
		store.insert(sample());
		store.insert(42u32);
		store.insert("string value".to_string());

		let data = store.get::<TestData>();
		let number = store.get::<u32>();
		let text = store.get::<String>();

		assert_eq!(data, Some(sample()));
		assert_eq!(number, Some(42));
		assert_eq!(text, Some("string value".to_string()));
	}

	#[test]
	fn test_clone_shares_backing_store() {
		// Arrange - two handles onto one store
		let first = Extensions::new();
		let second = first.clone();

		// Act - write through the second handle
		second.insert(42u32);

		// Assert - the first handle observes the write
		assert_eq!(first.get::<u32>(), Some(42));

		// Act - take the value out through the first handle
		let taken = first.remove::<u32>();

		// Assert - the second handle observes the removal
		assert_eq!(taken, Some(42));
		assert!(!second.contains::<u32>());
	}
}