1use crate::llm::{ContentBlock, StopReason, Usage};
8use futures::Stream;
9use std::pin::Pin;
10
/// One incremental event from a streaming LLM response.
///
/// Content-bearing variants carry a `block_index` identifying which content
/// block of the response they belong to; [`StreamAccumulator`] uses it to put
/// the reassembled blocks back in order.
#[derive(Debug, Clone)]
pub enum StreamDelta {
    /// A fragment of text to append to the text block at `block_index`.
    TextDelta {
        /// Text fragment to append.
        delta: String,
        /// Index of the content block this fragment belongs to.
        block_index: usize,
    },

    /// Start of a tool call; subsequent `ToolInputDelta`s with the same `id`
    /// supply its input.
    ToolUseStart {
        /// Tool-call identifier, used to match subsequent input deltas.
        id: String,
        /// Name of the tool being called.
        name: String,
        /// Index of the content block this tool call occupies.
        block_index: usize,
    },

    /// A fragment of a tool call's JSON-encoded input.
    ///
    /// Fragments are concatenated in arrival order; an individual fragment is
    /// not expected to be valid JSON on its own, only the full concatenation.
    ToolInputDelta {
        /// Identifier of the tool call this fragment belongs to.
        id: String,
        /// Raw JSON text fragment.
        delta: String,
        /// Index of the content block of the owning tool call.
        block_index: usize,
    },

    /// Token usage reported for the response.
    Usage(Usage),

    /// End of the response.
    Done {
        /// Why generation stopped, if it was reported.
        stop_reason: Option<StopReason>,
    },

    /// An error reported in-band on the stream.
    Error {
        /// Description of the error.
        message: String,
        /// Whether the producer flagged this error as recoverable; how that
        /// flag is acted on is up to the consumer.
        recoverable: bool,
    },
}
62
/// A pinned, boxed, `Send` stream of [`StreamDelta`]s (or errors) that may
/// borrow data for the lifetime `'a`.
pub type StreamBox<'a> = Pin<Box<dyn Stream<Item = anyhow::Result<StreamDelta>> + Send + 'a>>;
65
/// Reassembles a sequence of [`StreamDelta`]s into complete content blocks.
///
/// Feed each delta to [`StreamAccumulator::apply`], then call
/// [`StreamAccumulator::into_content_blocks`] once the stream has ended.
#[derive(Debug, Default)]
pub struct StreamAccumulator {
    /// Text buffers indexed by `block_index`; positions that never received
    /// text (e.g. those used by tool calls) are left as empty strings.
    text_blocks: Vec<String>,
    /// Tool calls in the order their `ToolUseStart` deltas arrived.
    tool_uses: Vec<ToolUseAccumulator>,
    /// Usage from the most recent `Usage` delta.
    usage: Option<Usage>,
    /// Stop reason from the most recent `Done` delta.
    stop_reason: Option<StopReason>,
}
81
/// In-progress state for a single streamed tool call.
#[derive(Debug, Default)]
pub struct ToolUseAccumulator {
    /// Tool-call identifier from `ToolUseStart`; input deltas are matched against it.
    pub id: String,
    /// Name of the tool being invoked.
    pub name: String,
    /// Raw concatenation of the input fragments received so far; only parsed
    /// as JSON when the accumulator is converted into content blocks.
    pub input_json: String,
    /// Content-block position of this call within the response.
    pub block_index: usize,
}
94
95impl StreamAccumulator {
96 #[must_use]
98 pub fn new() -> Self {
99 Self::default()
100 }
101
102 pub fn apply(&mut self, delta: &StreamDelta) {
104 match delta {
105 StreamDelta::TextDelta { delta, block_index } => {
106 while self.text_blocks.len() <= *block_index {
107 self.text_blocks.push(String::new());
108 }
109 self.text_blocks[*block_index].push_str(delta);
110 }
111 StreamDelta::ToolUseStart {
112 id,
113 name,
114 block_index,
115 } => {
116 self.tool_uses.push(ToolUseAccumulator {
117 id: id.clone(),
118 name: name.clone(),
119 input_json: String::new(),
120 block_index: *block_index,
121 });
122 }
123 StreamDelta::ToolInputDelta { id, delta, .. } => {
124 if let Some(tool) = self.tool_uses.iter_mut().find(|t| t.id == *id) {
125 tool.input_json.push_str(delta);
126 }
127 }
128 StreamDelta::Usage(u) => {
129 self.usage = Some(u.clone());
130 }
131 StreamDelta::Done { stop_reason } => {
132 self.stop_reason = *stop_reason;
133 }
134 StreamDelta::Error { .. } => {}
135 }
136 }
137
138 #[must_use]
140 pub const fn usage(&self) -> Option<&Usage> {
141 self.usage.as_ref()
142 }
143
144 #[must_use]
146 pub const fn stop_reason(&self) -> Option<&StopReason> {
147 self.stop_reason.as_ref()
148 }
149
150 #[must_use]
155 pub fn into_content_blocks(self) -> Vec<ContentBlock> {
156 let mut blocks: Vec<(usize, ContentBlock)> = Vec::new();
157
158 for (idx, text) in self.text_blocks.into_iter().enumerate() {
160 if !text.is_empty() {
161 blocks.push((idx, ContentBlock::Text { text }));
162 }
163 }
164
165 for tool in self.tool_uses {
167 let input: serde_json::Value =
168 serde_json::from_str(&tool.input_json).unwrap_or(serde_json::Value::Null);
169 blocks.push((
170 tool.block_index,
171 ContentBlock::ToolUse {
172 id: tool.id,
173 name: tool.name,
174 input,
175 thought_signature: None, },
177 ));
178 }
179
180 blocks.sort_by_key(|(idx, _)| *idx);
182
183 blocks.into_iter().map(|(_, block)| block).collect()
184 }
185
186 pub const fn take_usage(&mut self) -> Option<Usage> {
188 self.usage.take()
189 }
190
191 pub const fn take_stop_reason(&mut self) -> Option<StopReason> {
193 self.stop_reason.take()
194 }
195}
196
#[cfg(test)]
mod tests {
    use super::*;

