// agent_sdk/llm/streaming.rs

//! Streaming types for LLM responses.
//!
//! This module provides types for handling streaming responses from LLM providers.
//! The [`StreamDelta`] enum represents an individual event in a streaming response,
//! and [`StreamAccumulator`] collects these events into a final response.

use crate::llm::{ContentBlock, StopReason, Usage};
use futures::Stream;
use std::pin::Pin;

11/// Events yielded during streaming LLM responses.
12///
13/// Each variant represents a different type of event that can occur
14/// during a streaming response from an LLM provider.
15#[derive(Debug, Clone)]
16pub enum StreamDelta {
17    /// A text delta for streaming text content.
18    TextDelta {
19        /// The text fragment to append
20        delta: String,
21        /// Index of the content block being streamed
22        block_index: usize,
23    },
24
25    /// A thinking delta for streaming thinking/reasoning content.
26    ThinkingDelta {
27        /// The thinking fragment to append
28        delta: String,
29        /// Index of the content block being streamed
30        block_index: usize,
31    },
32
33    /// Start of a tool use block (name and id are known).
34    ToolUseStart {
35        /// Unique identifier for this tool call
36        id: String,
37        /// Name of the tool being called
38        name: String,
39        /// Index of the content block
40        block_index: usize,
41    },
42
43    /// Incremental JSON for tool input (partial/incomplete JSON).
44    ToolInputDelta {
45        /// Tool call ID this delta belongs to
46        id: String,
47        /// JSON fragment to append
48        delta: String,
49        /// Index of the content block
50        block_index: usize,
51    },
52
53    /// Usage information (typically at stream end).
54    Usage(Usage),
55
56    /// Stream completed with stop reason.
57    Done {
58        /// Why the stream ended
59        stop_reason: Option<StopReason>,
60    },
61
62    /// Error during streaming.
63    Error {
64        /// Error message
65        message: String,
66        /// Whether the error is recoverable (e.g., rate limit)
67        recoverable: bool,
68    },
69}
70
71/// Type alias for a boxed stream of stream deltas.
72pub type StreamBox<'a> = Pin<Box<dyn Stream<Item = anyhow::Result<StreamDelta>> + Send + 'a>>;
73
74/// Helper to accumulate streamed content into a final response.
75///
76/// This struct collects [`StreamDelta`] events and can convert them
77/// into the final content blocks once the stream is complete.
78#[derive(Debug, Default)]
79pub struct StreamAccumulator {
80    /// Accumulated text for each block index
81    text_blocks: Vec<String>,
82    /// Accumulated thinking blocks for each block index
83    thinking_blocks: Vec<String>,
84    /// Accumulated tool use calls
85    tool_uses: Vec<ToolUseAccumulator>,
86    /// Usage information from the stream
87    usage: Option<Usage>,
88    /// Stop reason from the stream
89    stop_reason: Option<StopReason>,
90}
91
/// Accumulator for a single tool use during streaming.
///
/// Holds the identity of one tool call together with the raw JSON text of
/// its input as received so far. The JSON is stored unparsed because it is
/// only guaranteed to be complete once the stream has finished.
#[derive(Debug, Default)]
pub struct ToolUseAccumulator {
    /// Tool call ID
    pub id: String,
    /// Tool name
    pub name: String,
    /// Accumulated JSON input (may be incomplete during streaming)
    pub input_json: String,
    /// Block index for ordering
    pub block_index: usize,
}

