1use tokio::sync::mpsc;
9
10use serde_json::Value;
11
12use crate::tools::ToolSchema;
13use crate::{AgentEvent, Session};
14
/// Session-level switches that affect how tools are executed.
///
/// The flags are read from a session once (via `ToolExecutionSessionFlags::from_session`).
/// They are then copied onto each `ToolExecutionContext` built by `for_dispatch`.
/// The `Default` value has every switch turned off.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub struct ToolExecutionSessionFlags {
    /// Mirrors `bypass_permissions` from the session's agent runtime state;
    /// `false` when the session has no runtime state.
    pub bypass_permissions: bool,
}
30
31impl ToolExecutionSessionFlags {
32 pub fn from_session(session: &Session) -> Self {
35 Self {
36 bypass_permissions: session
37 .agent_runtime_state
38 .as_ref()
39 .is_some_and(|state| state.bypass_permissions),
40 }
41 }
42}
43
44#[derive(Clone, Copy, Debug)]
54pub struct ToolExecutionContext<'a> {
55 pub session_id: Option<&'a str>,
57 pub tool_call_id: &'a str,
59 pub event_tx: Option<&'a mpsc::Sender<AgentEvent>>,
61 pub available_tool_schemas: Option<&'a [ToolSchema]>,
63 pub bypass_permissions: bool,
67 pub can_async_resume: bool,
77 pub pre_parsed_args: Option<&'a Value>,
87}
88
89impl<'a> ToolExecutionContext<'a> {
90 pub fn none(tool_call_id: &'a str) -> Self {
91 Self {
92 session_id: None,
93 tool_call_id,
94 event_tx: None,
95 available_tool_schemas: None,
96 bypass_permissions: false,
97 can_async_resume: false,
98 pre_parsed_args: None,
99 }
100 }
101
102 pub fn for_dispatch(
108 session_id: &'a str,
109 tool_call_id: &'a str,
110 event_tx: &'a mpsc::Sender<AgentEvent>,
111 available_tool_schemas: &'a [ToolSchema],
112 flags: ToolExecutionSessionFlags,
113 can_async_resume: bool,
118 pre_parsed_args: Option<&'a Value>,
126 ) -> Self {
127 Self {
128 session_id: Some(session_id),
129 tool_call_id,
130 event_tx: Some(event_tx),
131 available_tool_schemas: Some(available_tool_schemas),
132 bypass_permissions: flags.bypass_permissions,
133 can_async_resume,
134 pre_parsed_args,
135 }
136 }
137
138 pub fn cloned_sender(&self) -> Option<mpsc::Sender<AgentEvent>> {
140 self.event_tx.cloned()
141 }
142
143 pub async fn emit(&self, event: AgentEvent) {
145 if let Some(tx) = self.event_tx {
146 let event = match event {
150 AgentEvent::Token { content } => AgentEvent::ToolToken {
151 tool_call_id: self.tool_call_id.to_string(),
152 content,
153 },
154 other => other,
155 };
156 let _ = tx.try_send(event);
157 }
158 }
159
160 pub async fn emit_tool_token(&self, content: impl Into<String>) {
162 self.emit(AgentEvent::ToolToken {
163 tool_call_id: self.tool_call_id.to_string(),
164 content: content.into(),
165 })
166 .await;
167 }
168}
169
#[cfg(test)]
mod session_flags_tests {
    //! Tests for `ToolExecutionSessionFlags::from_session` and for how
    //! `ToolExecutionContext::for_dispatch` copies its inputs onto the context.
    use super::*;
    use bamboo_domain::AgentRuntimeState;

    #[test]
    fn from_session_defaults_false_without_runtime_state() {
        let session = Session::new("s-none", "test-model");
        assert_eq!(
            ToolExecutionSessionFlags::from_session(&session),
            ToolExecutionSessionFlags {
                bypass_permissions: false
            }
        );
    }

