Skip to main content

agent_sdk/llm/
streaming.rs

//! Streaming types for LLM responses.
//!
//! This module provides types for handling streaming responses from LLM providers.
//! The [`StreamDelta`] enum represents individual events in a streaming response,
//! [`StreamBox`] is the boxed stream type that yields them, and
//! [`StreamAccumulator`] collects those events into the final response content.
6
7use crate::llm::{ContentBlock, StopReason, Usage};
8use futures::Stream;
9use std::collections::HashMap;
10use std::pin::Pin;
11
12/// Events yielded during streaming LLM responses.
13///
14/// Each variant represents a different type of event that can occur
15/// during a streaming response from an LLM provider.
16#[derive(Debug, Clone)]
17pub enum StreamDelta {
18    /// A text delta for streaming text content.
19    TextDelta {
20        /// The text fragment to append
21        delta: String,
22        /// Index of the content block being streamed
23        block_index: usize,
24    },
25
26    /// A thinking delta for streaming thinking/reasoning content.
27    ThinkingDelta {
28        /// The thinking fragment to append
29        delta: String,
30        /// Index of the content block being streamed
31        block_index: usize,
32    },
33
34    /// Start of a tool use block (name and id are known).
35    ToolUseStart {
36        /// Unique identifier for this tool call
37        id: String,
38        /// Name of the tool being called
39        name: String,
40        /// Index of the content block
41        block_index: usize,
42        /// Optional thought signature (used by Gemini 3.x models)
43        thought_signature: Option<String>,
44    },
45
46    /// Incremental JSON for tool input (partial/incomplete JSON).
47    ToolInputDelta {
48        /// Tool call ID this delta belongs to
49        id: String,
50        /// JSON fragment to append
51        delta: String,
52        /// Index of the content block
53        block_index: usize,
54    },
55
56    /// Usage information (typically at stream end).
57    Usage(Usage),
58
59    /// Stream completed with stop reason.
60    Done {
61        /// Why the stream ended
62        stop_reason: Option<StopReason>,
63    },
64
65    /// A signature delta for a thinking block.
66    SignatureDelta {
67        /// The signature fragment to append
68        delta: String,
69        /// Index of the content block being streamed
70        block_index: usize,
71    },
72
73    /// A redacted thinking block received at `content_block_start`.
74    RedactedThinking {
75        /// Opaque data payload
76        data: String,
77        /// Index of the content block
78        block_index: usize,
79    },
80
81    /// Error during streaming.
82    Error {
83        /// Error message
84        message: String,
85        /// Whether the error is recoverable (e.g., rate limit)
86        recoverable: bool,
87    },
88}
89
90/// Type alias for a boxed stream of stream deltas.
91pub type StreamBox<'a> = Pin<Box<dyn Stream<Item = anyhow::Result<StreamDelta>> + Send + 'a>>;
92
93/// Helper to accumulate streamed content into a final response.
94///
95/// This struct collects [`StreamDelta`] events and can convert them
96/// into the final content blocks once the stream is complete.
97#[derive(Debug, Default)]
98pub struct StreamAccumulator {
99    /// Accumulated text for each block index
100    text_blocks: Vec<String>,
101    /// Accumulated thinking blocks for each block index
102    thinking_blocks: Vec<String>,
103    /// Accumulated signatures keyed by block index
104    thinking_signatures: HashMap<usize, String>,
105    /// Redacted thinking blocks: (`block_index`, data)
106    redacted_thinking_blocks: Vec<(usize, String)>,
107    /// Accumulated tool use calls
108    tool_uses: Vec<ToolUseAccumulator>,
109    /// Usage information from the stream
110    usage: Option<Usage>,
111    /// Stop reason from the stream
112    stop_reason: Option<StopReason>,
113}
114
/// In-progress state of one tool call while its arguments are being streamed.
#[derive(Default, Debug)]
pub struct ToolUseAccumulator {
    /// Identifier of the tool call.
    pub id: String,
    /// Name of the tool to invoke.
    pub name: String,
    /// Raw JSON arguments received so far; may be an incomplete document
    /// until the stream has finished.
    pub input_json: String,
    /// Content block index, used to order this call among the other blocks.
    pub block_index: usize,
    /// Thought signature from the start event, if any (used by Gemini 3.x models).
    pub thought_signature: Option<String>,
}
129
130impl StreamAccumulator {
131    /// Create a new empty accumulator.
132    #[must_use]
133    pub fn new() -> Self {
134        Self::default()
135    }
136
137    /// Apply a stream delta to the accumulator.
138    pub fn apply(&mut self, delta: &StreamDelta) {
139        match delta {
140            StreamDelta::TextDelta { delta, block_index } => {
141                while self.text_blocks.len() <= *block_index {
142                    self.text_blocks.push(String::new());
143                }
144                self.text_blocks[*block_index].push_str(delta);
145            }
146            StreamDelta::ThinkingDelta { delta, block_index } => {
147                while self.thinking_blocks.len() <= *block_index {
148                    self.thinking_blocks.push(String::new());
149                }
150                self.thinking_blocks[*block_index].push_str(delta);
151            }
152            StreamDelta::ToolUseStart {
153                id,
154                name,
155                block_index,
156                thought_signature,
157            } => {
158                self.tool_uses.push(ToolUseAccumulator {
159                    id: id.clone(),
160                    name: name.clone(),
161                    input_json: String::new(),
162                    block_index: *block_index,
163                    thought_signature: thought_signature.clone(),
164                });
165            }
166            StreamDelta::ToolInputDelta { id, delta, .. } => {
167                if let Some(tool) = self.tool_uses.iter_mut().find(|t| t.id == *id) {
168                    tool.input_json.push_str(delta);
169                }
170            }
171            StreamDelta::SignatureDelta { delta, block_index } => {
172                self.thinking_signatures
173                    .entry(*block_index)
174                    .or_default()
175                    .push_str(delta);
176            }
177            StreamDelta::RedactedThinking { data, block_index } => {
178                self.redacted_thinking_blocks
179                    .push((*block_index, data.clone()));
180            }
181            StreamDelta::Usage(u) => {
182                self.usage = Some(u.clone());
183            }
184            StreamDelta::Done { stop_reason } => {
185                self.stop_reason = *stop_reason;
186            }
187            StreamDelta::Error { .. } => {}
188        }
189    }
190
191    /// Get the accumulated usage information.
192    #[must_use]
193    pub const fn usage(&self) -> Option<&Usage> {
194        self.usage.as_ref()
195    }
196
197    /// Get the stop reason.
198    #[must_use]
199    pub const fn stop_reason(&self) -> Option<&StopReason> {
200        self.stop_reason.as_ref()
201    }
202
203    /// Convert accumulated content to `ContentBlock`s.
204    ///
205    /// This consumes the accumulator and returns the final content blocks.
206    /// Tool use JSON is parsed at this point; invalid JSON results in a null input.
207    #[must_use]
208    pub fn into_content_blocks(self) -> Vec<ContentBlock> {
209        let mut blocks: Vec<(usize, ContentBlock)> = Vec::new();
210
211        // Add thinking blocks with their indices, attaching signatures
212        let mut signatures = self.thinking_signatures;
213        for (idx, thinking) in self.thinking_blocks.into_iter().enumerate() {
214            if !thinking.is_empty() {
215                let signature = signatures.remove(&idx).filter(|s| !s.is_empty());
216                blocks.push((
217                    idx,
218                    ContentBlock::Thinking {
219                        thinking,
220                        signature,
221                    },
222                ));
223            }
224        }
225
226        // Add redacted thinking blocks
227        for (idx, data) in self.redacted_thinking_blocks {
228            blocks.push((idx, ContentBlock::RedactedThinking { data }));
229        }
230
231        // Add text blocks with their indices
232        for (idx, text) in self.text_blocks.into_iter().enumerate() {
233            if !text.is_empty() {
234                blocks.push((idx, ContentBlock::Text { text }));
235            }
236        }
237
238        // Add tool uses with their indices
239        for tool in self.tool_uses {
240            let input: serde_json::Value =
241                serde_json::from_str(&tool.input_json).unwrap_or_else(|e| {
242                    log::warn!(
243                        "Failed to parse streamed tool input JSON for tool '{}' (id={}): {} — \
244                         input_json ({} bytes): '{}'",
245                        tool.name,
246                        tool.id,
247                        e,
248                        tool.input_json.len(),
249                        tool.input_json.chars().take(500).collect::<String>(),
250                    );
251                    serde_json::json!({})
252                });
253            blocks.push((
254                tool.block_index,
255                ContentBlock::ToolUse {
256                    id: tool.id,
257                    name: tool.name,
258                    input,
259                    thought_signature: tool.thought_signature,
260                },
261            ));
262        }
263
264        // Sort by block index to maintain order
265        blocks.sort_by_key(|(idx, _)| *idx);
266
267        blocks.into_iter().map(|(_, block)| block).collect()
268    }
269
270    /// Take ownership of accumulated usage.
271    pub const fn take_usage(&mut self) -> Option<Usage> {
272        self.usage.take()
273    }
274
275    /// Take ownership of stop reason.
276    pub const fn take_stop_reason(&mut self) -> Option<StopReason> {
277        self.stop_reason.take()
278    }
279}
280
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_accumulator_text_deltas() {
        let mut acc = StreamAccumulator::new();

        acc.apply(&StreamDelta::TextDelta {
            delta: "Hello".to_string(),
            block_index: 0,
        });
        acc.apply(&StreamDelta::TextDelta {
            delta: " world".to_string(),
            block_index: 0,
        });

        let blocks = acc.into_content_blocks();
        assert_eq!(blocks.len(), 1);
        assert!(matches!(&blocks[0], ContentBlock::Text { text } if text == "Hello world"));
    }

    #[test]
    fn test_accumulator_multiple_text_blocks() {
        let mut acc = StreamAccumulator::new();

        acc.apply(&StreamDelta::TextDelta {
            delta: "First".to_string(),
            block_index: 0,
        });
        acc.apply(&StreamDelta::TextDelta {
            delta: "Second".to_string(),
            block_index: 1,
        });

        let blocks = acc.into_content_blocks();
        assert_eq!(blocks.len(), 2);
        assert!(matches!(&blocks[0], ContentBlock::Text { text } if text == "First"));
        assert!(matches!(&blocks[1], ContentBlock::Text { text } if text == "Second"));
    }

    #[test]
    fn test_accumulator_thinking_signature() {
        let mut acc = StreamAccumulator::new();

        acc.apply(&StreamDelta::ThinkingDelta {
            delta: "Reasoning".to_string(),
            block_index: 0,
        });
        acc.apply(&StreamDelta::SignatureDelta {
            delta: "sig_123".to_string(),
            block_index: 0,
        });

        let blocks = acc.into_content_blocks();
        assert_eq!(blocks.len(), 1);
        assert!(matches!(
            &blocks[0],
            ContentBlock::Thinking { thinking, signature }
            if thinking == "Reasoning" && signature.as_deref() == Some("sig_123")
        ));
    }

    #[test]
    fn test_accumulator_empty_signature_is_dropped() {
        // A signature that accumulated to the empty string is not attached.
        let mut acc = StreamAccumulator::new();

        acc.apply(&StreamDelta::ThinkingDelta {
            delta: "Reasoning".to_string(),
            block_index: 0,
        });
        acc.apply(&StreamDelta::SignatureDelta {
            delta: String::new(),
            block_index: 0,
        });

        let blocks = acc.into_content_blocks();
        assert_eq!(blocks.len(), 1);
        assert!(matches!(
            &blocks[0],
            ContentBlock::Thinking { thinking, signature }
            if thinking == "Reasoning" && signature.is_none()
        ));
    }

    #[test]
    fn test_accumulator_tool_use() {
        let mut acc = StreamAccumulator::new();

        acc.apply(&StreamDelta::ToolUseStart {
            id: "call_123".to_string(),
            name: "read_file".to_string(),
            block_index: 0,
            thought_signature: None,
        });
        acc.apply(&StreamDelta::ToolInputDelta {
            id: "call_123".to_string(),
            delta: r#"{"path":"#.to_string(),
            block_index: 0,
        });
        acc.apply(&StreamDelta::ToolInputDelta {
            id: "call_123".to_string(),
            delta: r#""test.txt"}"#.to_string(),
            block_index: 0,
        });

        let blocks = acc.into_content_blocks();
        assert_eq!(blocks.len(), 1);
        match &blocks[0] {
            ContentBlock::ToolUse {
                id, name, input, ..
            } => {
                assert_eq!(id, "call_123");
                assert_eq!(name, "read_file");
                assert_eq!(input["path"], "test.txt");
            }
            _ => panic!("Expected ToolUse block"),
        }
    }

    #[test]
    fn test_accumulator_interleaved_tool_inputs() {
        // Fragments for two concurrently streamed calls are routed by id.
        let mut acc = StreamAccumulator::new();

        acc.apply(&StreamDelta::ToolUseStart {
            id: "call_a".to_string(),
            name: "read_file".to_string(),
            block_index: 0,
            thought_signature: None,
        });
        acc.apply(&StreamDelta::ToolUseStart {
            id: "call_b".to_string(),
            name: "bash".to_string(),
            block_index: 1,
            thought_signature: None,
        });
        acc.apply(&StreamDelta::ToolInputDelta {
            id: "call_a".to_string(),
            delta: r#"{"path":"#.to_string(),
            block_index: 0,
        });
        acc.apply(&StreamDelta::ToolInputDelta {
            id: "call_b".to_string(),
            delta: r#"{"command":"#.to_string(),
            block_index: 1,
        });
        acc.apply(&StreamDelta::ToolInputDelta {
            id: "call_a".to_string(),
            delta: r#""a.txt"}"#.to_string(),
            block_index: 0,
        });
        acc.apply(&StreamDelta::ToolInputDelta {
            id: "call_b".to_string(),
            delta: r#""ls"}"#.to_string(),
            block_index: 1,
        });

        let blocks = acc.into_content_blocks();
        assert_eq!(blocks.len(), 2);
        match (&blocks[0], &blocks[1]) {
            (
                ContentBlock::ToolUse { input: first, .. },
                ContentBlock::ToolUse { input: second, .. },
            ) => {
                assert_eq!(first["path"], "a.txt");
                assert_eq!(second["command"], "ls");
            }
            _ => panic!("Expected two ToolUse blocks"),
        }
    }

    #[test]
    fn test_accumulator_mixed_content() {
        let mut acc = StreamAccumulator::new();

        acc.apply(&StreamDelta::TextDelta {
            delta: "Let me read that file.".to_string(),
            block_index: 0,
        });
        acc.apply(&StreamDelta::ToolUseStart {
            id: "call_456".to_string(),
            name: "read_file".to_string(),
            block_index: 1,
            thought_signature: None,
        });
        acc.apply(&StreamDelta::ToolInputDelta {
            id: "call_456".to_string(),
            delta: r#"{"path":"file.txt"}"#.to_string(),
            block_index: 1,
        });
        acc.apply(&StreamDelta::Usage(Usage {
            input_tokens: 100,
            output_tokens: 50,
        }));
        acc.apply(&StreamDelta::Done {
            stop_reason: Some(StopReason::ToolUse),
        });

        assert!(acc.usage().is_some());
        assert_eq!(acc.usage().map(|u| u.input_tokens), Some(100));
        assert!(matches!(acc.stop_reason(), Some(StopReason::ToolUse)));

        let blocks = acc.into_content_blocks();
        assert_eq!(blocks.len(), 2);
        assert!(matches!(&blocks[0], ContentBlock::Text { .. }));
        assert!(matches!(&blocks[1], ContentBlock::ToolUse { .. }));
    }

    #[test]
    fn test_accumulator_orders_thinking_redacted_text_and_tools() {
        // Deltas arrive out of index order; output must be sorted by block index.
        let mut acc = StreamAccumulator::new();

        acc.apply(&StreamDelta::ToolUseStart {
            id: "call_1".to_string(),
            name: "bash".to_string(),
            block_index: 3,
            thought_signature: Some("gem_sig".to_string()),
        });
        acc.apply(&StreamDelta::ToolInputDelta {
            id: "call_1".to_string(),
            delta: r#"{"command":"ls"}"#.to_string(),
            block_index: 3,
        });
        acc.apply(&StreamDelta::TextDelta {
            delta: "Answer".to_string(),
            block_index: 2,
        });
        acc.apply(&StreamDelta::RedactedThinking {
            data: "opaque".to_string(),
            block_index: 1,
        });
        acc.apply(&StreamDelta::ThinkingDelta {
            delta: "Hmm".to_string(),
            block_index: 0,
        });

        let blocks = acc.into_content_blocks();
        assert_eq!(blocks.len(), 4);
        assert!(matches!(
            &blocks[0],
            ContentBlock::Thinking { thinking, signature }
            if thinking == "Hmm" && signature.is_none()
        ));
        assert!(matches!(&blocks[1], ContentBlock::RedactedThinking { data } if data == "opaque"));
        assert!(matches!(&blocks[2], ContentBlock::Text { text } if text == "Answer"));
        assert!(matches!(
            &blocks[3],
            ContentBlock::ToolUse { id, thought_signature, .. }
            if id == "call_1" && thought_signature.as_deref() == Some("gem_sig")
        ));
    }

    #[test]
    fn test_accumulator_invalid_tool_json() {
        let mut acc = StreamAccumulator::new();

        acc.apply(&StreamDelta::ToolUseStart {
            id: "call_789".to_string(),
            name: "test_tool".to_string(),
            block_index: 0,
            thought_signature: None,
        });
        acc.apply(&StreamDelta::ToolInputDelta {
            id: "call_789".to_string(),
            delta: "invalid json {".to_string(),
            block_index: 0,
        });

        let blocks = acc.into_content_blocks();
        assert_eq!(blocks.len(), 1);
        match &blocks[0] {
            ContentBlock::ToolUse { input, .. } => {
                assert!(input.is_object());
            }
            _ => panic!("Expected ToolUse block"),
        }
    }

    #[test]
    fn test_accumulator_empty_tool_input_falls_back_to_empty_object() {
        // If no ToolInputDelta is received (e.g., stream interrupted or
        // deltas had mismatched IDs), the tool use block should still be
        // produced with an empty object so that the error is attributable
        // to the tool rather than silently lost.
        let mut acc = StreamAccumulator::new();

        acc.apply(&StreamDelta::ToolUseStart {
            id: "call_empty".to_string(),
            name: "read".to_string(),
            block_index: 0,
            thought_signature: None,
        });
        // No ToolInputDelta applied

        let blocks = acc.into_content_blocks();
        assert_eq!(blocks.len(), 1);
        match &blocks[0] {
            ContentBlock::ToolUse { input, name, .. } => {
                assert_eq!(name, "read");
                assert_eq!(input, &serde_json::json!({}));
            }
            _ => panic!("Expected ToolUse block"),
        }
    }

    #[test]
    fn test_accumulator_mismatched_delta_id_drops_input() {
        // If ToolInputDelta has a different ID than any ToolUseStart,
        // the input is silently dropped (the tool gets empty {}).
        let mut acc = StreamAccumulator::new();

        acc.apply(&StreamDelta::ToolUseStart {
            id: "call_A".to_string(),
            name: "bash".to_string(),
            block_index: 0,
            thought_signature: None,
        });
        // Delta with wrong ID
        acc.apply(&StreamDelta::ToolInputDelta {
            id: "call_B".to_string(),
            delta: r#"{"command":"ls"}"#.to_string(),
            block_index: 0,
        });

        let blocks = acc.into_content_blocks();
        assert_eq!(blocks.len(), 1);
        match &blocks[0] {
            ContentBlock::ToolUse { input, .. } => {
                // Input should be empty because the delta had a mismatched ID
                assert_eq!(input, &serde_json::json!({}));
            }
            _ => panic!("Expected ToolUse block"),
        }
    }

    #[test]
    fn test_accumulator_take_usage_and_stop_reason() {
        let mut acc = StreamAccumulator::new();

        acc.apply(&StreamDelta::Usage(Usage {
            input_tokens: 10,
            output_tokens: 5,
        }));
        acc.apply(&StreamDelta::Done {
            stop_reason: Some(StopReason::ToolUse),
        });

        assert_eq!(acc.take_usage().map(|u| u.output_tokens), Some(5));
        assert!(acc.take_usage().is_none());
        assert!(acc.usage().is_none());

        assert!(matches!(acc.take_stop_reason(), Some(StopReason::ToolUse)));
        assert!(acc.take_stop_reason().is_none());
        assert!(acc.stop_reason().is_none());
    }

    #[test]
    fn test_accumulator_ignores_error_delta() {
        let mut acc = StreamAccumulator::new();

        acc.apply(&StreamDelta::TextDelta {
            delta: "partial".to_string(),
            block_index: 0,
        });
        acc.apply(&StreamDelta::Error {
            message: "overloaded".to_string(),
            recoverable: true,
        });

        assert!(acc.usage().is_none());
        assert!(acc.stop_reason().is_none());

        let blocks = acc.into_content_blocks();
        assert_eq!(blocks.len(), 1);
        assert!(matches!(&blocks[0], ContentBlock::Text { text } if text == "partial"));
    }

    #[test]
    fn test_accumulator_empty() {
        let acc = StreamAccumulator::new();
        let blocks = acc.into_content_blocks();
        assert!(blocks.is_empty());
    }

    #[test]
    fn test_accumulator_skips_empty_text() {
        let mut acc = StreamAccumulator::new();

        acc.apply(&StreamDelta::TextDelta {
            delta: String::new(),
            block_index: 0,
        });
        acc.apply(&StreamDelta::TextDelta {
            delta: "Hello".to_string(),
            block_index: 1,
        });

        let blocks = acc.into_content_blocks();
        assert_eq!(blocks.len(), 1);
        assert!(matches!(&blocks[0], ContentBlock::Text { text } if text == "Hello"));
    }
}