armature_core/extensions.rs
1//! Type-safe request extensions for zero-cost state extraction.
2//!
//! This module provides a way to attach typed data to requests with
//! minimal runtime overhead. Like the DI container, extensions store values
//! as type-erased `Any` objects, but the map is keyed by `TypeId`, so a
//! retrieval is a single hash lookup followed by one `TypeId` comparison
//! inside the `Any` downcast.
7//!
8//! # Performance
9//!
//! - **Insertion**: O(1) average — one `TypeId` computation and one map insert
//! - **Retrieval**: O(1) average — one map lookup plus one `TypeId` comparison
//!   (the `Any` downcast, which always succeeds for a stored entry)
//! - **Memory**: One `Arc<T>` per extension type
13//!
14//! # Example
15//!
16//! ```rust,ignore
17//! use armature_core::{Extensions, State};
18//!
19//! // Application state
20//! #[derive(Clone)]
21//! struct AppState {
22//! db_pool: Pool,
23//! }
24//!
25//! // Insert state at startup
26//! let mut extensions = Extensions::new();
27//! extensions.insert(AppState { db_pool });
28//!
29//! // Extract in handler (zero-cost after setup)
30//! async fn handler(state: State<AppState>) -> Result<HttpResponse, Error> {
31//! let pool = &state.db_pool;
32//! // ...
33//! }
34//! ```
35
36use std::any::{Any, TypeId};
37use std::collections::HashMap;
38use std::sync::Arc;
39
/// Type-safe extensions container.
///
/// Stores at most one value per concrete type, keyed by that type's
/// `TypeId`. Lookups are O(1) on average: one hash-map probe followed by an
/// `Any` downcast, which costs a single `TypeId` comparison.
///
/// Values are held behind `Arc`, so cloning an `Extensions` is shallow: the
/// clone shares every stored value with the original rather than copying it.
#[derive(Clone, Default)]
pub struct Extensions {
    /// Type-erased storage: each key is `TypeId::of::<T>()` and each value is
    /// the matching `Arc<T>` coerced to `Arc<dyn Any + Send + Sync>`.
    map: HashMap<TypeId, Arc<dyn Any + Send + Sync>>,
}
50
51impl Extensions {
52 /// Create a new empty extensions container.
53 #[inline]
54 pub fn new() -> Self {
55 Self {
56 map: HashMap::new(),
57 }
58 }
59
60 /// Create with pre-allocated capacity.
61 #[inline]
62 pub fn with_capacity(capacity: usize) -> Self {
63 Self {
64 map: HashMap::with_capacity(capacity),
65 }
66 }
67
68 /// Insert a typed value into the extensions.
69 ///
70 /// If a value of this type already exists, it is replaced.
71 ///
72 /// # Example
73 ///
74 /// ```rust
75 /// use armature_core::Extensions;
76 ///
77 /// let mut ext = Extensions::new();
78 /// ext.insert(42i32);
79 /// ext.insert("hello");
80 /// ```
81 #[inline]
82 pub fn insert<T: Send + Sync + 'static>(&mut self, value: T) {
83 let type_id = TypeId::of::<T>();
84 let arc = Arc::new(value) as Arc<dyn Any + Send + Sync>;
85 self.map.insert(type_id, arc);
86 }
87
88 /// Insert an Arc-wrapped value directly.
89 ///
90 /// This is more efficient when you already have an Arc.
91 #[inline]
92 pub fn insert_arc<T: Send + Sync + 'static>(&mut self, value: Arc<T>) {
93 let type_id = TypeId::of::<T>();
94 let arc = value as Arc<dyn Any + Send + Sync>;
95 self.map.insert(type_id, arc);
96 }
97
98 /// Get a reference to a typed value.
99 ///
100 /// Returns `None` if no value of this type exists.
101 ///
102 /// # Performance
103 ///
104 /// This is O(1) and only involves a HashMap lookup followed by
105 /// a pointer cast (no runtime type checking).
106 ///
107 /// # Example
108 ///
109 /// ```rust
110 /// use armature_core::Extensions;
111 ///
112 /// let mut ext = Extensions::new();
113 /// ext.insert(42i32);
114 ///
115 /// assert_eq!(ext.get::<i32>(), Some(&42));
116 /// assert_eq!(ext.get::<String>(), None);
117 /// ```
118 #[inline]
119 pub fn get<T: Send + Sync + 'static>(&self) -> Option<&T> {
120 let type_id = TypeId::of::<T>();
121 self.map
122 .get(&type_id)
123 .and_then(|arc| arc.downcast_ref::<T>())
124 }
125
126 /// Get an Arc reference to a typed value.
127 ///
128 /// This is useful when you need to clone the Arc for async operations.
129 #[inline]
130 pub fn get_arc<T: Send + Sync + 'static>(&self) -> Option<Arc<T>> {
131 let type_id = TypeId::of::<T>();
132 self.map
133 .get(&type_id)
134 .and_then(|arc| arc.clone().downcast::<T>().ok())
135 }
136
137 /// Check if a value of this type exists.
138 #[inline]
139 pub fn contains<T: Send + Sync + 'static>(&self) -> bool {
140 let type_id = TypeId::of::<T>();
141 self.map.contains_key(&type_id)
142 }
143
144 /// Remove a typed value from the extensions.
145 ///
146 /// Returns true if the value existed and was removed.
147 #[inline]
148 pub fn remove<T: Send + Sync + 'static>(&mut self) -> bool {
149 let type_id = TypeId::of::<T>();
150 self.map.remove(&type_id).is_some()
151 }
152
153 /// Clear all extensions.
154 #[inline]
155 pub fn clear(&mut self) {
156 self.map.clear();
157 }
158
159 /// Get the number of extensions.
160 #[inline]
161 pub fn len(&self) -> usize {
162 self.map.len()
163 }
164
165 /// Check if extensions is empty.
166 #[inline]
167 pub fn is_empty(&self) -> bool {
168 self.map.is_empty()
169 }
170
171 /// Merge another extensions container into this one.
172 ///
173 /// Values from `other` will overwrite values in `self` for the same type.
174 #[inline]
175 pub fn extend(&mut self, other: Extensions) {
176 self.map.extend(other.map);
177 }
178}
179
180impl std::fmt::Debug for Extensions {
181 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
182 f.debug_struct("Extensions")
183 .field("count", &self.map.len())
184 .finish()
185 }
186}
187
#[cfg(test)]
mod tests {
    //! Unit tests for `Extensions`: insertion, lookup, removal, merging and
    //! the sharing semantics of the `Arc`-backed storage.

