// do_over/context.rs

//! Policy context for passing metadata through execution.
//!
//! The context module provides types for passing request-scoped data (such as
//! correlation IDs) through policy execution, which is useful for logging and
//! tracing.
//!
//! # Examples
//!
//! ```rust
//! use do_over::context::{Context, ContextKey};
//!
//! // Define a context key
//! static REQUEST_ID: ContextKey<String> = ContextKey::new("request_id");
//!
//! // Create a context with data
//! let mut ctx = Context::new();
//! ctx.insert(&REQUEST_ID, "req-123".to_string());
//!
//! // Retrieve data
//! if let Some(id) = ctx.get(&REQUEST_ID) {
//!     println!("Request ID: {}", id);
//! }
//! ```

use std::any::{Any, TypeId};
use std::collections::HashMap;
use std::fmt;
use std::sync::Arc;

/// A typed key for storing values in a [`Context`].
///
/// A key is identified by its name *and* the type `T` it stores: keys with the
/// same name but different value types refer to different entries, while keys
/// with the same name and the same `T` refer to the same entry.
///
/// Keys are `Copy`, `Send` and `Sync` regardless of `T`, because a key never
/// holds a `T` itself.
///
/// # Examples
///
/// ```rust
/// use do_over::context::ContextKey;
///
/// static USER_ID: ContextKey<u64> = ContextKey::new("user_id");
/// static TRACE_ID: ContextKey<String> = ContextKey::new("trace_id");
/// ```
pub struct ContextKey<T> {
    name: &'static str,
    // `fn() -> T` instead of `T`: the key only names a type, it does not own a
    // value of it, so auto traits and drop-check must not depend on `T`.
    _marker: std::marker::PhantomData<fn() -> T>,
}

impl<T> ContextKey<T> {
    /// Create a new context key.
    ///
    /// This is a `const fn`, so keys can be declared as `static` items.
    ///
    /// # Arguments
    ///
    /// * `name` - A descriptive name for debugging; together with `T` it forms
    ///   the key's identity
    pub const fn new(name: &'static str) -> Self {
        Self {
            name,
            _marker: std::marker::PhantomData,
        }
    }

    /// Get the name of this key.
    pub const fn name(&self) -> &'static str {
        self.name
    }
}

// Hand-written rather than derived: `#[derive]` would require `T: Clone` /
// `T: Debug`, which a key that never stores a `T` does not need.
impl<T> Clone for ContextKey<T> {
    fn clone(&self) -> Self {
        *self
    }
}

impl<T> Copy for ContextKey<T> {}

impl<T> fmt::Debug for ContextKey<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("ContextKey")
            .field("name", &self.name)
            .field("type", &std::any::type_name::<T>())
            .finish()
    }
}

/// A type-safe container for request-scoped data.
///
/// Context allows you to pass metadata through policy execution without
/// modifying function signatures.
///
/// Values are stored behind `Arc`. Cloning a `Context` therefore shares the
/// stored values with the original rather than deep-copying them.
///
/// # Examples
///
/// ```rust
/// use do_over::context::{Context, ContextKey};
///
/// static CORRELATION_ID: ContextKey<String> = ContextKey::new("correlation_id");
/// static RETRY_COUNT: ContextKey<u32> = ContextKey::new("retry_count");
///
/// let mut ctx = Context::new();
/// ctx.insert(&CORRELATION_ID, "abc-123".to_string());
/// ctx.insert(&RETRY_COUNT, 0u32);
///
/// assert_eq!(ctx.get(&CORRELATION_ID), Some(&"abc-123".to_string()));
/// assert_eq!(ctx.get(&RETRY_COUNT), Some(&0u32));
/// ```
#[derive(Default)]
pub struct Context {
    // Keyed by (stored value type, key name), so keys that share a name but
    // store different types never collide.
    values: HashMap<(TypeId, &'static str), Arc<dyn Any + Send + Sync>>,
}

impl fmt::Debug for Context {
    // Stored values are type-erased and cannot be formatted, so only the key
    // names are shown. They are sorted so the output is deterministic despite
    // HashMap's randomized iteration order.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut names: Vec<&'static str> = self.values.keys().map(|&(_, name)| name).collect();
        names.sort_unstable();
        f.debug_struct("Context").field("keys", &names).finish()
    }
}

