Skip to main content

foundation_utils/
context.rs

1//! Context management patterns
2//!
3//! This module provides thread-safe context management for WASM and native environments.
4//! Contexts are used to propagate information across function calls without explicit
5//! parameter passing.
6//!
7//! ## Key Features
8//!
9//! - **Thread-Local Context**: WASM-compatible thread-local storage
10//! - **Scoped Context**: Automatic context cleanup with RAII
11//! - **Context Inheritance**: Child contexts inherit from parent contexts
12//! - **Type Safety**: Strongly typed context values
13//! - **Performance**: Zero-cost abstractions with compile-time optimization
14
15use crate::raii::Guard;
16use std::any::{Any, TypeId};
17use std::collections::HashMap;
18use std::sync::{Arc, Mutex, RwLock};
19
20/// Thread-local context storage
21///
22/// This provides WASM-compatible thread-local context storage that automatically
23/// cleans up when contexts go out of scope.
24pub struct ThreadLocalContext;
25
26thread_local! {
27    static CONTEXT_STORAGE: std::cell::RefCell<HashMap<TypeId, Box<dyn Any>>> =
28        std::cell::RefCell::new(HashMap::new());
29}
30
31impl ThreadLocalContext {
32    /// Create a new thread-local context
33    pub fn new() -> Self {
34        Self
35    }
36
37    /// Set a typed value in the context
38    pub fn set<T: 'static>(&self, value: T) {
39        CONTEXT_STORAGE.with(|storage| {
40            storage
41                .borrow_mut()
42                .insert(TypeId::of::<T>(), Box::new(value));
43        });
44    }
45
46    /// Get a typed value from the context
47    pub fn get<T: 'static + Clone>(&self) -> Option<T> {
48        CONTEXT_STORAGE.with(|storage| {
49            storage
50                .borrow()
51                .get(&TypeId::of::<T>())
52                .and_then(|any| any.downcast_ref::<T>())
53                .cloned()
54        })
55    }
56
57    /// Remove a typed value from the context
58    pub fn remove<T: 'static>(&self) -> Option<T> {
59        CONTEXT_STORAGE.with(|storage| {
60            storage
61                .borrow_mut()
62                .remove(&TypeId::of::<T>())
63                .and_then(|any| any.downcast::<T>().ok())
64                .map(|boxed| *boxed)
65        })
66    }
67
68    /// Clear all context values
69    pub fn clear(&self) {
70        CONTEXT_STORAGE.with(|storage| {
71            storage.borrow_mut().clear();
72        });
73    }
74
75    /// Create a scoped context guard for a typed value
76    pub fn scoped<T: 'static + Clone>(&self, value: T) -> Guard<T, impl FnOnce(T) + use<T>> {
77        let previous = self.get::<T>();
78        self.set(value.clone());
79
80        Guard::new(value, move |_| {
81            if let Some(prev) = previous {
82                CONTEXT_STORAGE.with(|storage| {
83                    storage
84                        .borrow_mut()
85                        .insert(TypeId::of::<T>(), Box::new(prev));
86                });
87            } else {
88                CONTEXT_STORAGE.with(|storage| {
89                    storage.borrow_mut().remove(&TypeId::of::<T>());
90                });
91            }
92        })
93    }
94}
95
96impl Default for ThreadLocalContext {
97    fn default() -> Self {
98        Self::new()
99    }
100}
101
/// Typed context storage guarded by an `RwLock`, usable from several threads.
///
/// Values are keyed by their [`TypeId`], so at most one value per type is
/// stored. No user code (`Clone` or `Drop` of stored values) runs while the
/// internal lock is held: entries are moved or reference-counted out of the
/// map first and the lock is released before they are cloned or dropped.
/// This keeps critical sections short and lets a value's `Clone`/`Drop`
/// touch the same context without deadlocking or poisoning the lock.
///
/// # Panics
///
/// Every method panics if the internal lock is poisoned.
pub struct GlobalContext {
    storage: Arc<RwLock<HashMap<TypeId, Arc<dyn Any + Send + Sync>>>>,
}

impl GlobalContext {
    /// Create an empty context.
    pub fn new() -> Self {
        Self {
            storage: Arc::new(RwLock::new(HashMap::new())),
        }
    }

