armature_core/
extensions.rs

1//! Type-safe request extensions for zero-cost state extraction.
2//!
//! This module provides a way to attach typed data to requests. Like the DI
//! container, extensions rely on `Any` downcasting. However, values are stored
//! in a type-erased map keyed by `TypeId`, so a retrieval is one hash lookup
//! followed by a single `TypeId` comparison inside the downcast. There is no
//! search over registered types.
7//!
8//! # Performance
9//!
//! - **Insertion**: O(1) average (one `Arc` allocation plus a hash-map insert)
//! - **Retrieval**: O(1) average (one hash lookup plus one `TypeId` comparison in the downcast)
//! - **Memory**: One `Arc<T>` per extension type
13//!
14//! # Example
15//!
16//! ```rust,ignore
17//! use armature_core::{Extensions, State};
18//!
19//! // Application state
20//! #[derive(Clone)]
21//! struct AppState {
22//!     db_pool: Pool,
23//! }
24//!
25//! // Insert state at startup
26//! let mut extensions = Extensions::new();
27//! extensions.insert(AppState { db_pool });
28//!
29//! // Extract in handler (zero-cost after setup)
30//! async fn handler(state: State<AppState>) -> Result<HttpResponse, Error> {
31//!     let pool = &state.db_pool;
32//!     // ...
33//! }
34//! ```
35
36use std::any::{Any, TypeId};
37use std::collections::HashMap;
38use std::sync::Arc;
39
/// Type-safe extensions container.
///
/// Holds at most one value per concrete type `T`, keyed by `TypeId::of::<T>()`.
/// Retrieval is an average O(1) hash lookup followed by an `Any` downcast. The
/// downcast performs one `TypeId` comparison, which always succeeds for a
/// present key because the map key already identifies the stored type.
///
/// Values live behind `Arc`, so `clone()` is shallow: the clone shares the
/// stored values with the original, and only reference counts are bumped.
#[derive(Clone, Default)]
pub struct Extensions {
    /// Type-erased storage: `TypeId::of::<T>()` -> the `T` value behind an `Arc`.
    ///
    /// Invariant (upheld by `insert` / `insert_arc`): the value stored under a
    /// key is always of the type that key identifies.
    map: HashMap<TypeId, Arc<dyn Any + Send + Sync>>,
}

impl Extensions {
    /// Create a new, empty extensions container.
    #[inline]
    pub fn new() -> Self {
        Self::default()
    }

    /// Create an empty container with room for at least `capacity` distinct
    /// types before the underlying map needs to reallocate.
    #[inline]
    pub fn with_capacity(capacity: usize) -> Self {
        Self {
            map: HashMap::with_capacity(capacity),
        }
    }

