1use crate::llm::{ContentBlock, StopReason, Usage};
8use futures::Stream;
9use std::collections::HashMap;
10use std::pin::Pin;
11
12#[derive(Debug, Clone)]
17pub enum StreamDelta {
18 TextDelta {
20 delta: String,
22 block_index: usize,
24 },
25
26 ThinkingDelta {
28 delta: String,
30 block_index: usize,
32 },
33
34 ToolUseStart {
36 id: String,
38 name: String,
40 block_index: usize,
42 thought_signature: Option<String>,
44 },
45
46 ToolInputDelta {
48 id: String,
50 delta: String,
52 block_index: usize,
54 },
55
56 Usage(Usage),
58
59 Done {
61 stop_reason: Option<StopReason>,
63 },
64
65 SignatureDelta {
67 delta: String,
69 block_index: usize,
71 },
72
73 RedactedThinking {
75 data: String,
77 block_index: usize,
79 },
80
81 Error {
83 message: String,
85 recoverable: bool,
87 },
88}
89
90pub type StreamBox<'a> = Pin<Box<dyn Stream<Item = anyhow::Result<StreamDelta>> + Send + 'a>>;
92
/// Collects [`StreamDelta`] events into complete response content.
///
/// Feed events with [`StreamAccumulator::apply`] and finish with
/// [`StreamAccumulator::into_content_blocks`].
#[derive(Debug, Default)]
pub struct StreamAccumulator {
    /// Text buffers indexed by `block_index`; indices that belong to other block
    /// kinds are padded with empty strings.
    text_blocks: Vec<String>,
    /// Thinking buffers indexed by `block_index`, padded the same way as `text_blocks`.
    thinking_blocks: Vec<String>,
    /// Concatenated signature fragments, keyed by thinking-block `block_index`.
    thinking_signatures: HashMap<usize, String>,
    /// Redacted-thinking payloads paired with their `block_index`, in arrival order.
    redacted_thinking_blocks: Vec<(usize, String)>,
    /// Tool calls in the order their `ToolUseStart` events arrived.
    tool_uses: Vec<ToolUseAccumulator>,
    /// Most recently applied usage report, if any.
    usage: Option<Usage>,
    /// Stop reason recorded from a `Done` event.
    stop_reason: Option<StopReason>,
}
114
/// In-progress state for a single streamed tool call.
#[derive(Debug, Default)]
pub struct ToolUseAccumulator {
    /// Tool call id from `ToolUseStart`; `ToolInputDelta` events are matched against it.
    pub id: String,
    /// Name of the tool being invoked.
    pub name: String,
    /// Raw argument JSON concatenated from `ToolInputDelta` fragments; it is only
    /// parsed when the accumulator is turned into content blocks.
    pub input_json: String,
    /// Content block index from `ToolUseStart`, used to order the final blocks.
    pub block_index: usize,
    /// Optional signature from `ToolUseStart`, copied into the resulting tool-use block.
    pub thought_signature: Option<String>,
}
129
130impl StreamAccumulator {
131 #[must_use]
133 pub fn new() -> Self {
134 Self::default()
135 }
136
137 pub fn apply(&mut self, delta: &StreamDelta) {
139 match delta {
140 StreamDelta::TextDelta { delta, block_index } => {
141 while self.text_blocks.len() <= *block_index {
142 self.text_blocks.push(String::new());
143 }
144 self.text_blocks[*block_index].push_str(delta);
145 }
146 StreamDelta::ThinkingDelta { delta, block_index } => {
147 while self.thinking_blocks.len() <= *block_index {
148 self.thinking_blocks.push(String::new());
149 }
150 self.thinking_blocks[*block_index].push_str(delta);
151 }
152 StreamDelta::ToolUseStart {
153 id,
154 name,
155 block_index,
156 thought_signature,
157 } => {
158 self.tool_uses.push(ToolUseAccumulator {
159 id: id.clone(),
160 name: name.clone(),
161 input_json: String::new(),
162 block_index: *block_index,
163 thought_signature: thought_signature.clone(),
164 });
165 }
166 StreamDelta::ToolInputDelta { id, delta, .. } => {
167 if let Some(tool) = self.tool_uses.iter_mut().find(|t| t.id == *id) {
168 tool.input_json.push_str(delta);
169 }
170 }
171 StreamDelta::SignatureDelta { delta, block_index } => {
172 self.thinking_signatures
173 .entry(*block_index)
174 .or_default()
175 .push_str(delta);
176 }
177 StreamDelta::RedactedThinking { data, block_index } => {
178 self.redacted_thinking_blocks
179 .push((*block_index, data.clone()));
180 }
181 StreamDelta::Usage(u) => {
182 self.usage = Some(u.clone());
183 }
184 StreamDelta::Done { stop_reason } => {
185 self.stop_reason = *stop_reason;
186 }
187 StreamDelta::Error { .. } => {}
188 }
189 }
190
191 #[must_use]
193 pub const fn usage(&self) -> Option<&Usage> {
194 self.usage.as_ref()
195 }
196
197 #[must_use]
199 pub const fn stop_reason(&self) -> Option<&StopReason> {
200 self.stop_reason.as_ref()
201 }
202
203 #[must_use]
208 pub fn into_content_blocks(self) -> Vec<ContentBlock> {
209 let mut blocks: Vec<(usize, ContentBlock)> = Vec::new();
210
211 let mut signatures = self.thinking_signatures;
213 for (idx, thinking) in self.thinking_blocks.into_iter().enumerate() {
214 if !thinking.is_empty() {
215 let signature = signatures.remove(&idx).filter(|s| !s.is_empty());
216 blocks.push((
217 idx,
218 ContentBlock::Thinking {
219 thinking,
220 signature,
221 },
222 ));
223 }
224 }
225
226 for (idx, data) in self.redacted_thinking_blocks {
228 blocks.push((idx, ContentBlock::RedactedThinking { data }));
229 }
230
231 for (idx, text) in self.text_blocks.into_iter().enumerate() {
233 if !text.is_empty() {
234 blocks.push((idx, ContentBlock::Text { text }));
235 }
236 }
237
238 for tool in self.tool_uses {
240 let input: serde_json::Value =
241 serde_json::from_str(&tool.input_json).unwrap_or_else(|_| serde_json::json!({}));
242 blocks.push((
243 tool.block_index,
244 ContentBlock::ToolUse {
245 id: tool.id,
246 name: tool.name,
247 input,
248 thought_signature: tool.thought_signature,
249 },
250 ));
251 }
252
253 blocks.sort_by_key(|(idx, _)| *idx);
255
256 blocks.into_iter().map(|(_, block)| block).collect()
257 }
258
259 pub const fn take_usage(&mut self) -> Option<Usage> {
261 self.usage.take()
262 }
263
264 pub const fn take_stop_reason(&mut self) -> Option<StopReason> {
266 self.stop_reason.take()
267 }
268}
269
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `TextDelta` event.
    fn text_delta(delta: &str, block_index: usize) -> StreamDelta {
        StreamDelta::TextDelta {
            delta: delta.to_string(),
            block_index,
        }
    }

    /// Builds a `ToolUseStart` event without a thought signature.
    fn tool_start(id: &str, name: &str, block_index: usize) -> StreamDelta {
        StreamDelta::ToolUseStart {
            id: id.to_string(),
            name: name.to_string(),
            block_index,
            thought_signature: None,
        }
    }

    /// Builds a `ToolInputDelta` event.
    fn tool_input(id: &str, delta: &str, block_index: usize) -> StreamDelta {
        StreamDelta::ToolInputDelta {
            id: id.to_string(),
            delta: delta.to_string(),
            block_index,
        }
    }

    /// Applies `events`, in order, to a fresh accumulator and returns it.
    fn accumulate(events: &[StreamDelta]) -> StreamAccumulator {
        let mut acc = StreamAccumulator::new();
        events.iter().for_each(|event| acc.apply(event));
        acc
    }

    #[test]
    fn test_accumulator_text_deltas() {
        let blocks =
            accumulate(&[text_delta("Hello", 0), text_delta(" world", 0)]).into_content_blocks();

        assert_eq!(blocks.len(), 1);
        assert!(matches!(&blocks[0], ContentBlock::Text { text } if text == "Hello world"));
    }

    #[test]
    fn test_accumulator_multiple_text_blocks() {
        let blocks =
            accumulate(&[text_delta("First", 0), text_delta("Second", 1)]).into_content_blocks();

        assert_eq!(blocks.len(), 2);
        assert!(matches!(&blocks[0], ContentBlock::Text { text } if text == "First"));
        assert!(matches!(&blocks[1], ContentBlock::Text { text } if text == "Second"));
    }

    #[test]
    fn test_accumulator_thinking_signature() {
        let events = [
            StreamDelta::ThinkingDelta {
                delta: "Reasoning".to_string(),
                block_index: 0,
            },
            StreamDelta::SignatureDelta {
                delta: "sig_123".to_string(),
                block_index: 0,
            },
        ];
        let blocks = accumulate(&events).into_content_blocks();

        assert_eq!(blocks.len(), 1);
        assert!(matches!(
            &blocks[0],
            ContentBlock::Thinking { thinking, signature }
                if thinking == "Reasoning" && signature.as_deref() == Some("sig_123")
        ));
    }

    #[test]
    fn test_accumulator_tool_use() {
        let blocks = accumulate(&[
            tool_start("call_123", "read_file", 0),
            tool_input("call_123", r#"{"path":"#, 0),
            tool_input("call_123", r#""test.txt"}"#, 0),
        ])
        .into_content_blocks();

        assert_eq!(blocks.len(), 1);
        let ContentBlock::ToolUse {
            id, name, input, ..
        } = &blocks[0]
        else {
            panic!("Expected ToolUse block");
        };
        assert_eq!(id, "call_123");
        assert_eq!(name, "read_file");
        assert_eq!(input["path"], "test.txt");
    }

    #[test]
    fn test_accumulator_mixed_content() {
        let acc = accumulate(&[
            text_delta("Let me read that file.", 0),
            tool_start("call_456", "read_file", 1),
            tool_input("call_456", r#"{"path":"file.txt"}"#, 1),
            StreamDelta::Usage(Usage {
                input_tokens: 100,
                output_tokens: 50,
            }),
            StreamDelta::Done {
                stop_reason: Some(StopReason::ToolUse),
            },
        ]);

        assert!(acc.usage().is_some());
        assert_eq!(acc.usage().map(|u| u.input_tokens), Some(100));
        assert!(matches!(acc.stop_reason(), Some(StopReason::ToolUse)));

        let blocks = acc.into_content_blocks();
        assert_eq!(blocks.len(), 2);
        assert!(matches!(&blocks[0], ContentBlock::Text { .. }));
        assert!(matches!(&blocks[1], ContentBlock::ToolUse { .. }));
    }

    #[test]
    fn test_accumulator_invalid_tool_json() {
        let blocks = accumulate(&[
            tool_start("call_789", "test_tool", 0),
            tool_input("call_789", "invalid json {", 0),
        ])
        .into_content_blocks();

        assert_eq!(blocks.len(), 1);
        let ContentBlock::ToolUse { input, .. } = &blocks[0] else {
            panic!("Expected ToolUse block");
        };
        assert!(input.is_object());
    }

    #[test]
    fn test_accumulator_empty() {
        let blocks = StreamAccumulator::new().into_content_blocks();
        assert!(blocks.is_empty());
    }

    #[test]
    fn test_accumulator_skips_empty_text() {
        let blocks = accumulate(&[text_delta("", 0), text_delta("Hello", 1)]).into_content_blocks();

        assert_eq!(blocks.len(), 1);
        assert!(matches!(&blocks[0], ContentBlock::Text { text } if text == "Hello"));
    }
}