// agent_sdk/llm/streaming.rs

//! Streaming types for LLM responses.
//!
//! This module provides types for handling streaming responses from LLM providers.
//! The [`StreamDelta`] enum represents individual events in a streaming response,
//! and [`StreamAccumulator`] collects these events into the final response content.
6
7use crate::llm::{ContentBlock, StopReason, Usage};
8use futures::Stream;
9use std::pin::Pin;
10
11/// Events yielded during streaming LLM responses.
12///
13/// Each variant represents a different type of event that can occur
14/// during a streaming response from an LLM provider.
15#[derive(Debug, Clone)]
16pub enum StreamDelta {
17    /// A text delta for streaming text content.
18    TextDelta {
19        /// The text fragment to append
20        delta: String,
21        /// Index of the content block being streamed
22        block_index: usize,
23    },
24
25    /// Start of a tool use block (name and id are known).
26    ToolUseStart {
27        /// Unique identifier for this tool call
28        id: String,
29        /// Name of the tool being called
30        name: String,
31        /// Index of the content block
32        block_index: usize,
33    },
34
35    /// Incremental JSON for tool input (partial/incomplete JSON).
36    ToolInputDelta {
37        /// Tool call ID this delta belongs to
38        id: String,
39        /// JSON fragment to append
40        delta: String,
41        /// Index of the content block
42        block_index: usize,
43    },
44
45    /// Usage information (typically at stream end).
46    Usage(Usage),
47
48    /// Stream completed with stop reason.
49    Done {
50        /// Why the stream ended
51        stop_reason: Option<StopReason>,
52    },
53
54    /// Error during streaming.
55    Error {
56        /// Error message
57        message: String,
58        /// Whether the error is recoverable (e.g., rate limit)
59        recoverable: bool,
60    },
61}
62
63/// Type alias for a boxed stream of stream deltas.
64pub type StreamBox<'a> = Pin<Box<dyn Stream<Item = anyhow::Result<StreamDelta>> + Send + 'a>>;
65
66/// Helper to accumulate streamed content into a final response.
67///
68/// This struct collects [`StreamDelta`] events and can convert them
69/// into the final content blocks once the stream is complete.
70#[derive(Debug, Default)]
71pub struct StreamAccumulator {
72    /// Accumulated text for each block index
73    text_blocks: Vec<String>,
74    /// Accumulated tool use calls
75    tool_uses: Vec<ToolUseAccumulator>,
76    /// Usage information from the stream
77    usage: Option<Usage>,
78    /// Stop reason from the stream
79    stop_reason: Option<StopReason>,
80}
81
/// In-progress state for a single tool call whose arguments are still streaming.
#[derive(Debug, Default)]
pub struct ToolUseAccumulator {
    /// Identifier of the tool call.
    pub id: String,
    /// Name of the tool being invoked.
    pub name: String,
    /// Raw JSON argument text gathered so far; may be incomplete mid-stream.
    pub input_json: String,
    /// Content-block index, used to place this call among the response's blocks.
    pub block_index: usize,
}
94
95impl StreamAccumulator {
96    /// Create a new empty accumulator.
97    #[must_use]
98    pub fn new() -> Self {
99        Self::default()
100    }
101
102    /// Apply a stream delta to the accumulator.
103    pub fn apply(&mut self, delta: &StreamDelta) {
104        match delta {
105            StreamDelta::TextDelta { delta, block_index } => {
106                while self.text_blocks.len() <= *block_index {
107                    self.text_blocks.push(String::new());
108                }
109                self.text_blocks[*block_index].push_str(delta);
110            }
111            StreamDelta::ToolUseStart {
112                id,
113                name,
114                block_index,
115            } => {
116                self.tool_uses.push(ToolUseAccumulator {
117                    id: id.clone(),
118                    name: name.clone(),
119                    input_json: String::new(),
120                    block_index: *block_index,
121                });
122            }
123            StreamDelta::ToolInputDelta { id, delta, .. } => {
124                if let Some(tool) = self.tool_uses.iter_mut().find(|t| t.id == *id) {
125                    tool.input_json.push_str(delta);
126                }
127            }
128            StreamDelta::Usage(u) => {
129                self.usage = Some(u.clone());
130            }
131            StreamDelta::Done { stop_reason } => {
132                self.stop_reason = *stop_reason;
133            }
134            StreamDelta::Error { .. } => {}
135        }
136    }
137
138    /// Get the accumulated usage information.
139    #[must_use]
140    pub const fn usage(&self) -> Option<&Usage> {
141        self.usage.as_ref()
142    }
143
144    /// Get the stop reason.
145    #[must_use]
146    pub const fn stop_reason(&self) -> Option<&StopReason> {
147        self.stop_reason.as_ref()
148    }
149
150    /// Convert accumulated content to `ContentBlock`s.
151    ///
152    /// This consumes the accumulator and returns the final content blocks.
153    /// Tool use JSON is parsed at this point; invalid JSON results in a null input.
154    #[must_use]
155    pub fn into_content_blocks(self) -> Vec<ContentBlock> {
156        let mut blocks: Vec<(usize, ContentBlock)> = Vec::new();
157
158        // Add text blocks with their indices
159        for (idx, text) in self.text_blocks.into_iter().enumerate() {
160            if !text.is_empty() {
161                blocks.push((idx, ContentBlock::Text { text }));
162            }
163        }
164
165        // Add tool uses with their indices
166        for tool in self.tool_uses {
167            let input: serde_json::Value =
168                serde_json::from_str(&tool.input_json).unwrap_or(serde_json::Value::Null);
169            blocks.push((
170                tool.block_index,
171                ContentBlock::ToolUse {
172                    id: tool.id,
173                    name: tool.name,
174                    input,
175                    thought_signature: None, // Streaming doesn't provide thought signatures
176                },
177            ));
178        }
179
180        // Sort by block index to maintain order
181        blocks.sort_by_key(|(idx, _)| *idx);
182
183        blocks.into_iter().map(|(_, block)| block).collect()
184    }
185
186    /// Take ownership of accumulated usage.
187    pub const fn take_usage(&mut self) -> Option<Usage> {
188        self.usage.take()
189    }
190
191    /// Take ownership of stop reason.
192    pub const fn take_stop_reason(&mut self) -> Option<StopReason> {
193        self.stop_reason.take()
194    }
195}
196
#[cfg(test)]
mod tests {
    //! Unit tests for [`StreamAccumulator`]. They cover text and tool-use
    //! accumulation, block ordering, usage/stop-reason handling, and events
    //! the accumulator ignores.
    use super::*;

