1use agent_sdk_foundation::llm::{ContentBlock, StopReason, Usage};
8use futures::Stream;
9use std::collections::HashMap;
10use std::pin::Pin;
11
/// One incremental event from a streaming LLM response.
///
/// Deltas are folded into final content blocks by [`StreamAccumulator`]. The
/// `block_index` fields identify which content block a delta belongs to; the
/// accumulator orders its output blocks by that index.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum StreamDelta {
    /// A text fragment appended to the text block at `block_index`.
    TextDelta {
        /// Text fragment to append.
        delta: String,
        /// Index of the content block this fragment belongs to.
        block_index: usize,
    },

    /// A reasoning fragment appended to the thinking block at `block_index`.
    ThinkingDelta {
        /// Thinking fragment to append.
        delta: String,
        /// Index of the content block this fragment belongs to.
        block_index: usize,
    },

    /// Start of a tool call. Later [`StreamDelta::ToolInputDelta`]s carrying
    /// the same `id` append to this call's input.
    ToolUseStart {
        /// Tool-call id used to route subsequent input deltas.
        id: String,
        /// Name of the tool being invoked.
        name: String,
        /// Index of the content block this tool call occupies.
        block_index: usize,
        /// Optional opaque signature copied onto the resulting
        /// `ContentBlock::ToolUse`. NOTE(review): presumably provider-specific;
        /// confirm which providers populate it.
        thought_signature: Option<String>,
    },

    /// A fragment of a tool call's JSON input. The accumulator concatenates
    /// fragments per call and parses the result as JSON at the end.
    ToolInputDelta {
        /// Id of the tool call this fragment belongs to (must match the
        /// `id` of a preceding `ToolUseStart`, otherwise it is dropped).
        id: String,
        /// Raw JSON fragment; not necessarily valid JSON on its own.
        delta: String,
        /// Block index of the tool call. The accumulator routes by `id` and
        /// does not read this field.
        block_index: usize,
    },

    /// A token-usage report. The accumulator keeps only the most recent one.
    Usage(Usage),

    /// Carries the stop reason for the response, if one was reported. The
    /// accumulator records it, overwriting any earlier value.
    Done {
        /// Why generation stopped, if known.
        stop_reason: Option<StopReason>,
    },

    /// A fragment of the signature for the thinking block at `block_index`.
    /// Fragments are concatenated; an empty result, or a signature for an
    /// index with no thinking text, is dropped by the accumulator.
    SignatureDelta {
        /// Signature fragment to append.
        delta: String,
        /// Index of the thinking block the signature belongs to.
        block_index: usize,
    },

    /// A redacted thinking payload. Each event becomes its own
    /// `ContentBlock::RedactedThinking` with `data` passed through unchanged.
    RedactedThinking {
        /// Opaque payload.
        data: String,
        /// Index of the content block this payload occupies.
        block_index: usize,
    },

    /// An error reported on the stream. `StreamAccumulator::apply` ignores it.
    Error {
        /// Error message text.
        message: String,
        /// Classification of the error; see [`StreamErrorKind::is_recoverable`].
        kind: StreamErrorKind,
    },
}
93
/// Coarse classification of an error carried by `StreamDelta::Error`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum StreamErrorKind {
    /// Rate-limit error; classified as recoverable.
    RateLimited,
    /// Server-side error; classified as recoverable.
    ServerError,
    /// Invalid request; classified as not recoverable.
    InvalidRequest,
    /// Unclassified error; classified as not recoverable.
    Unknown,
}

