1use std::pin::Pin;
2
3use futures::Stream;
4
5use crate::error::LlmError;
6use crate::types::{ToolCall, Usage};
7
/// One incremental piece of a streamed model response.
///
/// Chunks are the `Ok` items carried by a [`ResponseStream`]. The `collect_*` helpers in this
/// module fold a slice of chunks back into text, tool calls, and usage.
#[derive(Debug, Clone, PartialEq)]
pub enum Chunk {
    /// A fragment of response text; [`collect_text`] concatenates these in order.
    Text(String),
    /// Opens a tool call with the tool's `name` and an optional call `id`.
    ToolCallStart { name: String, id: Option<String> },
    /// A fragment of argument text for a tool call; [`collect_tool_calls`] joins the fragments
    /// that follow a `ToolCallStart` and parses the result as JSON.
    ToolCallDelta { content: String },
    /// Usage figures reported in the stream; when several are present, [`collect_usage`]
    /// returns the last one.
    Usage(Usage),
    /// Unit marker carrying no data. Presumably signals the end of the response (the tests in
    /// this file always place it last) — TODO confirm against the stream producers.
    Done,
}
16
/// Pinned, boxed stream whose items are either a [`Chunk`] or an [`LlmError`].
///
/// This definition is used on every target except `wasm32` and additionally requires the
/// stream to be `Send`.
#[cfg(not(target_arch = "wasm32"))]
pub type ResponseStream = Pin<Box<dyn Stream<Item = Result<Chunk, LlmError>> + Send>>;

/// Pinned, boxed stream whose items are either a [`Chunk`] or an [`LlmError`].
///
/// This definition is used on `wasm32` only and carries no `Send` bound (presumably because
/// streams on that target are commonly not `Send` — TODO confirm).
#[cfg(target_arch = "wasm32")]
pub type ResponseStream = Pin<Box<dyn Stream<Item = Result<Chunk, LlmError>>>>;
22
23pub fn collect_text(chunks: &[Chunk]) -> String {
25 let mut text = String::new();
26 for chunk in chunks {
27 if let Chunk::Text(t) = chunk {
28 text.push_str(t);
29 }
30 }
31 text
32}
33
34pub fn collect_tool_calls(chunks: &[Chunk]) -> Vec<ToolCall> {
36 let mut calls = Vec::new();
37 let mut current_name: Option<String> = None;
38 let mut current_id: Option<String> = None;
39 let mut current_args = String::new();
40
41 for chunk in chunks {
42 match chunk {
43 Chunk::ToolCallStart { name, id } => {
44 if let Some(prev_name) = current_name.take() {
46 let arguments = serde_json::from_str(¤t_args).unwrap_or_default();
47 calls.push(ToolCall {
48 name: prev_name,
49 arguments,
50 tool_call_id: current_id.take(),
51 });
52 current_args.clear();
53 }
54 current_name = Some(name.clone());
55 current_id = id.clone();
56 }
57 Chunk::ToolCallDelta { content } => {
58 current_args.push_str(content);
59 }
60 _ => {}
61 }
62 }
63 if let Some(name) = current_name {
65 let arguments = serde_json::from_str(¤t_args).unwrap_or_default();
66 calls.push(ToolCall {
67 name,
68 arguments,
69 tool_call_id: current_id,
70 });
71 }
72 calls
73}
74
75pub fn collect_usage(chunks: &[Chunk]) -> Option<Usage> {
77 chunks.iter().rev().find_map(|c| {
78 if let Chunk::Usage(u) = c {
79 Some(u.clone())
80 } else {
81 None
82 }
83 })
84}
85
#[cfg(test)]
mod tests {
    use super::*;
    use futures::StreamExt;

    /// Builds a `Text` chunk from a string slice.
    fn text(body: &str) -> Chunk {
        Chunk::Text(body.to_owned())
    }

    /// Builds a `ToolCallStart` chunk from borrowed parts.
    fn start(name: &str, id: Option<&str>) -> Chunk {
        Chunk::ToolCallStart {
            name: name.to_owned(),
            id: id.map(str::to_owned),
        }
    }

    /// Builds a `ToolCallDelta` chunk from a string slice.
    fn delta(content: &str) -> Chunk {
        Chunk::ToolCallDelta {
            content: content.to_owned(),
        }
    }

    // --- Variant construction -------------------------------------------------------------

    #[test]
    fn chunk_text_carries_content() {
        let chunk = text("hello");
        match &chunk {
            Chunk::Text(body) => assert_eq!(body, "hello"),
            _ => panic!("expected Text chunk"),
        }
    }

