1use tokio::sync::mpsc;
9
10use crate::tools::ToolSchema;
11use crate::{AgentEvent, Session};
12
/// Per-session switches that influence how tools are executed.
///
/// Built from a [`Session`] via [`ToolExecutionSessionFlags::from_session`] and
/// copied onto a [`ToolExecutionContext`] by [`ToolExecutionContext::for_dispatch`].
/// The `Default` value has every flag disabled.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub struct ToolExecutionSessionFlags {
    /// Mirrors `bypass_permissions` of the session's agent runtime state
    /// (`false` when the session has no runtime state).
    pub bypass_permissions: bool,
}
28
29impl ToolExecutionSessionFlags {
30 pub fn from_session(session: &Session) -> Self {
33 Self {
34 bypass_permissions: session
35 .agent_runtime_state
36 .as_ref()
37 .is_some_and(|state| state.bypass_permissions),
38 }
39 }
40}
41
42#[derive(Clone, Copy, Debug)]
52pub struct ToolExecutionContext<'a> {
53 pub session_id: Option<&'a str>,
55 pub tool_call_id: &'a str,
57 pub event_tx: Option<&'a mpsc::Sender<AgentEvent>>,
59 pub available_tool_schemas: Option<&'a [ToolSchema]>,
61 pub bypass_permissions: bool,
65}
66
67impl<'a> ToolExecutionContext<'a> {
68 pub fn none(tool_call_id: &'a str) -> Self {
69 Self {
70 session_id: None,
71 tool_call_id,
72 event_tx: None,
73 available_tool_schemas: None,
74 bypass_permissions: false,
75 }
76 }
77
78 pub fn for_dispatch(
84 session_id: &'a str,
85 tool_call_id: &'a str,
86 event_tx: &'a mpsc::Sender<AgentEvent>,
87 available_tool_schemas: &'a [ToolSchema],
88 flags: ToolExecutionSessionFlags,
89 ) -> Self {
90 Self {
91 session_id: Some(session_id),
92 tool_call_id,
93 event_tx: Some(event_tx),
94 available_tool_schemas: Some(available_tool_schemas),
95 bypass_permissions: flags.bypass_permissions,
96 }
97 }
98
99 pub fn cloned_sender(&self) -> Option<mpsc::Sender<AgentEvent>> {
101 self.event_tx.cloned()
102 }
103
104 pub async fn emit(&self, event: AgentEvent) {
106 if let Some(tx) = self.event_tx {
107 let event = match event {
111 AgentEvent::Token { content } => AgentEvent::ToolToken {
112 tool_call_id: self.tool_call_id.to_string(),
113 content,
114 },
115 other => other,
116 };
117 let _ = tx.try_send(event);
118 }
119 }
120
121 pub async fn emit_tool_token(&self, content: impl Into<String>) {
123 self.emit(AgentEvent::ToolToken {
124 tool_call_id: self.tool_call_id.to_string(),
125 content: content.into(),
126 })
127 .await;
128 }
129}
130
/// Tests for deriving [`ToolExecutionSessionFlags`] and mapping them onto a
/// [`ToolExecutionContext`].
#[cfg(test)]
mod session_flags_tests {
    use super::*;
    use bamboo_domain::AgentRuntimeState;

    #[test]
    fn from_session_defaults_false_without_runtime_state() {
        let session = Session::new("s-none", "test-model");
        assert_eq!(
            ToolExecutionSessionFlags::from_session(&session),
            ToolExecutionSessionFlags {
                bypass_permissions: false
            }
        );
    }

    #[test]
    fn from_session_reads_bypass_from_runtime_state() {
        let mut session = Session::new("s-bypass", "test-model");
        let mut runtime = AgentRuntimeState::new("run-1");
        runtime.bypass_permissions = true;
        session.agent_runtime_state = Some(runtime);
        assert!(ToolExecutionSessionFlags::from_session(&session).bypass_permissions);
    }

    #[test]
    fn from_session_is_false_when_runtime_state_disables_bypass() {
        // Runtime state is present, but its flag is explicitly off.
        let mut session = Session::new("s-no-bypass", "test-model");
        let mut runtime = AgentRuntimeState::new("run-2");
        runtime.bypass_permissions = false;
        session.agent_runtime_state = Some(runtime);
        assert!(!ToolExecutionSessionFlags::from_session(&session).bypass_permissions);
    }