    #[test]
    fn from_session_reads_bypass_from_runtime_state() {
        let mut session = Session::new("s-bypass", "test-model");
        let mut runtime = AgentRuntimeState::new("run-1");
        runtime.bypass_permissions = true;
        session.agent_runtime_state = Some(runtime);
        assert!(ToolExecutionSessionFlags::from_session(&session).bypass_permissions);
    }

    #[test]
    fn from_session_is_false_when_runtime_state_disables_bypass() {
        // Having a runtime state is not enough on its own; the flag must be set.
        let mut session = Session::new("s-no-bypass", "test-model");
        let mut runtime = AgentRuntimeState::new("run-2");
        runtime.bypass_permissions = false;
        session.agent_runtime_state = Some(runtime);
        assert!(!ToolExecutionSessionFlags::from_session(&session).bypass_permissions);
    }

    #[test]
    fn for_dispatch_maps_flags_onto_context() {
        let (tx, _rx) = mpsc::channel(1);
        let ctx = ToolExecutionContext::for_dispatch(
            "s1",
            "call-1",
            &tx,
            &[],
            ToolExecutionSessionFlags {
                bypass_permissions: true,
            },
            true,
            None,
        );
        assert_eq!(ctx.session_id, Some("s1"));
        assert_eq!(ctx.tool_call_id, "call-1");
        assert!(ctx.event_tx.is_some());
        assert!(ctx
            .available_tool_schemas
            .is_some_and(|schemas| schemas.is_empty()));
        assert!(ctx.bypass_permissions);
        assert!(ctx.can_async_resume);
        assert!(ctx.pre_parsed_args.is_none());
    }

    #[test]
    fn for_dispatch_threads_pre_parsed_args() {
        let (tx, _rx) = mpsc::channel(1);
        let parsed = serde_json::json!({"v": "x"});
        let ctx = ToolExecutionContext::for_dispatch(
            "s1",
            "call-1",
            &tx,
            &[],
            ToolExecutionSessionFlags::default(),
            false,
            Some(&parsed),
        );
        assert_eq!(ctx.pre_parsed_args, Some(&parsed));
        // Default flags and `can_async_resume = false` must stay off on the context.
        assert!(!ctx.bypass_permissions);
        assert!(!ctx.can_async_resume);
    }
}
231
#[cfg(test)]
mod tests {
    //! Behavioural tests for `ToolExecutionContext`.
    //!
    //! They cover token re-tagging in `emit`, best-effort (non-blocking) delivery,
    //! and the convenience constructors and accessors.
    use super::*;