impl StreamErrorKind {
    /// Reports whether this error kind is classified as recoverable.
    ///
    /// Returns `true` exactly for [`Self::RateLimited`] and
    /// [`Self::ServerError`]; every other kind yields `false`.
    #[must_use]
    pub const fn is_recoverable(self) -> bool {
        let throttled = matches!(self, Self::RateLimited);
        let server_fault = matches!(self, Self::ServerError);
        throttled || server_fault
    }
}
134
/// A pinned, boxed, `Send` stream of `anyhow::Result<StreamDelta>` items that
/// may borrow data for the lifetime `'a`.
pub type StreamBox<'a> = Pin<Box<dyn Stream<Item = anyhow::Result<StreamDelta>> + Send + 'a>>;
137
/// Folds a sequence of [`StreamDelta`]s into final [`ContentBlock`]s, plus the
/// usage and stop-reason metadata reported along the way.
///
/// Feed every delta to [`StreamAccumulator::apply`], then call
/// [`StreamAccumulator::into_content_blocks`].
#[derive(Debug, Default)]
pub struct StreamAccumulator {
    // Text buffers indexed by `block_index`; indices that never received text
    // stay empty and are skipped on output.
    text_blocks: Vec<String>,
    // Thinking buffers indexed by `block_index`, padded like `text_blocks`.
    thinking_blocks: Vec<String>,
    // Concatenated signature fragments keyed by thinking `block_index`.
    thinking_signatures: HashMap<usize, String>,
    // `(block_index, data)` pairs in arrival order.
    redacted_thinking_blocks: Vec<(usize, String)>,
    // Tool calls in the order their `ToolUseStart` arrived.
    tool_uses: Vec<ToolUseAccumulator>,
    // Most recently received usage report.
    usage: Option<Usage>,
    // Stop reason from the most recent `Done` delta.
    stop_reason: Option<StopReason>,
}
159
/// In-progress state for a single streamed tool call.
#[derive(Debug, Default)]
pub struct ToolUseAccumulator {
    /// Tool-call id from `ToolUseStart`; `ToolInputDelta`s are matched on it.
    pub id: String,
    /// Name of the tool being invoked.
    pub name: String,
    /// Raw concatenation of input fragments; parsed as JSON when the
    /// accumulator is converted into content blocks.
    pub input_json: String,
    /// Content block index used to order the resulting `ContentBlock::ToolUse`.
    pub block_index: usize,
    /// Optional signature copied onto the resulting `ContentBlock::ToolUse`.
    pub thought_signature: Option<String>,
}
174
175impl StreamAccumulator {
176 #[must_use]
178 pub fn new() -> Self {
179 Self::default()
180 }
181
182 pub fn apply(&mut self, delta: &StreamDelta) {
184 match delta {
185 StreamDelta::TextDelta { delta, block_index } => {
186 while self.text_blocks.len() <= *block_index {
187 self.text_blocks.push(String::new());
188 }
189 self.text_blocks[*block_index].push_str(delta);
190 }
191 StreamDelta::ThinkingDelta { delta, block_index } => {
192 while self.thinking_blocks.len() <= *block_index {
193 self.thinking_blocks.push(String::new());
194 }
195 self.thinking_blocks[*block_index].push_str(delta);
196 }
197 StreamDelta::ToolUseStart {
198 id,
199 name,
200 block_index,
201 thought_signature,
202 } => {
203 self.tool_uses.push(ToolUseAccumulator {
204 id: id.clone(),
205 name: name.clone(),
206 input_json: String::new(),
207 block_index: *block_index,
208 thought_signature: thought_signature.clone(),
209 });
210 }
211 StreamDelta::ToolInputDelta { id, delta, .. } => {
212 if let Some(tool) = self.tool_uses.iter_mut().find(|t| t.id == *id) {
213 tool.input_json.push_str(delta);
214 }
215 }
216 StreamDelta::SignatureDelta { delta, block_index } => {
217 self.thinking_signatures
218 .entry(*block_index)
219 .or_default()
220 .push_str(delta);
221 }
222 StreamDelta::RedactedThinking { data, block_index } => {
223 self.redacted_thinking_blocks
224 .push((*block_index, data.clone()));
225 }
226 StreamDelta::Usage(u) => {
227 self.usage = Some(u.clone());
228 }
229 StreamDelta::Done { stop_reason } => {
230 self.stop_reason = *stop_reason;
231 }
232 StreamDelta::Error { .. } => {}
233 }
234 }
235
236 #[must_use]
238 pub const fn usage(&self) -> Option<&Usage> {
239 self.usage.as_ref()
240 }
241
242 #[must_use]
244 pub const fn stop_reason(&self) -> Option<&StopReason> {
245 self.stop_reason.as_ref()
246 }
247
248 #[must_use]
253 pub fn into_content_blocks(self) -> Vec<ContentBlock> {
254 let mut blocks: Vec<(usize, ContentBlock)> = Vec::new();
255
256 let mut signatures = self.thinking_signatures;
258 for (idx, thinking) in self.thinking_blocks.into_iter().enumerate() {
259 if !thinking.is_empty() {
260 let signature = signatures.remove(&idx).filter(|s| !s.is_empty());
261 blocks.push((
262 idx,
263 ContentBlock::Thinking {
264 thinking,
265 signature,
266 },
267 ));
268 }
269 }
270
271 for (idx, data) in self.redacted_thinking_blocks {
273 blocks.push((idx, ContentBlock::RedactedThinking { data }));
274 }
275
276 for (idx, text) in self.text_blocks.into_iter().enumerate() {
278 if !text.is_empty() {
279 blocks.push((idx, ContentBlock::Text { text }));
280 }
281 }
282
283 for tool in self.tool_uses {
285 let input: serde_json::Value =
286 serde_json::from_str(&tool.input_json).unwrap_or_else(|e| {
287 log::warn!(
288 "Failed to parse streamed tool input JSON for tool '{}' (id={}): {} — \
289 input_json ({} bytes): '{}'",
290 tool.name,
291 tool.id,
292 e,
293 tool.input_json.len(),
294 tool.input_json.chars().take(500).collect::<String>(),
295 );
296 serde_json::json!({})
297 });
298 blocks.push((
299 tool.block_index,
300 ContentBlock::ToolUse {
301 id: tool.id,
302 name: tool.name,
303 input,
304 thought_signature: tool.thought_signature,
305 },
306 ));
307 }
308
309 blocks.sort_by_key(|(idx, _)| *idx);
311
312 blocks.into_iter().map(|(_, block)| block).collect()
313 }
314
315 pub const fn take_usage(&mut self) -> Option<Usage> {
317 self.usage.take()
318 }
319
320 pub const fn take_stop_reason(&mut self) -> Option<StopReason> {
322 self.stop_reason.take()
323 }
324}
325
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_accumulator_text_deltas() {
        let mut acc = StreamAccumulator::new();

        acc.apply(&StreamDelta::TextDelta {
            delta: "Hello".to_string(),
            block_index: 0,
        });
        acc.apply(&StreamDelta::TextDelta {
            delta: " world".to_string(),
            block_index: 0,
        });

        let blocks = acc.into_content_blocks();
        assert_eq!(blocks.len(), 1);
        assert!(matches!(&blocks[0], ContentBlock::Text { text } if text == "Hello world"));
    }

    #[test]
    fn test_accumulator_multiple_text_blocks() {
        let mut acc = StreamAccumulator::new();

        acc.apply(&StreamDelta::TextDelta {
            delta: "First".to_string(),
            block_index: 0,
        });
        acc.apply(&StreamDelta::TextDelta {
            delta: "Second".to_string(),
            block_index: 1,
        });

        let blocks = acc.into_content_blocks();
        assert_eq!(blocks.len(), 2);
        assert!(matches!(&blocks[0], ContentBlock::Text { text } if text == "First"));
        assert!(matches!(&blocks[1], ContentBlock::Text { text } if text == "Second"));
    }

    #[test]
    fn test_accumulator_thinking_signature() {
        let mut acc = StreamAccumulator::new();

        acc.apply(&StreamDelta::ThinkingDelta {
            delta: "Reasoning".to_string(),
            block_index: 0,
        });
        acc.apply(&StreamDelta::SignatureDelta {
            delta: "sig_123".to_string(),
            block_index: 0,
        });

        let blocks = acc.into_content_blocks();
        assert_eq!(blocks.len(), 1);
        assert!(matches!(
            &blocks[0],
            ContentBlock::Thinking { thinking, signature }
                if thinking == "Reasoning" && signature.as_deref() == Some("sig_123")
        ));
    }

    #[test]
    fn test_accumulator_signature_fragments_are_concatenated() {
        let mut acc = StreamAccumulator::new();

        acc.apply(&StreamDelta::ThinkingDelta {
            delta: "Reasoning".to_string(),
            block_index: 0,
        });
        acc.apply(&StreamDelta::SignatureDelta {
            delta: "sig_".to_string(),
            block_index: 0,
        });
        acc.apply(&StreamDelta::SignatureDelta {
            delta: "123".to_string(),
            block_index: 0,
        });

        let blocks = acc.into_content_blocks();
        assert_eq!(blocks.len(), 1);
        assert!(matches!(
            &blocks[0],
            ContentBlock::Thinking { signature, .. } if signature.as_deref() == Some("sig_123")
        ));
    }

    #[test]
    fn test_accumulator_empty_signature_becomes_none() {
        let mut acc = StreamAccumulator::new();

        acc.apply(&StreamDelta::ThinkingDelta {
            delta: "Reasoning".to_string(),
            block_index: 0,
        });
        acc.apply(&StreamDelta::SignatureDelta {
            delta: String::new(),
            block_index: 0,
        });

        let blocks = acc.into_content_blocks();
        assert_eq!(blocks.len(), 1);
        assert!(matches!(
            &blocks[0],
            ContentBlock::Thinking { signature: None, .. }
        ));
    }

    #[test]
    fn test_accumulator_tool_use() {
        let mut acc = StreamAccumulator::new();

        acc.apply(&StreamDelta::ToolUseStart {
            id: "call_123".to_string(),
            name: "read_file".to_string(),
            block_index: 0,
            thought_signature: None,
        });
        acc.apply(&StreamDelta::ToolInputDelta {
            id: "call_123".to_string(),
            delta: r#"{"path":"#.to_string(),
            block_index: 0,
        });
        acc.apply(&StreamDelta::ToolInputDelta {
            id: "call_123".to_string(),
            delta: r#""test.txt"}"#.to_string(),
            block_index: 0,
        });

        let blocks = acc.into_content_blocks();
        assert_eq!(blocks.len(), 1);
        match &blocks[0] {
            ContentBlock::ToolUse {
                id, name, input, ..
            } => {
                assert_eq!(id, "call_123");
                assert_eq!(name, "read_file");
                assert_eq!(input["path"], "test.txt");
            }
            _ => panic!("Expected ToolUse block"),
        }
    }

    #[test]
    fn test_accumulator_interleaved_tool_inputs_routed_by_id() {
        let mut acc = StreamAccumulator::new();

        acc.apply(&StreamDelta::ToolUseStart {
            id: "call_A".to_string(),
            name: "first".to_string(),
            block_index: 0,
            thought_signature: None,
        });
        acc.apply(&StreamDelta::ToolUseStart {
            id: "call_B".to_string(),
            name: "second".to_string(),
            block_index: 1,
            thought_signature: None,
        });
        acc.apply(&StreamDelta::ToolInputDelta {
            id: "call_B".to_string(),
            delta: r#"{"b":"#.to_string(),
            block_index: 1,
        });
        acc.apply(&StreamDelta::ToolInputDelta {
            id: "call_A".to_string(),
            delta: r#"{"a":"x"}"#.to_string(),
            block_index: 0,
        });
        acc.apply(&StreamDelta::ToolInputDelta {
            id: "call_B".to_string(),
            delta: r#""y"}"#.to_string(),
            block_index: 1,
        });

        let blocks = acc.into_content_blocks();
        assert_eq!(blocks.len(), 2);
        match (&blocks[0], &blocks[1]) {
            (
                ContentBlock::ToolUse {
                    id: first_id,
                    input: first_input,
                    ..
                },
                ContentBlock::ToolUse {
                    id: second_id,
                    input: second_input,
                    ..
                },
            ) => {
                assert_eq!(first_id, "call_A");
                assert_eq!(first_input["a"], "x");
                assert_eq!(second_id, "call_B");
                assert_eq!(second_input["b"], "y");
            }
            _ => panic!("Expected two ToolUse blocks"),
        }
    }

    #[test]
    fn test_accumulator_mixed_content() {
        let mut acc = StreamAccumulator::new();

        acc.apply(&StreamDelta::TextDelta {
            delta: "Let me read that file.".to_string(),
            block_index: 0,
        });
        acc.apply(&StreamDelta::ToolUseStart {
            id: "call_456".to_string(),
            name: "read_file".to_string(),
            block_index: 1,
            thought_signature: None,
        });
        acc.apply(&StreamDelta::ToolInputDelta {
            id: "call_456".to_string(),
            delta: r#"{"path":"file.txt"}"#.to_string(),
            block_index: 1,
        });
        acc.apply(&StreamDelta::Usage(Usage {
            input_tokens: 100,
            output_tokens: 50,
            cached_input_tokens: 0,
            cache_creation_input_tokens: 0,
        }));
        acc.apply(&StreamDelta::Done {
            stop_reason: Some(StopReason::ToolUse),
        });

        assert!(acc.usage().is_some());
        assert_eq!(acc.usage().map(|u| u.input_tokens), Some(100));
        assert!(matches!(acc.stop_reason(), Some(StopReason::ToolUse)));

        let blocks = acc.into_content_blocks();
        assert_eq!(blocks.len(), 2);
        assert!(matches!(&blocks[0], ContentBlock::Text { .. }));
        assert!(matches!(&blocks[1], ContentBlock::ToolUse { .. }));
    }

    #[test]
    fn test_accumulator_orders_blocks_by_index() {
        let mut acc = StreamAccumulator::new();

        // Deltas arrive out of index order; output must follow block_index.
        acc.apply(&StreamDelta::TextDelta {
            delta: "Answer".to_string(),
            block_index: 2,
        });
        acc.apply(&StreamDelta::RedactedThinking {
            data: "opaque".to_string(),
            block_index: 1,
        });
        acc.apply(&StreamDelta::ThinkingDelta {
            delta: "Hmm".to_string(),
            block_index: 0,
        });

        let blocks = acc.into_content_blocks();
        assert_eq!(blocks.len(), 3);
        assert!(matches!(
            &blocks[0],
            ContentBlock::Thinking { thinking, .. } if thinking == "Hmm"
        ));
        assert!(matches!(
            &blocks[1],
            ContentBlock::RedactedThinking { data } if data == "opaque"
        ));
        assert!(matches!(&blocks[2], ContentBlock::Text { text } if text == "Answer"));
    }

    #[test]
    fn test_accumulator_invalid_tool_json() {
        let mut acc = StreamAccumulator::new();

        acc.apply(&StreamDelta::ToolUseStart {
            id: "call_789".to_string(),
            name: "test_tool".to_string(),
            block_index: 0,
            thought_signature: None,
        });
        acc.apply(&StreamDelta::ToolInputDelta {
            id: "call_789".to_string(),
            delta: "invalid json {".to_string(),
            block_index: 0,
        });

        let blocks = acc.into_content_blocks();
        assert_eq!(blocks.len(), 1);
        match &blocks[0] {
            ContentBlock::ToolUse { input, .. } => {
                assert!(input.is_object());
            }
            _ => panic!("Expected ToolUse block"),
        }
    }

    #[test]
    fn test_accumulator_empty_tool_input_falls_back_to_empty_object() {
        let mut acc = StreamAccumulator::new();

        acc.apply(&StreamDelta::ToolUseStart {
            id: "call_empty".to_string(),
            name: "read".to_string(),
            block_index: 0,
            thought_signature: None,
        });
        let blocks = acc.into_content_blocks();
        assert_eq!(blocks.len(), 1);
        match &blocks[0] {
            ContentBlock::ToolUse { input, name, .. } => {
                assert_eq!(name, "read");
                assert_eq!(input, &serde_json::json!({}));
            }
            _ => panic!("Expected ToolUse block"),
        }
    }

    #[test]
    fn test_accumulator_whitespace_tool_input_falls_back_to_empty_object() {
        let mut acc = StreamAccumulator::new();

        acc.apply(&StreamDelta::ToolUseStart {
            id: "call_ws".to_string(),
            name: "read".to_string(),
            block_index: 0,
            thought_signature: None,
        });
        acc.apply(&StreamDelta::ToolInputDelta {
            id: "call_ws".to_string(),
            delta: "  \n".to_string(),
            block_index: 0,
        });

        let blocks = acc.into_content_blocks();
        assert_eq!(blocks.len(), 1);
        match &blocks[0] {
            ContentBlock::ToolUse { input, .. } => {
                assert_eq!(input, &serde_json::json!({}));
            }
            _ => panic!("Expected ToolUse block"),
        }
    }

    #[test]
    fn test_accumulator_mismatched_delta_id_drops_input() {
        let mut acc = StreamAccumulator::new();

        acc.apply(&StreamDelta::ToolUseStart {
            id: "call_A".to_string(),
            name: "bash".to_string(),
            block_index: 0,
            thought_signature: None,
        });
        acc.apply(&StreamDelta::ToolInputDelta {
            id: "call_B".to_string(),
            delta: r#"{"command":"ls"}"#.to_string(),
            block_index: 0,
        });

        let blocks = acc.into_content_blocks();
        assert_eq!(blocks.len(), 1);
        match &blocks[0] {
            ContentBlock::ToolUse { input, .. } => {
                assert_eq!(input, &serde_json::json!({}));
            }
            _ => panic!("Expected ToolUse block"),
        }
    }

    #[test]
    fn test_accumulator_ignores_error_delta() {
        let mut acc = StreamAccumulator::new();

        acc.apply(&StreamDelta::Error {
            message: "boom".to_string(),
            kind: StreamErrorKind::ServerError,
        });

        assert!(acc.usage().is_none());
        assert!(acc.stop_reason().is_none());
        assert!(acc.into_content_blocks().is_empty());
    }

    #[test]
    fn test_accumulator_take_usage_and_stop_reason() {
        let mut acc = StreamAccumulator::new();

        acc.apply(&StreamDelta::Usage(Usage {
            input_tokens: 10,
            output_tokens: 5,
            cached_input_tokens: 0,
            cache_creation_input_tokens: 0,
        }));
        acc.apply(&StreamDelta::Done {
            stop_reason: Some(StopReason::ToolUse),
        });

        assert_eq!(acc.take_usage().map(|u| u.output_tokens), Some(5));
        assert!(acc.take_usage().is_none());
        assert!(matches!(acc.take_stop_reason(), Some(StopReason::ToolUse)));
        assert!(acc.take_stop_reason().is_none());
    }

    #[test]
    fn test_accumulator_empty() {
        let acc = StreamAccumulator::new();
        let blocks = acc.into_content_blocks();
        assert!(blocks.is_empty());
    }

    #[test]
    fn test_accumulator_skips_empty_text() {
        let mut acc = StreamAccumulator::new();

        acc.apply(&StreamDelta::TextDelta {
            delta: String::new(),
            block_index: 0,
        });
        acc.apply(&StreamDelta::TextDelta {
            delta: "Hello".to_string(),
            block_index: 1,
        });

        let blocks = acc.into_content_blocks();
        assert_eq!(blocks.len(), 1);
        assert!(matches!(&blocks[0], ContentBlock::Text { text } if text == "Hello"));
    }
}