1use tokio::sync::mpsc;
9
10use crate::tools::ToolSchema;
11use crate::AgentEvent;
12
13#[derive(Clone, Copy, Debug)]
17pub struct ToolExecutionContext<'a> {
18 pub session_id: Option<&'a str>,
20 pub tool_call_id: &'a str,
22 pub event_tx: Option<&'a mpsc::Sender<AgentEvent>>,
24 pub available_tool_schemas: Option<&'a [ToolSchema]>,
26}
27
28impl<'a> ToolExecutionContext<'a> {
29 pub fn none(tool_call_id: &'a str) -> Self {
30 Self {
31 session_id: None,
32 tool_call_id,
33 event_tx: None,
34 available_tool_schemas: None,
35 }
36 }
37
38 pub fn cloned_sender(&self) -> Option<mpsc::Sender<AgentEvent>> {
40 self.event_tx.cloned()
41 }
42
43 pub async fn emit(&self, event: AgentEvent) {
45 if let Some(tx) = self.event_tx {
46 let event = match event {
50 AgentEvent::Token { content } => AgentEvent::ToolToken {
51 tool_call_id: self.tool_call_id.to_string(),
52 content,
53 },
54 other => other,
55 };
56 let _ = tx.try_send(event);
57 }
58 }
59
60 pub async fn emit_tool_token(&self, content: impl Into<String>) {
62 self.emit(AgentEvent::ToolToken {
63 tool_call_id: self.tool_call_id.to_string(),
64 content: content.into(),
65 })
66 .await;
67 }
68}
69
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a context that publishes on `tx` and carries no tool schemas.
    fn context<'a>(
        session_id: Option<&'a str>,
        tool_call_id: &'a str,
        tx: &'a mpsc::Sender<AgentEvent>,
    ) -> ToolExecutionContext<'a> {
        ToolExecutionContext {
            session_id,
            tool_call_id,
            event_tx: Some(tx),
            available_tool_schemas: None,
        }
    }

    /// Receives the next event and returns its `(tool_call_id, content)`,
    /// panicking when the event is not a `ToolToken`.
    async fn recv_tool_token(rx: &mut mpsc::Receiver<AgentEvent>) -> (String, String) {
        match rx.recv().await.unwrap() {
            AgentEvent::ToolToken {
                tool_call_id,
                content,
            } => (tool_call_id, content),
            other => panic!("Expected ToolToken, got: {other:?}"),
        }
    }

    #[tokio::test]
    async fn emit_does_not_block_when_channel_is_full() {
        // Fill the single-slot channel before emitting through the context.
        let (tx, mut rx) = mpsc::channel(1);
        let filler = AgentEvent::Token {
            content: "full".to_string(),
        };
        tx.send(filler).await.unwrap();
        let ctx = context(Some("session_1"), "call_1", &tx);

        let pending = ctx.emit(AgentEvent::Token {
            content: "next".to_string(),
        });
        tokio::time::timeout(std::time::Duration::from_millis(100), pending)
            .await
            .expect("emit should not block on full channel");

        // The event that was already queued is still the first one out.
        match rx.recv().await.unwrap() {
            AgentEvent::Token { content } => assert_eq!(content, "full"),
            other => panic!("unexpected event: {other:?}"),
        }
    }

    #[tokio::test]
    async fn emit_converts_token_to_tool_token() {
        let (tx, mut rx) = mpsc::channel(10);
        let ctx = context(Some("session_1"), "call_123", &tx);

        let token = AgentEvent::Token {
            content: "test content".to_string(),
        };
        ctx.emit(token).await;

        let (tool_call_id, content) = recv_tool_token(&mut rx).await;
        assert_eq!(tool_call_id, "call_123");
        assert_eq!(content, "test content");
    }

    #[tokio::test]
    async fn emit_passes_through_non_token_events() {
        let (tx, mut rx) = mpsc::channel(10);
        let ctx = context(Some("session_1"), "call_456", &tx);

        let direct = AgentEvent::ToolToken {
            tool_call_id: "other".to_string(),
            content: "direct tool token".to_string(),
        };
        ctx.emit(direct).await;

        let (_, content) = recv_tool_token(&mut rx).await;
        assert_eq!(content, "direct tool token");
    }

    #[tokio::test]
    async fn emit_does_nothing_when_no_sender() {
        // Passing means the call completed without panicking.
        let ctx = ToolExecutionContext::none("call_789");
        let token = AgentEvent::Token {
            content: "test".to_string(),
        };
        ctx.emit(token).await;
    }

    #[tokio::test]
    async fn emit_tool_token_convenience_method() {
        let (tx, mut rx) = mpsc::channel(10);
        let ctx = context(None, "call_abc", &tx);

        ctx.emit_tool_token("convenient output").await;

        let (tool_call_id, content) = recv_tool_token(&mut rx).await;
        assert_eq!(tool_call_id, "call_abc");
        assert_eq!(content, "convenient output");
    }

    #[tokio::test]
    async fn emit_tool_token_with_no_sender_does_nothing() {
        // Passing means the call completed without panicking.
        let ctx = ToolExecutionContext::none("call_def");
        ctx.emit_tool_token("test").await;
    }

    #[test]
    fn none_creates_context_with_no_optional_fields() {
        let ctx = ToolExecutionContext::none("call_xyz");

        assert_eq!(ctx.session_id, None);
        assert_eq!(ctx.tool_call_id, "call_xyz");
        assert!(ctx.event_tx.is_none());
    }

    #[test]
    fn cloned_sender_returns_none_when_no_sender() {
        let sender = ToolExecutionContext::none("call_test").cloned_sender();
        assert!(sender.is_none());
    }

    #[tokio::test]
    async fn cloned_sender_returns_clone_when_sender_present() {
        // `_rx` stays alive so the channel remains open for the send below.
        let (tx, _rx) = mpsc::channel(10);
        let ctx = context(None, "call_clone", &tx);

        let sender = ctx.cloned_sender();
        assert!(sender.is_some());

        // The clone must be a working handle onto the same channel.
        let probe = AgentEvent::Token {
            content: "test".to_string(),
        };
        sender.unwrap().send(probe).await.unwrap();
    }

    #[tokio::test]
    async fn emit_handles_multiple_sequential_calls() {
        let (tx, mut rx) = mpsc::channel(10);
        let ctx = context(Some("session_multi"), "call_multi", &tx);

        for i in 0..5 {
            let token = AgentEvent::Token {
                content: format!("message {}", i),
            };
            ctx.emit(token).await;
        }

        // Events must come out in the order they were emitted.
        for i in 0..5 {
            let (_, content) = recv_tool_token(&mut rx).await;
            assert_eq!(content, format!("message {}", i));
        }
    }

    #[test]
    fn context_is_clone_and_copy() {
        let (tx, _rx) = mpsc::channel(10);
        let ctx = context(Some("session_copy"), "call_copy", &tx);

        // Exercise `Clone` explicitly, then `Copy` through a plain assignment.
        let _cloned = ctx.clone();
        let copied = ctx;

        assert_eq!(copied.tool_call_id, "call_copy");
    }

    #[test]
    fn context_is_debug() {
        let rendered = format!("{:?}", ToolExecutionContext::none("call_debug"));
        assert!(rendered.contains("call_debug"));
    }

    #[tokio::test]
    async fn emit_with_empty_tool_call_id() {
        let (tx, mut rx) = mpsc::channel(10);
        let ctx = context(None, "", &tx);

        let token = AgentEvent::Token {
            content: "test".to_string(),
        };
        ctx.emit(token).await;

        let (tool_call_id, _) = recv_tool_token(&mut rx).await;
        assert_eq!(tool_call_id, "");
    }

    #[tokio::test]
    async fn emit_with_unicode_content() {
        let (tx, mut rx) = mpsc::channel(10);
        let ctx = context(Some("会话"), "调用_123", &tx);

        let token = AgentEvent::Token {
            content: "测试内容 🎯".to_string(),
        };
        ctx.emit(token).await;

        let (tool_call_id, content) = recv_tool_token(&mut rx).await;
        assert_eq!(tool_call_id, "调用_123");
        assert_eq!(content, "测试内容 🎯");
    }

    #[tokio::test]
    async fn emit_with_special_characters_in_tool_call_id() {
        let (tx, mut rx) = mpsc::channel(10);
        let ctx = context(None, "call-with_special.chars:123", &tx);

        let token = AgentEvent::Token {
            content: "test".to_string(),
        };
        ctx.emit(token).await;

        let (tool_call_id, _) = recv_tool_token(&mut rx).await;
        assert_eq!(tool_call_id, "call-with_special.chars:123");
    }

    #[tokio::test]
    async fn emit_tool_token_with_string_content() {
        let (tx, mut rx) = mpsc::channel(10);
        let ctx = context(None, "call_string", &tx);

        // An owned `String` is accepted through `impl Into<String>`.
        let owned = String::from("owned string");
        ctx.emit_tool_token(owned).await;

        let (_, content) = recv_tool_token(&mut rx).await;
        assert_eq!(content, "owned string");
    }

    #[tokio::test]
    async fn emit_tool_token_with_str_content() {
        let (tx, mut rx) = mpsc::channel(10);
        let ctx = context(None, "call_str", &tx);

        // A `&str` is accepted through `impl Into<String>`.
        ctx.emit_tool_token("string slice").await;

        let (_, content) = recv_tool_token(&mut rx).await;
        assert_eq!(content, "string slice");
    }
}