descry_tool_core/
context.rs1use dashmap::DashMap;
6use serde::{Deserialize, Serialize};
7use std::any::{Any, TypeId};
8use std::sync::Arc;
9
10use crate::JsonObject;
11
12#[derive(Debug, Clone, Default, Serialize, Deserialize)]
14pub struct Meta {
15 pub request_id: Option<String>,
17 pub timestamp: Option<u64>,
19 pub source: Option<String>,
21}
22
23impl Meta {
24 pub fn new() -> Self {
26 Self::default()
27 }
28
29 pub fn with_request_id(mut self, id: impl Into<String>) -> Self {
31 self.request_id = Some(id.into());
32 self
33 }
34
35 pub fn with_timestamp(mut self, ts: u64) -> Self {
37 self.timestamp = Some(ts);
38 self
39 }
40
41 pub fn with_source(mut self, source: impl Into<String>) -> Self {
43 self.source = Some(source.into());
44 self
45 }
46}
47
48#[derive(Debug, Clone)]
73pub struct ToolContext {
74 pub meta: Meta,
76 args: Option<JsonObject>,
78 extensions: Arc<DashMap<TypeId, Arc<dyn Any + Send + Sync>>>,
80}
81
82impl ToolContext {
83 pub fn new() -> Self {
85 Self {
86 meta: Meta::default(),
87 args: Some(JsonObject::new()),
88 extensions: Arc::new(DashMap::new()),
89 }
90 }
91
92 pub fn with_args(args: JsonObject) -> Self {
94 Self {
95 meta: Meta::default(),
96 args: Some(args),
97 extensions: Arc::new(DashMap::new()),
98 }
99 }
100
101 pub fn empty() -> Self {
103 Self {
104 meta: Meta::default(),
105 args: None,
106 extensions: Arc::new(DashMap::new()),
107 }
108 }
109
110 pub fn with_meta(mut self, meta: Meta) -> Self {
112 self.meta = meta;
113 self
114 }
115
116 pub fn take_args(&mut self) -> Option<JsonObject> {
120 self.args.take()
121 }
122
123 pub fn peek_args(&self) -> Option<&JsonObject> {
125 self.args.as_ref()
126 }
127
128 pub fn insert<T>(&self, ext: T) -> Option<Arc<T>>
132 where
133 T: Send + Sync + 'static,
134 {
135 self.extensions
136 .insert(TypeId::of::<T>(), Arc::new(ext))
137 .and_then(|arc| arc.downcast().ok())
138 }
139
140 pub fn get<T>(&self) -> Option<Arc<T>>
142 where
143 T: Send + Sync + 'static,
144 {
145 self.extensions
146 .get(&TypeId::of::<T>())
147 .and_then(|entry| Arc::downcast(entry.clone()).ok())
148 }
149
150 pub fn remove<T>(&self) -> Option<Arc<T>>
152 where
153 T: Send + Sync + 'static,
154 {
155 self.extensions
156 .remove(&TypeId::of::<T>())
157 .and_then(|(_, arc)| Arc::downcast(arc).ok())
158 }
159
160 pub fn contains<T>(&self) -> bool
162 where
163 T: Send + Sync + 'static,
164 {
165 self.extensions.contains_key(&TypeId::of::<T>())
166 }
167
168 pub fn clear(&self) {
170 self.extensions.clear();
171 }
172
173 pub fn len(&self) -> usize {
175 self.extensions.len()
176 }
177
178 pub fn is_empty(&self) -> bool {
180 self.extensions.is_empty()
181 }
182}
183
184impl Default for ToolContext {
185 fn default() -> Self {
186 Self::new()
187 }
188}
189
#[cfg(test)]
mod tests {
    use super::*;

    /// Unwraps a `json!` literal into the `JsonObject` that
    /// `ToolContext::with_args` expects; panics if it is not an object.
    fn object(value: serde_json::Value) -> JsonObject {
        match value {
            serde_json::Value::Object(map) => map,
            other => panic!("test fixture must be a JSON object, got {other}"),
        }
    }

    #[test]
    fn test_meta_builder() {
        let meta = Meta::new()
            .with_request_id("test-123")
            .with_timestamp(1234567890)
            .with_source("test");

        assert_eq!(meta.request_id, Some("test-123".to_string()));
        assert_eq!(meta.timestamp, Some(1234567890));
        assert_eq!(meta.source, Some("test".to_string()));
    }

    #[test]
    fn test_meta_default_is_unset() {
        let meta = Meta::default();
        assert!(meta.request_id.is_none());
        assert!(meta.timestamp.is_none());
        assert!(meta.source.is_none());
    }

