1use crate::llm::{ContentBlock, StopReason, Usage};
8use futures::Stream;
9use std::collections::HashMap;
10use std::pin::Pin;
11
/// One incremental event produced while streaming a model response.
///
/// Most variants carry a `block_index` identifying the content block they
/// belong to; [`StreamAccumulator`] uses it to place and order the finished
/// blocks.
#[derive(Debug, Clone)]
pub enum StreamDelta {
    /// A fragment of text to append to the text block at `block_index`.
    TextDelta {
        /// Text fragment to append.
        delta: String,
        /// Content-block index this fragment belongs to.
        block_index: usize,
    },

    /// A fragment of reasoning ("thinking") text for the block at `block_index`.
    ThinkingDelta {
        /// Thinking-text fragment to append.
        delta: String,
        /// Content-block index this fragment belongs to.
        block_index: usize,
    },

    /// Start of a tool call. Its argument JSON is supplied afterwards by
    /// [`StreamDelta::ToolInputDelta`] events carrying the same `id`.
    ToolUseStart {
        /// Tool call id; input deltas are matched against it.
        id: String,
        /// Name of the tool being called.
        name: String,
        /// Content-block index of the tool call, used for output ordering.
        block_index: usize,
        /// Optional signature attached to the tool call; carried through to
        /// the resulting `ContentBlock::ToolUse`.
        thought_signature: Option<String>,
    },

    /// A fragment of a tool call's argument JSON. Fragments are concatenated
    /// and only parsed once the accumulated response is finalized.
    ToolInputDelta {
        /// Id of the tool call this fragment belongs to.
        id: String,
        /// Raw JSON fragment (not necessarily valid JSON on its own).
        delta: String,
        /// Content-block index of the tool call. `StreamAccumulator` does not
        /// read it; it routes fragments by `id`.
        block_index: usize,
    },

    /// Token-usage report. `StreamAccumulator` keeps only the most recent one.
    Usage(Usage),

    /// End of the response.
    Done {
        /// Why generation stopped, if reported.
        stop_reason: Option<StopReason>,
    },

    /// A fragment of the signature for the thinking block at `block_index`;
    /// successive fragments are concatenated.
    SignatureDelta {
        /// Signature fragment to append.
        delta: String,
        /// Content-block index of the thinking block being signed.
        block_index: usize,
    },

    /// A complete redacted thinking block, carried through verbatim as
    /// `ContentBlock::RedactedThinking`.
    RedactedThinking {
        /// Redacted payload, passed through unchanged.
        data: String,
        /// Content-block index of this block.
        block_index: usize,
    },

