// Source: descry_tool_core/context.rs

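//! Tool execution context
//!
//! Provides [`ToolContext`], a thread-safe, async-friendly per-invocation context:
//! request metadata ([`Meta`]), raw JSON arguments that can be taken once, and a
//! type-keyed extension map backed by `DashMap`.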
4
5use dashmap::DashMap;
6use serde::{Deserialize, Serialize};
7use std::any::{Any, TypeId};
8use std::sync::Arc;
9
10use crate::JsonObject;
11
12/// Request metadata
13#[derive(Debug, Clone, Default, Serialize, Deserialize)]
14pub struct Meta {
15    /// Unique request identifier
16    pub request_id: Option<String>,
17    /// Unix timestamp in milliseconds
18    pub timestamp: Option<u64>,
19    /// Source of the request (e.g., "cli", "http", "mcp")
20    pub source: Option<String>,
21}
22
23impl Meta {
24    /// Create empty metadata
25    pub fn new() -> Self {
26        Self::default()
27    }
28
29    /// Set request ID
30    pub fn with_request_id(mut self, id: impl Into<String>) -> Self {
31        self.request_id = Some(id.into());
32        self
33    }
34
35    /// Set timestamp
36    pub fn with_timestamp(mut self, ts: u64) -> Self {
37        self.timestamp = Some(ts);
38        self
39    }
40
41    /// Set source
42    pub fn with_source(mut self, source: impl Into<String>) -> Self {
43        self.source = Some(source.into());
44        self
45    }
46}
47
48/// Tool execution context
49///
50/// Thread-safe context using `Arc<DashMap>` for extensions.
51/// Can be freely cloned and passed across await points.
52///
53/// # Thread Safety
54///
55/// All operations are thread-safe:
56/// - `insert`, `get`, `remove` use DashMap's concurrent HashMap
57/// - Extensions wrapped in `Arc<T>` for cheap cloning
58///
59/// # Examples
60///
61/// ```
62/// use descry_tool_core::ToolContext;
63/// use std::sync::Arc;
64///
65/// let ctx = Arc::new(ToolContext::new());
66///
67/// // Insert extension
68/// #[derive(Debug)]
69/// struct MyService;
70/// ctx.insert(MyService);
71/// ```
72#[derive(Debug, Clone)]
73pub struct ToolContext {
74    /// Request metadata
75    pub meta: Meta,
76    /// Raw arguments (consumed on first access)
77    args: Option<JsonObject>,
78    /// Thread-safe extensions (DashMap + Arc)
79    extensions: Arc<DashMap<TypeId, Arc<dyn Any + Send + Sync>>>,
80}
81
82impl ToolContext {
83    /// Create a new context with empty args
84    pub fn new() -> Self {
85        Self {
86            meta: Meta::default(),
87            args: Some(JsonObject::new()),
88            extensions: Arc::new(DashMap::new()),
89        }
90    }
91
92    /// Create context with args
93    pub fn with_args(args: JsonObject) -> Self {
94        Self {
95            meta: Meta::default(),
96            args: Some(args),
97            extensions: Arc::new(DashMap::new()),
98        }
99    }
100
101    /// Create empty context
102    pub fn empty() -> Self {
103        Self {
104            meta: Meta::default(),
105            args: None,
106            extensions: Arc::new(DashMap::new()),
107        }
108    }
109
110    /// Set metadata
111    pub fn with_meta(mut self, meta: Meta) -> Self {
112        self.meta = meta;
113        self
114    }
115
116    /// Take and consume the raw arguments
117    ///
118    /// This can only be called once. Subsequent calls will return None.
119    pub fn take_args(&mut self) -> Option<JsonObject> {
120        self.args.take()
121    }
122
123    /// Peek at the raw arguments without consuming
124    pub fn peek_args(&self) -> Option<&JsonObject> {
125        self.args.as_ref()
126    }
127
128    /// Insert an extension into the context
129    ///
130    /// Returns the previous value if it existed.
131    pub fn insert<T>(&self, ext: T) -> Option<Arc<T>>
132    where
133        T: Send + Sync + 'static,
134    {
135        self.extensions
136            .insert(TypeId::of::<T>(), Arc::new(ext))
137            .and_then(|arc| arc.downcast().ok())
138    }
139
140    /// Get a reference to an extension
141    pub fn get<T>(&self) -> Option<Arc<T>>
142    where
143        T: Send + Sync + 'static,
144    {
145        self.extensions
146            .get(&TypeId::of::<T>())
147            .and_then(|entry| Arc::downcast(entry.clone()).ok())
148    }
149
150    /// Remove an extension from the context
151    pub fn remove<T>(&self) -> Option<Arc<T>>
152    where
153        T: Send + Sync + 'static,
154    {
155        self.extensions
156            .remove(&TypeId::of::<T>())
157            .and_then(|(_, arc)| Arc::downcast(arc).ok())
158    }
159
160    /// Check if an extension exists
161    pub fn contains<T>(&self) -> bool
162    where
163        T: Send + Sync + 'static,
164    {
165        self.extensions.contains_key(&TypeId::of::<T>())
166    }
167
168    /// Clear all extensions
169    pub fn clear(&self) {
170        self.extensions.clear();
171    }
172
173    /// Get extension count
174    pub fn len(&self) -> usize {
175        self.extensions.len()
176    }
177
178    /// Check if extensions is empty
179    pub fn is_empty(&self) -> bool {
180        self.extensions.is_empty()
181    }
182}
183
184impl Default for ToolContext {
185    fn default() -> Self {
186        Self::new()
187    }
188}
189
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;

    /// Extension type shared by the tests below.
    #[derive(Debug, PartialEq)]
    struct MyExt {
        value: i32,
    }

    /// Builds the `{"key": "value"}` argument object used by the argument tests.
    fn sample_args() -> JsonObject {
        serde_json::json!({"key": "value"})
            .as_object()
            .expect("json! object literal is a JSON object")
            .clone()
    }

    #[test]
    fn test_meta_builder() {
        let meta = Meta::new()
            .with_request_id("test-123")
            .with_timestamp(1234567890)
            .with_source("test");

        assert_eq!(meta.request_id, Some("test-123".to_string()));
        assert_eq!(meta.timestamp, Some(1234567890));
        assert_eq!(meta.source, Some("test".to_string()));
    }

    #[test]
    fn test_context_new() {
        let ctx = ToolContext::new();
        // `new` starts with an empty argument object rather than `None`.
        assert_eq!(ctx.peek_args().map(|args| args.len()), Some(0));
        assert_eq!(ctx.meta.request_id, None);
        assert!(ctx.is_empty());
    }

    #[test]
    fn test_default_matches_new() {
        let ctx = ToolContext::default();
        assert_eq!(ctx.peek_args().map(|args| args.len()), Some(0));
        assert!(ctx.is_empty());
    }

    #[test]
    fn test_context_empty_has_no_args() {
        let mut ctx = ToolContext::empty();
        assert!(ctx.peek_args().is_none());
        assert!(ctx.take_args().is_none());
    }

    #[test]
    fn test_context_with_meta() {
        let ctx = ToolContext::new().with_meta(Meta::new().with_source("cli"));
        assert_eq!(ctx.meta.source.as_deref(), Some("cli"));
    }

    #[test]
    fn test_context_take_args() {
        let mut ctx = ToolContext::with_args(sample_args());

        // Peeking does not consume the arguments.
        assert_eq!(
            ctx.peek_args().map(|args| args.contains_key("key")),
            Some(true)
        );

        let taken = ctx.take_args().expect("args were provided");
        assert!(taken.contains_key("key"));

        // Second call should return None, and there is nothing left to peek at.
        assert!(ctx.take_args().is_none());
        assert!(ctx.peek_args().is_none());
    }

    #[test]
    fn test_context_extensions() {
        let ctx = ToolContext::new();

        // Insert (returns None for new key)
        assert!(ctx.insert(MyExt { value: 42 }).is_none());

        // Get
        let ext = ctx.get::<MyExt>().expect("just inserted");
        assert_eq!(ext.value, 42);

        // Contains
        assert!(ctx.contains::<MyExt>());

        // Remove
        let removed = ctx.remove::<MyExt>().expect("still present");
        assert_eq!(removed.value, 42);
        assert!(!ctx.contains::<MyExt>());
    }

    #[test]
    fn test_insert_returns_previous_value() {
        let ctx = ToolContext::new();
        assert!(ctx.insert(MyExt { value: 1 }).is_none());

        let previous = ctx
            .insert(MyExt { value: 2 })
            .expect("value 1 was already stored");
        assert_eq!(previous.value, 1);
        assert_eq!(ctx.get::<MyExt>().map(|ext| ext.value), Some(2));
        assert_eq!(ctx.len(), 1);
    }

    #[test]
    fn test_missing_extension() {
        let ctx = ToolContext::new();
        assert!(ctx.get::<MyExt>().is_none());
        assert!(ctx.remove::<MyExt>().is_none());
        assert!(!ctx.contains::<MyExt>());
    }

    #[test]
    fn test_len_and_clear() {
        let ctx = ToolContext::new();
        assert!(ctx.is_empty());

        ctx.insert(MyExt { value: 1 });
        ctx.insert(7_u32);
        assert_eq!(ctx.len(), 2);
        assert!(!ctx.is_empty());

        ctx.clear();
        assert_eq!(ctx.len(), 0);
        assert!(ctx.is_empty());
    }

    #[test]
    fn test_context_clone_shares_extensions() {
        let ctx = ToolContext::new();
        ctx.insert(MyExt { value: 42 });

        // Clone should preserve extensions
        let cloned = ctx.clone();
        assert_eq!(cloned.get::<MyExt>().map(|ext| ext.value), Some(42));

        // ...and share the same map: writes through the clone reach the original.
        cloned.insert(7_u32);
        assert_eq!(ctx.get::<u32>().map(|v| *v), Some(7));
    }

    #[test]
    fn test_context_clone_copies_args() {
        let mut original = ToolContext::with_args(sample_args());
        let mut cloned = original.clone();

        // Arguments are deep-copied, so each clone can take them independently.
        assert!(original.take_args().is_some());
        assert!(cloned.take_args().is_some());
    }

    #[test]
    fn test_context_thread_safety() {
        use std::sync::Arc;

        let ctx = Arc::new(ToolContext::new());
        let ctx_clone = Arc::clone(&ctx);

        #[derive(Debug)]
        struct Counter;

        // Insert in main thread
        ctx.insert(Counter);

        // Access in spawned thread
        let handle = thread::spawn(move || {
            assert!(ctx_clone.contains::<Counter>());
        });

        handle.join().expect("reader thread panicked");
        assert!(ctx.contains::<Counter>());
    }

    #[test]
    fn test_concurrent_inserts() {
        let ctx = ToolContext::new();

        thread::scope(|scope| {
            scope.spawn(|| {
                ctx.insert(1_u8);
            });
            scope.spawn(|| {
                ctx.insert(2_u16);
            });
            scope.spawn(|| {
                ctx.insert(3_u32);
            });
            scope.spawn(|| {
                ctx.insert(4_u64);
            });
        });

        assert_eq!(ctx.len(), 4);
        assert_eq!(ctx.get::<u32>().map(|v| *v), Some(3));
    }
}