    // Delta constructors keep the scenarios below short and readable.

    fn text_delta(delta: &str, block_index: usize) -> StreamDelta {
        StreamDelta::TextDelta {
            delta: delta.to_owned(),
            block_index,
        }
    }

    fn tool_start(id: &str, name: &str, block_index: usize) -> StreamDelta {
        StreamDelta::ToolUseStart {
            id: id.to_owned(),
            name: name.to_owned(),
            block_index,
        }
    }

    fn tool_input(id: &str, delta: &str, block_index: usize) -> StreamDelta {
        StreamDelta::ToolInputDelta {
            id: id.to_owned(),
            delta: delta.to_owned(),
            block_index,
        }
    }

    /// Feeds `deltas`, in order, into a fresh accumulator and returns it.
    fn accumulate(deltas: &[StreamDelta]) -> StreamAccumulator {
        deltas.iter().fold(StreamAccumulator::new(), |mut acc, d| {
            acc.apply(d);
            acc
        })
    }

    /// Asserts that `block` is a text block whose text equals `expected`.
    fn assert_text(block: &ContentBlock, expected: &str) {
        assert!(matches!(block, ContentBlock::Text { text } if text == expected));
    }

    #[test]
    fn test_accumulator_text_deltas() {
        let blocks =
            accumulate(&[text_delta("Hello", 0), text_delta(" world", 0)]).into_content_blocks();

        assert_eq!(blocks.len(), 1);
        assert_text(&blocks[0], "Hello world");
    }

    #[test]
    fn test_accumulator_multiple_text_blocks() {
        let blocks =
            accumulate(&[text_delta("First", 0), text_delta("Second", 1)]).into_content_blocks();

        assert_eq!(blocks.len(), 2);
        assert_text(&blocks[0], "First");
        assert_text(&blocks[1], "Second");
    }

    #[test]
    fn test_accumulator_tool_use() {
        let blocks = accumulate(&[
            tool_start("call_123", "read_file", 0),
            tool_input("call_123", r#"{"path":"#, 0),
            tool_input("call_123", r#""test.txt"}"#, 0),
        ])
        .into_content_blocks();

        assert_eq!(blocks.len(), 1);
        let ContentBlock::ToolUse {
            id, name, input, ..
        } = &blocks[0]
        else {
            panic!("Expected ToolUse block");
        };
        assert_eq!(id, "call_123");
        assert_eq!(name, "read_file");
        assert_eq!(input["path"], "test.txt");
    }

    #[test]
    fn test_accumulator_mixed_content() {
        let acc = accumulate(&[
            text_delta("Let me read that file.", 0),
            tool_start("call_456", "read_file", 1),
            tool_input("call_456", r#"{"path":"file.txt"}"#, 1),
            StreamDelta::Usage(Usage {
                input_tokens: 100,
                output_tokens: 50,
            }),
            StreamDelta::Done {
                stop_reason: Some(StopReason::ToolUse),
            },
        ]);

        // Metadata is readable before the accumulator is consumed.
        assert!(acc.usage().is_some());
        assert_eq!(acc.usage().map(|u| u.input_tokens), Some(100));
        assert!(matches!(acc.stop_reason(), Some(StopReason::ToolUse)));

        let blocks = acc.into_content_blocks();
        assert_eq!(blocks.len(), 2);
        assert!(matches!(&blocks[0], ContentBlock::Text { .. }));
        assert!(matches!(&blocks[1], ContentBlock::ToolUse { .. }));
    }

    #[test]
    fn test_accumulator_invalid_tool_json() {
        let blocks = accumulate(&[
            tool_start("call_789", "test_tool", 0),
            tool_input("call_789", "invalid json {", 0),
        ])
        .into_content_blocks();

        assert_eq!(blocks.len(), 1);
        let ContentBlock::ToolUse { input, .. } = &blocks[0] else {
            panic!("Expected ToolUse block");
        };
        assert!(input.is_null());
    }

    #[test]
    fn test_accumulator_empty() {
        assert!(StreamAccumulator::new().into_content_blocks().is_empty());
    }

    #[test]
    fn test_accumulator_skips_empty_text() {
        let blocks =
            accumulate(&[text_delta("", 0), text_delta("Hello", 1)]).into_content_blocks();

        assert_eq!(blocks.len(), 1);
        assert_text(&blocks[0], "Hello");
    }
}