    #[test]
    fn test_accumulator_text_deltas() {
        let mut acc = StreamAccumulator::new();

        acc.apply(&StreamDelta::TextDelta {
            delta: "Hello".to_string(),
            block_index: 0,
        });
        acc.apply(&StreamDelta::TextDelta {
            delta: " world".to_string(),
            block_index: 0,
        });

        let blocks = acc.into_content_blocks();
        assert_eq!(blocks.len(), 1);
        assert!(matches!(&blocks[0], ContentBlock::Text { text } if text == "Hello world"));
    }

    #[test]
    fn test_accumulator_multiple_text_blocks() {
        let mut acc = StreamAccumulator::new();

        acc.apply(&StreamDelta::TextDelta {
            delta: "First".to_string(),
            block_index: 0,
        });
        acc.apply(&StreamDelta::TextDelta {
            delta: "Second".to_string(),
            block_index: 1,
        });

        let blocks = acc.into_content_blocks();
        assert_eq!(blocks.len(), 2);
        assert!(matches!(&blocks[0], ContentBlock::Text { text } if text == "First"));
        assert!(matches!(&blocks[1], ContentBlock::Text { text } if text == "Second"));
    }

    #[test]
    fn test_accumulator_tool_use() {
        let mut acc = StreamAccumulator::new();

        acc.apply(&StreamDelta::ToolUseStart {
            id: "call_123".to_string(),
            name: "read_file".to_string(),
            block_index: 0,
        });
        acc.apply(&StreamDelta::ToolInputDelta {
            id: "call_123".to_string(),
            delta: r#"{"path":"#.to_string(),
            block_index: 0,
        });
        acc.apply(&StreamDelta::ToolInputDelta {
            id: "call_123".to_string(),
            delta: r#""test.txt"}"#.to_string(),
            block_index: 0,
        });

        let blocks = acc.into_content_blocks();
        assert_eq!(blocks.len(), 1);
        match &blocks[0] {
            ContentBlock::ToolUse {
                id, name, input, ..
            } => {
                assert_eq!(id, "call_123");
                assert_eq!(name, "read_file");
                assert_eq!(input["path"], "test.txt");
            }
            _ => panic!("Expected ToolUse block"),
        }
    }

    #[test]
    fn test_accumulator_mixed_content() {
        let mut acc = StreamAccumulator::new();

        acc.apply(&StreamDelta::TextDelta {
            delta: "Let me read that file.".to_string(),
            block_index: 0,
        });
        acc.apply(&StreamDelta::ToolUseStart {
            id: "call_456".to_string(),
            name: "read_file".to_string(),
            block_index: 1,
        });
        acc.apply(&StreamDelta::ToolInputDelta {
            id: "call_456".to_string(),
            delta: r#"{"path":"file.txt"}"#.to_string(),
            block_index: 1,
        });
        acc.apply(&StreamDelta::Usage(Usage {
            input_tokens: 100,
            output_tokens: 50,
        }));
        acc.apply(&StreamDelta::Done {
            stop_reason: Some(StopReason::ToolUse),
        });

        assert!(acc.usage().is_some());
        assert_eq!(acc.usage().map(|u| u.input_tokens), Some(100));
        assert!(matches!(acc.stop_reason(), Some(StopReason::ToolUse)));

        let blocks = acc.into_content_blocks();
        assert_eq!(blocks.len(), 2);
        assert!(matches!(&blocks[0], ContentBlock::Text { .. }));
        assert!(matches!(&blocks[1], ContentBlock::ToolUse { .. }));
    }

