1use crate::llm::{ContentBlock, StopReason, Usage};
8use futures::Stream;
9use std::pin::Pin;
10
/// One incremental event from a streaming model response.
///
/// Consumers typically fold a sequence of these into a complete response with
/// [`StreamAccumulator::apply`].
#[derive(Debug, Clone)]
pub enum StreamDelta {
    /// A fragment of visible text to append to content block `block_index`.
    TextDelta {
        /// Text fragment, appended verbatim.
        delta: String,
        /// Index of the content block this fragment belongs to; also used to order the
        /// final blocks in [`StreamAccumulator::into_content_blocks`].
        block_index: usize,
    },

    /// A fragment of model "thinking" text to append to content block `block_index`.
    ThinkingDelta {
        /// Thinking fragment, appended verbatim.
        delta: String,
        /// Index of the content block this fragment belongs to.
        block_index: usize,
    },

    /// Announces a new tool call; its JSON input follows as [`StreamDelta::ToolInputDelta`]s.
    ToolUseStart {
        /// Identifier of the tool call, used to route subsequent input fragments.
        id: String,
        /// Name of the tool being invoked.
        name: String,
        /// Index of the content block holding this tool call.
        block_index: usize,
    },

    /// A fragment of the JSON-encoded input for a previously started tool call.
    ///
    /// Fragments are concatenated in arrival order; the result is only parsed as JSON by
    /// [`StreamAccumulator::into_content_blocks`], so individual fragments need not be valid JSON.
    ToolInputDelta {
        /// Identifier of the tool call this fragment belongs to.
        id: String,
        /// Raw JSON text fragment.
        delta: String,
        /// Index of the content block holding the tool call.
        block_index: usize,
    },

    /// Token usage report. [`StreamAccumulator`] keeps only the most recent one.
    Usage(Usage),

    /// Marks the end of the response.
    Done {
        /// Why generation stopped, if the provider reported it.
        stop_reason: Option<StopReason>,
    },

