foundation_utils/
context.rs1use crate::raii::Guard;
16use std::any::{Any, TypeId};
17use std::collections::HashMap;
18use std::sync::{Arc, Mutex, RwLock};
19
20pub struct ThreadLocalContext;
25
26thread_local! {
27 static CONTEXT_STORAGE: std::cell::RefCell<HashMap<TypeId, Box<dyn Any>>> =
28 std::cell::RefCell::new(HashMap::new());
29}
30
31impl ThreadLocalContext {
32 pub fn new() -> Self {
34 Self
35 }
36
37 pub fn set<T: 'static>(&self, value: T) {
39 CONTEXT_STORAGE.with(|storage| {
40 storage
41 .borrow_mut()
42 .insert(TypeId::of::<T>(), Box::new(value));
43 });
44 }
45
46 pub fn get<T: 'static + Clone>(&self) -> Option<T> {
48 CONTEXT_STORAGE.with(|storage| {
49 storage
50 .borrow()
51 .get(&TypeId::of::<T>())
52 .and_then(|any| any.downcast_ref::<T>())
53 .cloned()
54 })
55 }
56
57 pub fn remove<T: 'static>(&self) -> Option<T> {
59 CONTEXT_STORAGE.with(|storage| {
60 storage
61 .borrow_mut()
62 .remove(&TypeId::of::<T>())
63 .and_then(|any| any.downcast::<T>().ok())
64 .map(|boxed| *boxed)
65 })
66 }
67
68 pub fn clear(&self) {
70 CONTEXT_STORAGE.with(|storage| {
71 storage.borrow_mut().clear();
72 });
73 }
74
75 pub fn scoped<T: 'static + Clone>(&self, value: T) -> Guard<T, impl FnOnce(T) + use<T>> {
77 let previous = self.get::<T>();
78 self.set(value.clone());
79
80 Guard::new(value, move |_| {
81 if let Some(prev) = previous {
82 CONTEXT_STORAGE.with(|storage| {
83 storage
84 .borrow_mut()
85 .insert(TypeId::of::<T>(), Box::new(prev));
86 });
87 } else {
88 CONTEXT_STORAGE.with(|storage| {
89 storage.borrow_mut().remove(&TypeId::of::<T>());
90 });
91 }
92 })
93 }
94}
95
96impl Default for ThreadLocalContext {
97 fn default() -> Self {
98 Self::new()
99 }
100}
101
/// Backing map type for [`GlobalContext`]: one shared, type-erased value per `TypeId`.
type GlobalStorage = HashMap<TypeId, Arc<dyn Any + Send + Sync>>;

/// Thread-safe, type-keyed context store that can be shared between threads.
///
/// At most one value per type `T` is stored (entries are keyed by [`TypeId`]). Cloning a
/// `GlobalContext` returns a handle to the *same* storage, not a copy of it.
#[derive(Clone)]
pub struct GlobalContext {
    storage: Arc<RwLock<GlobalStorage>>,
}

impl GlobalContext {
    /// Creates an empty context with its own storage.
    pub fn new() -> Self {
        Self {
            storage: Arc::new(RwLock::new(HashMap::new())),
        }
    }

    /// Acquires the read lock, recovering the map if the lock is poisoned.
    fn read_storage(&self) -> std::sync::RwLockReadGuard<'_, GlobalStorage> {
        // No user code runs while the write lock is held (displaced values are dropped after it
        // is released), so poisoning never indicates a half-updated map.
        self.storage
            .read()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Acquires the write lock, recovering the map if the lock is poisoned.
    fn write_storage(&self) -> std::sync::RwLockWriteGuard<'_, GlobalStorage> {
        self.storage
            .write()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Stores `value` as the entry for type `T`, replacing any previous entry.
    pub fn set<T: 'static + Send + Sync>(&self, value: T) {
        // The write guard is a temporary of this statement and is released before `displaced`
        // is dropped, so a `Drop` impl that uses this context cannot deadlock on the lock.
        let displaced = self
            .write_storage()
            .insert(TypeId::of::<T>(), Arc::new(value));
        drop(displaced);
    }

    /// Returns a clone of the entry for type `T`, or `None` if no such entry is set.
    pub fn get<T: 'static + Send + Sync + Clone>(&self) -> Option<T> {
        self.read_storage()
            .get(&TypeId::of::<T>())
            .and_then(|any| any.downcast_ref::<T>())
            .cloned()
    }

    /// Removes the entry for type `T`, returning whether one was present.
    pub fn remove<T: 'static + Send + Sync>(&self) -> bool {
        // Dropped at the end of the function, after the write guard has been released.
        let removed = self.write_storage().remove(&TypeId::of::<T>());
        removed.is_some()
    }

    /// Removes every entry from the context.
    pub fn clear(&self) {
        // Move the map out under the lock; its values are dropped after the lock is released.
        let drained = std::mem::take(&mut *self.write_storage());
        drop(drained);
    }
}

impl Default for GlobalContext {
    fn default() -> Self {
        Self::new()
    }
}
151
/// String-keyed, string-valued context store that can be shared across threads.
pub trait ContextManager: Send + Sync {
    /// Stores `value` under `key`, replacing any previous value for that key.
    fn set_string(&self, key: &str, value: String);

