1use crate::llm::{ContentBlock, StopReason, Usage};
8use futures::Stream;
9use std::collections::HashMap;
10use std::pin::Pin;
11
/// One incremental event emitted while streaming a model response.
///
/// Content-bearing variants carry a `block_index` identifying the content
/// block the fragment belongs to; [`StreamAccumulator`] uses it to group
/// fragments and to order the finished blocks.
#[derive(Debug, Clone)]
pub enum StreamDelta {
    /// A fragment of visible text.
    TextDelta {
        /// Text to append to the block's buffer.
        delta: String,
        /// Index of the content block this fragment belongs to.
        block_index: usize,
    },

    /// A fragment of thinking/reasoning text.
    ThinkingDelta {
        /// Text to append to the thinking block's buffer.
        delta: String,
        /// Index of the content block this fragment belongs to.
        block_index: usize,
    },

    /// Marks the start of a tool call; its arguments follow as `ToolInputDelta`s.
    ToolUseStart {
        /// Tool-call id; subsequent `ToolInputDelta`s are matched to this call by it.
        id: String,
        /// Name of the tool being invoked.
        name: String,
        /// Index of the content block this tool call occupies.
        block_index: usize,
        /// Optional signature attached to the call; the accumulator copies it
        /// into `ContentBlock::ToolUse::thought_signature` unchanged.
        thought_signature: Option<String>,
    },

    /// A fragment of a tool call's JSON arguments.
    ToolInputDelta {
        /// Id of the tool call this fragment belongs to. The accumulator drops
        /// fragments whose id matches no previously started call.
        id: String,
        /// Raw JSON text; fragments are concatenated before being parsed.
        delta: String,
        /// Index of the content block (not used by the accumulator for routing).
        block_index: usize,
    },

    /// Token usage for the response; the accumulator keeps the latest value.
    Usage(Usage),

    /// End of the stream.
    Done {
        /// Why generation stopped, if reported.
        stop_reason: Option<StopReason>,
    },

    /// A fragment of the signature for a thinking block.
    SignatureDelta {
        /// Signature text; fragments for the same index are concatenated.
        delta: String,
        /// Index of the thinking block the signature belongs to.
        block_index: usize,
    },

    /// A complete redacted-thinking block, delivered in one event.
    RedactedThinking {
        /// Payload of the redacted block, passed through unchanged by the accumulator.
        data: String,
        /// Index of the content block.
        block_index: usize,
    },

