1use tokio::sync::mpsc;
9
10use crate::tools::ToolSchema;
11use crate::{AgentEvent, Session};
12
/// Session-wide switches that are copied into every [`ToolExecutionContext`]
/// dispatched for a session.
///
/// Build one from a session with [`Self::from_session`]; the `Default` value
/// has every switch turned off.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub struct ToolExecutionSessionFlags {
    /// Mirrors `bypass_permissions` from the session's agent runtime state.
    /// It is `false` when the session carries no runtime state.
    pub bypass_permissions: bool,
}
28
29impl ToolExecutionSessionFlags {
30 pub fn from_session(session: &Session) -> Self {
33 Self {
34 bypass_permissions: session
35 .agent_runtime_state
36 .as_ref()
37 .is_some_and(|state| state.bypass_permissions),
38 }
39 }
40}
41
42#[derive(Clone, Copy, Debug)]
52pub struct ToolExecutionContext<'a> {
53 pub session_id: Option<&'a str>,
55 pub tool_call_id: &'a str,
57 pub event_tx: Option<&'a mpsc::Sender<AgentEvent>>,
59 pub available_tool_schemas: Option<&'a [ToolSchema]>,
61 pub bypass_permissions: bool,
65 pub can_async_resume: bool,
75}
76
77impl<'a> ToolExecutionContext<'a> {
78 pub fn none(tool_call_id: &'a str) -> Self {
79 Self {
80 session_id: None,
81 tool_call_id,
82 event_tx: None,
83 available_tool_schemas: None,
84 bypass_permissions: false,
85 can_async_resume: false,
86 }
87 }
88
89 pub fn for_dispatch(
95 session_id: &'a str,
96 tool_call_id: &'a str,
97 event_tx: &'a mpsc::Sender<AgentEvent>,
98 available_tool_schemas: &'a [ToolSchema],
99 flags: ToolExecutionSessionFlags,
100 can_async_resume: bool,
105 ) -> Self {
106 Self {
107 session_id: Some(session_id),
108 tool_call_id,
109 event_tx: Some(event_tx),
110 available_tool_schemas: Some(available_tool_schemas),
111 bypass_permissions: flags.bypass_permissions,
112 can_async_resume,
113 }
114 }
115
116 pub fn cloned_sender(&self) -> Option<mpsc::Sender<AgentEvent>> {
118 self.event_tx.cloned()
119 }
120
121 pub async fn emit(&self, event: AgentEvent) {
123 if let Some(tx) = self.event_tx {
124 let event = match event {
128 AgentEvent::Token { content } => AgentEvent::ToolToken {
129 tool_call_id: self.tool_call_id.to_string(),
130 content,
131 },
132 other => other,
133 };
134 let _ = tx.try_send(event);
135 }
136 }
137
138 pub async fn emit_tool_token(&self, content: impl Into<String>) {
140 self.emit(AgentEvent::ToolToken {
141 tool_call_id: self.tool_call_id.to_string(),
142 content: content.into(),
143 })
144 .await;
145 }
146}
147
/// Tests for deriving [`ToolExecutionSessionFlags`] from a session and for
/// mapping them onto a dispatched [`ToolExecutionContext`].
#[cfg(test)]
mod session_flags_tests {
    use super::*;
    use bamboo_domain::AgentRuntimeState;

    #[test]
    fn from_session_defaults_false_without_runtime_state() {
        let session = Session::new("s-none", "test-model");
        assert_eq!(
            ToolExecutionSessionFlags::from_session(&session),
            ToolExecutionSessionFlags {
                bypass_permissions: false
            }
        );
    }

    #[test]
    fn from_session_reads_bypass_from_runtime_state() {
        let mut session = Session::new("s-bypass", "test-model");
        let mut runtime = AgentRuntimeState::new("run-1");
        runtime.bypass_permissions = true;
        session.agent_runtime_state = Some(runtime);
        assert!(ToolExecutionSessionFlags::from_session(&session).bypass_permissions);
    }

