1use crate::llm::{ContentBlock, StopReason, Usage};
8use futures::Stream;
9use std::collections::HashMap;
10use std::pin::Pin;
11
/// One incremental event from a streaming LLM response.
///
/// Events are folded into final content blocks by [`StreamAccumulator::apply`].
/// Content-bearing variants carry a `block_index` identifying the content block they
/// belong to; [`StreamAccumulator::into_content_blocks`] orders its output by it.
#[derive(Debug, Clone)]
pub enum StreamDelta {
    /// A fragment of assistant text.
    TextDelta {
        /// Text appended to the block's accumulated text.
        delta: String,
        /// Index of the text block this fragment belongs to.
        block_index: usize,
    },

    /// A fragment of model "thinking" text.
    ThinkingDelta {
        /// Text appended to the block's accumulated thinking.
        delta: String,
        /// Index of the thinking block this fragment belongs to.
        block_index: usize,
    },

    /// Start of a tool call; its JSON input follows as `ToolInputDelta` events with
    /// the same `id`.
    ToolUseStart {
        /// Tool-call identifier used to route subsequent input fragments.
        id: String,
        /// Name of the tool being invoked.
        name: String,
        /// Index of the content block holding this tool call.
        block_index: usize,
        /// Opaque signature copied unchanged into `ContentBlock::ToolUse`.
        thought_signature: Option<String>,
    },

    /// A fragment of a tool call's JSON-encoded input.
    ToolInputDelta {
        /// Identifier of the tool call (from `ToolUseStart`) this fragment extends.
        id: String,
        /// Raw JSON text; fragments are concatenated and parsed once streaming ends.
        delta: String,
        /// Index of the content block holding the tool call.
        block_index: usize,
    },

    /// A token-usage report; the accumulator keeps only the most recent one.
    Usage(Usage),

    /// End of the response.
    Done {
        /// Why generation stopped, if the provider reported it.
        stop_reason: Option<StopReason>,
    },

    /// A fragment of the signature attached to a thinking block.
    SignatureDelta {
        /// Signature text appended to the block's accumulated signature.
        delta: String,
        /// Index of the thinking block the signature belongs to.
        block_index: usize,
    },

    /// A redacted thinking block, delivered whole rather than incrementally.
    RedactedThinking {
        /// Opaque redacted payload, passed through unchanged.
        data: String,
        /// Index of the content block.
        block_index: usize,
    },