    /// An error reported by the stream producer. The accumulator ignores it.
    Error {
        /// Description of the error.
        message: String,
        /// Whether the producer flagged the error as recoverable.
        /// NOTE(review): confirm how callers act on this flag.
        recoverable: bool,
    },
}
89
/// A pinned, boxed, `Send` stream of fallible [`StreamDelta`] events that may
/// borrow data for the lifetime `'a`.
pub type StreamBox<'a> = Pin<Box<dyn Stream<Item = anyhow::Result<StreamDelta>> + Send + 'a>>;
92
/// Folds a sequence of [`StreamDelta`] events into complete response content.
///
/// Feed every event to [`StreamAccumulator::apply`], then call
/// [`StreamAccumulator::into_content_blocks`] to obtain the finished blocks.
#[derive(Debug, Default)]
pub struct StreamAccumulator {
    /// Text buffers indexed by `block_index`; indices that never received text stay empty.
    text_blocks: Vec<String>,
    /// Thinking buffers indexed by `block_index`; indices without thinking stay empty.
    thinking_blocks: Vec<String>,
    /// Concatenated signature text keyed by the thinking block's `block_index`.
    thinking_signatures: HashMap<usize, String>,
    /// `(block_index, data)` pairs for redacted-thinking blocks, in arrival order.
    redacted_thinking_blocks: Vec<(usize, String)>,
    /// Tool calls in the order their `ToolUseStart` events arrived.
    tool_uses: Vec<ToolUseAccumulator>,
    /// Most recently reported usage, if any.
    usage: Option<Usage>,
    /// Stop reason from the most recent `Done` event.
    stop_reason: Option<StopReason>,
}
114
/// In-progress state for a single streamed tool call.
#[derive(Debug, Default)]
pub struct ToolUseAccumulator {
    /// Tool-call id from `ToolUseStart`; incoming `ToolInputDelta`s are matched against it.
    pub id: String,
    /// Name of the tool being invoked.
    pub name: String,
    /// Concatenation of every argument-JSON fragment received so far (not yet parsed).
    pub input_json: String,
    /// Position of this tool call among the response's content blocks.
    pub block_index: usize,
    /// Optional signature from `ToolUseStart`, carried through unchanged.
    pub thought_signature: Option<String>,
}
129
130impl StreamAccumulator {
131 #[must_use]
133 pub fn new() -> Self {
134 Self::default()
135 }
136
137 pub fn apply(&mut self, delta: &StreamDelta) {
139 match delta {
140 StreamDelta::TextDelta { delta, block_index } => {
141 while self.text_blocks.len() <= *block_index {
142 self.text_blocks.push(String::new());
143 }
144 self.text_blocks[*block_index].push_str(delta);
145 }
146 StreamDelta::ThinkingDelta { delta, block_index } => {
147 while self.thinking_blocks.len() <= *block_index {
148 self.thinking_blocks.push(String::new());
149 }
150 self.thinking_blocks[*block_index].push_str(delta);
151 }
152 StreamDelta::ToolUseStart {
153 id,
154 name,
155 block_index,
156 thought_signature,
157 } => {
158 self.tool_uses.push(ToolUseAccumulator {
159 id: id.clone(),
160 name: name.clone(),
161 input_json: String::new(),
162 block_index: *block_index,
163 thought_signature: thought_signature.clone(),
164 });
165 }
166 StreamDelta::ToolInputDelta { id, delta, .. } => {
167 if let Some(tool) = self.tool_uses.iter_mut().find(|t| t.id == *id) {
168 tool.input_json.push_str(delta);
169 }
170 }
171 StreamDelta::SignatureDelta { delta, block_index } => {
172 self.thinking_signatures
173 .entry(*block_index)
174 .or_default()
175 .push_str(delta);
176 }
177 StreamDelta::RedactedThinking { data, block_index } => {
178 self.redacted_thinking_blocks
179 .push((*block_index, data.clone()));
180 }
181 StreamDelta::Usage(u) => {
182 self.usage = Some(u.clone());
183 }
184 StreamDelta::Done { stop_reason } => {
185 self.stop_reason = *stop_reason;
186 }
187 StreamDelta::Error { .. } => {}
188 }
189 }
190
191 #[must_use]
193 pub const fn usage(&self) -> Option<&Usage> {
194 self.usage.as_ref()
195 }
196
197 #[must_use]
199 pub const fn stop_reason(&self) -> Option<&StopReason> {
200 self.stop_reason.as_ref()
201 }
202
203 #[must_use]
208 pub fn into_content_blocks(self) -> Vec<ContentBlock> {
209 let mut blocks: Vec<(usize, ContentBlock)> = Vec::new();
210
211 let mut signatures = self.thinking_signatures;
213 for (idx, thinking) in self.thinking_blocks.into_iter().enumerate() {
214 if !thinking.is_empty() {
215 let signature = signatures.remove(&idx).filter(|s| !s.is_empty());
216 blocks.push((
217 idx,
218 ContentBlock::Thinking {
219 thinking,
220 signature,
221 },
222 ));
223 }
224 }
225
226 for (idx, data) in self.redacted_thinking_blocks {
228 blocks.push((idx, ContentBlock::RedactedThinking { data }));
229 }
230
231 for (idx, text) in self.text_blocks.into_iter().enumerate() {
233 if !text.is_empty() {
234 blocks.push((idx, ContentBlock::Text { text }));
235 }
236 }
237
238 for tool in self.tool_uses {
240 let input: serde_json::Value =
241 serde_json::from_str(&tool.input_json).unwrap_or_else(|e| {
242 log::warn!(
243 "Failed to parse streamed tool input JSON for tool '{}' (id={}): {} — \
244 input_json ({} bytes): '{}'",
245 tool.name,
246 tool.id,
247 e,
248 tool.input_json.len(),
249 tool.input_json.chars().take(500).collect::<String>(),
250 );
251 serde_json::json!({})
252 });
253 blocks.push((
254 tool.block_index,
255 ContentBlock::ToolUse {
256 id: tool.id,
257 name: tool.name,
258 input,
259 thought_signature: tool.thought_signature,
260 },
261 ));
262 }
263
264 blocks.sort_by_key(|(idx, _)| *idx);
266
267 blocks.into_iter().map(|(_, block)| block).collect()
268 }
269
270 pub const fn take_usage(&mut self) -> Option<Usage> {
272 self.usage.take()
273 }
274
275 pub const fn take_stop_reason(&mut self) -> Option<StopReason> {
277 self.stop_reason.take()
278 }
279}
280
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_accumulator_text_deltas() {
        let mut acc = StreamAccumulator::new();

        acc.apply(&StreamDelta::TextDelta {
            delta: "Hello".to_string(),
            block_index: 0,
        });
        acc.apply(&StreamDelta::TextDelta {
            delta: " world".to_string(),
            block_index: 0,
        });

        let blocks = acc.into_content_blocks();
        assert_eq!(blocks.len(), 1);
        assert!(matches!(&blocks[0], ContentBlock::Text { text } if text == "Hello world"));
    }

    #[test]
    fn test_accumulator_multiple_text_blocks() {
        let mut acc = StreamAccumulator::new();

        acc.apply(&StreamDelta::TextDelta {
            delta: "First".to_string(),
            block_index: 0,
        });
        acc.apply(&StreamDelta::TextDelta {
            delta: "Second".to_string(),
            block_index: 1,
        });

        let blocks = acc.into_content_blocks();
        assert_eq!(blocks.len(), 2);
        assert!(matches!(&blocks[0], ContentBlock::Text { text } if text == "First"));
        assert!(matches!(&blocks[1], ContentBlock::Text { text } if text == "Second"));
    }

    #[test]
    fn test_accumulator_thinking_signature() {
        let mut acc = StreamAccumulator::new();

        acc.apply(&StreamDelta::ThinkingDelta {
            delta: "Reasoning".to_string(),
            block_index: 0,
        });
        acc.apply(&StreamDelta::SignatureDelta {
            delta: "sig_123".to_string(),
            block_index: 0,
        });

        let blocks = acc.into_content_blocks();
        assert_eq!(blocks.len(), 1);
        assert!(matches!(
            &blocks[0],
            ContentBlock::Thinking { thinking, signature }
                if thinking == "Reasoning" && signature.as_deref() == Some("sig_123")
        ));
    }

    #[test]
    fn test_accumulator_empty_signature_becomes_none() {
        let mut acc = StreamAccumulator::new();

        acc.apply(&StreamDelta::ThinkingDelta {
            delta: "Reasoning".to_string(),
            block_index: 0,
        });
        acc.apply(&StreamDelta::SignatureDelta {
            delta: String::new(),
            block_index: 0,
        });

        let blocks = acc.into_content_blocks();
        assert_eq!(blocks.len(), 1);
        assert!(matches!(
            &blocks[0],
            ContentBlock::Thinking { signature, .. } if signature.is_none()
        ));
    }

    #[test]
    fn test_accumulator_tool_use() {
        let mut acc = StreamAccumulator::new();

        acc.apply(&StreamDelta::ToolUseStart {
            id: "call_123".to_string(),
            name: "read_file".to_string(),
            block_index: 0,
            thought_signature: None,
        });
        acc.apply(&StreamDelta::ToolInputDelta {
            id: "call_123".to_string(),
            delta: r#"{"path":"#.to_string(),
            block_index: 0,
        });
        acc.apply(&StreamDelta::ToolInputDelta {
            id: "call_123".to_string(),
            delta: r#""test.txt"}"#.to_string(),
            block_index: 0,
        });

        let blocks = acc.into_content_blocks();
        assert_eq!(blocks.len(), 1);
        match &blocks[0] {
            ContentBlock::ToolUse {
                id, name, input, ..
            } => {
                assert_eq!(id, "call_123");
                assert_eq!(name, "read_file");
                assert_eq!(input["path"], "test.txt");
            }
            _ => panic!("Expected ToolUse block"),
        }
    }

    #[test]
    fn test_accumulator_tool_use_keeps_thought_signature() {
        let mut acc = StreamAccumulator::new();

        acc.apply(&StreamDelta::ToolUseStart {
            id: "call_sig".to_string(),
            name: "read_file".to_string(),
            block_index: 0,
            thought_signature: Some("ts_1".to_string()),
        });

        let blocks = acc.into_content_blocks();
        assert_eq!(blocks.len(), 1);
        assert!(matches!(
            &blocks[0],
            ContentBlock::ToolUse { thought_signature, .. }
                if thought_signature.as_deref() == Some("ts_1")
        ));
    }

    #[test]
    fn test_accumulator_mixed_content() {
        let mut acc = StreamAccumulator::new();

        acc.apply(&StreamDelta::TextDelta {
            delta: "Let me read that file.".to_string(),
            block_index: 0,
        });
        acc.apply(&StreamDelta::ToolUseStart {
            id: "call_456".to_string(),
            name: "read_file".to_string(),
            block_index: 1,
            thought_signature: None,
        });
        acc.apply(&StreamDelta::ToolInputDelta {
            id: "call_456".to_string(),
            delta: r#"{"path":"file.txt"}"#.to_string(),
            block_index: 1,
        });
        acc.apply(&StreamDelta::Usage(Usage {
            input_tokens: 100,
            output_tokens: 50,
            cached_input_tokens: 0,
        }));
        acc.apply(&StreamDelta::Done {
            stop_reason: Some(StopReason::ToolUse),
        });

        assert!(acc.usage().is_some());
        assert_eq!(acc.usage().map(|u| u.input_tokens), Some(100));
        assert!(matches!(acc.stop_reason(), Some(StopReason::ToolUse)));

        let blocks = acc.into_content_blocks();
        assert_eq!(blocks.len(), 2);
        assert!(matches!(&blocks[0], ContentBlock::Text { .. }));
        assert!(matches!(&blocks[1], ContentBlock::ToolUse { .. }));
    }

    #[test]
    fn test_accumulator_orders_blocks_by_index_across_types() {
        let mut acc = StreamAccumulator::new();

        // Applied in reverse index order: output must follow block_index, not arrival.
        acc.apply(&StreamDelta::ToolUseStart {
            id: "call_1".to_string(),
            name: "bash".to_string(),
            block_index: 2,
            thought_signature: None,
        });
        acc.apply(&StreamDelta::TextDelta {
            delta: "Running it.".to_string(),
            block_index: 1,
        });
        acc.apply(&StreamDelta::ThinkingDelta {
            delta: "Plan".to_string(),
            block_index: 0,
        });

        let blocks = acc.into_content_blocks();
        assert_eq!(blocks.len(), 3);
        assert!(matches!(&blocks[0], ContentBlock::Thinking { thinking, .. } if thinking == "Plan"));
        assert!(matches!(&blocks[1], ContentBlock::Text { text } if text == "Running it."));
        assert!(matches!(&blocks[2], ContentBlock::ToolUse { id, .. } if id == "call_1"));
    }

    #[test]
    fn test_accumulator_redacted_thinking() {
        let mut acc = StreamAccumulator::new();

        acc.apply(&StreamDelta::RedactedThinking {
            data: "opaque".to_string(),
            block_index: 0,
        });
        acc.apply(&StreamDelta::TextDelta {
            delta: "Answer".to_string(),
            block_index: 1,
        });

        let blocks = acc.into_content_blocks();
        assert_eq!(blocks.len(), 2);
        assert!(matches!(&blocks[0], ContentBlock::RedactedThinking { data } if data == "opaque"));
        assert!(matches!(&blocks[1], ContentBlock::Text { text } if text == "Answer"));
    }

    #[test]
    fn test_accumulator_invalid_tool_json() {
        let mut acc = StreamAccumulator::new();

        acc.apply(&StreamDelta::ToolUseStart {
            id: "call_789".to_string(),
            name: "test_tool".to_string(),
            block_index: 0,
            thought_signature: None,
        });
        acc.apply(&StreamDelta::ToolInputDelta {
            id: "call_789".to_string(),
            delta: "invalid json {".to_string(),
            block_index: 0,
        });

        let blocks = acc.into_content_blocks();
        assert_eq!(blocks.len(), 1);
        match &blocks[0] {
            ContentBlock::ToolUse { input, .. } => {
                assert!(input.is_object());
            }
            _ => panic!("Expected ToolUse block"),
        }
    }

    #[test]
    fn test_accumulator_empty_tool_input_falls_back_to_empty_object() {
        let mut acc = StreamAccumulator::new();

        acc.apply(&StreamDelta::ToolUseStart {
            id: "call_empty".to_string(),
            name: "read".to_string(),
            block_index: 0,
            thought_signature: None,
        });
        let blocks = acc.into_content_blocks();
        assert_eq!(blocks.len(), 1);
        match &blocks[0] {
            ContentBlock::ToolUse { input, name, .. } => {
                assert_eq!(name, "read");
                assert_eq!(input, &serde_json::json!({}));
            }
            _ => panic!("Expected ToolUse block"),
        }
    }

    #[test]
    fn test_accumulator_whitespace_tool_input_falls_back_to_empty_object() {
        let mut acc = StreamAccumulator::new();

        acc.apply(&StreamDelta::ToolUseStart {
            id: "call_ws".to_string(),
            name: "read".to_string(),
            block_index: 0,
            thought_signature: None,
        });
        acc.apply(&StreamDelta::ToolInputDelta {
            id: "call_ws".to_string(),
            delta: "  \n".to_string(),
            block_index: 0,
        });

        let blocks = acc.into_content_blocks();
        assert_eq!(blocks.len(), 1);
        match &blocks[0] {
            ContentBlock::ToolUse { input, .. } => {
                assert_eq!(input, &serde_json::json!({}));
            }
            _ => panic!("Expected ToolUse block"),
        }
    }

    #[test]
    fn test_accumulator_mismatched_delta_id_drops_input() {
        let mut acc = StreamAccumulator::new();

        acc.apply(&StreamDelta::ToolUseStart {
            id: "call_A".to_string(),
            name: "bash".to_string(),
            block_index: 0,
            thought_signature: None,
        });
        acc.apply(&StreamDelta::ToolInputDelta {
            id: "call_B".to_string(),
            delta: r#"{"command":"ls"}"#.to_string(),
            block_index: 0,
        });

        let blocks = acc.into_content_blocks();
        assert_eq!(blocks.len(), 1);
        match &blocks[0] {
            ContentBlock::ToolUse { input, .. } => {
                assert_eq!(input, &serde_json::json!({}));
            }
            _ => panic!("Expected ToolUse block"),
        }
    }

    #[test]
    fn test_accumulator_take_usage_and_stop_reason() {
        let mut acc = StreamAccumulator::new();

        acc.apply(&StreamDelta::Usage(Usage {
            input_tokens: 10,
            output_tokens: 5,
            cached_input_tokens: 0,
        }));
        acc.apply(&StreamDelta::Done {
            stop_reason: Some(StopReason::ToolUse),
        });

        assert_eq!(acc.take_usage().map(|u| u.output_tokens), Some(5));
        assert!(acc.take_usage().is_none());
        assert!(matches!(acc.take_stop_reason(), Some(StopReason::ToolUse)));
        assert!(acc.take_stop_reason().is_none());
    }

    #[test]
    fn test_accumulator_empty() {
        let acc = StreamAccumulator::new();
        let blocks = acc.into_content_blocks();
        assert!(blocks.is_empty());
    }

    #[test]
    fn test_accumulator_skips_empty_text() {
        let mut acc = StreamAccumulator::new();

        acc.apply(&StreamDelta::TextDelta {
            delta: String::new(),
            block_index: 0,
        });
        acc.apply(&StreamDelta::TextDelta {
            delta: "Hello".to_string(),
            block_index: 1,
        });

        let blocks = acc.into_content_blocks();
        assert_eq!(blocks.len(), 1);
        assert!(matches!(&blocks[0], ContentBlock::Text { text } if text == "Hello"));
    }
}