Skip to main content

bamboo_agent_core/tools/
context.rs

//! Execution context for tool calls.
//!
//! Tools normally return a single `ToolResult` after completion. Some tools
//! (for example, long-running CLIs) may want to stream intermediate progress
//! to clients. The agent loop passes a `ToolExecutionContext` that allows tools
//! to emit `AgentEvent`s while they run. Emission is best-effort and
//! non-blocking: events may be dropped if the client cannot keep up.
7
8use tokio::sync::mpsc;
9
10use crate::tools::ToolSchema;
11use crate::AgentEvent;
12
13/// Context passed to tools during execution.
14///
15/// All fields are optional and should be treated as best-effort hints.
16#[derive(Clone, Copy, Debug)]
17pub struct ToolExecutionContext<'a> {
18    /// Bamboo session id that is executing the tool.
19    pub session_id: Option<&'a str>,
20    /// Tool call id from the model (`ToolCall.id`).
21    pub tool_call_id: &'a str,
22    /// Event sender for streaming progress to clients (agent SSE stream).
23    pub event_tx: Option<&'a mpsc::Sender<AgentEvent>>,
24    /// Snapshot of tools currently available to the executing session.
25    pub available_tool_schemas: Option<&'a [ToolSchema]>,
26}
27
28impl<'a> ToolExecutionContext<'a> {
29    pub fn none(tool_call_id: &'a str) -> Self {
30        Self {
31            session_id: None,
32            tool_call_id,
33            event_tx: None,
34            available_tool_schemas: None,
35        }
36    }
37
38    /// Clone the sender (when present) for use in spawned tasks.
39    pub fn cloned_sender(&self) -> Option<mpsc::Sender<AgentEvent>> {
40        self.event_tx.cloned()
41    }
42
43    /// Best-effort emit of an event (ignored if no sender).
44    pub async fn emit(&self, event: AgentEvent) {
45        if let Some(tx) = self.event_tx {
46            // Tools sometimes want to stream incremental output. Historically they emitted
47            // `AgentEvent::Token`, but that mixes tool output into the assistant stream.
48            // When emitting from a tool context, treat `Token` as tool-scoped output.
49            let event = match event {
50                AgentEvent::Token { content } => AgentEvent::ToolToken {
51                    tool_call_id: self.tool_call_id.to_string(),
52                    content,
53                },
54                other => other,
55            };
56            let _ = tx.try_send(event);
57        }
58    }
59
60    /// Convenience helper for streaming tool-scoped output.
61    pub async fn emit_tool_token(&self, content: impl Into<String>) {
62        self.emit(AgentEvent::ToolToken {
63            tool_call_id: self.tool_call_id.to_string(),
64            content: content.into(),
65        })
66        .await;
67    }
68}
69
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn emit_does_not_block_when_channel_is_full() {
        let (tx, mut rx) = mpsc::channel(1);
        tx.send(AgentEvent::Token {
            content: "full".to_string(),
        })
        .await
        .unwrap();
        let ctx = ToolExecutionContext {
            session_id: Some("session_1"),
            tool_call_id: "call_1",
            event_tx: Some(&tx),
            available_tool_schemas: None,
        };

        tokio::time::timeout(
            std::time::Duration::from_millis(100),
            ctx.emit(AgentEvent::Token {
                content: "next".to_string(),
            }),
        )
        .await
        .expect("emit should not block on full channel");

        let first = rx.recv().await.unwrap();
        match first {
            AgentEvent::Token { content } => assert_eq!(content, "full"),
            other => panic!("unexpected event: {other:?}"),
        }

        // The overflowing event must have been dropped, not queued.
        assert!(
            rx.try_recv().is_err(),
            "event emitted into a full channel should be dropped"
        );
    }

    #[tokio::test]
    async fn emit_is_silent_when_receiver_is_closed() {
        let (tx, rx) = mpsc::channel(1);
        drop(rx);
        let ctx = ToolExecutionContext {
            session_id: None,
            tool_call_id: "call_closed",
            event_tx: Some(&tx),
            available_tool_schemas: None,
        };

        // Neither call may panic or block when the client has gone away.
        ctx.emit(AgentEvent::Token {
            content: "lost".to_string(),
        })
        .await;
        ctx.emit_tool_token("also lost").await;

        assert!(tx.is_closed());
    }

    #[tokio::test]
    async fn emit_converts_token_to_tool_token() {
        let (tx, mut rx) = mpsc::channel(10);
        let ctx = ToolExecutionContext {
            session_id: Some("session_1"),
            tool_call_id: "call_123",
            event_tx: Some(&tx),
            available_tool_schemas: None,
        };

        ctx.emit(AgentEvent::Token {
            content: "test content".to_string(),
        })
        .await;

        let event = rx.recv().await.unwrap();
        match event {
            AgentEvent::ToolToken {
                tool_call_id,
                content,
            } => {
                assert_eq!(tool_call_id, "call_123");
                assert_eq!(content, "test content");
            }
            other => panic!("Expected ToolToken, got: {other:?}"),
        }
    }

    #[tokio::test]
    async fn emit_passes_through_non_token_events() {
        let (tx, mut rx) = mpsc::channel(10);
        let ctx = ToolExecutionContext {
            session_id: Some("session_1"),
            tool_call_id: "call_456",
            event_tx: Some(&tx),
            available_tool_schemas: None,
        };

        // A ToolToken that already carries an id must be forwarded as-is,
        // not re-tagged with the context's id.
        ctx.emit(AgentEvent::ToolToken {
            tool_call_id: "other".to_string(),
            content: "direct tool token".to_string(),
        })
        .await;

        let event = rx.recv().await.unwrap();
        match event {
            AgentEvent::ToolToken {
                tool_call_id,
                content,
            } => {
                assert_eq!(tool_call_id, "other");
                assert_eq!(content, "direct tool token");
            }
            other => panic!("Expected ToolToken, got: {other:?}"),
        }
    }

    #[tokio::test]
    async fn emit_does_nothing_when_no_sender() {
        let ctx = ToolExecutionContext::none("call_789");

        // Should not panic or block.
        ctx.emit(AgentEvent::Token {
            content: "test".to_string(),
        })
        .await;
    }

    #[tokio::test]
    async fn emit_tool_token_convenience_method() {
        let (tx, mut rx) = mpsc::channel(10);
        let ctx = ToolExecutionContext {
            session_id: None,
            tool_call_id: "call_abc",
            event_tx: Some(&tx),
            available_tool_schemas: None,
        };

        ctx.emit_tool_token("convenient output").await;

        let event = rx.recv().await.unwrap();
        match event {
            AgentEvent::ToolToken {
                tool_call_id,
                content,
            } => {
                assert_eq!(tool_call_id, "call_abc");
                assert_eq!(content, "convenient output");
            }
            other => panic!("Expected ToolToken, got: {other:?}"),
        }
    }

    #[tokio::test]
    async fn emit_tool_token_with_no_sender_does_nothing() {
        let ctx = ToolExecutionContext::none("call_def");

        // Should not panic or block.
        ctx.emit_tool_token("test").await;
    }

    #[test]
    fn none_creates_context_with_no_optional_fields() {
        let ctx = ToolExecutionContext::none("call_xyz");

        assert_eq!(ctx.session_id, None);
        assert_eq!(ctx.tool_call_id, "call_xyz");
        assert!(ctx.event_tx.is_none());
        assert!(ctx.available_tool_schemas.is_none());
    }

    #[test]
    fn cloned_sender_returns_none_when_no_sender() {
        let ctx = ToolExecutionContext::none("call_test");
        assert!(ctx.cloned_sender().is_none());
    }

    #[tokio::test]
    async fn cloned_sender_returns_clone_when_sender_present() {
        let (tx, mut rx) = mpsc::channel(10);
        let ctx = ToolExecutionContext {
            session_id: None,
            tool_call_id: "call_clone",
            event_tx: Some(&tx),
            available_tool_schemas: None,
        };

        let cloned = ctx.cloned_sender().expect("sender should be present");

        // The clone must feed the same channel as the original sender.
        cloned
            .send(AgentEvent::Token {
                content: "test".to_string(),
            })
            .await
            .unwrap();

        match rx.recv().await.unwrap() {
            AgentEvent::Token { content } => assert_eq!(content, "test"),
            other => panic!("Expected Token, got: {other:?}"),
        }
    }

    #[tokio::test]
    async fn emit_handles_multiple_sequential_calls() {
        let (tx, mut rx) = mpsc::channel(10);
        let ctx = ToolExecutionContext {
            session_id: Some("session_multi"),
            tool_call_id: "call_multi",
            event_tx: Some(&tx),
            available_tool_schemas: None,
        };

        for i in 0..5 {
            ctx.emit(AgentEvent::Token {
                content: format!("message {i}"),
            })
            .await;
        }

        for i in 0..5 {
            let event = rx.recv().await.unwrap();
            match event {
                AgentEvent::ToolToken {
                    tool_call_id,
                    content,
                } => {
                    assert_eq!(tool_call_id, "call_multi");
                    assert_eq!(content, format!("message {i}"));
                }
                other => panic!("Expected ToolToken, got: {other:?}"),
            }
        }
    }

    #[test]
    fn context_is_clone_and_copy() {
        // Compile-time check. Calling `.clone()` on a `Copy` value would trip
        // clippy's `clone_on_copy` lint, so assert the bounds directly.
        fn assert_clone_and_copy<T: Clone + Copy>() {}
        assert_clone_and_copy::<ToolExecutionContext<'static>>();

        let (tx, _rx) = mpsc::channel(10);
        let ctx = ToolExecutionContext {
            session_id: Some("session_copy"),
            tool_call_id: "call_copy",
            event_tx: Some(&tx),
            available_tool_schemas: None,
        };

        // Using `ctx` after assigning it to `copied` only compiles if it is `Copy`.
        let copied = ctx;
        assert_eq!(copied.tool_call_id, "call_copy");
        assert_eq!(ctx.tool_call_id, "call_copy");
        assert!(ctx.event_tx.is_some());
    }

    #[test]
    fn context_is_debug() {
        let ctx = ToolExecutionContext::none("call_debug");
        let debug_str = format!("{ctx:?}");
        assert!(debug_str.contains("call_debug"));
    }

    #[tokio::test]
    async fn emit_with_empty_tool_call_id() {
        let (tx, mut rx) = mpsc::channel(10);
        let ctx = ToolExecutionContext {
            session_id: None,
            tool_call_id: "",
            event_tx: Some(&tx),
            available_tool_schemas: None,
        };

        ctx.emit(AgentEvent::Token {
            content: "test".to_string(),
        })
        .await;

        let event = rx.recv().await.unwrap();
        match event {
            AgentEvent::ToolToken { tool_call_id, .. } => {
                assert_eq!(tool_call_id, "");
            }
            other => panic!("Expected ToolToken, got: {other:?}"),
        }
    }

    #[tokio::test]
    async fn emit_with_unicode_content() {
        let (tx, mut rx) = mpsc::channel(10);
        let ctx = ToolExecutionContext {
            session_id: Some("会话"),
            tool_call_id: "调用_123",
            event_tx: Some(&tx),
            available_tool_schemas: None,
        };

        ctx.emit(AgentEvent::Token {
            content: "测试内容 🎯".to_string(),
        })
        .await;

        let event = rx.recv().await.unwrap();
        match event {
            AgentEvent::ToolToken {
                tool_call_id,
                content,
            } => {
                assert_eq!(tool_call_id, "调用_123");
                assert_eq!(content, "测试内容 🎯");
            }
            other => panic!("Expected ToolToken, got: {other:?}"),
        }
    }

    #[tokio::test]
    async fn emit_with_special_characters_in_tool_call_id() {
        let (tx, mut rx) = mpsc::channel(10);
        let ctx = ToolExecutionContext {
            session_id: None,
            tool_call_id: "call-with_special.chars:123",
            event_tx: Some(&tx),
            available_tool_schemas: None,
        };

        ctx.emit(AgentEvent::Token {
            content: "test".to_string(),
        })
        .await;

        let event = rx.recv().await.unwrap();
        match event {
            AgentEvent::ToolToken { tool_call_id, .. } => {
                assert_eq!(tool_call_id, "call-with_special.chars:123");
            }
            other => panic!("Expected ToolToken, got: {other:?}"),
        }
    }

    #[tokio::test]
    async fn emit_tool_token_with_string_content() {
        let (tx, mut rx) = mpsc::channel(10);
        let ctx = ToolExecutionContext {
            session_id: None,
            tool_call_id: "call_string",
            event_tx: Some(&tx),
            available_tool_schemas: None,
        };

        let content = String::from("owned string");
        ctx.emit_tool_token(content).await;

        let event = rx.recv().await.unwrap();
        match event {
            AgentEvent::ToolToken { content, .. } => {
                assert_eq!(content, "owned string");
            }
            other => panic!("Expected ToolToken, got: {other:?}"),
        }
    }

    #[tokio::test]
    async fn emit_tool_token_with_str_content() {
        let (tx, mut rx) = mpsc::channel(10);
        let ctx = ToolExecutionContext {
            session_id: None,
            tool_call_id: "call_str",
            event_tx: Some(&tx),
            available_tool_schemas: None,
        };

        ctx.emit_tool_token("string slice").await;

        let event = rx.recv().await.unwrap();
        match event {
            AgentEvent::ToolToken { content, .. } => {
                assert_eq!(content, "string slice");
            }
            other => panic!("Expected ToolToken, got: {other:?}"),
        }
    }
}