    #[test]
    fn test_context_new() {
        let ctx = ToolContext::new();
        assert!(ctx.peek_args().is_some());
        assert_eq!(ctx.peek_args().map(|args| args.len()), Some(0));
        assert_eq!(ctx.meta.request_id, None);
        assert!(ctx.is_empty());
    }

    #[test]
    fn test_default_matches_new() {
        let ctx = ToolContext::default();
        assert_eq!(ctx.peek_args().map(|args| args.len()), Some(0));
        assert_eq!(ctx.meta.request_id, None);
        assert!(ctx.is_empty());
    }

    #[test]
    fn test_context_empty_has_no_args() {
        let mut ctx = ToolContext::empty();
        assert!(ctx.peek_args().is_none());
        assert!(ctx.take_args().is_none());
        assert!(ctx.is_empty());
    }

    #[test]
    fn test_context_with_meta() {
        let ctx = ToolContext::empty().with_meta(Meta::new().with_source("cli"));
        assert_eq!(ctx.meta.source.as_deref(), Some("cli"));
        assert!(ctx.peek_args().is_none());
    }

    #[test]
    fn test_context_take_args() {
        let mut ctx = ToolContext::with_args(object(serde_json::json!({"key": "value"})));

        // Peeking borrows without consuming, so it can be repeated.
        assert!(ctx.peek_args().unwrap().contains_key("key"));
        assert!(ctx.peek_args().is_some());

        let taken = ctx.take_args().unwrap();
        assert!(taken.contains_key("key"));

        // Taking leaves `None` behind.
        assert!(ctx.take_args().is_none());
        assert!(ctx.peek_args().is_none());
    }

    #[test]
    fn test_context_extensions() {
        let ctx = ToolContext::new();

        #[derive(Debug, PartialEq)]
        struct MyExt {
            value: i32,
        }

        assert!(ctx.insert(MyExt { value: 42 }).is_none());

        let ext = ctx.get::<MyExt>().unwrap();
        assert_eq!(ext.value, 42);

        assert!(ctx.contains::<MyExt>());

        let removed = ctx.remove::<MyExt>().unwrap();
        assert_eq!(removed.value, 42);
        assert!(!ctx.contains::<MyExt>());
    }

    #[test]
    fn test_insert_replaces_previous_value() {
        #[derive(Debug, PartialEq)]
        struct Level(u8);

        let ctx = ToolContext::new();
        assert!(ctx.insert(Level(1)).is_none());

        let previous = ctx.insert(Level(2)).expect("previous value should be returned");
        assert_eq!(*previous, Level(1));
        assert_eq!(*ctx.get::<Level>().unwrap(), Level(2));
        assert_eq!(ctx.len(), 1);
    }

    #[test]
    fn test_missing_extension() {
        // Only used as a type key; never constructed.
        #[allow(dead_code)]
        struct Absent;

        let ctx = ToolContext::new();
        assert!(ctx.get::<Absent>().is_none());
        assert!(ctx.remove::<Absent>().is_none());
        assert!(!ctx.contains::<Absent>());
    }

    #[test]
    fn test_len_and_clear() {
        struct First;
        struct Second;

        let ctx = ToolContext::new();
        assert!(ctx.is_empty());

        ctx.insert(First);
        ctx.insert(Second);
        assert_eq!(ctx.len(), 2);
        assert!(!ctx.is_empty());

        ctx.clear();
        assert_eq!(ctx.len(), 0);
        assert!(ctx.is_empty());
        assert!(!ctx.contains::<First>());
    }

    #[test]
    fn test_context_clone() {
        #[derive(Debug, PartialEq)]
        struct MyExt {
            value: i32,
        }

        let ctx = ToolContext::new();
        ctx.insert(MyExt { value: 42 });

        let cloned = ctx.clone();
        let ext = cloned.get::<MyExt>().unwrap();
        assert_eq!(ext.value, 42);
    }

    #[test]
    fn test_clone_shares_extensions() {
        #[derive(Debug, PartialEq)]
        struct Flag(bool);

        let original = ToolContext::new();
        let cloned = original.clone();

        // The extension map sits behind an `Arc`, so writes through a clone are
        // visible through the original and vice versa.
        cloned.insert(Flag(true));
        assert_eq!(*original.get::<Flag>().unwrap(), Flag(true));

        original.clear();
        assert!(cloned.is_empty());
    }

    #[test]
    fn test_context_thread_safety() {
        use std::sync::Arc;
        use std::thread;

        #[derive(Debug)]
        struct Marker;

        let ctx = Arc::new(ToolContext::new());
        let ctx_clone = Arc::clone(&ctx);

        ctx.insert(Marker);

        let handle = thread::spawn(move || {
            assert!(ctx_clone.contains::<Marker>());
        });

        handle.join().unwrap();
        assert!(ctx.contains::<Marker>());
    }
}
292}