    /// An error reported mid-stream. [`StreamAccumulator::apply`] does not record it.
    Error {
        /// Human-readable error description.
        message: String,
        /// NOTE(review): presumably whether the stream/request may be continued or retried —
        /// confirm with the producers of this variant.
        recoverable: bool,
    },
}
70
/// Type-erased, heap-pinned stream of [`StreamDelta`]s.
///
/// Each item is either a delta or an `anyhow` error. The stream is `Send` and may borrow
/// data for the lifetime `'a`.
pub type StreamBox<'a> = Pin<Box<dyn Stream<Item = anyhow::Result<StreamDelta>> + Send + 'a>>;
73
/// Rebuilds a complete response from a sequence of [`StreamDelta`]s.
///
/// Feed deltas in arrival order with [`StreamAccumulator::apply`], then call
/// [`StreamAccumulator::into_content_blocks`] to obtain the ordered content blocks.
#[derive(Debug, Default)]
pub struct StreamAccumulator {
    // Text buffers indexed by block index; slots that never received text stay empty.
    text_blocks: Vec<String>,
    // Thinking buffers indexed by block index; slots that never received thinking stay empty.
    thinking_blocks: Vec<String>,
    // Tool calls in the order their `ToolUseStart` deltas arrived.
    tool_uses: Vec<ToolUseAccumulator>,
    // Most recent usage report, if any.
    usage: Option<Usage>,
    // Stop reason from the last `Done` delta, if any.
    stop_reason: Option<StopReason>,
}
91
/// A tool call whose input is still being streamed.
#[derive(Debug, Default)]
pub struct ToolUseAccumulator {
    /// Tool-call identifier from the `ToolUseStart` delta.
    pub id: String,
    /// Name of the invoked tool.
    pub name: String,
    /// Concatenation of all input fragments received so far. Until the stream ends this may be
    /// partial and therefore not yet valid JSON.
    pub input_json: String,
    /// Content-block index of this tool call, used for ordering the final output.
    pub block_index: usize,
}
104
105impl StreamAccumulator {
106 #[must_use]
108 pub fn new() -> Self {
109 Self::default()
110 }
111
112 pub fn apply(&mut self, delta: &StreamDelta) {
114 match delta {
115 StreamDelta::TextDelta { delta, block_index } => {
116 while self.text_blocks.len() <= *block_index {
117 self.text_blocks.push(String::new());
118 }
119 self.text_blocks[*block_index].push_str(delta);
120 }
121 StreamDelta::ThinkingDelta { delta, block_index } => {
122 while self.thinking_blocks.len() <= *block_index {
123 self.thinking_blocks.push(String::new());
124 }
125 self.thinking_blocks[*block_index].push_str(delta);
126 }
127 StreamDelta::ToolUseStart {
128 id,
129 name,
130 block_index,
131 } => {
132 self.tool_uses.push(ToolUseAccumulator {
133 id: id.clone(),
134 name: name.clone(),
135 input_json: String::new(),
136 block_index: *block_index,
137 });
138 }
139 StreamDelta::ToolInputDelta { id, delta, .. } => {
140 if let Some(tool) = self.tool_uses.iter_mut().find(|t| t.id == *id) {
141 tool.input_json.push_str(delta);
142 }
143 }
144 StreamDelta::Usage(u) => {
145 self.usage = Some(u.clone());
146 }
147 StreamDelta::Done { stop_reason } => {
148 self.stop_reason = *stop_reason;
149 }
150 StreamDelta::Error { .. } => {}
151 }
152 }
153
154 #[must_use]
156 pub const fn usage(&self) -> Option<&Usage> {
157 self.usage.as_ref()
158 }
159
160 #[must_use]
162 pub const fn stop_reason(&self) -> Option<&StopReason> {
163 self.stop_reason.as_ref()
164 }
165
166 #[must_use]
171 pub fn into_content_blocks(self) -> Vec<ContentBlock> {
172 let mut blocks: Vec<(usize, ContentBlock)> = Vec::new();
173
174 for (idx, thinking) in self.thinking_blocks.into_iter().enumerate() {
176 if !thinking.is_empty() {
177 blocks.push((idx, ContentBlock::Thinking { thinking }));
178 }
179 }
180
181 for (idx, text) in self.text_blocks.into_iter().enumerate() {
183 if !text.is_empty() {
184 blocks.push((idx, ContentBlock::Text { text }));
185 }
186 }
187
188 for tool in self.tool_uses {
190 let input: serde_json::Value =
191 serde_json::from_str(&tool.input_json).unwrap_or_else(|_| serde_json::json!({}));
192 blocks.push((
193 tool.block_index,
194 ContentBlock::ToolUse {
195 id: tool.id,
196 name: tool.name,
197 input,
198 thought_signature: None, },
200 ));
201 }
202
203 blocks.sort_by_key(|(idx, _)| *idx);
205
206 blocks.into_iter().map(|(_, block)| block).collect()
207 }
208
209 pub const fn take_usage(&mut self) -> Option<Usage> {
211 self.usage.take()
212 }
213
214 pub const fn take_stop_reason(&mut self) -> Option<StopReason> {
216 self.stop_reason.take()
217 }
218}
219
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs `deltas` through a fresh accumulator, in order, and returns it.
    fn accumulate(deltas: &[StreamDelta]) -> StreamAccumulator {
        let mut acc = StreamAccumulator::new();
        for delta in deltas {
            acc.apply(delta);
        }
        acc
    }

    fn text_delta(delta: &str, block_index: usize) -> StreamDelta {
        StreamDelta::TextDelta {
            delta: delta.to_owned(),
            block_index,
        }
    }

    fn tool_start(id: &str, name: &str, block_index: usize) -> StreamDelta {
        StreamDelta::ToolUseStart {
            id: id.to_owned(),
            name: name.to_owned(),
            block_index,
        }
    }

    fn tool_input(id: &str, delta: &str, block_index: usize) -> StreamDelta {
        StreamDelta::ToolInputDelta {
            id: id.to_owned(),
            delta: delta.to_owned(),
            block_index,
        }
    }

    #[test]
    fn test_accumulator_text_deltas() {
        let blocks =
            accumulate(&[text_delta("Hello", 0), text_delta(" world", 0)]).into_content_blocks();

        assert!(matches!(
            blocks.as_slice(),
            [ContentBlock::Text { text }] if text == "Hello world"
        ));
    }

    #[test]
    fn test_accumulator_multiple_text_blocks() {
        let blocks =
            accumulate(&[text_delta("First", 0), text_delta("Second", 1)]).into_content_blocks();

        match blocks.as_slice() {
            [ContentBlock::Text { text: first }, ContentBlock::Text { text: second }] => {
                assert_eq!(first, "First");
                assert_eq!(second, "Second");
            }
            _ => panic!("expected exactly two text blocks"),
        }
    }

    #[test]
    fn test_accumulator_tool_use() {
        let blocks = accumulate(&[
            tool_start("call_123", "read_file", 0),
            tool_input("call_123", r#"{"path":"#, 0),
            tool_input("call_123", r#""test.txt"}"#, 0),
        ])
        .into_content_blocks();

        let [ContentBlock::ToolUse {
            id, name, input, ..
        }] = blocks.as_slice()
        else {
            panic!("Expected ToolUse block");
        };
        assert_eq!(id, "call_123");
        assert_eq!(name, "read_file");
        assert_eq!(input["path"], "test.txt");
    }

    #[test]
    fn test_accumulator_mixed_content() {
        let acc = accumulate(&[
            text_delta("Let me read that file.", 0),
            tool_start("call_456", "read_file", 1),
            tool_input("call_456", r#"{"path":"file.txt"}"#, 1),
            StreamDelta::Usage(Usage {
                input_tokens: 100,
                output_tokens: 50,
            }),
            StreamDelta::Done {
                stop_reason: Some(StopReason::ToolUse),
            },
        ]);

        assert_eq!(acc.usage().map(|usage| usage.input_tokens), Some(100));
        assert!(matches!(acc.stop_reason(), Some(StopReason::ToolUse)));

        let blocks = acc.into_content_blocks();
        assert!(matches!(
            blocks.as_slice(),
            [ContentBlock::Text { .. }, ContentBlock::ToolUse { .. }]
        ));
    }

    #[test]
    fn test_accumulator_invalid_tool_json() {
        let blocks = accumulate(&[
            tool_start("call_789", "test_tool", 0),
            tool_input("call_789", "invalid json {", 0),
        ])
        .into_content_blocks();

        let [ContentBlock::ToolUse { input, .. }] = blocks.as_slice() else {
            panic!("Expected ToolUse block");
        };
        assert!(input.is_object());
    }

    #[test]
    fn test_accumulator_empty() {
        assert!(StreamAccumulator::new().into_content_blocks().is_empty());
    }

    #[test]
    fn test_accumulator_skips_empty_text() {
        let blocks =
            accumulate(&[text_delta("", 0), text_delta("Hello", 1)]).into_content_blocks();

        assert!(matches!(
            blocks.as_slice(),
            [ContentBlock::Text { text }] if text == "Hello"
        ));
    }
}