    #[test]
    fn chunk_tool_call_start() {
        let chunk = start("search", Some("tc_1"));
        match &chunk {
            Chunk::ToolCallStart { name, id } => {
                assert_eq!(name, "search");
                assert_eq!(id.as_deref(), Some("tc_1"));
            }
            _ => panic!("expected ToolCallStart"),
        }
    }

    #[test]
    fn chunk_tool_call_delta() {
        let chunk = delta(r#"{"query":"#);
        match &chunk {
            Chunk::ToolCallDelta { content } => assert_eq!(content, r#"{"query":"#),
            _ => panic!("expected ToolCallDelta"),
        }
    }

    #[test]
    fn chunk_usage() {
        let usage = Usage {
            input: Some(10),
            output: Some(5),
            details: None,
        };
        let chunk = Chunk::Usage(usage.clone());
        match &chunk {
            Chunk::Usage(carried) => assert_eq!(carried, &usage),
            _ => panic!("expected Usage chunk"),
        }
    }

    #[test]
    fn chunk_done() {
        let chunk = Chunk::Done;
        assert!(matches!(chunk, Chunk::Done));
    }

    // --- ResponseStream -------------------------------------------------------------------

    #[tokio::test]
    async fn response_stream_collects_text() {
        let items = vec![Ok(text("Hello")), Ok(text(" world")), Ok(Chunk::Done)];
        let stream: ResponseStream = Box::pin(futures::stream::iter(items));
        let collected: Vec<_> = stream.collect().await;
        assert_eq!(collected.len(), 3);

        // Join the text of every successful Text item, ignoring everything else.
        let joined: String = collected
            .iter()
            .filter_map(|item| match item {
                Ok(Chunk::Text(body)) => Some(body.as_str()),
                _ => None,
            })
            .collect();
        assert_eq!(joined, "Hello world");
    }

    #[tokio::test]
    async fn response_stream_propagates_error() {
        let items: Vec<Result<Chunk, LlmError>> = vec![
            Ok(text("Hi")),
            Err(LlmError::Provider("connection reset".into())),
        ];
        let stream: ResponseStream = Box::pin(futures::stream::iter(items));
        let collected: Vec<_> = stream.collect().await;
        assert_eq!(collected.len(), 2);
        // The first item succeeds and the error is passed through as the second item.
        assert!(matches!(collected.as_slice(), [Ok(_), Err(_)]));
    }

    // --- collect_* helpers ----------------------------------------------------------------

    #[test]
    fn collect_text_from_chunks() {
        let chunks = vec![
            text("Hello"),
            text(" "),
            start("x", None),
            text("world"),
            Chunk::Done,
        ];
        assert_eq!(collect_text(&chunks), "Hello world");
    }

    #[test]
    fn collect_tool_calls_from_chunks() {
        let chunks = vec![
            text("Let me search."),
            start("search", Some("tc_1")),
            delta(r#"{"query":"#),
            delta(r#""rust"}"#),
            Chunk::Done,
        ];
        let calls = collect_tool_calls(&chunks);
        assert_eq!(calls.len(), 1);

        let call = &calls[0];
        assert_eq!(call.name, "search");
        assert_eq!(call.tool_call_id.as_deref(), Some("tc_1"));
        assert_eq!(call.arguments, serde_json::json!({"query": "rust"}));
    }

    #[test]
    fn collect_tool_calls_multiple() {
        let chunks = vec![
            start("a", Some("1")),
            delta(r#"{}"#),
            start("b", Some("2")),
            delta(r#"{}"#),
            Chunk::Done,
        ];
        let calls = collect_tool_calls(&chunks);
        assert_eq!(calls.len(), 2);
        assert_eq!(calls[0].name, "a");
        assert_eq!(calls[1].name, "b");
    }

    #[test]
    fn collect_usage_from_chunks() {
        let chunks = vec![
            text("Hi"),
            Chunk::Usage(Usage {
                input: Some(5),
                output: Some(1),
                details: None,
            }),
            Chunk::Done,
        ];
        // A usage chunk must be found and its `input` must be the one supplied above.
        let found_input = collect_usage(&chunks).map(|usage| usage.input);
        assert_eq!(found_input, Some(Some(5)));
    }

    #[test]
    fn collect_usage_returns_none_when_absent() {
        let chunks = vec![text("Hi"), Chunk::Done];
        assert!(collect_usage(&chunks).is_none());
    }
}