89impl Context {
90    /// Create a new empty context.
91    pub fn new() -> Self {
92        Self::default()
93    }
94
95    /// Insert a value into the context.
96    ///
97    /// # Arguments
98    ///
99    /// * `key` - The context key
100    /// * `value` - The value to store
101    ///
102    /// # Examples
103    ///
104    /// ```rust
105    /// use do_over::context::{Context, ContextKey};
106    ///
107    /// static KEY: ContextKey<String> = ContextKey::new("key");
108    ///
109    /// let mut ctx = Context::new();
110    /// ctx.insert(&KEY, "value".to_string());
111    /// ```
112    pub fn insert<T: Send + Sync + 'static>(&mut self, key: &ContextKey<T>, value: T) {
113        let type_id = TypeId::of::<T>();
114        self.values
115            .insert((type_id, key.name), Arc::new(value));
116    }
117
118    /// Get a value from the context.
119    ///
120    /// # Arguments
121    ///
122    /// * `key` - The context key
123    ///
124    /// # Returns
125    ///
126    /// A reference to the value if it exists, or None.
127    ///
128    /// # Examples
129    ///
130    /// ```rust
131    /// use do_over::context::{Context, ContextKey};
132    ///
133    /// static KEY: ContextKey<String> = ContextKey::new("key");
134    ///
135    /// let mut ctx = Context::new();
136    /// ctx.insert(&KEY, "value".to_string());
137    ///
138    /// assert_eq!(ctx.get(&KEY), Some(&"value".to_string()));
139    /// ```
140    pub fn get<T: Send + Sync + 'static>(&self, key: &ContextKey<T>) -> Option<&T> {
141        let type_id = TypeId::of::<T>();
142        self.values
143            .get(&(type_id, key.name))
144            .and_then(|v| v.downcast_ref::<T>())
145    }
146
147    /// Remove a value from the context.
148    ///
149    /// # Arguments
150    ///
151    /// * `key` - The context key
152    ///
153    /// # Returns
154    ///
155    /// The removed value if it existed (cloned before removal).
156    pub fn remove<T: Send + Sync + Clone + 'static>(&mut self, key: &ContextKey<T>) -> Option<T> {
157        let type_id = TypeId::of::<T>();
158        // Clone the value before removing (required because Arc contains trait object)
159        let value = self.values
160            .get(&(type_id, key.name))
161            .and_then(|v| v.downcast_ref::<T>())
162            .cloned();
163        self.values.remove(&(type_id, key.name));
164        value
165    }
166
167    /// Check if the context contains a key.
168    pub fn contains<T: Send + Sync + 'static>(&self, key: &ContextKey<T>) -> bool {
169        let type_id = TypeId::of::<T>();
170        self.values.contains_key(&(type_id, key.name))
171    }
172
173    /// Clear all values from the context.
174    pub fn clear(&mut self) {
175        self.values.clear();
176    }
177}
178
179impl Clone for Context {
180    fn clone(&self) -> Self {
181        Self {
182            values: self.values.clone(),
183        }
184    }
185}
186
187/// Common context keys for resilience operations.
188pub mod keys {
189    use super::ContextKey;
190
191    /// Correlation ID for distributed tracing.
192    pub static CORRELATION_ID: ContextKey<String> = ContextKey::new("correlation_id");
193
194    /// Current retry attempt number.
195    pub static RETRY_ATTEMPT: ContextKey<u32> = ContextKey::new("retry_attempt");
196
197    /// Operation name or identifier.
198    pub static OPERATION_NAME: ContextKey<String> = ContextKey::new("operation_name");
199
200    /// Start time of the operation.
201    pub static START_TIME: ContextKey<std::time::Instant> = ContextKey::new("start_time");
202}
203
#[cfg(test)]
mod tests {
    use super::*;