    #[test]
    fn test_accumulator_invalid_tool_json() {
        let mut acc = StreamAccumulator::new();

        acc.apply(&StreamDelta::ToolUseStart {
            id: "call_789".to_string(),
            name: "test_tool".to_string(),
            block_index: 0,
        });
        acc.apply(&StreamDelta::ToolInputDelta {
            id: "call_789".to_string(),
            delta: "invalid json {".to_string(),
            block_index: 0,
        });

        let blocks = acc.into_content_blocks();
        assert_eq!(blocks.len(), 1);
        match &blocks[0] {
            ContentBlock::ToolUse { input, .. } => {
                assert!(input.is_null());
            }
            _ => panic!("Expected ToolUse block"),
        }
    }

    #[test]
    fn test_accumulator_empty() {
        let acc = StreamAccumulator::new();
        let blocks = acc.into_content_blocks();
        assert!(blocks.is_empty());
    }

    #[test]
    fn test_accumulator_skips_empty_text() {
        let mut acc = StreamAccumulator::new();

        acc.apply(&StreamDelta::TextDelta {
            delta: String::new(),
            block_index: 0,
        });
        acc.apply(&StreamDelta::TextDelta {
            delta: "Hello".to_string(),
            block_index: 1,
        });

        let blocks = acc.into_content_blocks();
        assert_eq!(blocks.len(), 1);
        assert!(matches!(&blocks[0], ContentBlock::Text { text } if text == "Hello"));
    }

    #[test]
    fn test_accumulator_orders_by_block_index() {
        let mut acc = StreamAccumulator::new();

        // Events arrive out of block order: the tool call (block 1) comes first.
        acc.apply(&StreamDelta::ToolUseStart {
            id: "call_1".to_string(),
            name: "search".to_string(),
            block_index: 1,
        });
        acc.apply(&StreamDelta::ToolInputDelta {
            id: "call_1".to_string(),
            delta: r#"{"q":"rust"}"#.to_string(),
            block_index: 1,
        });
        acc.apply(&StreamDelta::TextDelta {
            delta: "Searching.".to_string(),
            block_index: 0,
        });

        let blocks = acc.into_content_blocks();
        assert_eq!(blocks.len(), 2);
        assert!(matches!(&blocks[0], ContentBlock::Text { text } if text == "Searching."));
        assert!(matches!(&blocks[1], ContentBlock::ToolUse { name, .. } if name == "search"));
    }

    #[test]
    fn test_accumulator_take_usage_and_stop_reason() {
        let mut acc = StreamAccumulator::new();

        acc.apply(&StreamDelta::Usage(Usage {
            input_tokens: 7,
            output_tokens: 3,
        }));
        acc.apply(&StreamDelta::Done {
            stop_reason: Some(StopReason::ToolUse),
        });

        // Taking moves the value out and leaves `None` behind.
        assert_eq!(acc.take_usage().map(|u| u.output_tokens), Some(3));
        assert!(acc.take_usage().is_none());
        assert!(acc.usage().is_none());

        assert!(matches!(acc.take_stop_reason(), Some(StopReason::ToolUse)));
        assert!(acc.take_stop_reason().is_none());
        assert!(acc.stop_reason().is_none());
    }

    #[test]
    fn test_accumulator_ignores_error_delta() {
        let mut acc = StreamAccumulator::new();

        acc.apply(&StreamDelta::Error {
            message: "boom".to_string(),
            recoverable: true,
        });
        acc.apply(&StreamDelta::TextDelta {
            delta: "ok".to_string(),
            block_index: 0,
        });

        assert!(acc.usage().is_none());
        assert!(acc.stop_reason().is_none());
        let blocks = acc.into_content_blocks();
        assert_eq!(blocks.len(), 1);
        assert!(matches!(&blocks[0], ContentBlock::Text { text } if text == "ok"));
    }

    #[test]
    fn test_accumulator_drops_unmatched_tool_input() {
        let mut acc = StreamAccumulator::new();

        // No tool call has been started, so this fragment has nowhere to go.
        acc.apply(&StreamDelta::ToolInputDelta {
            id: "ghost".to_string(),
            delta: "{}".to_string(),
            block_index: 3,
        });

        assert!(acc.into_content_blocks().is_empty());
    }
}