    /// Store `value` under its type, replacing any previous value of type `T`.
    pub fn set<T: 'static + Send + Sync>(&self, value: T) {
        let entry: Arc<dyn Any + Send + Sync> = Arc::new(value);
        // The write guard is a temporary of this statement, so the lock is
        // released before `replaced` (and possibly the old `T`) is dropped.
        let replaced = self
            .storage
            .write()
            .expect("GlobalContext lock poisoned")
            .insert(TypeId::of::<T>(), entry);
        drop(replaced);
    }

    /// Return a clone of the stored value of type `T`, or `None` if absent.
    pub fn get<T: 'static + Send + Sync + Clone>(&self) -> Option<T> {
        // Only the `Arc` is cloned under the read lock; `T::clone` runs after
        // the lock has been released.
        let entry = self
            .storage
            .read()
            .expect("GlobalContext lock poisoned")
            .get(&TypeId::of::<T>())
            .cloned()?;
        entry.downcast_ref::<T>().cloned()
    }

    /// Remove the value of type `T`; returns `true` if one was stored.
    pub fn remove<T: 'static + Send + Sync>(&self) -> bool {
        let removed = self
            .storage
            .write()
            .expect("GlobalContext lock poisoned")
            .remove(&TypeId::of::<T>());
        // `removed` is dropped on return, after the lock was released above.
        removed.is_some()
    }

    /// Remove every value from the context.
    pub fn clear(&self) {
        // Swap the map out under the lock, drop its contents afterwards.
        let drained = std::mem::take(
            &mut *self
                .storage
                .write()
                .expect("GlobalContext lock poisoned"),
        );
        drop(drained);
    }
}

impl Default for GlobalContext {
    /// Equivalent to [`GlobalContext::new`].
    fn default() -> Self {
        Self::new()
    }
}
151
/// Common interface for string-keyed context stores.
///
/// Implementors must be `Send + Sync`; all operations take `&self`.
pub trait ContextManager: Send + Sync {
    /// Associate `value` with `key`, replacing any existing value.
    fn set_string(&self, key: &str, value: String);

    /// Return a copy of the value stored under `key`, if any.
    fn get_string(&self, key: &str) -> Option<String>;

    /// Delete the entry for `key`; returns `true` if an entry existed.
    fn remove_string(&self, key: &str) -> bool;

    /// Delete every entry.
    fn clear_all(&self);
}

/// [`ContextManager`] backed by a `HashMap<String, String>`.
///
/// The map lives behind an internal `Mutex`, which each operation locks for
/// its own duration. Operations panic if that mutex is poisoned.
pub struct HashMapContext {
    storage: Arc<Mutex<HashMap<String, String>>>,
}

impl HashMapContext {
    /// Create an empty context.
    pub fn new() -> Self {
        let storage = Arc::new(Mutex::new(HashMap::new()));
        Self { storage }
    }
}

impl Default for HashMapContext {
    /// Equivalent to [`HashMapContext::new`].
    fn default() -> Self {
        HashMapContext::new()
    }
}

impl ContextManager for HashMapContext {
    fn set_string(&self, key: &str, value: String) {
        self.storage.lock().unwrap().insert(key.to_owned(), value);
    }

    fn get_string(&self, key: &str) -> Option<String> {
        let entries = self.storage.lock().unwrap();
        entries.get(key).map(String::clone)
    }

    fn remove_string(&self, key: &str) -> bool {
        let mut entries = self.storage.lock().unwrap();
        let removed = entries.remove(key);
        removed.is_some()
    }