105impl StreamAccumulator {
106    /// Create a new empty accumulator.
107    #[must_use]
108    pub fn new() -> Self {
109        Self::default()
110    }
111
112    /// Apply a stream delta to the accumulator.
113    pub fn apply(&mut self, delta: &StreamDelta) {
114        match delta {
115            StreamDelta::TextDelta { delta, block_index } => {
116                while self.text_blocks.len() <= *block_index {
117                    self.text_blocks.push(String::new());
118                }
119                self.text_blocks[*block_index].push_str(delta);
120            }
121            StreamDelta::ThinkingDelta { delta, block_index } => {
122                while self.thinking_blocks.len() <= *block_index {
123                    self.thinking_blocks.push(String::new());
124                }
125                self.thinking_blocks[*block_index].push_str(delta);
126            }
127            StreamDelta::ToolUseStart {
128                id,
129                name,
130                block_index,
131            } => {
132                self.tool_uses.push(ToolUseAccumulator {
133                    id: id.clone(),
134                    name: name.clone(),
135                    input_json: String::new(),
136                    block_index: *block_index,
137                });
138            }
139            StreamDelta::ToolInputDelta { id, delta, .. } => {
140                if let Some(tool) = self.tool_uses.iter_mut().find(|t| t.id == *id) {
141                    tool.input_json.push_str(delta);
142                }
143            }
144            StreamDelta::Usage(u) => {
145                self.usage = Some(u.clone());
146            }
147            StreamDelta::Done { stop_reason } => {
148                self.stop_reason = *stop_reason;
149            }
150            StreamDelta::Error { .. } => {}
151        }
152    }
153
154    /// Get the accumulated usage information.
155    #[must_use]
156    pub const fn usage(&self) -> Option<&Usage> {
157        self.usage.as_ref()
158    }
159
160    /// Get the stop reason.
161    #[must_use]
162    pub const fn stop_reason(&self) -> Option<&StopReason> {
163        self.stop_reason.as_ref()
164    }
165
166    /// Convert accumulated content to `ContentBlock`s.
167    ///
168    /// This consumes the accumulator and returns the final content blocks.
169    /// Tool use JSON is parsed at this point; invalid JSON results in a null input.
170    #[must_use]
171    pub fn into_content_blocks(self) -> Vec<ContentBlock> {
172        let mut blocks: Vec<(usize, ContentBlock)> = Vec::new();
173
174        // Add thinking blocks with their indices
175        for (idx, thinking) in self.thinking_blocks.into_iter().enumerate() {
176            if !thinking.is_empty() {
177                blocks.push((idx, ContentBlock::Thinking { thinking }));
178            }
179        }
180
181        // Add text blocks with their indices
182        for (idx, text) in self.text_blocks.into_iter().enumerate() {
183            if !text.is_empty() {
184                blocks.push((idx, ContentBlock::Text { text }));
185            }
186        }
187
188        // Add tool uses with their indices
189        for tool in self.tool_uses {
190            let input: serde_json::Value =
191                serde_json::from_str(&tool.input_json).unwrap_or_else(|_| serde_json::json!({}));
192            blocks.push((
193                tool.block_index,
194                ContentBlock::ToolUse {
195                    id: tool.id,
196                    name: tool.name,
197                    input,
198                    thought_signature: None, // Streaming doesn't provide thought signatures
199                },
200            ));
201        }
202
203        // Sort by block index to maintain order
204        blocks.sort_by_key(|(idx, _)| *idx);
205
206        blocks.into_iter().map(|(_, block)| block).collect()
207    }
208
209    /// Take ownership of accumulated usage.
210    pub const fn take_usage(&mut self) -> Option<Usage> {
211        self.usage.take()
212    }
213
214    /// Take ownership of stop reason.
215    pub const fn take_stop_reason(&mut self) -> Option<StopReason> {
216        self.stop_reason.take()
217    }
218}
219
#[cfg(test)]
mod tests {
    //! Unit tests for [`StreamAccumulator`]: fragment concatenation, block
    //! ordering, tool input parsing, and usage / stop-reason bookkeeping.

    use super::*;

    /// Fragments for the same block index are concatenated in order.
    #[test]
    fn test_accumulator_text_deltas() {
        let mut acc = StreamAccumulator::new();

        acc.apply(&StreamDelta::TextDelta {
            delta: "Hello".to_string(),
            block_index: 0,
        });
        acc.apply(&StreamDelta::TextDelta {
            delta: " world".to_string(),
            block_index: 0,
        });

        let blocks = acc.into_content_blocks();
        assert_eq!(blocks.len(), 1);
        assert!(matches!(&blocks[0], ContentBlock::Text { text } if text == "Hello world"));
    }

