Skip to main content

llm_core/
stream.rs

1use std::pin::Pin;
2
3use futures::Stream;
4
5use crate::error::LlmError;
6use crate::types::{ToolCall, Usage};
7
8#[derive(Debug, Clone, PartialEq)]
9pub enum Chunk {
10    Text(String),
11    ToolCallStart { name: String, id: Option<String> },
12    ToolCallDelta { content: String },
13    Usage(Usage),
14    Done,
15}
16
17#[cfg(not(target_arch = "wasm32"))]
18pub type ResponseStream = Pin<Box<dyn Stream<Item = Result<Chunk, LlmError>> + Send>>;
19
20#[cfg(target_arch = "wasm32")]
21pub type ResponseStream = Pin<Box<dyn Stream<Item = Result<Chunk, LlmError>>>>;
22
23/// Extract concatenated text from a slice of chunks.
24pub fn collect_text(chunks: &[Chunk]) -> String {
25    let mut text = String::new();
26    for chunk in chunks {
27        if let Chunk::Text(t) = chunk {
28            text.push_str(t);
29        }
30    }
31    text
32}
33
34/// Assemble tool calls from a sequence of ToolCallStart/ToolCallDelta chunks.
35pub fn collect_tool_calls(chunks: &[Chunk]) -> Vec<ToolCall> {
36    let mut calls = Vec::new();
37    let mut current_name: Option<String> = None;
38    let mut current_id: Option<String> = None;
39    let mut current_args = String::new();
40
41    for chunk in chunks {
42        match chunk {
43            Chunk::ToolCallStart { name, id } => {
44                // Flush previous tool call if any
45                if let Some(prev_name) = current_name.take() {
46                    let arguments = serde_json::from_str(&current_args).unwrap_or_default();
47                    calls.push(ToolCall {
48                        name: prev_name,
49                        arguments,
50                        tool_call_id: current_id.take(),
51                    });
52                    current_args.clear();
53                }
54                current_name = Some(name.clone());
55                current_id = id.clone();
56            }
57            Chunk::ToolCallDelta { content } => {
58                current_args.push_str(content);
59            }
60            _ => {}
61        }
62    }
63    // Flush last tool call
64    if let Some(name) = current_name {
65        let arguments = serde_json::from_str(&current_args).unwrap_or_default();
66        calls.push(ToolCall {
67            name,
68            arguments,
69            tool_call_id: current_id,
70        });
71    }
72    calls
73}
74
75/// Extract the last Usage chunk, if any.
76pub fn collect_usage(chunks: &[Chunk]) -> Option<Usage> {
77    chunks.iter().rev().find_map(|c| {
78        if let Chunk::Usage(u) = c {
79            Some(u.clone())
80        } else {
81            None
82        }
83    })
84}
85
#[cfg(test)]
mod tests {
    use super::*;
    use futures::StreamExt;

    #[test]
    fn chunk_text_carries_content() {
        let chunk = Chunk::Text("hello".into());
        if let Chunk::Text(t) = &chunk {
            assert_eq!(t, "hello");
        } else {
            panic!("expected Text chunk");
        }
    }

    #[test]
    fn chunk_tool_call_start() {
        let chunk = Chunk::ToolCallStart {
            name: "search".into(),
            id: Some("tc_1".into()),
        };
        if let Chunk::ToolCallStart { name, id } = &chunk {
            assert_eq!(name, "search");
            assert_eq!(id.as_deref(), Some("tc_1"));
        } else {
            panic!("expected ToolCallStart");
        }
    }