    fn clear_all(&self) {
        self.storage.lock().unwrap().clear();
    }
}
213
214// Convenience functions for working with a global thread-local context
215thread_local! {
216    static GLOBAL_THREAD_CONTEXT: ThreadLocalContext = ThreadLocalContext::new();
217}
218
219/// Set a typed value in the global thread-local context
220pub fn set_context<T: 'static>(value: T) {
221    GLOBAL_THREAD_CONTEXT.with(|ctx| ctx.set(value));
222}
223
224/// Get a typed value from the global thread-local context
225pub fn get_context<T: 'static + Clone>() -> Option<T> {
226    GLOBAL_THREAD_CONTEXT.with(|ctx| ctx.get())
227}
228
229/// Remove a typed value from the global thread-local context
230pub fn remove_context<T: 'static>() -> Option<T> {
231    GLOBAL_THREAD_CONTEXT.with(|ctx| ctx.remove())
232}
233
234/// Clear all values from the global thread-local context
235pub fn clear_context() {
236    GLOBAL_THREAD_CONTEXT.with(|ctx| ctx.clear());
237}
238
239/// Create a scoped context guard for the global thread-local context
240pub fn scoped_context<T: 'static + Clone>(value: T) -> impl Drop {
241    GLOBAL_THREAD_CONTEXT.with(|ctx| ctx.scoped(value))
242}
243
244/// Execute a function with a scoped context value
245pub fn with_context_value<T, F, R>(value: T, f: F) -> R
246where
247    T: 'static + Clone,
248    F: FnOnce() -> R,
249{
250    let _guard = scoped_context(value);
251    f()
252}
253
/// Named key tagged with the type `T` of the value it refers to.
///
/// A key stores only its name; `T` exists purely at the type level. Keys are
/// `Copy`, and are `Send + Sync` for every `T` because no `T` is ever owned.
pub struct ContextKey<T> {
    name: &'static str,
    // `fn() -> T` ties the key to `T` without claiming to hold a `T`, so the
    // key's auto traits do not depend on `T`.
    _phantom: std::marker::PhantomData<fn() -> T>,
}

impl<T> ContextKey<T> {
    /// Create a key with the given name; usable in `const` items.
    pub const fn new(name: &'static str) -> Self {
        Self {
            name,
            _phantom: std::marker::PhantomData,
        }
    }

    /// The name this key was created with.
    pub const fn name(&self) -> &'static str {
        self.name
    }
}

// Implemented by hand (not derived) so that no `T: Clone` / `T: Copy` bound
// is required.
impl<T> Clone for ContextKey<T> {
    fn clone(&self) -> Self {
        *self
    }
}

impl<T> Copy for ContextKey<T> {}

// Implemented by hand so that no `T: Debug` bound is required.
impl<T> std::fmt::Debug for ContextKey<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ContextKey")
            .field("name", &self.name)
            .finish()
    }
}
288
/// Declare a `pub const` [`ContextKey`](crate::context::ContextKey).
///
/// `context_key!(NAME, Type)` expands to a constant called `NAME` of type
/// `ContextKey<Type>` whose name is the identifier rendered as a string.
///
/// # Example
/// ```rust
/// use foundation_utils::context_key;
///
/// context_key!(USER_ID, String);
/// context_key!(REQUEST_ID, String);
/// context_key!(TRACE_ID, String);
/// ```
#[macro_export]
macro_rules! context_key {
    ($key_name:ident, $value_ty:ty) => {
        pub const $key_name: $crate::context::ContextKey<$value_ty> =
            $crate::context::ContextKey::new(stringify!($key_name));
    };
}
309
/// Run a block with a value installed in the thread-local context.
///
/// Expands to a call to `with_context_value`, passing the value and the
/// block wrapped in a closure; the macro evaluates to the block's result.
/// The key argument is accepted for readability at the call site but is not
/// referenced by the expansion: the value is stored under its own type.
///
/// # Example
/// ```rust
/// use foundation_utils::{with_context_scoped, context_key};
///
/// context_key!(USER_ID, String);
///
/// let result = with_context_scoped!(USER_ID, "user123".to_string(), {
///     // Work with context
///     42
/// });
/// ```
#[macro_export]
macro_rules! with_context_scoped {
    ($key:expr_2021, $scoped_value:expr_2021, $body:block) => {
        $crate::context::with_context_value($scoped_value, || $body)
    };
}
331
#[cfg(test)]
mod tests {
    use super::*;