    /// Distinct block indices produce distinct text blocks, in index order.
    #[test]
    fn test_accumulator_multiple_text_blocks() {
        let mut acc = StreamAccumulator::new();

        acc.apply(&StreamDelta::TextDelta {
            delta: "First".to_string(),
            block_index: 0,
        });
        acc.apply(&StreamDelta::TextDelta {
            delta: "Second".to_string(),
            block_index: 1,
        });

        let blocks = acc.into_content_blocks();
        assert_eq!(blocks.len(), 2);
        assert!(matches!(&blocks[0], ContentBlock::Text { text } if text == "First"));
        assert!(matches!(&blocks[1], ContentBlock::Text { text } if text == "Second"));
    }

    /// Thinking fragments are concatenated like text fragments.
    #[test]
    fn test_accumulator_thinking_deltas() {
        let mut acc = StreamAccumulator::new();

        acc.apply(&StreamDelta::ThinkingDelta {
            delta: "Let me ".to_string(),
            block_index: 0,
        });
        acc.apply(&StreamDelta::ThinkingDelta {
            delta: "think.".to_string(),
            block_index: 0,
        });

        let blocks = acc.into_content_blocks();
        assert_eq!(blocks.len(), 1);
        assert!(
            matches!(&blocks[0], ContentBlock::Thinking { thinking } if thinking == "Let me think.")
        );
    }

    /// Split JSON fragments are joined and parsed into the tool input.
    #[test]
    fn test_accumulator_tool_use() {
        let mut acc = StreamAccumulator::new();

        acc.apply(&StreamDelta::ToolUseStart {
            id: "call_123".to_string(),
            name: "read_file".to_string(),
            block_index: 0,
        });
        acc.apply(&StreamDelta::ToolInputDelta {
            id: "call_123".to_string(),
            delta: r#"{"path":"#.to_string(),
            block_index: 0,
        });
        acc.apply(&StreamDelta::ToolInputDelta {
            id: "call_123".to_string(),
            delta: r#""test.txt"}"#.to_string(),
            block_index: 0,
        });

        let blocks = acc.into_content_blocks();
        assert_eq!(blocks.len(), 1);
        match &blocks[0] {
            ContentBlock::ToolUse {
                id, name, input, ..
            } => {
                assert_eq!(id, "call_123");
                assert_eq!(name, "read_file");
                assert_eq!(input["path"], "test.txt");
            }
            _ => panic!("Expected ToolUse block"),
        }
    }

    /// Text, tool use, usage and stop reason are all tracked together.
    #[test]
    fn test_accumulator_mixed_content() {
        let mut acc = StreamAccumulator::new();

        acc.apply(&StreamDelta::TextDelta {
            delta: "Let me read that file.".to_string(),
            block_index: 0,
        });
        acc.apply(&StreamDelta::ToolUseStart {
            id: "call_456".to_string(),
            name: "read_file".to_string(),
            block_index: 1,
        });
        acc.apply(&StreamDelta::ToolInputDelta {
            id: "call_456".to_string(),
            delta: r#"{"path":"file.txt"}"#.to_string(),
            block_index: 1,
        });
        acc.apply(&StreamDelta::Usage(Usage {
            input_tokens: 100,
            output_tokens: 50,
        }));
        acc.apply(&StreamDelta::Done {
            stop_reason: Some(StopReason::ToolUse),
        });

        assert!(acc.usage().is_some());
        assert_eq!(acc.usage().map(|u| u.input_tokens), Some(100));
        assert!(matches!(acc.stop_reason(), Some(StopReason::ToolUse)));

        let blocks = acc.into_content_blocks();
        assert_eq!(blocks.len(), 2);
        assert!(matches!(&blocks[0], ContentBlock::Text { .. }));
        assert!(matches!(&blocks[1], ContentBlock::ToolUse { .. }));
    }

    /// Output order follows block index, not arrival order.
    #[test]
    fn test_accumulator_orders_by_block_index() {
        let mut acc = StreamAccumulator::new();

        acc.apply(&StreamDelta::ToolUseStart {
            id: "call_1".to_string(),
            name: "read_file".to_string(),
            block_index: 2,
        });
        acc.apply(&StreamDelta::TextDelta {
            delta: "Answer".to_string(),
            block_index: 1,
        });
        acc.apply(&StreamDelta::ThinkingDelta {
            delta: "Reasoning".to_string(),
            block_index: 0,
        });

        let blocks = acc.into_content_blocks();
        assert_eq!(blocks.len(), 3);
        assert!(matches!(&blocks[0], ContentBlock::Thinking { .. }));
        assert!(matches!(&blocks[1], ContentBlock::Text { .. }));
        assert!(matches!(&blocks[2], ContentBlock::ToolUse { .. }));
    }

