// bamboo_agent/agent/loop_module/stream/handler.rs

use futures::StreamExt;
use tokio::sync::mpsc;
use tokio_util::sync::CancellationToken;

use crate::agent::core::tools::{ToolCall, ToolCallAccumulator};
use crate::agent::core::{AgentError, AgentEvent};
use crate::agent::llm::{LLMChunk, LLMStream};
/// Aggregated result of draining one LLM response stream.
pub struct StreamHandlingOutput {
    /// Full assistant text, concatenated from every `LLMChunk::Token`.
    pub content: String,
    /// Accumulated `str::len()` of the streamed token chunks — i.e. the
    /// byte length of `content`, not a model-token count.
    pub token_count: usize,
    /// Tool calls assembled from the partial fragments seen on the stream.
    pub tool_calls: Vec<ToolCall>,
}
14
15pub async fn consume_llm_stream(
16 mut stream: LLMStream,
17 event_tx: &mpsc::Sender<AgentEvent>,
18 cancel_token: &CancellationToken,
19 session_id: &str,
20) -> Result<StreamHandlingOutput, AgentError> {
21 let mut content = String::new();
22 let mut token_count = 0usize;
23 let mut tool_calls = ToolCallAccumulator::new();
24
25 while let Some(chunk_result) = stream.next().await {
26 if cancel_token.is_cancelled() {
27 return Err(AgentError::Cancelled);
28 }
29
30 match chunk_result {
31 Ok(LLMChunk::Token(token)) => {
32 token_count += token.len();
33 content.push_str(&token);
34
35 let _ = event_tx
36 .send(AgentEvent::Token {
37 content: token.clone(),
38 })
39 .await;
40 }
41 Ok(LLMChunk::ToolCalls(partial_calls)) => {
42 log::debug!(
43 "[{}] Received {} tool call parts",
44 session_id,
45 partial_calls.len()
46 );
47 tool_calls.extend(partial_calls);
48 }
49 Ok(LLMChunk::Done) => {
50 log::debug!("[{}] LLM stream completed", session_id);
51 }
52 Err(error) => {
53 let message = format!("Stream error: {error}");
54 let _ = event_tx
55 .send(AgentEvent::Error {
56 message: message.clone(),
57 })
58 .await;
59 return Err(AgentError::LLM(error.to_string()));
60 }
61 }
62 }
63
64 Ok(StreamHandlingOutput {
65 content,
66 token_count,
67 tool_calls: tool_calls.finalize(),
68 })
69}
70
#[cfg(test)]
mod tests {
    use futures::stream;
    use tokio::sync::mpsc;
    use tokio_util::sync::CancellationToken;

    use crate::agent::core::tools::{FunctionCall, ToolCall};
    use crate::agent::core::AgentEvent;
    use crate::agent::llm::LLMStream;

    use super::*;

    /// Wraps a fixed chunk sequence in a pinned stream for the handler.
    fn build_stream(items: Vec<crate::agent::llm::provider::Result<LLMChunk>>) -> LLMStream {
        Box::pin(stream::iter(items))
    }

    /// Builds one partial function-call fragment as a provider would emit it.
    fn partial_call(id: &str, name: &str, arguments: &str) -> ToolCall {
        ToolCall {
            id: id.to_string(),
            tool_type: "function".to_string(),
            function: FunctionCall {
                name: name.to_string(),
                arguments: arguments.to_string(),
            },
        }
    }

    #[tokio::test]
    async fn consume_llm_stream_accumulates_tokens_and_tool_calls() {
        // One text token, a tool call split across two fragments, then Done.
        let chunks = build_stream(vec![
            Ok(LLMChunk::Token("hi".to_string())),
            Ok(LLMChunk::ToolCalls(vec![partial_call(
                "call_1",
                "test_tool",
                "{",
            )])),
            Ok(LLMChunk::ToolCalls(vec![partial_call("call_1", "", "}")])),
            Ok(LLMChunk::Done),
        ]);

        let (event_tx, mut event_rx) = mpsc::channel::<AgentEvent>(8);
        let cancel = CancellationToken::new();
        let result = consume_llm_stream(chunks, &event_tx, &cancel, "session-1")
            .await
            .expect("stream should succeed");

        assert_eq!(result.content, "hi");
        assert_eq!(result.token_count, 2);

        // The two fragments must merge into a single completed call.
        assert_eq!(result.tool_calls.len(), 1);
        assert_eq!(result.tool_calls[0].function.name, "test_tool");
        assert_eq!(result.tool_calls[0].function.arguments, "{}");

        let first_event = event_rx.recv().await.expect("missing token event");
        assert!(matches!(first_event, AgentEvent::Token { .. }));
    }
}
124}