reinhardt_http/
extensions.rs1use std::any::{Any, TypeId};
7use std::collections::HashMap;
8use std::sync::{Arc, Mutex};
9
/// Storage behind [`Extensions`]: at most one boxed value per concrete `TypeId`.
type AnyMap = HashMap<TypeId, Box<dyn Any + Send + Sync>>;

/// A map that holds at most one value per concrete type.
///
/// Values are keyed by their [`TypeId`], so inserting a second value of type `T`
/// replaces the first. Stored values must be `Send + Sync + 'static`.
///
/// The map sits behind an `Arc<Mutex<_>>`, which is why every method takes `&self`.
/// As a consequence, cloning an `Extensions` does **not** copy the stored values.
/// The clone is another handle to the same map, and changes made through either
/// handle are visible through both.
#[derive(Clone, Default, Debug)]
pub struct Extensions {
    map: Arc<Mutex<AnyMap>>,
}

impl Extensions {
    /// Creates an empty `Extensions` with its own, unshared map.
    pub fn new() -> Self {
        Self {
            map: Arc::new(Mutex::new(HashMap::new())),
        }
    }

    /// Locks the map for as long as the returned guard lives.
    ///
    /// A poisoned mutex (a previous holder panicked) is deliberately not treated
    /// as an error. The guard is recovered and the map is used as-is.
    fn lock(&self) -> std::sync::MutexGuard<'_, AnyMap> {
        self.map
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
    }

    /// Stores `value` under its type `T`, replacing any existing value of type `T`.
    pub fn insert<T: Send + Sync + 'static>(&self, value: T) {
        // The guard is a temporary of this statement and is released at the `;`.
        // Binding the displaced value delays its `Drop` until after that point,
        // so a `Drop` impl that touches this same map cannot deadlock.
        let previous = self.lock().insert(TypeId::of::<T>(), Box::new(value));
        drop(previous);
    }

    /// Returns a clone of the stored value of type `T`, or `None` if there is none.
    pub fn get<T>(&self) -> Option<T>
    where
        T: Clone + Send + Sync + 'static,
    {
        let map = self.lock();
        map.get(&TypeId::of::<T>())
            .and_then(|boxed| boxed.downcast_ref::<T>())
            .cloned()
    }

    /// Returns `true` if a value of type `T` is stored.
    pub fn contains<T: Send + Sync + 'static>(&self) -> bool {
        self.lock().contains_key(&TypeId::of::<T>())
    }

    /// Removes the value of type `T` and returns it, or returns `None` if absent.
    ///
    /// The boxed value is moved out rather than cloned, so `T` need not be `Clone`.
    /// (Earlier versions required `T: Clone`; dropping that bound is
    /// backward-compatible.)
    pub fn remove<T>(&self) -> Option<T>
    where
        T: Send + Sync + 'static,
    {
        // The guard is released at the end of this statement, so the downcast
        // and any drop below run without the lock held.
        let removed = self.lock().remove(&TypeId::of::<T>());
        removed
            .and_then(|boxed| boxed.downcast::<T>().ok())
            .map(|boxed| *boxed)
    }

    /// Removes every stored value.
    ///
    /// Under the lock, the map is swapped for a fresh empty one. The old contents
    /// are dropped only after the lock is released. The backing allocation is
    /// therefore not reused.
    pub fn clear(&self) {
        let removed = std::mem::take(&mut *self.lock());
        drop(removed);
    }
}
135
#[cfg(test)]
mod tests {
    use super::*;

    #[derive(Clone, Debug, PartialEq)]
    struct TestData {
        value: String,
    }

    #[test]
    fn test_insert_and_get() {
        let extensions = Extensions::new();
        let data = TestData {
            value: "test".to_string(),
        };

        extensions.insert(data.clone());
        let retrieved = extensions.get::<TestData>();

        assert_eq!(retrieved, Some(data));
    }

    #[test]
    fn test_get_nonexistent() {
        let extensions = Extensions::new();
        let retrieved = extensions.get::<TestData>();

        assert_eq!(retrieved, None);
    }

    #[test]
    fn test_contains() {
        let extensions = Extensions::new();
        extensions.insert(TestData {
            value: "test".to_string(),
        });

        assert!(extensions.contains::<TestData>());
        assert!(!extensions.contains::<String>());
    }

    #[test]
    fn test_remove() {
        let extensions = Extensions::new();
        let data = TestData {
            value: "test".to_string(),
        };

        extensions.insert(data.clone());
        let removed = extensions.remove::<TestData>();

        assert_eq!(removed, Some(data));
        assert!(!extensions.contains::<TestData>());
    }

    #[test]
    fn test_remove_nonexistent() {
        let extensions = Extensions::new();

        assert_eq!(extensions.remove::<TestData>(), None);
    }

    #[test]
    fn test_clear() {
        let extensions = Extensions::new();
        extensions.insert(TestData {
            value: "test".to_string(),
        });
        extensions.insert("another value".to_string());

        extensions.clear();

        assert!(!extensions.contains::<TestData>());
        assert!(!extensions.contains::<String>());
    }

    #[test]
    fn test_multiple_types() {
        let extensions = Extensions::new();
        extensions.insert(TestData {
            value: "test".to_string(),
        });
        extensions.insert(42u32);
        extensions.insert("string value".to_string());

        assert_eq!(
            extensions.get::<TestData>(),
            Some(TestData {
                value: "test".to_string()
            })
        );
        assert_eq!(extensions.get::<u32>(), Some(42));
        assert_eq!(extensions.get::<String>(), Some("string value".to_string()));
    }

    /// A second insert of the same type replaces the first (one value per type).
    #[test]
    fn test_insert_overwrites_same_type() {
        let extensions = Extensions::new();
        extensions.insert(1u32);
        extensions.insert(2u32);

        assert_eq!(extensions.get::<u32>(), Some(2));
    }

    /// `Clone` shares the underlying `Arc` map instead of deep-copying it.
    #[test]
    fn test_clone_shares_storage() {
        let original = Extensions::new();
        let handle = original.clone();
        handle.insert(7u64);

        assert_eq!(original.get::<u64>(), Some(7));

        original.clear();
        assert!(!handle.contains::<u64>());
    }

    /// Operations keep working after another thread panicked while holding the lock.
    #[test]
    fn test_recovers_from_poisoned_lock() {
        let extensions = Extensions::new();
        let handle = extensions.clone();
        let result = std::thread::spawn(move || {
            let _guard = handle.map.lock().unwrap();
            panic!("poison the extensions lock");
        })
        .join();

        assert!(result.is_err());
        extensions.insert(5u8);
        assert_eq!(extensions.get::<u8>(), Some(5));
        assert!(extensions.contains::<u8>());
    }
}