    /// Insert a typed value, keyed by its concrete type `T`.
    ///
    /// If a value of type `T` is already present, it is replaced and this
    /// container's handle to the old value is dropped. The value is moved into
    /// a fresh `Arc`. Use [`insert_arc`](Self::insert_arc) to store an existing
    /// `Arc<T>` without another allocation.
    ///
    /// The key is the exact type passed in. `insert(Arc::new(x))` stores an
    /// `Arc<X>` value, which is retrievable with `get::<Arc<X>>()` and not
    /// with `get::<X>()`.
    ///
    /// # Example
    ///
    /// ```rust
    /// use armature_core::Extensions;
    ///
    /// let mut ext = Extensions::new();
    /// ext.insert(42i32);
    /// ext.insert("hello");
    /// assert_eq!(ext.get::<i32>(), Some(&42));
    /// assert_eq!(ext.get::<&'static str>(), Some(&"hello"));
    /// ```
    #[inline]
    pub fn insert<T: Send + Sync + 'static>(&mut self, value: T) {
        let type_id = TypeId::of::<T>();
        let arc = Arc::new(value) as Arc<dyn Any + Send + Sync>;
        self.map.insert(type_id, arc);
    }

    /// Insert an `Arc`-wrapped value, keyed by the inner type `T`.
    ///
    /// This avoids the extra allocation `insert` would make. The stored handle
    /// shares the allocation with `value`, so `get_arc::<T>()` later returns a
    /// pointer-equal `Arc`. Any existing value of type `T` is replaced.
    #[inline]
    pub fn insert_arc<T: Send + Sync + 'static>(&mut self, value: Arc<T>) {
        let type_id = TypeId::of::<T>();
        let arc = value as Arc<dyn Any + Send + Sync>;
        self.map.insert(type_id, arc);
    }

    /// Get a shared reference to the stored value of type `T`.
    ///
    /// Returns `None` if no value of type `T` has been inserted.
    ///
    /// # Performance
    ///
    /// Average O(1): one hash-map lookup, then `downcast_ref`. The only runtime
    /// work in the downcast is a single `TypeId` comparison via the `Any`
    /// vtable. That comparison cannot fail for a present key, because the key
    /// is the stored value's `TypeId`.
    ///
    /// # Example
    ///
    /// ```rust
    /// use armature_core::Extensions;
    ///
    /// let mut ext = Extensions::new();
    /// ext.insert(42i32);
    ///
    /// assert_eq!(ext.get::<i32>(), Some(&42));
    /// assert_eq!(ext.get::<String>(), None);
    /// ```
    #[inline]
    pub fn get<T: Send + Sync + 'static>(&self) -> Option<&T> {
        let type_id = TypeId::of::<T>();
        self.map
            .get(&type_id)
            .and_then(|arc| arc.downcast_ref::<T>())
    }

    /// Get a new `Arc` handle to the stored value of type `T`.
    ///
    /// The returned `Arc` shares the stored allocation; only the reference
    /// count is incremented. It can therefore outlive the borrow of `self`,
    /// for example when moved into an async task. Returns `None` if no value
    /// of type `T` exists.
    #[inline]
    pub fn get_arc<T: Send + Sync + 'static>(&self) -> Option<Arc<T>> {
        let erased = self.map.get(&TypeId::of::<T>())?;
        Arc::clone(erased).downcast::<T>().ok()
    }

    /// Return `true` if a value of type `T` is stored.
    #[inline]
    pub fn contains<T: Send + Sync + 'static>(&self) -> bool {
        self.map.contains_key(&TypeId::of::<T>())
    }

    /// Remove the value of type `T`, if present.
    ///
    /// Returns `true` if a value existed and was removed. Only this
    /// container's handle is dropped. `Arc`s obtained earlier via
    /// [`get_arc`](Self::get_arc), or held by clones of this container, keep
    /// the value alive.
    #[inline]
    pub fn remove<T: Send + Sync + 'static>(&mut self) -> bool {
        self.map.remove(&TypeId::of::<T>()).is_some()
    }

    /// Remove every stored value (dropping this container's handles to them).
    #[inline]
    pub fn clear(&mut self) {
        self.map.clear();
    }

    /// Number of distinct types currently stored.
    #[inline]
    pub fn len(&self) -> usize {
        self.map.len()
    }

    /// Return `true` if no values are stored.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.map.is_empty()
    }

    /// Merge another extensions container into this one.
    ///
    /// For types present in both, the value from `other` replaces the one in
    /// `self`. Types present only in `self` are kept.
    #[inline]
    pub fn extend(&mut self, other: Extensions) {
        self.map.extend(other.map);
    }
}
179
180impl std::fmt::Debug for Extensions {
181    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
182        f.debug_struct("Extensions")
183            .field("count", &self.map.len())
184            .finish()
185    }
186}
187
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_insert_and_get() {
        let mut ext = Extensions::new();

        ext.insert(42i32);
        ext.insert("hello".to_string());

        assert_eq!(ext.get::<i32>(), Some(&42));
        assert_eq!(ext.get::<String>(), Some(&"hello".to_string()));
        assert_eq!(ext.get::<f64>(), None);
    }

    #[test]
    fn test_insert_replaces() {
        let mut ext = Extensions::new();

        ext.insert(42i32);
        ext.insert(100i32);

        assert_eq!(ext.get::<i32>(), Some(&100));
        assert_eq!(ext.len(), 1);
    }

    #[test]
    fn test_contains() {
        let mut ext = Extensions::new();

        assert!(!ext.contains::<i32>());
        ext.insert(42i32);
        assert!(ext.contains::<i32>());
    }

    #[test]
    fn test_remove() {
        let mut ext = Extensions::new();
        ext.insert(42i32);

        let removed = ext.remove::<i32>();
        assert!(removed);
        assert!(!ext.contains::<i32>());
    }

    // Removing a type that was never inserted reports `false` and leaves
    // other entries untouched.
    #[test]
    fn test_remove_missing_returns_false() {
        let mut ext = Extensions::new();
        assert!(!ext.remove::<i32>());

        ext.insert(1u8);
        assert!(!ext.remove::<i32>());
        assert_eq!(ext.len(), 1);
    }

    #[test]
    fn test_arc_insert() {
        let mut ext = Extensions::new();
        let arc = Arc::new(42i32);

        ext.insert_arc(arc.clone());

        let retrieved = ext.get_arc::<i32>().unwrap();
        assert_eq!(*retrieved, 42);
        // `insert_arc` stores the caller's allocation rather than copying it.
        assert!(Arc::ptr_eq(&arc, &retrieved));
    }

    #[test]
    fn test_get_arc_missing() {
        let ext = Extensions::new();
        assert!(ext.get_arc::<i32>().is_none());
    }

    // `insert` keys by the exact type passed in, so an `Arc<i32>` value is
    // stored under `Arc<i32>`, not under `i32`.
    #[test]
    fn test_insert_arc_value_is_keyed_by_wrapper_type() {
        let mut ext = Extensions::new();
        ext.insert(Arc::new(5i32));

        assert!(ext.contains::<Arc<i32>>());
        assert!(!ext.contains::<i32>());
    }

    #[test]
    fn test_len_is_empty_clear() {
        let mut ext = Extensions::with_capacity(8);
        assert!(ext.is_empty());
        assert_eq!(ext.len(), 0);

        ext.insert(1i32);
        ext.insert(true);
        assert!(!ext.is_empty());
        assert_eq!(ext.len(), 2);

        ext.clear();
        assert!(ext.is_empty());
        assert_eq!(ext.get::<i32>(), None);
    }

    #[test]
    fn test_extend_overwrites_and_merges() {
        let mut a = Extensions::new();
        a.insert(1i32);
        a.insert(3u8);

        let mut b = Extensions::new();
        b.insert(2i32);
        b.insert(true);

        a.extend(b);
        assert_eq!(a.len(), 3);
        assert_eq!(a.get::<i32>(), Some(&2));
        assert_eq!(a.get::<u8>(), Some(&3));
        assert_eq!(a.get::<bool>(), Some(&true));
    }

    #[test]
    fn test_clone() {
        let mut ext = Extensions::new();
        ext.insert(42i32);

        let cloned = ext.clone();
        assert_eq!(cloned.get::<i32>(), Some(&42));
        // Cloning is shallow: both containers share the same stored value.
        assert!(Arc::ptr_eq(
            &ext.get_arc::<i32>().unwrap(),
            &cloned.get_arc::<i32>().unwrap()
        ));
    }

    #[test]
    fn test_debug_shows_count() {
        let mut ext = Extensions::new();
        ext.insert(42i32);
        assert_eq!(format!("{:?}", ext), "Extensions { count: 1 }");
    }
}