    /// Typed values round-trip through a `ThreadLocalContext`.
    #[test]
    fn test_thread_local_context() {
        let local = ThreadLocalContext::new();

        // Store one value per type and read each back.
        local.set(42i32);
        local.set("hello".to_string());

        assert_eq!(local.get::<i32>(), Some(42));
        assert_eq!(local.get::<String>(), Some("hello".to_string()));
        assert_eq!(local.get::<f64>(), None);

        // Removing hands the value back and empties the slot.
        assert_eq!(local.remove::<i32>(), Some(42));
        assert_eq!(local.get::<i32>(), None);

        // Clearing drops whatever is left.
        local.clear();
        assert_eq!(local.get::<String>(), None);
    }

    /// Typed values round-trip through a `GlobalContext`.
    #[test]
    fn test_global_context() {
        let global = GlobalContext::new();

        // Store one value per type and read each back.
        global.set(42i32);
        global.set("hello".to_string());

        assert_eq!(global.get::<i32>(), Some(42));
        assert_eq!(global.get::<String>(), Some("hello".to_string()));
        assert_eq!(global.get::<f64>(), None);

        // The first removal succeeds, the second finds nothing.
        assert!(global.remove::<i32>());
        assert_eq!(global.get::<i32>(), None);
        assert!(!global.remove::<i32>());

        // Clearing drops whatever is left.
        global.clear();
        assert_eq!(global.get::<String>(), None);
    }

    /// String entries round-trip through a `HashMapContext`.
    #[test]
    fn test_hashmap_context() {
        let strings = HashMapContext::new();

        // Store two entries and read them back.
        strings.set_string("key1", "value1".to_string());
        strings.set_string("key2", "value2".to_string());

        assert_eq!(strings.get_string("key1"), Some("value1".to_string()));
        assert_eq!(strings.get_string("key2"), Some("value2".to_string()));
        assert_eq!(strings.get_string("key3"), None);

        // The first removal succeeds, the second finds nothing.
        assert!(strings.remove_string("key1"));
        assert_eq!(strings.get_string("key1"), None);
        assert!(!strings.remove_string("key1"));

        // Clearing drops whatever is left.
        strings.clear_all();
        assert_eq!(strings.get_string("key2"), None);
    }

    /// A scoped value shadows the outer one and is undone when the guard drops.
    #[test]
    fn test_scoped_context() {
        // Outer value.
        set_context(42i32);
        assert_eq!(get_context::<i32>(), Some(42));

        {
            // Shadow it for the duration of this block.
            let _guard = scoped_context(100i32);
            assert_eq!(get_context::<i32>(), Some(100));
        }

        // The outer value is visible again.
        assert_eq!(get_context::<i32>(), Some(42));

        // Leave the thread's context empty.
        clear_context();
        assert_eq!(get_context::<i32>(), None);
    }

    /// `with_context_value` exposes the value to the closure only.
    #[test]
    fn test_with_context_value() {
        let sum = with_context_value(42i32, || get_context::<i32>().unwrap() + 10);

        assert_eq!(sum, 52);

        // Nothing was set before the call, so nothing remains after it.
        assert_eq!(get_context::<i32>(), None);
    }

    /// Nested scopes unwind in reverse order.
    #[test]
    fn test_nested_scoped_context() {
        set_context(10i32);

        let outer_seen = with_context_value(20i32, || {
            let inner_seen = with_context_value(30i32, || get_context::<i32>().unwrap());

            assert_eq!(inner_seen, 30);
            get_context::<i32>().unwrap()
        });

        assert_eq!(outer_seen, 20);
        assert_eq!(get_context::<i32>(), Some(10));

        clear_context();
    }

    /// `context_key!` names each key after its identifier.
    #[test]
    fn test_context_key() {
        context_key!(USER_ID, String);
        context_key!(SESSION_ID, i64);

        assert_eq!(USER_ID.name(), "USER_ID");
        assert_eq!(SESSION_ID.name(), "SESSION_ID");
    }

    /// The outer value comes back even when the closure panics.
    #[test]
    fn test_panic_safety() {
        set_context(42i32);

        let outcome = std::panic::catch_unwind(|| {
            with_context_value(100i32, || {
                panic!("test panic");
            })
        });

        assert!(outcome.is_err());

        // Unwinding dropped the guard, which restored the outer value.
        assert_eq!(get_context::<i32>(), Some(42));

        clear_context();
    }
}