    #[test]
    fn for_dispatch_maps_flags_onto_context() {
        let (tx, _rx) = mpsc::channel(1);
        let ctx = ToolExecutionContext::for_dispatch(
            "s1",
            "call-1",
            &tx,
            &[],
            ToolExecutionSessionFlags {
                bypass_permissions: true,
            },
        );
        assert_eq!(ctx.session_id, Some("s1"));
        assert_eq!(ctx.tool_call_id, "call-1");
        assert!(ctx.event_tx.is_some());
        assert!(ctx
            .available_tool_schemas
            .is_some_and(|schemas| schemas.is_empty()));
        assert!(ctx.bypass_permissions);
    }

    #[test]
    fn for_dispatch_with_default_flags_leaves_bypass_disabled() {
        let (tx, _rx) = mpsc::channel(1);
        let ctx = ToolExecutionContext::for_dispatch(
            "s2",
            "call-2",
            &tx,
            &[],
            ToolExecutionSessionFlags::default(),
        );
        assert!(!ctx.bypass_permissions);
    }
}
172
/// Tests for [`ToolExecutionContext`] construction and event emission.
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn emit_does_not_block_when_channel_is_full() {
        let (tx, mut rx) = mpsc::channel(1);
        tx.send(AgentEvent::Token {
            content: "full".to_string(),
        })
        .await
        .unwrap();
        let ctx = ToolExecutionContext {
            session_id: Some("session_1"),
            tool_call_id: "call_1",
            event_tx: Some(&tx),
            available_tool_schemas: None,
            bypass_permissions: false,
        };

        tokio::time::timeout(
            std::time::Duration::from_millis(100),
            ctx.emit(AgentEvent::Token {
                content: "next".to_string(),
            }),
        )
        .await
        .expect("emit should not block on full channel");

        let first = rx.recv().await.unwrap();
        match first {
            AgentEvent::Token { content } => assert_eq!(content, "full"),
            other => panic!("unexpected event: {other:?}"),
        }

        // The event emitted while the channel was full is dropped, not queued.
        assert!(
            rx.try_recv().is_err(),
            "event emitted into a full channel should have been dropped"
        );
    }

    #[tokio::test]
    async fn emit_ignores_closed_channel() {
        let (tx, rx) = mpsc::channel(1);
        drop(rx);
        let ctx = ToolExecutionContext {
            session_id: None,
            tool_call_id: "call_closed",
            event_tx: Some(&tx),
            available_tool_schemas: None,
            bypass_permissions: false,
        };

        // try_send fails because the receiver is gone; emit must swallow that.
        ctx.emit(AgentEvent::Token {
            content: "dropped".to_string(),
        })
        .await;

        assert!(tx.is_closed());
    }

    #[tokio::test]
    async fn emit_converts_token_to_tool_token() {
        let (tx, mut rx) = mpsc::channel(10);
        let ctx = ToolExecutionContext {
            session_id: Some("session_1"),
            tool_call_id: "call_123",
            event_tx: Some(&tx),
            available_tool_schemas: None,
            bypass_permissions: false,
        };

        ctx.emit(AgentEvent::Token {
            content: "test content".to_string(),
        })
        .await;

        let event = rx.recv().await.unwrap();
        match event {
            AgentEvent::ToolToken {
                tool_call_id,
                content,
            } => {
                assert_eq!(tool_call_id, "call_123");
                assert_eq!(content, "test content");
            }
            other => panic!("Expected ToolToken, got: {other:?}"),
        }
    }

    #[tokio::test]
    async fn emit_passes_through_non_token_events() {
        let (tx, mut rx) = mpsc::channel(10);
        let ctx = ToolExecutionContext {
            session_id: Some("session_1"),
            tool_call_id: "call_456",
            event_tx: Some(&tx),
            available_tool_schemas: None,
            bypass_permissions: false,
        };

        ctx.emit(AgentEvent::ToolToken {
            tool_call_id: "other".to_string(),
            content: "direct tool token".to_string(),
        })
        .await;

        let event = rx.recv().await.unwrap();
        match event {
            AgentEvent::ToolToken {
                tool_call_id,
                content,
            } => {
                // Pass-through must not re-tag with the context's own id.
                assert_eq!(tool_call_id, "other");
                assert_eq!(content, "direct tool token");
            }
            other => panic!("Expected ToolToken, got: {other:?}"),
        }
    }

    #[tokio::test]
    async fn emit_does_nothing_when_no_sender() {
        let ctx = ToolExecutionContext::none("call_789");

        // Must complete without panicking even though there is no channel.
        ctx.emit(AgentEvent::Token {
            content: "test".to_string(),
        })
        .await;
    }

    #[tokio::test]
    async fn emit_tool_token_convenience_method() {
        let (tx, mut rx) = mpsc::channel(10);
        let ctx = ToolExecutionContext {
            session_id: None,
            tool_call_id: "call_abc",
            event_tx: Some(&tx),
            available_tool_schemas: None,
            bypass_permissions: false,
        };

        ctx.emit_tool_token("convenient output").await;

        let event = rx.recv().await.unwrap();
        match event {
            AgentEvent::ToolToken {
                tool_call_id,
                content,
            } => {
                assert_eq!(tool_call_id, "call_abc");
                assert_eq!(content, "convenient output");
            }
            other => panic!("Expected ToolToken, got: {other:?}"),
        }
    }

    #[tokio::test]
    async fn emit_tool_token_with_no_sender_does_nothing() {
        let ctx = ToolExecutionContext::none("call_def");

        // Must complete without panicking even though there is no channel.
        ctx.emit_tool_token("test").await;
    }

    #[test]
    fn none_creates_context_with_no_optional_fields() {
        let ctx = ToolExecutionContext::none("call_xyz");

        assert_eq!(ctx.session_id, None);
        assert_eq!(ctx.tool_call_id, "call_xyz");
        assert!(ctx.event_tx.is_none());
        assert!(ctx.available_tool_schemas.is_none());
        assert!(!ctx.bypass_permissions);
    }

    #[test]
    fn cloned_sender_returns_none_when_no_sender() {
        let ctx = ToolExecutionContext::none("call_test");
        assert!(ctx.cloned_sender().is_none());
    }

    #[tokio::test]
    async fn cloned_sender_returns_clone_when_sender_present() {
        let (tx, mut rx) = mpsc::channel(10);
        let ctx = ToolExecutionContext {
            session_id: None,
            tool_call_id: "call_clone",
            event_tx: Some(&tx),
            available_tool_schemas: None,
            bypass_permissions: false,
        };

        let cloned = ctx.cloned_sender();
        assert!(cloned.is_some());

        cloned
            .unwrap()
            .send(AgentEvent::Token {
                content: "test".to_string(),
            })
            .await
            .unwrap();

        // The clone feeds the same channel, and sending directly skips the
        // Token -> ToolToken rewrite that `emit` performs.
        match rx.recv().await.unwrap() {
            AgentEvent::Token { content } => assert_eq!(content, "test"),
            other => panic!("Expected Token, got: {other:?}"),
        }
    }

    #[tokio::test]
    async fn emit_handles_multiple_sequential_calls() {
        let (tx, mut rx) = mpsc::channel(10);
        let ctx = ToolExecutionContext {
            session_id: Some("session_multi"),
            tool_call_id: "call_multi",
            event_tx: Some(&tx),
            available_tool_schemas: None,
            bypass_permissions: false,
        };

        for i in 0..5 {
            ctx.emit(AgentEvent::Token {
                content: format!("message {}", i),
            })
            .await;
        }

        for i in 0..5 {
            let event = rx.recv().await.unwrap();
            match event {
                AgentEvent::ToolToken { content, .. } => {
                    assert_eq!(content, format!("message {}", i));
                }
                other => panic!("Expected ToolToken, got: {other:?}"),
            }
        }
    }

    #[test]
    fn context_is_clone_and_copy() {
        let (tx, _rx) = mpsc::channel(10);
        let ctx = ToolExecutionContext {
            session_id: Some("session_copy"),
            tool_call_id: "call_copy",
            event_tx: Some(&tx),
            available_tool_schemas: None,
            bypass_permissions: false,
        };

        // Call `Clone` explicitly; clippy would otherwise suggest a plain copy.
        #[allow(clippy::clone_on_copy)]
        let cloned = ctx.clone();
        assert_eq!(cloned.tool_call_id, "call_copy");

        // `ctx` stays usable after the by-value move only because it is `Copy`.
        let copied = ctx;
        assert_eq!(copied.tool_call_id, "call_copy");
        assert_eq!(ctx.tool_call_id, "call_copy");
    }

    #[test]
    fn context_is_debug() {
        let ctx = ToolExecutionContext::none("call_debug");
        let debug_str = format!("{:?}", ctx);
        assert!(debug_str.contains("call_debug"));
    }

    #[tokio::test]
    async fn emit_with_empty_tool_call_id() {
        let (tx, mut rx) = mpsc::channel(10);
        let ctx = ToolExecutionContext {
            session_id: None,
            tool_call_id: "",
            event_tx: Some(&tx),
            available_tool_schemas: None,
            bypass_permissions: false,
        };

        ctx.emit(AgentEvent::Token {
            content: "test".to_string(),
        })
        .await;

        let event = rx.recv().await.unwrap();
        match event {
            AgentEvent::ToolToken { tool_call_id, .. } => {
                assert_eq!(tool_call_id, "");
            }
            other => panic!("Expected ToolToken, got: {other:?}"),
        }
    }

    #[tokio::test]
    async fn emit_with_unicode_content() {
        let (tx, mut rx) = mpsc::channel(10);
        let ctx = ToolExecutionContext {
            session_id: Some("会话"),
            tool_call_id: "调用_123",
            event_tx: Some(&tx),
            available_tool_schemas: None,
            bypass_permissions: false,
        };

        ctx.emit(AgentEvent::Token {
            content: "测试内容 🎯".to_string(),
        })
        .await;

        let event = rx.recv().await.unwrap();
        match event {
            AgentEvent::ToolToken {
                tool_call_id,
                content,
            } => {
                assert_eq!(tool_call_id, "调用_123");
                assert_eq!(content, "测试内容 🎯");
            }
            other => panic!("Expected ToolToken, got: {other:?}"),
        }
    }

    #[tokio::test]
    async fn emit_with_special_characters_in_tool_call_id() {
        let (tx, mut rx) = mpsc::channel(10);
        let ctx = ToolExecutionContext {
            session_id: None,
            tool_call_id: "call-with_special.chars:123",
            event_tx: Some(&tx),
            available_tool_schemas: None,
            bypass_permissions: false,
        };

        ctx.emit(AgentEvent::Token {
            content: "test".to_string(),
        })
        .await;

        let event = rx.recv().await.unwrap();
        match event {
            AgentEvent::ToolToken { tool_call_id, .. } => {
                assert_eq!(tool_call_id, "call-with_special.chars:123");
            }
            other => panic!("Expected ToolToken, got: {other:?}"),
        }
    }

    #[tokio::test]
    async fn emit_tool_token_with_string_content() {
        let (tx, mut rx) = mpsc::channel(10);
        let ctx = ToolExecutionContext {
            session_id: None,
            tool_call_id: "call_string",
            event_tx: Some(&tx),
            available_tool_schemas: None,
            bypass_permissions: false,
        };

        let content = String::from("owned string");
        ctx.emit_tool_token(content).await;

        let event = rx.recv().await.unwrap();
        match event {
            AgentEvent::ToolToken { content, .. } => {
                assert_eq!(content, "owned string");
            }
            other => panic!("Expected ToolToken, got: {other:?}"),
        }
    }

    #[tokio::test]
    async fn emit_tool_token_with_str_content() {
        let (tx, mut rx) = mpsc::channel(10);
        let ctx = ToolExecutionContext {
            session_id: None,
            tool_call_id: "call_str",
            event_tx: Some(&tx),
            available_tool_schemas: None,
            bypass_permissions: false,
        };

        ctx.emit_tool_token("string slice").await;

        let event = rx.recv().await.unwrap();
        match event {
            AgentEvent::ToolToken { content, .. } => {
                assert_eq!(content, "string slice");
            }
            other => panic!("Expected ToolToken, got: {other:?}"),
        }
    }
}