    /// An error reported in-band on the stream.
    ///
    /// `StreamAccumulator::apply` ignores this variant; consumers must act on it.
    Error {
        /// Human-readable error description.
        message: String,
        /// Whether the producer flagged the error as recoverable.
        /// NOTE(review): retry semantics are defined outside this file — confirm.
        recoverable: bool,
    },
}
89
/// A pinned, boxed, `Send` stream of [`StreamDelta`] events.
///
/// Each item is an `anyhow::Result`, so the stream can yield failures as `Err`
/// items in addition to in-band `StreamDelta::Error` events. The `'a` lifetime
/// allows the stream to borrow from whatever produced it.
pub type StreamBox<'a> = Pin<Box<dyn Stream<Item = anyhow::Result<StreamDelta>> + Send + 'a>>;
92
/// Collects [`StreamDelta`] events and turns them into final content blocks.
///
/// Feed events in arrival order with [`StreamAccumulator::apply`], then call
/// [`StreamAccumulator::into_content_blocks`].
#[derive(Debug, Default)]
pub struct StreamAccumulator {
    // Text per block index; indices never written stay as empty strings.
    text_blocks: Vec<String>,
    // Thinking text per block index; indices never written stay as empty strings.
    thinking_blocks: Vec<String>,
    // Concatenated signature fragments, keyed by thinking-block index.
    thinking_signatures: HashMap<usize, String>,
    // (block_index, data) pairs in arrival order.
    redacted_thinking_blocks: Vec<(usize, String)>,
    // Tool calls in the order their `ToolUseStart` events arrived.
    tool_uses: Vec<ToolUseAccumulator>,
    // Most recent `Usage` event, if any.
    usage: Option<Usage>,
    // Stop reason from the latest `Done` event.
    stop_reason: Option<StopReason>,
}
114
/// In-progress state for a single streamed tool call.
#[derive(Debug, Default)]
pub struct ToolUseAccumulator {
    /// Tool-call identifier; `ToolInputDelta` events are routed by this value.
    pub id: String,
    /// Name of the tool being invoked.
    pub name: String,
    /// Raw JSON input concatenated from all fragments received so far.
    /// Parsed only when the accumulator is finalized.
    pub input_json: String,
    /// Index of the content block holding this tool call.
    pub block_index: usize,
    /// Opaque signature from `ToolUseStart`, passed through unchanged.
    pub thought_signature: Option<String>,
}
129
130impl StreamAccumulator {
131 #[must_use]
133 pub fn new() -> Self {
134 Self::default()
135 }
136
137 pub fn apply(&mut self, delta: &StreamDelta) {
139 match delta {
140 StreamDelta::TextDelta { delta, block_index } => {
141 while self.text_blocks.len() <= *block_index {
142 self.text_blocks.push(String::new());
143 }
144 self.text_blocks[*block_index].push_str(delta);
145 }
146 StreamDelta::ThinkingDelta { delta, block_index } => {
147 while self.thinking_blocks.len() <= *block_index {
148 self.thinking_blocks.push(String::new());
149 }
150 self.thinking_blocks[*block_index].push_str(delta);
151 }
152 StreamDelta::ToolUseStart {
153 id,
154 name,
155 block_index,
156 thought_signature,
157 } => {
158 self.tool_uses.push(ToolUseAccumulator {
159 id: id.clone(),
160 name: name.clone(),
161 input_json: String::new(),
162 block_index: *block_index,
163 thought_signature: thought_signature.clone(),
164 });
165 }
166 StreamDelta::ToolInputDelta { id, delta, .. } => {
167 if let Some(tool) = self.tool_uses.iter_mut().find(|t| t.id == *id) {
168 tool.input_json.push_str(delta);
169 }
170 }
171 StreamDelta::SignatureDelta { delta, block_index } => {
172 self.thinking_signatures
173 .entry(*block_index)
174 .or_default()
175 .push_str(delta);
176 }
177 StreamDelta::RedactedThinking { data, block_index } => {
178 self.redacted_thinking_blocks
179 .push((*block_index, data.clone()));
180 }
181 StreamDelta::Usage(u) => {
182 self.usage = Some(u.clone());
183 }
184 StreamDelta::Done { stop_reason } => {
185 self.stop_reason = *stop_reason;
186 }
187 StreamDelta::Error { .. } => {}
188 }
189 }
190
191 #[must_use]
193 pub const fn usage(&self) -> Option<&Usage> {
194 self.usage.as_ref()
195 }
196
197 #[must_use]
199 pub const fn stop_reason(&self) -> Option<&StopReason> {
200 self.stop_reason.as_ref()
201 }
202
203 #[must_use]
208 pub fn into_content_blocks(self) -> Vec<ContentBlock> {
209 let mut blocks: Vec<(usize, ContentBlock)> = Vec::new();
210
211 let mut signatures = self.thinking_signatures;
213 for (idx, thinking) in self.thinking_blocks.into_iter().enumerate() {
214 if !thinking.is_empty() {
215 let signature = signatures.remove(&idx).filter(|s| !s.is_empty());
216 blocks.push((
217 idx,
218 ContentBlock::Thinking {
219 thinking,
220 signature,
221 },
222 ));
223 }
224 }
225
226 for (idx, data) in self.redacted_thinking_blocks {
228 blocks.push((idx, ContentBlock::RedactedThinking { data }));
229 }
230
231 for (idx, text) in self.text_blocks.into_iter().enumerate() {
233 if !text.is_empty() {
234 blocks.push((idx, ContentBlock::Text { text }));
235 }
236 }
237
238 for tool in self.tool_uses {
240 let input: serde_json::Value =
241 serde_json::from_str(&tool.input_json).unwrap_or_else(|_| serde_json::json!({}));
242 blocks.push((
243 tool.block_index,
244 ContentBlock::ToolUse {
245 id: tool.id,
246 name: tool.name,
247 input,
248 thought_signature: tool.thought_signature,
249 },
250 ));
251 }
252
253 blocks.sort_by_key(|(idx, _)| *idx);
255
256 blocks.into_iter().map(|(_, block)| block).collect()
257 }
258
259 pub const fn take_usage(&mut self) -> Option<Usage> {
261 self.usage.take()
262 }
263
264 pub const fn take_stop_reason(&mut self) -> Option<StopReason> {
266 self.stop_reason.take()
267 }
268}
269
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a text fragment event.
    fn text_delta(delta: &str, block_index: usize) -> StreamDelta {
        StreamDelta::TextDelta {
            delta: delta.to_string(),
            block_index,
        }
    }

    /// Builds a tool-call start event with no thought signature.
    fn tool_start(id: &str, name: &str, block_index: usize) -> StreamDelta {
        StreamDelta::ToolUseStart {
            id: id.to_string(),
            name: name.to_string(),
            block_index,
            thought_signature: None,
        }
    }

    /// Builds a tool-input JSON fragment event.
    fn tool_input(id: &str, delta: &str, block_index: usize) -> StreamDelta {
        StreamDelta::ToolInputDelta {
            id: id.to_string(),
            delta: delta.to_string(),
            block_index,
        }
    }

    /// Applies `events` to a fresh accumulator in order and returns it.
    fn accumulate(events: &[StreamDelta]) -> StreamAccumulator {
        events
            .iter()
            .fold(StreamAccumulator::new(), |mut acc, event| {
                acc.apply(event);
                acc
            })
    }

    #[test]
    fn test_accumulator_text_deltas() {
        let blocks = accumulate(&[text_delta("Hello", 0), text_delta(" world", 0)])
            .into_content_blocks();

        assert_eq!(blocks.len(), 1);
        assert!(matches!(&blocks[0], ContentBlock::Text { text } if text == "Hello world"));
    }

    #[test]
    fn test_accumulator_multiple_text_blocks() {
        let blocks = accumulate(&[text_delta("First", 0), text_delta("Second", 1)])
            .into_content_blocks();

        assert_eq!(blocks.len(), 2);
        assert!(matches!(&blocks[0], ContentBlock::Text { text } if text == "First"));
        assert!(matches!(&blocks[1], ContentBlock::Text { text } if text == "Second"));
    }

    #[test]
    fn test_accumulator_tool_use() {
        let blocks = accumulate(&[
            tool_start("call_123", "read_file", 0),
            tool_input("call_123", r#"{"path":"#, 0),
            tool_input("call_123", r#""test.txt"}"#, 0),
        ])
        .into_content_blocks();

        assert_eq!(blocks.len(), 1);
        let ContentBlock::ToolUse {
            id, name, input, ..
        } = &blocks[0]
        else {
            panic!("Expected ToolUse block");
        };
        assert_eq!(id, "call_123");
        assert_eq!(name, "read_file");
        assert_eq!(input["path"], "test.txt");
    }

    #[test]
    fn test_accumulator_mixed_content() {
        let acc = accumulate(&[
            text_delta("Let me read that file.", 0),
            tool_start("call_456", "read_file", 1),
            tool_input("call_456", r#"{"path":"file.txt"}"#, 1),
            StreamDelta::Usage(Usage {
                input_tokens: 100,
                output_tokens: 50,
            }),
            StreamDelta::Done {
                stop_reason: Some(StopReason::ToolUse),
            },
        ]);

        assert!(acc.usage().is_some());
        assert_eq!(acc.usage().map(|u| u.input_tokens), Some(100));
        assert!(matches!(acc.stop_reason(), Some(StopReason::ToolUse)));

        let blocks = acc.into_content_blocks();
        assert_eq!(blocks.len(), 2);
        assert!(matches!(&blocks[0], ContentBlock::Text { .. }));
        assert!(matches!(&blocks[1], ContentBlock::ToolUse { .. }));
    }

    #[test]
    fn test_accumulator_invalid_tool_json() {
        let blocks = accumulate(&[
            tool_start("call_789", "test_tool", 0),
            tool_input("call_789", "invalid json {", 0),
        ])
        .into_content_blocks();

        assert_eq!(blocks.len(), 1);
        let ContentBlock::ToolUse { input, .. } = &blocks[0] else {
            panic!("Expected ToolUse block");
        };
        assert!(input.is_object());
    }

    #[test]
    fn test_accumulator_empty() {
        assert!(StreamAccumulator::new().into_content_blocks().is_empty());
    }

    #[test]
    fn test_accumulator_skips_empty_text() {
        let blocks = accumulate(&[text_delta("", 0), text_delta("Hello", 1)])
            .into_content_blocks();

        assert_eq!(blocks.len(), 1);
        assert!(matches!(&blocks[0], ContentBlock::Text { text } if text == "Hello"));
    }
}