    /// Returns a copy of the value stored under `key`, if any.
    fn get_string(&self, key: &str) -> Option<String>;

    /// Removes the value stored under `key`, returning `true` if one was present.
    fn remove_string(&self, key: &str) -> bool;

    /// Removes every stored value.
    fn clear_all(&self);
}
168
169pub struct HashMapContext {
174 storage: Arc<Mutex<HashMap<String, String>>>,
175}
176
177impl HashMapContext {
178 pub fn new() -> Self {
180 Self {
181 storage: Arc::new(Mutex::new(HashMap::new())),
182 }
183 }
184}
185
186impl Default for HashMapContext {
187 fn default() -> Self {
188 Self::new()
189 }
190}
191
192impl ContextManager for HashMapContext {
193 fn set_string(&self, key: &str, value: String) {
194 let mut storage = self.storage.lock().unwrap();
195 storage.insert(key.to_string(), value);
196 }
197
198 fn get_string(&self, key: &str) -> Option<String> {
199 let storage = self.storage.lock().unwrap();
200 storage.get(key).cloned()
201 }
202
203 fn remove_string(&self, key: &str) -> bool {
204 let mut storage = self.storage.lock().unwrap();
205 storage.remove(key).is_some()
206 }
207
208 fn clear_all(&self) {
209 let mut storage = self.storage.lock().unwrap();
210 storage.clear();
211 }
212}
213
214thread_local! {
216 static GLOBAL_THREAD_CONTEXT: ThreadLocalContext = ThreadLocalContext::new();
217}
218
219pub fn set_context<T: 'static>(value: T) {
221 GLOBAL_THREAD_CONTEXT.with(|ctx| ctx.set(value));
222}
223
224pub fn get_context<T: 'static + Clone>() -> Option<T> {
226 GLOBAL_THREAD_CONTEXT.with(|ctx| ctx.get())
227}
228
229pub fn remove_context<T: 'static>() -> Option<T> {
231 GLOBAL_THREAD_CONTEXT.with(|ctx| ctx.remove())
232}
233
234pub fn clear_context() {
236 GLOBAL_THREAD_CONTEXT.with(|ctx| ctx.clear());
237}
238
239pub fn scoped_context<T: 'static + Clone>(value: T) -> impl Drop {
241 GLOBAL_THREAD_CONTEXT.with(|ctx| ctx.scoped(value))
242}
243
244pub fn with_context_value<T, F, R>(value: T, f: F) -> R
246where
247 T: 'static + Clone,
248 F: FnOnce() -> R,
249{
250 let _guard = scoped_context(value);
251 f()
252}
253
/// Typed, named identifier for a context entry.
///
/// A key carries only a `'static` name. `T` records the type of value the key refers to. The
/// context stores in this module are keyed by the value's type, so keys that share a `T`
/// refer to the same entry.
pub struct ContextKey<T> {
    name: &'static str,
    // `fn() -> T` marks `T` without owning one. This keeps the key `Send + Sync`, and so usable
    // in a `static`, even when `T` itself is not (e.g. `ContextKey<Rc<_>>`).
    _phantom: std::marker::PhantomData<fn() -> T>,
}

impl<T> ContextKey<T> {
    /// Creates a key with the given name; usable in `const` and `static` items.
    pub const fn new(name: &'static str) -> Self {
        Self {
            name,
            _phantom: std::marker::PhantomData,
        }
    }

    /// Returns the key's name.
    pub const fn name(&self) -> &'static str {
        self.name
    }
}

// Implemented by hand (not derived) so that copying a key does not require `T: Clone`.
impl<T> Clone for ContextKey<T> {
    fn clone(&self) -> Self {
        *self
    }
}

impl<T> Copy for ContextKey<T> {}