    #[test]
    fn chunk_tool_call_delta() {
        let chunk = Chunk::ToolCallDelta {
            content: r#"{"query":"#.into(),
        };
        if let Chunk::ToolCallDelta { content } = &chunk {
            assert_eq!(content, r#"{"query":"#);
        } else {
            panic!("expected ToolCallDelta");
        }
    }

    #[test]
    fn chunk_usage() {
        let usage = Usage {
            input: Some(10),
            output: Some(5),
            details: None,
        };
        let chunk = Chunk::Usage(usage.clone());
        if let Chunk::Usage(u) = &chunk {
            assert_eq!(u, &usage);
        } else {
            panic!("expected Usage chunk");
        }
    }

    #[test]
    fn chunk_done() {
        let chunk = Chunk::Done;
        assert!(matches!(chunk, Chunk::Done));
    }

    #[tokio::test]
    async fn response_stream_collects_text() {
        let chunks = vec![
            Ok(Chunk::Text("Hello".into())),
            Ok(Chunk::Text(" world".into())),
            Ok(Chunk::Done),
        ];
        let stream: ResponseStream = Box::pin(futures::stream::iter(chunks));
        let collected: Vec<_> = stream.collect().await;
        assert_eq!(collected.len(), 3);

        // Every item is Ok here, so keeping only the Ok values must not lose anything.
        let ok_chunks: Vec<Chunk> = collected.into_iter().filter_map(Result::ok).collect();
        assert_eq!(ok_chunks.len(), 3);
        assert_eq!(collect_text(&ok_chunks), "Hello world");
    }

    #[tokio::test]
    async fn response_stream_propagates_error() {
        let chunks: Vec<Result<Chunk, LlmError>> = vec![
            Ok(Chunk::Text("Hi".into())),
            Err(LlmError::Provider("connection reset".into())),
        ];
        let stream: ResponseStream = Box::pin(futures::stream::iter(chunks));
        let collected: Vec<_> = stream.collect().await;
        assert_eq!(collected.len(), 2);
        assert_eq!(
            collected[0].as_ref().ok(),
            Some(&Chunk::Text("Hi".into()))
        );
        assert!(matches!(collected[1], Err(LlmError::Provider(_))));
    }

    #[test]
    fn collect_text_from_chunks() {
        let chunks = vec![
            Chunk::Text("Hello".into()),
            Chunk::Text(" ".into()),
            Chunk::ToolCallStart {
                name: "x".into(),
                id: None,
            },
            Chunk::Text("world".into()),
            Chunk::Done,
        ];
        let text = collect_text(&chunks);
        assert_eq!(text, "Hello world");
    }

    #[test]
    fn collect_tool_calls_from_chunks() {
        let chunks = vec![
            Chunk::Text("Let me search.".into()),
            Chunk::ToolCallStart {
                name: "search".into(),
                id: Some("tc_1".into()),
            },
            Chunk::ToolCallDelta {
                content: r#"{"query":"#.into(),
            },
            Chunk::ToolCallDelta {
                content: r#""rust"}"#.into(),
            },
            Chunk::Done,
        ];
        let calls = collect_tool_calls(&chunks);
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].name, "search");
        assert_eq!(calls[0].tool_call_id.as_deref(), Some("tc_1"));
        assert_eq!(calls[0].arguments, serde_json::json!({"query": "rust"}));
    }

    #[test]
    fn collect_tool_calls_multiple() {
        let chunks = vec![
            Chunk::ToolCallStart {
                name: "a".into(),
                id: Some("1".into()),
            },
            Chunk::ToolCallDelta {
                content: r#"{"x":1}"#.into(),
            },
            Chunk::ToolCallStart {
                name: "b".into(),
                id: Some("2".into()),
            },
            Chunk::ToolCallDelta {
                content: r#"{}"#.into(),
            },
            Chunk::Done,
        ];
        let calls = collect_tool_calls(&chunks);
        assert_eq!(calls.len(), 2);
        assert_eq!(calls[0].name, "a");
        assert_eq!(calls[0].tool_call_id.as_deref(), Some("1"));
        assert_eq!(calls[0].arguments, serde_json::json!({"x": 1}));
        assert_eq!(calls[1].name, "b");
        assert_eq!(calls[1].tool_call_id.as_deref(), Some("2"));
        assert_eq!(calls[1].arguments, serde_json::json!({}));
    }

    #[test]
    fn collect_tool_calls_empty_when_no_calls() {
        let chunks = vec![Chunk::Text("no tools".into()), Chunk::Done];
        assert!(collect_tool_calls(&chunks).is_empty());
    }

    #[test]
    fn collect_tool_calls_invalid_json_falls_back_to_null() {
        let chunks = vec![
            Chunk::ToolCallStart {
                name: "broken".into(),
                id: None,
            },
            // Truncated JSON: the stream ended mid-object.
            Chunk::ToolCallDelta {
                content: r#"{"query":"#.into(),
            },
        ];
        let calls = collect_tool_calls(&chunks);
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].name, "broken");
        assert_eq!(calls[0].tool_call_id, None);
        assert_eq!(calls[0].arguments, serde_json::Value::Null);
    }

    #[test]
    fn collect_usage_from_chunks() {
        let usage = Usage {
            input: Some(5),
            output: Some(1),
            details: None,
        };
        let chunks = vec![
            Chunk::Text("Hi".into()),
            Chunk::Usage(usage.clone()),
            Chunk::Done,
        ];
        assert_eq!(collect_usage(&chunks), Some(usage));
    }

    #[test]
    fn collect_usage_prefers_last_chunk() {
        let first = Usage {
            input: Some(5),
            output: None,
            details: None,
        };
        let last = Usage {
            input: Some(5),
            output: Some(7),
            details: None,
        };
        let chunks = vec![
            Chunk::Usage(first),
            Chunk::Text("Hi".into()),
            Chunk::Usage(last.clone()),
            Chunk::Done,
        ];
        assert_eq!(collect_usage(&chunks), Some(last));
    }

    #[test]
    fn collect_usage_returns_none_when_absent() {
        let chunks = vec![Chunk::Text("Hi".into()), Chunk::Done];
        assert!(collect_usage(&chunks).is_none());
    }
}