Skip to main content

bamboo_agent/agent/loop_module/stream/
handler.rs

1use futures::StreamExt;
2use tokio::sync::mpsc;
3use tokio_util::sync::CancellationToken;
4
5use crate::agent::core::tools::{ToolCall, ToolCallAccumulator};
6use crate::agent::core::{AgentError, AgentEvent};
7use crate::agent::llm::{LLMChunk, LLMStream};
8
9pub struct StreamHandlingOutput {
10    pub content: String,
11    pub token_count: usize,
12    pub tool_calls: Vec<ToolCall>,
13}
14
15pub async fn consume_llm_stream(
16    mut stream: LLMStream,
17    event_tx: &mpsc::Sender<AgentEvent>,
18    cancel_token: &CancellationToken,
19    session_id: &str,
20) -> Result<StreamHandlingOutput, AgentError> {
21    let mut content = String::new();
22    let mut token_count = 0usize;
23    let mut tool_calls = ToolCallAccumulator::new();
24
25    while let Some(chunk_result) = stream.next().await {
26        if cancel_token.is_cancelled() {
27            return Err(AgentError::Cancelled);
28        }
29
30        match chunk_result {
31            Ok(LLMChunk::Token(token)) => {
32                token_count += token.len();
33                content.push_str(&token);
34
35                let _ = event_tx
36                    .send(AgentEvent::Token {
37                        content: token.clone(),
38                    })
39                    .await;
40            }
41            Ok(LLMChunk::ToolCalls(partial_calls)) => {
42                log::debug!(
43                    "[{}] Received {} tool call parts",
44                    session_id,
45                    partial_calls.len()
46                );
47                tool_calls.extend(partial_calls);
48            }
49            Ok(LLMChunk::Done) => {
50                log::debug!("[{}] LLM stream completed", session_id);
51            }
52            Err(error) => {
53                let message = format!("Stream error: {error}");
54                let _ = event_tx
55                    .send(AgentEvent::Error {
56                        message: message.clone(),
57                    })
58                    .await;
59                return Err(AgentError::LLM(error.to_string()));
60            }
61        }
62    }
63
64    Ok(StreamHandlingOutput {
65        content,
66        token_count,
67        tool_calls: tool_calls.finalize(),
68    })
69}
70
#[cfg(test)]
mod tests {
    use futures::stream;
    use tokio::sync::mpsc;
    use tokio_util::sync::CancellationToken;

    use crate::agent::core::tools::{FunctionCall, ToolCall};
    use crate::agent::core::AgentEvent;
    use crate::agent::llm::LLMStream;

    use super::*;

    /// Wraps a fixed list of chunk results into an [`LLMStream`].
    fn build_stream(items: Vec<crate::agent::llm::provider::Result<LLMChunk>>) -> LLMStream {
        Box::pin(stream::iter(items))
    }

    /// Builds one `function`-type tool-call fragment as a provider would stream it.
    fn tool_call_fragment(id: &str, name: &str, arguments: &str) -> ToolCall {
        ToolCall {
            id: id.to_string(),
            tool_type: "function".to_string(),
            function: FunctionCall {
                name: name.to_string(),
                arguments: arguments.to_string(),
            },
        }
    }

    #[tokio::test]
    async fn consume_llm_stream_accumulates_tokens_and_tool_calls() {
        let stream = build_stream(vec![
            Ok(LLMChunk::Token("hi".to_string())),
            // Two fragments of the same call: the name arrives with the first one and
            // the JSON arguments are split across both.
            Ok(LLMChunk::ToolCalls(vec![tool_call_fragment(
                "call_1",
                "test_tool",
                "{",
            )])),
            Ok(LLMChunk::ToolCalls(vec![tool_call_fragment("call_1", "", "}")])),
            Ok(LLMChunk::Done),
        ]);

        let (event_tx, mut event_rx) = mpsc::channel::<AgentEvent>(8);
        let output = consume_llm_stream(stream, &event_tx, &CancellationToken::new(), "session-1")
            .await
            .expect("stream should succeed");

        assert_eq!(output.content, "hi");
        assert_eq!(output.token_count, 2);
        assert_eq!(output.tool_calls.len(), 1);
        assert_eq!(output.tool_calls[0].function.name, "test_tool");
        assert_eq!(output.tool_calls[0].function.arguments, "{}");

        let token_event = event_rx.recv().await.expect("missing token event");
        let AgentEvent::Token { content } = token_event else {
            panic!("expected the first event to be a token event");
        };
        assert_eq!(content, "hi");
    }

    #[tokio::test]
    async fn consume_llm_stream_returns_cancelled_for_cancelled_token() {
        let stream = build_stream(vec![
            Ok(LLMChunk::Token("ignored".to_string())),
            Ok(LLMChunk::Done),
        ]);
        let cancel_token = CancellationToken::new();
        cancel_token.cancel();

        let (event_tx, mut event_rx) = mpsc::channel::<AgentEvent>(8);
        let result = consume_llm_stream(stream, &event_tx, &cancel_token, "session-2").await;

        assert!(matches!(result, Err(AgentError::Cancelled)));
        // No chunk may be forwarded once the session is cancelled.
        assert!(event_rx.try_recv().is_err());
    }

    #[tokio::test]
    async fn consume_llm_stream_handles_empty_stream() {
        let (event_tx, mut event_rx) = mpsc::channel::<AgentEvent>(8);
        let output = consume_llm_stream(
            build_stream(Vec::new()),
            &event_tx,
            &CancellationToken::new(),
            "session-3",
        )
        .await
        .expect("empty stream should succeed");

        assert!(output.content.is_empty());
        assert_eq!(output.token_count, 0);
        assert!(output.tool_calls.is_empty());
        assert!(event_rx.try_recv().is_err());
    }
}
124}