// Implemented by hand so that formatting a key does not require `T: Debug`.
impl<T> std::fmt::Debug for ContextKey<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ContextKey").field("name", &self.name).finish()
    }
}
288
/// Declares `pub const $ident: ContextKey<$value_ty>` whose key name is the identifier's own
/// text, e.g. `context_key!(USER_ID, String)` yields a key named `"USER_ID"`.
#[macro_export]
macro_rules! context_key {
    ($ident:ident, $value_ty:ty) => {
        pub const $ident: $crate::context::ContextKey<$value_ty> =
            $crate::context::ContextKey::new(::core::stringify!($ident));
    };
}
309
/// Runs `$block` with `$value` installed as the thread-local context value for the key's type,
/// restoring the previous value afterwards (see `with_context_value`).
///
/// `$key` must be a `ContextKey<T>`. `$value` is checked against that `T`, and untyped literals
/// are inferred as `T`, so the value is stored under the type the key names. Storage is still
/// keyed by type, so keys that share a `T` share an entry.
///
/// `$block` runs inside a closure, so `return` or `?` inside it leaves the closure, not the
/// enclosing function.
#[macro_export]
macro_rules! with_context_scoped {
    ($key:expr_2021, $value:expr_2021, $block:block) => {{
        // Ties the value's type to the key's type parameter; the key is otherwise only used for
        // type checking, because entries are looked up by `TypeId`.
        fn __bind_key_value<T>(_key: &$crate::context::ContextKey<T>, value: T) -> T {
            value
        }
        $crate::context::with_context_value(__bind_key_value(&$key, $value), || $block)
    }};
}
331
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_thread_local_context() {
        let ctx = ThreadLocalContext::new();

        ctx.set(42i32);
        ctx.set("hello".to_string());

        assert_eq!(ctx.get::<i32>(), Some(42));
        assert_eq!(ctx.get::<String>(), Some("hello".to_string()));
        assert_eq!(ctx.get::<f64>(), None);

        assert_eq!(ctx.remove::<i32>(), Some(42));
        assert_eq!(ctx.get::<i32>(), None);

        ctx.clear();
        assert_eq!(ctx.get::<String>(), None);
    }

    #[test]
    fn test_thread_local_context_is_per_thread() {
        set_context(1u8);

        // A freshly spawned thread starts with an empty context.
        let seen_elsewhere = std::thread::spawn(get_context::<u8>).join().unwrap();
        assert_eq!(seen_elsewhere, None);
        assert_eq!(get_context::<u8>(), Some(1));

        clear_context();
    }

    #[test]
    fn test_global_context() {
        let ctx = GlobalContext::new();

        ctx.set(42i32);
        ctx.set("hello".to_string());

        assert_eq!(ctx.get::<i32>(), Some(42));
        assert_eq!(ctx.get::<String>(), Some("hello".to_string()));
        assert_eq!(ctx.get::<f64>(), None);

        assert!(ctx.remove::<i32>());
        assert_eq!(ctx.get::<i32>(), None);
        assert!(!ctx.remove::<i32>());

        ctx.clear();
        assert_eq!(ctx.get::<String>(), None);
    }

    #[test]
    fn test_global_context_is_shared_across_threads() {
        let ctx = GlobalContext::new();

        std::thread::scope(|scope| {
            scope.spawn(|| ctx.set(5u64));
        });

        assert_eq!(ctx.get::<u64>(), Some(5));
    }

    #[test]
    fn test_hashmap_context() {
        let ctx = HashMapContext::new();

        ctx.set_string("key1", "value1".to_string());
        ctx.set_string("key2", "value2".to_string());

        assert_eq!(ctx.get_string("key1"), Some("value1".to_string()));
        assert_eq!(ctx.get_string("key2"), Some("value2".to_string()));
        assert_eq!(ctx.get_string("key3"), None);

        assert!(ctx.remove_string("key1"));
        assert_eq!(ctx.get_string("key1"), None);
        assert!(!ctx.remove_string("key1"));

        ctx.clear_all();
        assert_eq!(ctx.get_string("key2"), None);
    }

    #[test]
    fn test_scoped_context() {
        set_context(42i32);
        assert_eq!(get_context::<i32>(), Some(42));

        {
            let _guard = scoped_context(100i32);
            assert_eq!(get_context::<i32>(), Some(100));
        }
        assert_eq!(get_context::<i32>(), Some(42));

        clear_context();
        assert_eq!(get_context::<i32>(), None);
    }

    #[test]
    fn test_with_context_value() {
        let result = with_context_value(42i32, || get_context::<i32>().unwrap() + 10);

        assert_eq!(result, 52);
        // No value was set before the call, so none remains afterwards.
        assert_eq!(get_context::<i32>(), None);
    }

    #[test]
    fn test_nested_scoped_context() {
        set_context(10i32);

        let result = with_context_value(20i32, || {
            let inner_result = with_context_value(30i32, || get_context::<i32>().unwrap());

            assert_eq!(inner_result, 30);
            get_context::<i32>().unwrap()
        });

        assert_eq!(result, 20);
        assert_eq!(get_context::<i32>(), Some(10));

        clear_context();
    }

    #[test]
    fn test_remove_context() {
        set_context("value".to_string());

        assert_eq!(remove_context::<String>(), Some("value".to_string()));
        assert_eq!(remove_context::<String>(), None);
    }

    #[test]
    fn test_context_key() {
        context_key!(USER_ID, String);
        context_key!(SESSION_ID, i64);

        assert_eq!(USER_ID.name(), "USER_ID");
        assert_eq!(SESSION_ID.name(), "SESSION_ID");
    }

    #[test]
    fn test_with_context_scoped_macro() {
        context_key!(REQUEST_ID, i32);

        let result = with_context_scoped!(REQUEST_ID, 7i32, { get_context::<i32>().unwrap() * 6 });

        assert_eq!(result, 42);
        assert_eq!(get_context::<i32>(), None);
    }

    #[test]
    fn test_panic_safety() {
        set_context(42i32);

        let result = std::panic::catch_unwind(|| {
            with_context_value(100i32, || {
                panic!("test panic");
            })
        });

        assert!(result.is_err());
        // The guard restored the outer value while unwinding.
        assert_eq!(get_context::<i32>(), Some(42));

        clear_context();
    }
}