    /// Unparseable tool input falls back to an empty JSON object.
    #[test]
    fn test_accumulator_invalid_tool_json() {
        let mut acc = StreamAccumulator::new();

        acc.apply(&StreamDelta::ToolUseStart {
            id: "call_789".to_string(),
            name: "test_tool".to_string(),
            block_index: 0,
        });
        acc.apply(&StreamDelta::ToolInputDelta {
            id: "call_789".to_string(),
            delta: "invalid json {".to_string(),
            block_index: 0,
        });

        let blocks = acc.into_content_blocks();
        assert_eq!(blocks.len(), 1);
        match &blocks[0] {
            ContentBlock::ToolUse { input, .. } => {
                assert!(input.is_object());
            }
            _ => panic!("Expected ToolUse block"),
        }
    }

    /// A tool call that never receives input yields an empty JSON object.
    #[test]
    fn test_accumulator_tool_use_without_input() {
        let mut acc = StreamAccumulator::new();

        acc.apply(&StreamDelta::ToolUseStart {
            id: "call_empty".to_string(),
            name: "list_files".to_string(),
            block_index: 0,
        });

        let blocks = acc.into_content_blocks();
        assert_eq!(blocks.len(), 1);
        match &blocks[0] {
            ContentBlock::ToolUse { input, .. } => {
                assert_eq!(*input, serde_json::json!({}));
            }
            _ => panic!("Expected ToolUse block"),
        }
    }

    /// Input fragments for an id that was never started are dropped.
    #[test]
    fn test_accumulator_ignores_unknown_tool_id() {
        let mut acc = StreamAccumulator::new();

        acc.apply(&StreamDelta::ToolUseStart {
            id: "call_known".to_string(),
            name: "read_file".to_string(),
            block_index: 0,
        });
        acc.apply(&StreamDelta::ToolInputDelta {
            id: "call_unknown".to_string(),
            delta: r#"{"path":"file.txt"}"#.to_string(),
            block_index: 0,
        });

        let blocks = acc.into_content_blocks();
        assert_eq!(blocks.len(), 1);
        match &blocks[0] {
            ContentBlock::ToolUse { id, input, .. } => {
                assert_eq!(id, "call_known");
                assert_eq!(*input, serde_json::json!({}));
            }
            _ => panic!("Expected ToolUse block"),
        }
    }

    /// `take_*` hands out the stored value once and leaves `None` behind.
    #[test]
    fn test_accumulator_take_usage_and_stop_reason() {
        let mut acc = StreamAccumulator::new();

        acc.apply(&StreamDelta::Usage(Usage {
            input_tokens: 7,
            output_tokens: 3,
        }));
        acc.apply(&StreamDelta::Done {
            stop_reason: Some(StopReason::ToolUse),
        });

        assert_eq!(acc.take_usage().map(|u| u.output_tokens), Some(3));
        assert!(acc.usage().is_none());
        assert!(acc.take_usage().is_none());

        assert!(matches!(acc.take_stop_reason(), Some(StopReason::ToolUse)));
        assert!(acc.stop_reason().is_none());
        assert!(acc.take_stop_reason().is_none());
    }

    /// An untouched accumulator produces no content blocks.
    #[test]
    fn test_accumulator_empty() {
        let acc = StreamAccumulator::new();
        let blocks = acc.into_content_blocks();
        assert!(blocks.is_empty());
    }

    /// Blocks whose accumulated text is empty are omitted from the output.
    #[test]
    fn test_accumulator_skips_empty_text() {
        let mut acc = StreamAccumulator::new();

        acc.apply(&StreamDelta::TextDelta {
            delta: String::new(),
            block_index: 0,
        });
        acc.apply(&StreamDelta::TextDelta {
            delta: "Hello".to_string(),
            block_index: 1,
        });

        let blocks = acc.into_content_blocks();
        assert_eq!(blocks.len(), 1);
        assert!(matches!(&blocks[0], ContentBlock::Text { text } if text == "Hello"));
    }
}