    /// Builds a context that streams to `tx`.
    ///
    /// Every other optional field and flag keeps the value that
    /// `ToolExecutionContext::none` gives it.
    fn sending_ctx<'a>(
        session_id: Option<&'a str>,
        tool_call_id: &'a str,
        tx: &'a mpsc::Sender<AgentEvent>,
    ) -> ToolExecutionContext<'a> {
        ToolExecutionContext {
            session_id,
            event_tx: Some(tx),
            ..ToolExecutionContext::none(tool_call_id)
        }
    }

    #[tokio::test]
    async fn emit_does_not_block_when_channel_is_full() {
        let (tx, mut rx) = mpsc::channel(1);
        tx.send(AgentEvent::Token {
            content: "full".to_string(),
        })
        .await
        .unwrap();
        let ctx = sending_ctx(Some("session_1"), "call_1", &tx);

        tokio::time::timeout(
            std::time::Duration::from_millis(100),
            ctx.emit(AgentEvent::Token {
                content: "next".to_string(),
            }),
        )
        .await
        .expect("emit should not block on full channel");

        let first = rx.recv().await.unwrap();
        match first {
            AgentEvent::Token { content } => assert_eq!(content, "full"),
            other => panic!("unexpected event: {other:?}"),
        }
        // Best-effort delivery: the event that did not fit must have been
        // dropped, not queued behind the first one.
        assert!(
            rx.try_recv().is_err(),
            "overflowing event should be dropped, not delivered later"
        );
    }

    #[tokio::test]
    async fn emit_converts_token_to_tool_token() {
        let (tx, mut rx) = mpsc::channel(10);
        let ctx = sending_ctx(Some("session_1"), "call_123", &tx);

        ctx.emit(AgentEvent::Token {
            content: "test content".to_string(),
        })
        .await;

        match rx.recv().await.unwrap() {
            AgentEvent::ToolToken {
                tool_call_id,
                content,
            } => {
                assert_eq!(tool_call_id, "call_123");
                assert_eq!(content, "test content");
            }
            other => panic!("Expected ToolToken, got: {other:?}"),
        }
    }

    #[tokio::test]
    async fn emit_passes_through_non_token_events() {
        let (tx, mut rx) = mpsc::channel(10);
        let ctx = sending_ctx(Some("session_1"), "call_456", &tx);

        ctx.emit(AgentEvent::ToolToken {
            tool_call_id: "other".to_string(),
            content: "direct tool token".to_string(),
        })
        .await;

        match rx.recv().await.unwrap() {
            AgentEvent::ToolToken {
                tool_call_id,
                content,
            } => {
                // Non-`Token` events are forwarded untouched; in particular the
                // caller-supplied id is not overwritten with the context's id.
                assert_eq!(tool_call_id, "other");
                assert_eq!(content, "direct tool token");
            }
            other => panic!("Expected ToolToken, got: {other:?}"),
        }
    }

    #[tokio::test]
    async fn emit_does_nothing_when_no_sender() {
        let ctx = ToolExecutionContext::none("call_789");

        // There is nothing to observe; this only checks that emitting without a
        // sender neither panics nor hangs.
        ctx.emit(AgentEvent::Token {
            content: "test".to_string(),
        })
        .await;
    }

    #[tokio::test]
    async fn emit_tool_token_convenience_method() {
        let (tx, mut rx) = mpsc::channel(10);
        let ctx = sending_ctx(None, "call_abc", &tx);

        ctx.emit_tool_token("convenient output").await;

        match rx.recv().await.unwrap() {
            AgentEvent::ToolToken {
                tool_call_id,
                content,
            } => {
                assert_eq!(tool_call_id, "call_abc");
                assert_eq!(content, "convenient output");
            }
            other => panic!("Expected ToolToken, got: {other:?}"),
        }
    }

    #[tokio::test]
    async fn emit_tool_token_with_no_sender_does_nothing() {
        let ctx = ToolExecutionContext::none("call_def");

        // Smoke test: must neither panic nor hang without a sender.
        ctx.emit_tool_token("test").await;
    }

    #[test]
    fn none_creates_context_with_no_optional_fields() {
        let ctx = ToolExecutionContext::none("call_xyz");

        assert_eq!(ctx.session_id, None);
        assert_eq!(ctx.tool_call_id, "call_xyz");
        assert!(ctx.event_tx.is_none());
        assert!(ctx.available_tool_schemas.is_none());
        assert!(ctx.pre_parsed_args.is_none());
        assert!(!ctx.bypass_permissions);
        assert!(!ctx.can_async_resume);
    }

    #[test]
    fn cloned_sender_returns_none_when_no_sender() {
        let ctx = ToolExecutionContext::none("call_test");
        assert!(ctx.cloned_sender().is_none());
    }

    #[tokio::test]
    async fn cloned_sender_returns_clone_when_sender_present() {
        let (tx, mut rx) = mpsc::channel(10);
        let ctx = sending_ctx(None, "call_clone", &tx);

        let cloned = ctx
            .cloned_sender()
            .expect("a context with a sender should hand out a clone");
        cloned
            .send(AgentEvent::Token {
                content: "test".to_string(),
            })
            .await
            .unwrap();

        // The clone must feed the same channel as the original sender.
        match rx.recv().await.unwrap() {
            AgentEvent::Token { content } => assert_eq!(content, "test"),
            other => panic!("Expected Token, got: {other:?}"),
        }
    }

    #[tokio::test]
    async fn emit_handles_multiple_sequential_calls() {
        let (tx, mut rx) = mpsc::channel(10);
        let ctx = sending_ctx(Some("session_multi"), "call_multi", &tx);

        for i in 0..5 {
            ctx.emit(AgentEvent::Token {
                content: format!("message {i}"),
            })
            .await;
        }

        for i in 0..5 {
            match rx.recv().await.unwrap() {
                AgentEvent::ToolToken {
                    tool_call_id,
                    content,
                } => {
                    assert_eq!(tool_call_id, "call_multi");
                    assert_eq!(content, format!("message {i}"));
                }
                other => panic!("Expected ToolToken, got: {other:?}"),
            }
        }
    }

    #[test]
    fn context_is_clone_and_copy() {
        // `Copy` has `Clone` as a supertrait, so this bound proves both.
        fn assert_clone_and_copy<T: Copy>(_: &T) {}

        let (tx, _rx) = mpsc::channel(10);
        let ctx = sending_ctx(Some("session_copy"), "call_copy", &tx);
        assert_clone_and_copy(&ctx);

        let copied = ctx;
        // `ctx` is still usable after being copied.
        assert_eq!(copied.tool_call_id, ctx.tool_call_id);
        assert_eq!(copied.tool_call_id, "call_copy");
        assert_eq!(copied.session_id, Some("session_copy"));
    }

    #[test]
    fn context_is_debug() {
        let ctx = ToolExecutionContext::none("call_debug");
        let debug_str = format!("{ctx:?}");
        assert!(debug_str.contains("call_debug"));
    }

    #[tokio::test]
    async fn emit_with_empty_tool_call_id() {
        let (tx, mut rx) = mpsc::channel(10);
        let ctx = sending_ctx(None, "", &tx);

        ctx.emit(AgentEvent::Token {
            content: "test".to_string(),
        })
        .await;

        match rx.recv().await.unwrap() {
            AgentEvent::ToolToken { tool_call_id, .. } => {
                assert_eq!(tool_call_id, "");
            }
            other => panic!("Expected ToolToken, got: {other:?}"),
        }
    }

    #[tokio::test]
    async fn emit_with_unicode_content() {
        let (tx, mut rx) = mpsc::channel(10);
        let ctx = sending_ctx(Some("会话"), "调用_123", &tx);

        ctx.emit(AgentEvent::Token {
            content: "测试内容 🎯".to_string(),
        })
        .await;

        match rx.recv().await.unwrap() {
            AgentEvent::ToolToken {
                tool_call_id,
                content,
            } => {
                assert_eq!(tool_call_id, "调用_123");
                assert_eq!(content, "测试内容 🎯");
            }
            other => panic!("Expected ToolToken, got: {other:?}"),
        }
    }

    #[tokio::test]
    async fn emit_with_special_characters_in_tool_call_id() {
        let (tx, mut rx) = mpsc::channel(10);
        let ctx = sending_ctx(None, "call-with_special.chars:123", &tx);

        ctx.emit(AgentEvent::Token {
            content: "test".to_string(),
        })
        .await;

        match rx.recv().await.unwrap() {
            AgentEvent::ToolToken { tool_call_id, .. } => {
                assert_eq!(tool_call_id, "call-with_special.chars:123");
            }
            other => panic!("Expected ToolToken, got: {other:?}"),
        }
    }

    #[tokio::test]
    async fn emit_tool_token_with_string_content() {
        let (tx, mut rx) = mpsc::channel(10);
        let ctx = sending_ctx(None, "call_string", &tx);

        let content = String::from("owned string");
        ctx.emit_tool_token(content).await;

        match rx.recv().await.unwrap() {
            AgentEvent::ToolToken { content, .. } => {
                assert_eq!(content, "owned string");
            }
            other => panic!("Expected ToolToken, got: {other:?}"),
        }
    }

    #[tokio::test]
    async fn emit_tool_token_with_str_content() {
        let (tx, mut rx) = mpsc::channel(10);
        let ctx = sending_ctx(None, "call_str", &tx);

        ctx.emit_tool_token("string slice").await;

        match rx.recv().await.unwrap() {
            AgentEvent::ToolToken { content, .. } => {
                assert_eq!(content, "string slice");
            }
            other => panic!("Expected ToolToken, got: {other:?}"),
        }
    }
}