    /// An error reported in the middle of the stream. `StreamAccumulator`
    /// ignores this variant.
    Error {
        /// Error description.
        message: String,
        /// Set by the producer. NOTE(review): presumably whether the stream can
        /// continue after this error — confirm with the producers.
        recoverable: bool,
    },
}
89
/// A pinned, boxed, `Send` stream of [`StreamDelta`] results that may borrow
/// data for the lifetime `'a`.
pub type StreamBox<'a> = Pin<Box<dyn Stream<Item = anyhow::Result<StreamDelta>> + Send + 'a>>;
92
/// Folds a sequence of [`StreamDelta`] events (via [`StreamAccumulator::apply`])
/// into the final list of [`ContentBlock`]s, plus the last reported usage and
/// stop reason.
#[derive(Debug, Default)]
pub struct StreamAccumulator {
    /// Text per content-block index; indexes that never received text hold
    /// empty strings and are skipped when finalizing.
    text_blocks: Vec<String>,
    /// Thinking text per content-block index, padded the same way as
    /// `text_blocks`.
    thinking_blocks: Vec<String>,
    /// Concatenated signature fragments keyed by content-block index.
    thinking_signatures: HashMap<usize, String>,
    /// `(block_index, data)` for each redacted thinking block, in arrival order.
    redacted_thinking_blocks: Vec<(usize, String)>,
    /// Tool calls in the order their `ToolUseStart` events arrived.
    tool_uses: Vec<ToolUseAccumulator>,
    /// Most recently applied usage report.
    usage: Option<Usage>,
    /// Stop reason from the most recent `Done` event.
    stop_reason: Option<StopReason>,
}
114
/// In-progress state for one streamed tool call.
#[derive(Debug, Default)]
pub struct ToolUseAccumulator {
    /// Tool call id; `ToolInputDelta` events are routed by matching it.
    pub id: String,
    /// Name of the tool being called.
    pub name: String,
    /// Concatenated raw argument JSON fragments; parsed only when the
    /// accumulator is converted into content blocks.
    pub input_json: String,
    /// Content-block index, used to order the final blocks.
    pub block_index: usize,
    /// Signature from `ToolUseStart`, copied onto the output block.
    pub thought_signature: Option<String>,
}
129
130impl StreamAccumulator {
131 #[must_use]
133 pub fn new() -> Self {
134 Self::default()
135 }
136
137 pub fn apply(&mut self, delta: &StreamDelta) {
139 match delta {
140 StreamDelta::TextDelta { delta, block_index } => {
141 while self.text_blocks.len() <= *block_index {
142 self.text_blocks.push(String::new());
143 }
144 self.text_blocks[*block_index].push_str(delta);
145 }
146 StreamDelta::ThinkingDelta { delta, block_index } => {
147 while self.thinking_blocks.len() <= *block_index {
148 self.thinking_blocks.push(String::new());
149 }
150 self.thinking_blocks[*block_index].push_str(delta);
151 }
152 StreamDelta::ToolUseStart {
153 id,
154 name,
155 block_index,
156 thought_signature,
157 } => {
158 self.tool_uses.push(ToolUseAccumulator {
159 id: id.clone(),
160 name: name.clone(),
161 input_json: String::new(),
162 block_index: *block_index,
163 thought_signature: thought_signature.clone(),
164 });
165 }
166 StreamDelta::ToolInputDelta { id, delta, .. } => {
167 if let Some(tool) = self.tool_uses.iter_mut().find(|t| t.id == *id) {
168 tool.input_json.push_str(delta);
169 }
170 }
171 StreamDelta::SignatureDelta { delta, block_index } => {
172 self.thinking_signatures
173 .entry(*block_index)
174 .or_default()
175 .push_str(delta);
176 }
177 StreamDelta::RedactedThinking { data, block_index } => {
178 self.redacted_thinking_blocks
179 .push((*block_index, data.clone()));
180 }
181 StreamDelta::Usage(u) => {
182 self.usage = Some(u.clone());
183 }
184 StreamDelta::Done { stop_reason } => {
185 self.stop_reason = *stop_reason;
186 }
187 StreamDelta::Error { .. } => {}
188 }
189 }
190
191 #[must_use]
193 pub const fn usage(&self) -> Option<&Usage> {
194 self.usage.as_ref()
195 }
196
197 #[must_use]
199 pub const fn stop_reason(&self) -> Option<&StopReason> {
200 self.stop_reason.as_ref()
201 }
202
203 #[must_use]
208 pub fn into_content_blocks(self) -> Vec<ContentBlock> {
209 let mut blocks: Vec<(usize, ContentBlock)> = Vec::new();
210
211 let mut signatures = self.thinking_signatures;
213 for (idx, thinking) in self.thinking_blocks.into_iter().enumerate() {
214 if !thinking.is_empty() {
215 let signature = signatures.remove(&idx).filter(|s| !s.is_empty());
216 blocks.push((
217 idx,
218 ContentBlock::Thinking {
219 thinking,
220 signature,
221 },
222 ));
223 }
224 }
225
226 for (idx, data) in self.redacted_thinking_blocks {
228 blocks.push((idx, ContentBlock::RedactedThinking { data }));
229 }
230
231 for (idx, text) in self.text_blocks.into_iter().enumerate() {
233 if !text.is_empty() {
234 blocks.push((idx, ContentBlock::Text { text }));
235 }
236 }
237
238 for tool in self.tool_uses {
240 let input: serde_json::Value =
241 serde_json::from_str(&tool.input_json).unwrap_or_else(|e| {
242 log::warn!(
243 "Failed to parse streamed tool input JSON for tool '{}' (id={}): {} — \
244 input_json ({} bytes): '{}'",
245 tool.name,
246 tool.id,
247 e,
248 tool.input_json.len(),
249 tool.input_json.chars().take(500).collect::<String>(),
250 );
251 serde_json::json!({})
252 });
253 blocks.push((
254 tool.block_index,
255 ContentBlock::ToolUse {
256 id: tool.id,
257 name: tool.name,
258 input,
259 thought_signature: tool.thought_signature,
260 },
261 ));
262 }
263
264 blocks.sort_by_key(|(idx, _)| *idx);
266
267 blocks.into_iter().map(|(_, block)| block).collect()
268 }
269
270 pub const fn take_usage(&mut self) -> Option<Usage> {
272 self.usage.take()
273 }
274
275 pub const fn take_stop_reason(&mut self) -> Option<StopReason> {
277 self.stop_reason.take()
278 }
279}
280
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_accumulator_text_deltas() {
        let mut acc = StreamAccumulator::new();

        acc.apply(&StreamDelta::TextDelta {
            delta: "Hello".to_string(),
            block_index: 0,
        });
        acc.apply(&StreamDelta::TextDelta {
            delta: " world".to_string(),
            block_index: 0,
        });

        let blocks = acc.into_content_blocks();
        assert_eq!(blocks.len(), 1);
        assert!(matches!(&blocks[0], ContentBlock::Text { text } if text == "Hello world"));
    }

    #[test]
    fn test_accumulator_multiple_text_blocks() {
        let mut acc = StreamAccumulator::new();

        acc.apply(&StreamDelta::TextDelta {
            delta: "First".to_string(),
            block_index: 0,
        });
        acc.apply(&StreamDelta::TextDelta {
            delta: "Second".to_string(),
            block_index: 1,
        });

        let blocks = acc.into_content_blocks();
        assert_eq!(blocks.len(), 2);
        assert!(matches!(&blocks[0], ContentBlock::Text { text } if text == "First"));
        assert!(matches!(&blocks[1], ContentBlock::Text { text } if text == "Second"));
    }

    #[test]
    fn test_accumulator_thinking_signature() {
        let mut acc = StreamAccumulator::new();

        acc.apply(&StreamDelta::ThinkingDelta {
            delta: "Reasoning".to_string(),
            block_index: 0,
        });
        acc.apply(&StreamDelta::SignatureDelta {
            delta: "sig_123".to_string(),
            block_index: 0,
        });

        let blocks = acc.into_content_blocks();
        assert_eq!(blocks.len(), 1);
        assert!(matches!(
            &blocks[0],
            ContentBlock::Thinking { thinking, signature }
            if thinking == "Reasoning" && signature.as_deref() == Some("sig_123")
        ));
    }

    #[test]
    fn test_accumulator_signature_split_across_deltas() {
        let mut acc = StreamAccumulator::new();

        acc.apply(&StreamDelta::ThinkingDelta {
            delta: "Step".to_string(),
            block_index: 0,
        });
        // Signature fragments for the same index are concatenated in order.
        acc.apply(&StreamDelta::SignatureDelta {
            delta: "sig_".to_string(),
            block_index: 0,
        });
        acc.apply(&StreamDelta::SignatureDelta {
            delta: "abc".to_string(),
            block_index: 0,
        });

        let blocks = acc.into_content_blocks();
        assert_eq!(blocks.len(), 1);
        assert!(matches!(
            &blocks[0],
            ContentBlock::Thinking { signature, .. } if signature.as_deref() == Some("sig_abc")
        ));
    }

    #[test]
    fn test_accumulator_empty_signature_becomes_none() {
        let mut acc = StreamAccumulator::new();

        acc.apply(&StreamDelta::ThinkingDelta {
            delta: "Step".to_string(),
            block_index: 0,
        });
        acc.apply(&StreamDelta::SignatureDelta {
            delta: String::new(),
            block_index: 0,
        });

        let blocks = acc.into_content_blocks();
        assert_eq!(blocks.len(), 1);
        assert!(matches!(
            &blocks[0],
            ContentBlock::Thinking { signature, .. } if signature.is_none()
        ));
    }

    #[test]
    fn test_accumulator_tool_use() {
        let mut acc = StreamAccumulator::new();

        acc.apply(&StreamDelta::ToolUseStart {
            id: "call_123".to_string(),
            name: "read_file".to_string(),
            block_index: 0,
            thought_signature: None,
        });
        // Argument JSON arrives split across deltas and is only parsed at the end.
        acc.apply(&StreamDelta::ToolInputDelta {
            id: "call_123".to_string(),
            delta: r#"{"path":"#.to_string(),
            block_index: 0,
        });
        acc.apply(&StreamDelta::ToolInputDelta {
            id: "call_123".to_string(),
            delta: r#""test.txt"}"#.to_string(),
            block_index: 0,
        });

        let blocks = acc.into_content_blocks();
        assert_eq!(blocks.len(), 1);
        match &blocks[0] {
            ContentBlock::ToolUse {
                id, name, input, ..
            } => {
                assert_eq!(id, "call_123");
                assert_eq!(name, "read_file");
                assert_eq!(input["path"], "test.txt");
            }
            _ => panic!("Expected ToolUse block"),
        }
    }

    #[test]
    fn test_accumulator_mixed_content() {
        let mut acc = StreamAccumulator::new();

        acc.apply(&StreamDelta::TextDelta {
            delta: "Let me read that file.".to_string(),
            block_index: 0,
        });
        acc.apply(&StreamDelta::ToolUseStart {
            id: "call_456".to_string(),
            name: "read_file".to_string(),
            block_index: 1,
            thought_signature: None,
        });
        acc.apply(&StreamDelta::ToolInputDelta {
            id: "call_456".to_string(),
            delta: r#"{"path":"file.txt"}"#.to_string(),
            block_index: 1,
        });
        acc.apply(&StreamDelta::Usage(Usage {
            input_tokens: 100,
            output_tokens: 50,
        }));
        acc.apply(&StreamDelta::Done {
            stop_reason: Some(StopReason::ToolUse),
        });

        assert!(acc.usage().is_some());
        assert_eq!(acc.usage().map(|u| u.input_tokens), Some(100));
        assert!(matches!(acc.stop_reason(), Some(StopReason::ToolUse)));

        let blocks = acc.into_content_blocks();
        assert_eq!(blocks.len(), 2);
        assert!(matches!(&blocks[0], ContentBlock::Text { .. }));
        assert!(matches!(&blocks[1], ContentBlock::ToolUse { .. }));
    }

    #[test]
    fn test_accumulator_orders_mixed_block_kinds_by_index() {
        let mut acc = StreamAccumulator::new();

        // Events arrive out of index order; output must follow block_index.
        acc.apply(&StreamDelta::ToolUseStart {
            id: "call_1".to_string(),
            name: "bash".to_string(),
            block_index: 3,
            thought_signature: Some("tsig".to_string()),
        });
        acc.apply(&StreamDelta::TextDelta {
            delta: "Running it.".to_string(),
            block_index: 2,
        });
        acc.apply(&StreamDelta::RedactedThinking {
            data: "opaque".to_string(),
            block_index: 1,
        });
        acc.apply(&StreamDelta::ThinkingDelta {
            delta: "Plan".to_string(),
            block_index: 0,
        });
        acc.apply(&StreamDelta::ToolInputDelta {
            id: "call_1".to_string(),
            delta: r#"{"command":"ls"}"#.to_string(),
            block_index: 3,
        });

        let blocks = acc.into_content_blocks();
        assert_eq!(blocks.len(), 4);
        assert!(matches!(
            &blocks[0],
            ContentBlock::Thinking { thinking, signature }
            if thinking == "Plan" && signature.is_none()
        ));
        assert!(matches!(
            &blocks[1],
            ContentBlock::RedactedThinking { data } if data == "opaque"
        ));
        assert!(matches!(&blocks[2], ContentBlock::Text { text } if text == "Running it."));
        match &blocks[3] {
            ContentBlock::ToolUse {
                id,
                input,
                thought_signature,
                ..
            } => {
                assert_eq!(id, "call_1");
                assert_eq!(input["command"], "ls");
                assert_eq!(thought_signature.as_deref(), Some("tsig"));
            }
            _ => panic!("Expected ToolUse block"),
        }
    }

    #[test]
    fn test_accumulator_invalid_tool_json() {
        let mut acc = StreamAccumulator::new();

        acc.apply(&StreamDelta::ToolUseStart {
            id: "call_789".to_string(),
            name: "test_tool".to_string(),
            block_index: 0,
            thought_signature: None,
        });
        acc.apply(&StreamDelta::ToolInputDelta {
            id: "call_789".to_string(),
            delta: "invalid json {".to_string(),
            block_index: 0,
        });

        let blocks = acc.into_content_blocks();
        assert_eq!(blocks.len(), 1);
        match &blocks[0] {
            ContentBlock::ToolUse { input, .. } => {
                assert!(input.is_object());
            }
            _ => panic!("Expected ToolUse block"),
        }
    }

    #[test]
    fn test_accumulator_empty_tool_input_falls_back_to_empty_object() {
        let mut acc = StreamAccumulator::new();

        // A tool call that never receives a ToolInputDelta: the accumulated
        // input stays empty and must still yield an (empty) object.
        acc.apply(&StreamDelta::ToolUseStart {
            id: "call_empty".to_string(),
            name: "read".to_string(),
            block_index: 0,
            thought_signature: None,
        });
        let blocks = acc.into_content_blocks();

        assert_eq!(blocks.len(), 1);
        match &blocks[0] {
            ContentBlock::ToolUse { input, name, .. } => {
                assert_eq!(name, "read");
                assert_eq!(input, &serde_json::json!({}));
            }
            _ => panic!("Expected ToolUse block"),
        }
    }

    #[test]
    fn test_accumulator_whitespace_tool_input_falls_back_to_empty_object() {
        let mut acc = StreamAccumulator::new();

        acc.apply(&StreamDelta::ToolUseStart {
            id: "call_ws".to_string(),
            name: "list".to_string(),
            block_index: 0,
            thought_signature: None,
        });
        acc.apply(&StreamDelta::ToolInputDelta {
            id: "call_ws".to_string(),
            delta: "   ".to_string(),
            block_index: 0,
        });

        let blocks = acc.into_content_blocks();
        assert_eq!(blocks.len(), 1);
        match &blocks[0] {
            ContentBlock::ToolUse { input, .. } => {
                assert_eq!(input, &serde_json::json!({}));
            }
            _ => panic!("Expected ToolUse block"),
        }
    }

    #[test]
    fn test_accumulator_mismatched_delta_id_drops_input() {
        let mut acc = StreamAccumulator::new();

        // Input deltas are routed by id, not block_index.
        acc.apply(&StreamDelta::ToolUseStart {
            id: "call_A".to_string(),
            name: "bash".to_string(),
            block_index: 0,
            thought_signature: None,
        });
        acc.apply(&StreamDelta::ToolInputDelta {
            // Same block_index but a different id: the delta must be dropped.
            id: "call_B".to_string(),
            delta: r#"{"command":"ls"}"#.to_string(),
            block_index: 0,
        });

        let blocks = acc.into_content_blocks();
        assert_eq!(blocks.len(), 1);
        match &blocks[0] {
            ContentBlock::ToolUse { input, .. } => {
                assert_eq!(input, &serde_json::json!({}));
            }
            _ => panic!("Expected ToolUse block"),
        }
    }

    #[test]
    fn test_accumulator_take_usage_and_stop_reason() {
        let mut acc = StreamAccumulator::new();

        acc.apply(&StreamDelta::Usage(Usage {
            input_tokens: 7,
            output_tokens: 3,
        }));
        acc.apply(&StreamDelta::Done {
            stop_reason: Some(StopReason::ToolUse),
        });

        let usage = acc.take_usage();
        assert_eq!(usage.as_ref().map(|u| u.output_tokens), Some(3));
        // take_* leaves None behind.
        assert!(acc.take_usage().is_none());
        assert!(acc.usage().is_none());

        assert!(matches!(acc.take_stop_reason(), Some(StopReason::ToolUse)));
        assert!(acc.take_stop_reason().is_none());
        assert!(acc.stop_reason().is_none());
    }

    #[test]
    fn test_accumulator_ignores_error_delta() {
        let mut acc = StreamAccumulator::new();

        acc.apply(&StreamDelta::TextDelta {
            delta: "partial".to_string(),
            block_index: 0,
        });
        acc.apply(&StreamDelta::Error {
            message: "boom".to_string(),
            recoverable: true,
        });

        let blocks = acc.into_content_blocks();
        assert_eq!(blocks.len(), 1);
        assert!(matches!(&blocks[0], ContentBlock::Text { text } if text == "partial"));
    }

    #[test]
    fn test_accumulator_empty() {
        let acc = StreamAccumulator::new();
        let blocks = acc.into_content_blocks();
        assert!(blocks.is_empty());
    }

    #[test]
    fn test_accumulator_skips_empty_text() {
        let mut acc = StreamAccumulator::new();

        acc.apply(&StreamDelta::TextDelta {
            delta: String::new(),
            block_index: 0,
        });
        acc.apply(&StreamDelta::TextDelta {
            delta: "Hello".to_string(),
            block_index: 1,
        });

        let blocks = acc.into_content_blocks();
        assert_eq!(blocks.len(), 1);
        assert!(matches!(&blocks[0], ContentBlock::Text { text } if text == "Hello"));
    }
}