    use super::*;

    #[test]
    fn test_insert_and_get() {
        let mut ext = Extensions::new();

        ext.insert(42i32);
        ext.insert("hello".to_string());

        assert_eq!(ext.get::<i32>(), Some(&42));
        assert_eq!(ext.get::<String>(), Some(&"hello".to_string()));
        assert_eq!(ext.get::<f64>(), None);
    }

    #[test]
    fn test_insert_replaces() {
        let mut ext = Extensions::new();

        ext.insert(42i32);
        ext.insert(100i32);

        assert_eq!(ext.get::<i32>(), Some(&100));
        assert_eq!(ext.len(), 1);
    }

    #[test]
    fn test_contains() {
        let mut ext = Extensions::new();

        assert!(!ext.contains::<i32>());
        ext.insert(42i32);
        assert!(ext.contains::<i32>());
    }

    #[test]
    fn test_remove() {
        let mut ext = Extensions::new();
        ext.insert(42i32);

        let removed = ext.remove::<i32>();
        assert!(removed);
        assert!(!ext.contains::<i32>());
    }

    #[test]
    fn test_remove_missing_returns_false() {
        let mut ext = Extensions::new();
        assert!(!ext.remove::<i32>());

        // Removing a different type must not disturb stored entries.
        ext.insert(1u8);
        assert!(!ext.remove::<i32>());
        assert!(ext.contains::<u8>());
    }

    #[test]
    fn test_arc_insert() {
        let mut ext = Extensions::new();
        let arc = Arc::new(42i32);

        ext.insert_arc(Arc::clone(&arc));

        let retrieved = ext.get_arc::<i32>().unwrap();
        assert_eq!(*retrieved, 42);
        // The container shares the caller's allocation rather than copying it.
        assert!(Arc::ptr_eq(&arc, &retrieved));
    }

    #[test]
    fn test_get_arc_missing_type() {
        let mut ext = Extensions::new();
        ext.insert(42i32);

        assert!(ext.get_arc::<String>().is_none());
    }

    #[test]
    fn test_clone() {
        let mut ext = Extensions::new();
        ext.insert(42i32);

        let cloned = ext.clone();
        assert_eq!(cloned.get::<i32>(), Some(&42));
    }

    #[test]
    fn test_clone_is_shallow() {
        let mut ext = Extensions::new();
        ext.insert(String::from("shared"));

        let cloned = ext.clone();
        let a = ext.get_arc::<String>().unwrap();
        let b = cloned.get_arc::<String>().unwrap();
        assert!(Arc::ptr_eq(&a, &b));
    }

    #[test]
    fn test_extend_overwrites() {
        let mut base = Extensions::new();
        base.insert(1i32);
        base.insert(true);

        let mut other = Extensions::new();
        other.insert(2i32);
        other.insert('x');

        base.extend(other);

        assert_eq!(base.get::<i32>(), Some(&2));
        assert_eq!(base.get::<bool>(), Some(&true));
        assert_eq!(base.get::<char>(), Some(&'x'));
        assert_eq!(base.len(), 3);
    }

    #[test]
    fn test_len_clear_is_empty() {
        let mut ext = Extensions::with_capacity(2);
        assert!(ext.is_empty());
        assert_eq!(ext.len(), 0);

        ext.insert(1u8);
        ext.insert(2u16);
        assert_eq!(ext.len(), 2);
        assert!(!ext.is_empty());

        ext.clear();
        assert!(ext.is_empty());
        assert_eq!(ext.get::<u8>(), None);
    }
}