    #[test]
    fn from_session_keeps_bypass_off_when_runtime_state_disables_it() {
        // Having runtime state alone must not turn bypass on; only the flag does.
        let mut session = Session::new("s-no-bypass", "test-model");
        let mut runtime = AgentRuntimeState::new("run-2");
        runtime.bypass_permissions = false;
        session.agent_runtime_state = Some(runtime);
        assert!(!ToolExecutionSessionFlags::from_session(&session).bypass_permissions);
    }

    #[test]
    fn for_dispatch_maps_flags_onto_context() {
        let (tx, _rx) = mpsc::channel(1);
        let ctx = ToolExecutionContext::for_dispatch(
            "s1",
            "call-1",
            &tx,
            &[],
            ToolExecutionSessionFlags {
                bypass_permissions: true,
            },
            true,
        );
        assert_eq!(ctx.session_id, Some("s1"));
        assert_eq!(ctx.tool_call_id, "call-1");
        assert!(ctx.event_tx.is_some());
        assert_eq!(
            ctx.available_tool_schemas.map(|schemas| schemas.len()),
            Some(0)
        );
        assert!(ctx.bypass_permissions);
        assert!(ctx.can_async_resume);
    }

    #[test]
    fn for_dispatch_leaves_flags_off_when_not_requested() {
        let (tx, _rx) = mpsc::channel(1);
        let ctx = ToolExecutionContext::for_dispatch(
            "s2",
            "call-2",
            &tx,
            &[],
            ToolExecutionSessionFlags::default(),
            false,
        );
        assert!(!ctx.bypass_permissions);
        assert!(!ctx.can_async_resume);
    }
}
191
/// Behavioural tests for [`ToolExecutionContext`] event emission and construction.
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a context wired to `tx` with no tool schemas and both flags off.
    /// This is the shape most tests below need.
    fn ctx_with_sender<'a>(
        tx: &'a mpsc::Sender<AgentEvent>,
        session_id: Option<&'a str>,
        tool_call_id: &'a str,
    ) -> ToolExecutionContext<'a> {
        ToolExecutionContext {
            session_id,
            tool_call_id,
            event_tx: Some(tx),
            available_tool_schemas: None,
            bypass_permissions: false,
            can_async_resume: false,
        }
    }

    #[tokio::test]
    async fn emit_does_not_block_when_channel_is_full() {
        let (tx, mut rx) = mpsc::channel(1);
        tx.send(AgentEvent::Token {
            content: "full".to_string(),
        })
        .await
        .unwrap();
        let ctx = ctx_with_sender(&tx, Some("session_1"), "call_1");

        tokio::time::timeout(
            std::time::Duration::from_millis(100),
            ctx.emit(AgentEvent::Token {
                content: "next".to_string(),
            }),
        )
        .await
        .expect("emit should not block on full channel");

        let first = rx.recv().await.unwrap();
        match first {
            AgentEvent::Token { content } => assert_eq!(content, "full"),
            other => panic!("unexpected event: {other:?}"),
        }
        // Best-effort semantics: the overflow event is dropped, not queued.
        assert!(
            rx.try_recv().is_err(),
            "event emitted into a full channel should have been dropped"
        );
    }

    #[tokio::test]
    async fn emit_ignores_closed_channel() {
        let (tx, rx) = mpsc::channel(1);
        drop(rx);
        let ctx = ctx_with_sender(&tx, None, "call_closed");

        // Both entry points must swallow the send error instead of panicking.
        ctx.emit(AgentEvent::Token {
            content: "lost".to_string(),
        })
        .await;
        ctx.emit_tool_token("also lost").await;
    }

    #[tokio::test]
    async fn emit_converts_token_to_tool_token() {
        let (tx, mut rx) = mpsc::channel(10);
        let ctx = ctx_with_sender(&tx, Some("session_1"), "call_123");

        ctx.emit(AgentEvent::Token {
            content: "test content".to_string(),
        })
        .await;

        let event = rx.recv().await.unwrap();
        match event {
            AgentEvent::ToolToken {
                tool_call_id,
                content,
            } => {
                assert_eq!(tool_call_id, "call_123");
                assert_eq!(content, "test content");
            }
            other => panic!("Expected ToolToken, got: {other:?}"),
        }
    }

    #[tokio::test]
    async fn emit_passes_through_non_token_events() {
        let (tx, mut rx) = mpsc::channel(10);
        let ctx = ctx_with_sender(&tx, Some("session_1"), "call_456");

        // A ToolToken that is already attributed must not be re-stamped
        // with this context's call ID.
        ctx.emit(AgentEvent::ToolToken {
            tool_call_id: "other".to_string(),
            content: "direct tool token".to_string(),
        })
        .await;

        let event = rx.recv().await.unwrap();
        match event {
            AgentEvent::ToolToken {
                tool_call_id,
                content,
            } => {
                assert_eq!(tool_call_id, "other");
                assert_eq!(content, "direct tool token");
            }
            other => panic!("Expected ToolToken, got: {other:?}"),
        }
    }

    #[tokio::test]
    async fn emit_does_nothing_when_no_sender() {
        let ctx = ToolExecutionContext::none("call_789");

        // There is nothing to observe; this only checks the call completes without panicking.
        ctx.emit(AgentEvent::Token {
            content: "test".to_string(),
        })
        .await;
    }

    #[tokio::test]
    async fn emit_tool_token_convenience_method() {
        let (tx, mut rx) = mpsc::channel(10);
        let ctx = ctx_with_sender(&tx, None, "call_abc");

        ctx.emit_tool_token("convenient output").await;

        let event = rx.recv().await.unwrap();
        match event {
            AgentEvent::ToolToken {
                tool_call_id,
                content,
            } => {
                assert_eq!(tool_call_id, "call_abc");
                assert_eq!(content, "convenient output");
            }
            other => panic!("Expected ToolToken, got: {other:?}"),
        }
    }

    #[tokio::test]
    async fn emit_tool_token_with_no_sender_does_nothing() {
        let ctx = ToolExecutionContext::none("call_def");

        // There is nothing to observe; this only checks the call completes without panicking.
        ctx.emit_tool_token("test").await;
    }

    #[test]
    fn none_creates_context_with_no_optional_fields() {
        let ctx = ToolExecutionContext::none("call_xyz");

        assert_eq!(ctx.session_id, None);
        assert_eq!(ctx.tool_call_id, "call_xyz");
        assert!(ctx.event_tx.is_none());
        assert!(ctx.available_tool_schemas.is_none());
        assert!(!ctx.bypass_permissions);
        assert!(!ctx.can_async_resume);
    }

    #[test]
    fn cloned_sender_returns_none_when_no_sender() {
        let ctx = ToolExecutionContext::none("call_test");
        assert!(ctx.cloned_sender().is_none());
    }

    #[tokio::test]
    async fn cloned_sender_returns_clone_when_sender_present() {
        let (tx, mut rx) = mpsc::channel(10);
        let ctx = ctx_with_sender(&tx, None, "call_clone");

        let cloned = ctx
            .cloned_sender()
            .expect("context with a sender should yield a clone");
        cloned
            .send(AgentEvent::Token {
                content: "test".to_string(),
            })
            .await
            .unwrap();

        // The clone must feed the same channel as the original sender.
        match rx.recv().await.unwrap() {
            AgentEvent::Token { content } => assert_eq!(content, "test"),
            other => panic!("Expected Token, got: {other:?}"),
        }
    }

    #[tokio::test]
    async fn emit_handles_multiple_sequential_calls() {
        let (tx, mut rx) = mpsc::channel(10);
        let ctx = ctx_with_sender(&tx, Some("session_multi"), "call_multi");

        for i in 0..5 {
            ctx.emit(AgentEvent::Token {
                content: format!("message {}", i),
            })
            .await;
        }

        for i in 0..5 {
            let event = rx.recv().await.unwrap();
            match event {
                AgentEvent::ToolToken {
                    tool_call_id,
                    content,
                } => {
                    assert_eq!(tool_call_id, "call_multi");
                    assert_eq!(content, format!("message {}", i));
                }
                other => panic!("Expected ToolToken, got: {other:?}"),
            }
        }
    }

    #[test]
    fn context_is_clone_and_copy() {
        fn assert_clone_and_copy<T: Clone + Copy>(_: &T) {}

        let (tx, _rx) = mpsc::channel(10);
        let ctx = ctx_with_sender(&tx, Some("session_copy"), "call_copy");
        assert_clone_and_copy(&ctx);

        // Reusing `ctx` after moving it only compiles because the type is `Copy`.
        let _first = ctx;
        let copied = ctx;

        assert_eq!(copied.tool_call_id, "call_copy");
    }

    #[test]
    fn context_is_debug() {
        let ctx = ToolExecutionContext::none("call_debug");
        let debug_str = format!("{:?}", ctx);
        assert!(debug_str.contains("call_debug"));
    }

    #[tokio::test]
    async fn emit_with_empty_tool_call_id() {
        let (tx, mut rx) = mpsc::channel(10);
        let ctx = ctx_with_sender(&tx, None, "");

        ctx.emit(AgentEvent::Token {
            content: "test".to_string(),
        })
        .await;

        let event = rx.recv().await.unwrap();
        match event {
            AgentEvent::ToolToken { tool_call_id, .. } => {
                assert_eq!(tool_call_id, "");
            }
            other => panic!("Expected ToolToken, got: {other:?}"),
        }
    }

    #[tokio::test]
    async fn emit_with_unicode_content() {
        let (tx, mut rx) = mpsc::channel(10);
        let ctx = ctx_with_sender(&tx, Some("会话"), "调用_123");

        ctx.emit(AgentEvent::Token {
            content: "测试内容 🎯".to_string(),
        })
        .await;

        let event = rx.recv().await.unwrap();
        match event {
            AgentEvent::ToolToken {
                tool_call_id,
                content,
            } => {
                assert_eq!(tool_call_id, "调用_123");
                assert_eq!(content, "测试内容 🎯");
            }
            other => panic!("Expected ToolToken, got: {other:?}"),
        }
    }

    #[tokio::test]
    async fn emit_with_special_characters_in_tool_call_id() {
        let (tx, mut rx) = mpsc::channel(10);
        let ctx = ctx_with_sender(&tx, None, "call-with_special.chars:123");

        ctx.emit(AgentEvent::Token {
            content: "test".to_string(),
        })
        .await;

        let event = rx.recv().await.unwrap();
        match event {
            AgentEvent::ToolToken { tool_call_id, .. } => {
                assert_eq!(tool_call_id, "call-with_special.chars:123");
            }
            other => panic!("Expected ToolToken, got: {other:?}"),
        }
    }

    #[tokio::test]
    async fn emit_tool_token_with_string_content() {
        let (tx, mut rx) = mpsc::channel(10);
        let ctx = ctx_with_sender(&tx, None, "call_string");

        let content = String::from("owned string");
        ctx.emit_tool_token(content).await;

        let event = rx.recv().await.unwrap();
        match event {
            AgentEvent::ToolToken { content, .. } => {
                assert_eq!(content, "owned string");
            }
            other => panic!("Expected ToolToken, got: {other:?}"),
        }
    }

    #[tokio::test]
    async fn emit_tool_token_with_str_content() {
        let (tx, mut rx) = mpsc::channel(10);
        let ctx = ctx_with_sender(&tx, None, "call_str");

        ctx.emit_tool_token("string slice").await;

        let event = rx.recv().await.unwrap();
        match event {
            AgentEvent::ToolToken { content, .. } => {
                assert_eq!(content, "string slice");
            }
            other => panic!("Expected ToolToken, got: {other:?}"),
        }
    }
}