Skip to main content

agent_sdk/llm/
streaming.rs

//! Streaming types for LLM responses.
//!
//! This module provides types for handling streaming responses from LLM providers.
//! The [`StreamDelta`] enum represents individual events in a streaming response,
//! and [`StreamAccumulator`] helps collect these events into a final response.

7use crate::llm::{ContentBlock, StopReason, Usage};
8use futures::Stream;
9use std::collections::HashMap;
10use std::pin::Pin;
11
12/// Events yielded during streaming LLM responses.
13///
14/// Each variant represents a different type of event that can occur
15/// during a streaming response from an LLM provider.
16#[derive(Debug, Clone)]
17pub enum StreamDelta {
18    /// A text delta for streaming text content.
19    TextDelta {
20        /// The text fragment to append
21        delta: String,
22        /// Index of the content block being streamed
23        block_index: usize,
24    },
25
26    /// A thinking delta for streaming thinking/reasoning content.
27    ThinkingDelta {
28        /// The thinking fragment to append
29        delta: String,
30        /// Index of the content block being streamed
31        block_index: usize,
32    },
33
34    /// Start of a tool use block (name and id are known).
35    ToolUseStart {
36        /// Unique identifier for this tool call
37        id: String,
38        /// Name of the tool being called
39        name: String,
40        /// Index of the content block
41        block_index: usize,
42        /// Optional thought signature (used by Gemini 3.x models)
43        thought_signature: Option<String>,
44    },
45
46    /// Incremental JSON for tool input (partial/incomplete JSON).
47    ToolInputDelta {
48        /// Tool call ID this delta belongs to
49        id: String,
50        /// JSON fragment to append
51        delta: String,
52        /// Index of the content block
53        block_index: usize,
54    },
55
56    /// Usage information (typically at stream end).
57    Usage(Usage),
58
59    /// Stream completed with stop reason.
60    Done {
61        /// Why the stream ended
62        stop_reason: Option<StopReason>,
63    },
64
65    /// A signature delta for a thinking block.
66    SignatureDelta {
67        /// The signature fragment to append
68        delta: String,
69        /// Index of the content block being streamed
70        block_index: usize,
71    },
72
73    /// A redacted thinking block received at `content_block_start`.
74    RedactedThinking {
75        /// Opaque data payload
76        data: String,
77        /// Index of the content block
78        block_index: usize,
79    },
80
81    /// Error during streaming.
82    Error {
83        /// Error message
84        message: String,
85        /// Whether the error is recoverable (e.g., rate limit)
86        recoverable: bool,
87    },
88}
89
90/// Type alias for a boxed stream of stream deltas.
91pub type StreamBox<'a> = Pin<Box<dyn Stream<Item = anyhow::Result<StreamDelta>> + Send + 'a>>;
92
93/// Helper to accumulate streamed content into a final response.
94///
95/// This struct collects [`StreamDelta`] events and can convert them
96/// into the final content blocks once the stream is complete.
97#[derive(Debug, Default)]
98pub struct StreamAccumulator {
99    /// Accumulated text for each block index
100    text_blocks: Vec<String>,
101    /// Accumulated thinking blocks for each block index
102    thinking_blocks: Vec<String>,
103    /// Accumulated signatures keyed by block index
104    thinking_signatures: HashMap<usize, String>,
105    /// Redacted thinking blocks: (`block_index`, data)
106    redacted_thinking_blocks: Vec<(usize, String)>,
107    /// Accumulated tool use calls
108    tool_uses: Vec<ToolUseAccumulator>,
109    /// Usage information from the stream
110    usage: Option<Usage>,
111    /// Stop reason from the stream
112    stop_reason: Option<StopReason>,
113}
114
/// In-progress state of a single tool call while its input is streaming.
#[derive(Default, Debug)]
pub struct ToolUseAccumulator {
    /// Identifier of the tool call
    pub id: String,
    /// Name of the tool being invoked
    pub name: String,
    /// Raw JSON input collected so far (can be incomplete mid-stream)
    pub input_json: String,
    /// Content block position, used to order the final blocks
    pub block_index: usize,
    /// Optional thought signature (used by Gemini 3.x models)
    pub thought_signature: Option<String>,
}