    static STRING_KEY: ContextKey<String> = ContextKey::new("string");
    static INT_KEY: ContextKey<i32> = ContextKey::new("int");

    #[test]
    fn test_insert_and_get() {
        let mut ctx = Context::new();
        ctx.insert(&STRING_KEY, "hello".to_string());
        ctx.insert(&INT_KEY, 42);

        assert_eq!(ctx.get(&STRING_KEY), Some(&"hello".to_string()));
        assert_eq!(ctx.get(&INT_KEY), Some(&42));
    }

    #[test]
    fn test_insert_overwrites_existing_value() {
        let mut ctx = Context::new();
        ctx.insert(&INT_KEY, 1);
        ctx.insert(&INT_KEY, 2);

        assert_eq!(ctx.get(&INT_KEY), Some(&2));
    }

    #[test]
    fn test_same_name_different_types_do_not_collide() {
        // Entries are keyed by (type, name), so these two keys are independent.
        static SHARED_AS_STRING: ContextKey<String> = ContextKey::new("shared");
        static SHARED_AS_INT: ContextKey<i32> = ContextKey::new("shared");

        let mut ctx = Context::new();
        ctx.insert(&SHARED_AS_STRING, "text".to_string());
        ctx.insert(&SHARED_AS_INT, 7);

        assert_eq!(ctx.get(&SHARED_AS_STRING), Some(&"text".to_string()));
        assert_eq!(ctx.get(&SHARED_AS_INT), Some(&7));
    }

    #[test]
    fn test_get_missing_key() {
        let ctx = Context::new();
        assert_eq!(ctx.get(&STRING_KEY), None);
    }

    #[test]
    fn test_contains() {
        let mut ctx = Context::new();
        assert!(!ctx.contains(&STRING_KEY));

        ctx.insert(&STRING_KEY, "value".to_string());
        assert!(ctx.contains(&STRING_KEY));
    }

    #[test]
    fn test_remove() {
        let mut ctx = Context::new();
        ctx.insert(&STRING_KEY, "value".to_string());

        let removed = ctx.remove(&STRING_KEY);
        assert_eq!(removed, Some("value".to_string()));
        assert!(!ctx.contains(&STRING_KEY));
    }

    #[test]
    fn test_remove_missing_key() {
        let mut ctx = Context::new();
        assert_eq!(ctx.remove(&STRING_KEY), None);
    }

    #[test]
    fn test_remove_from_clone_keeps_original() {
        let mut ctx = Context::new();
        ctx.insert(&STRING_KEY, "value".to_string());

        let mut copy = ctx.clone();
        assert_eq!(copy.remove(&STRING_KEY), Some("value".to_string()));
        assert!(!copy.contains(&STRING_KEY));
        assert_eq!(ctx.get(&STRING_KEY), Some(&"value".to_string()));
    }

    #[test]
    fn test_clear() {
        let mut ctx = Context::new();
        ctx.insert(&STRING_KEY, "value".to_string());
        ctx.insert(&INT_KEY, 42);

        ctx.clear();
        assert!(!ctx.contains(&STRING_KEY));
        assert!(!ctx.contains(&INT_KEY));
    }

    #[test]
    fn test_clone() {
        let mut ctx = Context::new();
        ctx.insert(&STRING_KEY, "value".to_string());

        let ctx2 = ctx.clone();
        assert_eq!(ctx2.get(&STRING_KEY), Some(&"value".to_string()));
    }

    #[test]
    fn test_common_key_names() {
        assert_eq!(keys::CORRELATION_ID.name(), "correlation_id");
        assert_eq!(keys::RETRY_ATTEMPT.name(), "retry_attempt");
        assert_eq!(keys::OPERATION_NAME.name(), "operation_name");
        assert_eq!(keys::START_TIME.name(), "start_time");
    }
}