130impl StreamAccumulator {
131    /// Create a new empty accumulator.
132    #[must_use]
133    pub fn new() -> Self {
134        Self::default()
135    }
136
137    /// Apply a stream delta to the accumulator.
138    pub fn apply(&mut self, delta: &StreamDelta) {
139        match delta {
140            StreamDelta::TextDelta { delta, block_index } => {
141                while self.text_blocks.len() <= *block_index {
142                    self.text_blocks.push(String::new());
143                }
144                self.text_blocks[*block_index].push_str(delta);
145            }
146            StreamDelta::ThinkingDelta { delta, block_index } => {
147                while self.thinking_blocks.len() <= *block_index {
148                    self.thinking_blocks.push(String::new());
149                }
150                self.thinking_blocks[*block_index].push_str(delta);
151            }
152            StreamDelta::ToolUseStart {
153                id,
154                name,
155                block_index,
156                thought_signature,
157            } => {
158                self.tool_uses.push(ToolUseAccumulator {
159                    id: id.clone(),
160                    name: name.clone(),
161                    input_json: String::new(),
162                    block_index: *block_index,
163                    thought_signature: thought_signature.clone(),
164                });
165            }
166            StreamDelta::ToolInputDelta { id, delta, .. } => {
167                if let Some(tool) = self.tool_uses.iter_mut().find(|t| t.id == *id) {
168                    tool.input_json.push_str(delta);
169                }
170            }
171            StreamDelta::SignatureDelta { delta, block_index } => {
172                self.thinking_signatures
173                    .entry(*block_index)
174                    .or_default()
175                    .push_str(delta);
176            }
177            StreamDelta::RedactedThinking { data, block_index } => {
178                self.redacted_thinking_blocks
179                    .push((*block_index, data.clone()));
180            }
181            StreamDelta::Usage(u) => {
182                self.usage = Some(u.clone());
183            }
184            StreamDelta::Done { stop_reason } => {
185                self.stop_reason = *stop_reason;
186            }
187            StreamDelta::Error { .. } => {}
188        }
189    }
190
191    /// Get the accumulated usage information.
192    #[must_use]
193    pub const fn usage(&self) -> Option<&Usage> {
194        self.usage.as_ref()
195    }
196
197    /// Get the stop reason.
198    #[must_use]
199    pub const fn stop_reason(&self) -> Option<&StopReason> {
200        self.stop_reason.as_ref()
201    }
202
203    /// Convert accumulated content to `ContentBlock`s.
204    ///
205    /// This consumes the accumulator and returns the final content blocks.
206    /// Tool use JSON is parsed at this point; invalid JSON results in a null input.
207    #[must_use]
208    pub fn into_content_blocks(self) -> Vec<ContentBlock> {
209        let mut blocks: Vec<(usize, ContentBlock)> = Vec::new();
210
211        // Add thinking blocks with their indices, attaching signatures
212        let mut signatures = self.thinking_signatures;
213        for (idx, thinking) in self.thinking_blocks.into_iter().enumerate() {
214            if !thinking.is_empty() {
215                let signature = signatures.remove(&idx).filter(|s| !s.is_empty());
216                blocks.push((
217                    idx,
218                    ContentBlock::Thinking {
219                        thinking,
220                        signature,
221                    },
222                ));
223            }
224        }
225
226        // Add redacted thinking blocks
227        for (idx, data) in self.redacted_thinking_blocks {
228            blocks.push((idx, ContentBlock::RedactedThinking { data }));
229        }
230
231        // Add text blocks with their indices
232        for (idx, text) in self.text_blocks.into_iter().enumerate() {
233            if !text.is_empty() {
234                blocks.push((idx, ContentBlock::Text { text }));
235            }
236        }
237
238        // Add tool uses with their indices
239        for tool in self.tool_uses {
240            let input: serde_json::Value =
241                serde_json::from_str(&tool.input_json).unwrap_or_else(|_| serde_json::json!({}));
242            blocks.push((
243                tool.block_index,
244                ContentBlock::ToolUse {
245                    id: tool.id,
246                    name: tool.name,
247                    input,
248                    thought_signature: tool.thought_signature,
249                },
250            ));
251        }
252
253        // Sort by block index to maintain order
254        blocks.sort_by_key(|(idx, _)| *idx);
255
256        blocks.into_iter().map(|(_, block)| block).collect()
257    }
258
259    /// Take ownership of accumulated usage.
260    pub const fn take_usage(&mut self) -> Option<Usage> {
261        self.usage.take()
262    }
263
264    /// Take ownership of stop reason.
265    pub const fn take_stop_reason(&mut self) -> Option<StopReason> {
266        self.stop_reason.take()
267    }
268}
269
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_accumulator_text_deltas() {
        let mut acc = StreamAccumulator::new();

        acc.apply(&StreamDelta::TextDelta {
            delta: "Hello".to_string(),
            block_index: 0,
        });
        acc.apply(&StreamDelta::TextDelta {
            delta: " world".to_string(),
            block_index: 0,
        });

        let blocks = acc.into_content_blocks();
        assert_eq!(blocks.len(), 1);
        assert!(matches!(&blocks[0], ContentBlock::Text { text } if text == "Hello world"));
    }

    #[test]
    fn test_accumulator_multiple_text_blocks() {
        let mut acc = StreamAccumulator::new();

        acc.apply(&StreamDelta::TextDelta {
            delta: "First".to_string(),
            block_index: 0,
        });
        acc.apply(&StreamDelta::TextDelta {
            delta: "Second".to_string(),
            block_index: 1,
        });

        let blocks = acc.into_content_blocks();
        assert_eq!(blocks.len(), 2);
        assert!(matches!(&blocks[0], ContentBlock::Text { text } if text == "First"));
        assert!(matches!(&blocks[1], ContentBlock::Text { text } if text == "Second"));
    }

    #[test]
    fn test_accumulator_tool_use() {
        let mut acc = StreamAccumulator::new();

        acc.apply(&StreamDelta::ToolUseStart {
            id: "call_123".to_string(),
            name: "read_file".to_string(),
            block_index: 0,
            thought_signature: None,
        });
        acc.apply(&StreamDelta::ToolInputDelta {
            id: "call_123".to_string(),
            delta: r#"{"path":"#.to_string(),
            block_index: 0,
        });
        acc.apply(&StreamDelta::ToolInputDelta {
            id: "call_123".to_string(),
            delta: r#""test.txt"}"#.to_string(),
            block_index: 0,
        });

        let blocks = acc.into_content_blocks();
        assert_eq!(blocks.len(), 1);
        match &blocks[0] {
            ContentBlock::ToolUse {
                id, name, input, ..
            } => {
                assert_eq!(id, "call_123");
                assert_eq!(name, "read_file");
                assert_eq!(input["path"], "test.txt");
            }
            _ => panic!("Expected ToolUse block"),
        }
    }

    #[test]
    fn test_accumulator_interleaved_tool_inputs() {
        let mut acc = StreamAccumulator::new();

        acc.apply(&StreamDelta::ToolUseStart {
            id: "call_a".to_string(),
            name: "first".to_string(),
            block_index: 0,
            thought_signature: None,
        });
        acc.apply(&StreamDelta::ToolUseStart {
            id: "call_b".to_string(),
            name: "second".to_string(),
            block_index: 1,
            thought_signature: None,
        });
        // Fragments arrive out of start order; each must reach its own call.
        acc.apply(&StreamDelta::ToolInputDelta {
            id: "call_b".to_string(),
            delta: r#"{"n":2}"#.to_string(),
            block_index: 1,
        });
        acc.apply(&StreamDelta::ToolInputDelta {
            id: "call_a".to_string(),
            delta: r#"{"n":1}"#.to_string(),
            block_index: 0,
        });

        let blocks = acc.into_content_blocks();
        assert_eq!(blocks.len(), 2);
        match (&blocks[0], &blocks[1]) {
            (
                ContentBlock::ToolUse {
                    id: first_id,
                    input: first_input,
                    ..
                },
                ContentBlock::ToolUse {
                    id: second_id,
                    input: second_input,
                    ..
                },
            ) => {
                assert_eq!(first_id, "call_a");
                assert_eq!(first_input["n"], 1);
                assert_eq!(second_id, "call_b");
                assert_eq!(second_input["n"], 2);
            }
            _ => panic!("Expected two ToolUse blocks"),
        }
    }

    #[test]
    fn test_accumulator_mixed_content() {
        let mut acc = StreamAccumulator::new();

        acc.apply(&StreamDelta::TextDelta {
            delta: "Let me read that file.".to_string(),
            block_index: 0,
        });
        acc.apply(&StreamDelta::ToolUseStart {
            id: "call_456".to_string(),
            name: "read_file".to_string(),
            block_index: 1,
            thought_signature: None,
        });
        acc.apply(&StreamDelta::ToolInputDelta {
            id: "call_456".to_string(),
            delta: r#"{"path":"file.txt"}"#.to_string(),
            block_index: 1,
        });
        acc.apply(&StreamDelta::Usage(Usage {
            input_tokens: 100,
            output_tokens: 50,
        }));
        acc.apply(&StreamDelta::Done {
            stop_reason: Some(StopReason::ToolUse),
        });

        assert!(acc.usage().is_some());
        assert_eq!(acc.usage().map(|u| u.input_tokens), Some(100));
        assert!(matches!(acc.stop_reason(), Some(StopReason::ToolUse)));

        let blocks = acc.into_content_blocks();
        assert_eq!(blocks.len(), 2);
        assert!(matches!(&blocks[0], ContentBlock::Text { .. }));
        assert!(matches!(&blocks[1], ContentBlock::ToolUse { .. }));
    }

    #[test]
    fn test_accumulator_invalid_tool_json() {
        let mut acc = StreamAccumulator::new();

        acc.apply(&StreamDelta::ToolUseStart {
            id: "call_789".to_string(),
            name: "test_tool".to_string(),
            block_index: 0,
            thought_signature: None,
        });
        acc.apply(&StreamDelta::ToolInputDelta {
            id: "call_789".to_string(),
            delta: "invalid json {".to_string(),
            block_index: 0,
        });

        let blocks = acc.into_content_blocks();
        assert_eq!(blocks.len(), 1);
        match &blocks[0] {
            ContentBlock::ToolUse { input, .. } => {
                // Unparseable input falls back to an empty object, not null.
                assert_eq!(input, &serde_json::json!({}));
            }
            _ => panic!("Expected ToolUse block"),
        }
    }

    #[test]
    fn test_accumulator_thinking_with_signature() {
        let mut acc = StreamAccumulator::new();

        acc.apply(&StreamDelta::ThinkingDelta {
            delta: "Let me ".to_string(),
            block_index: 0,
        });
        acc.apply(&StreamDelta::ThinkingDelta {
            delta: "think.".to_string(),
            block_index: 0,
        });
        acc.apply(&StreamDelta::SignatureDelta {
            delta: "sig".to_string(),
            block_index: 0,
        });
        acc.apply(&StreamDelta::SignatureDelta {
            delta: "nature".to_string(),
            block_index: 0,
        });
        acc.apply(&StreamDelta::TextDelta {
            delta: "Answer".to_string(),
            block_index: 1,
        });

        let blocks = acc.into_content_blocks();
        assert_eq!(blocks.len(), 2);
        match &blocks[0] {
            ContentBlock::Thinking {
                thinking,
                signature,
            } => {
                assert_eq!(thinking, "Let me think.");
                assert_eq!(signature.as_deref(), Some("signature"));
            }
            _ => panic!("Expected Thinking block"),
        }
        assert!(matches!(&blocks[1], ContentBlock::Text { text } if text == "Answer"));
    }

    #[test]
    fn test_accumulator_orders_blocks_by_index() {
        let mut acc = StreamAccumulator::new();

        // Deltas are applied in reverse block order on purpose.
        acc.apply(&StreamDelta::ToolUseStart {
            id: "call_1".to_string(),
            name: "list_files".to_string(),
            block_index: 3,
            thought_signature: None,
        });
        acc.apply(&StreamDelta::TextDelta {
            delta: "Checking.".to_string(),
            block_index: 2,
        });
        acc.apply(&StreamDelta::RedactedThinking {
            data: "opaque".to_string(),
            block_index: 1,
        });
        acc.apply(&StreamDelta::ThinkingDelta {
            delta: "Hmm.".to_string(),
            block_index: 0,
        });

        let blocks = acc.into_content_blocks();
        assert_eq!(blocks.len(), 4);
        assert!(matches!(
            &blocks[0],
            ContentBlock::Thinking { thinking, signature: None } if thinking == "Hmm."
        ));
        assert!(matches!(&blocks[1], ContentBlock::RedactedThinking { data } if data == "opaque"));
        assert!(matches!(&blocks[2], ContentBlock::Text { text } if text == "Checking."));
        match &blocks[3] {
            ContentBlock::ToolUse { id, input, .. } => {
                assert_eq!(id, "call_1");
                // No input fragments were streamed, so the input defaults to `{}`.
                assert_eq!(input, &serde_json::json!({}));
            }
            _ => panic!("Expected ToolUse block"),
        }
    }

    #[test]
    fn test_accumulator_take_usage_and_stop_reason() {
        let mut acc = StreamAccumulator::new();

        acc.apply(&StreamDelta::Usage(Usage {
            input_tokens: 7,
            output_tokens: 3,
        }));
        acc.apply(&StreamDelta::Done {
            stop_reason: Some(StopReason::ToolUse),
        });

        assert_eq!(acc.take_usage().map(|u| u.output_tokens), Some(3));
        assert!(acc.usage().is_none());
        assert!(matches!(acc.take_stop_reason(), Some(StopReason::ToolUse)));
        assert!(acc.stop_reason().is_none());
    }

    #[test]
    fn test_accumulator_empty() {
        let acc = StreamAccumulator::new();
        let blocks = acc.into_content_blocks();
        assert!(blocks.is_empty());
    }

    #[test]
    fn test_accumulator_skips_empty_text() {
        let mut acc = StreamAccumulator::new();

        acc.apply(&StreamDelta::TextDelta {
            delta: String::new(),
            block_index: 0,
        });
        acc.apply(&StreamDelta::TextDelta {
            delta: "Hello".to_string(),
            block_index: 1,
        });

        let blocks = acc.into_content_blocks();
        assert_eq!(blocks.len(), 1);
        assert!(matches!(&blocks[0], ContentBlock::Text { text } if text == "Hello"));
    }
}