1use agent_sdk_foundation::llm::{ContentBlock, StopReason, Usage};
8#[cfg(any(feature = "openai", feature = "openai-codex"))]
9use bytes::BytesMut;
10use futures::Stream;
11use std::collections::HashMap;
12use std::pin::Pin;
13
/// Largest `block_index` for which [`StreamAccumulator::apply`] will grow its text and
/// thinking buffers; deltas with a larger index are logged and dropped, which keeps a
/// bogus index from forcing a huge allocation.
const MAX_BLOCK_INDEX: usize = 1 << 12;
23
/// Byte-level accumulator that turns an arbitrarily chunked SSE response body into lines.
///
/// Raw bytes are kept until a line feed arrives, so a multi-byte UTF-8 character split
/// across two network chunks is only decoded once its whole line is available.
#[cfg(any(feature = "openai", feature = "openai-codex"))]
#[derive(Default, Debug)]
pub(crate) struct SseLineBuffer {
    /// Bytes received but not yet handed out as a complete line.
    buf: BytesMut,
}
44
#[cfg(any(feature = "openai", feature = "openai-codex"))]
impl SseLineBuffer {
    /// Creates an empty buffer.
    #[must_use]
    pub(crate) fn new() -> Self {
        Self::default()
    }

    /// Appends a raw chunk of response bytes to the end of the buffer.
    pub(crate) fn extend(&mut self, chunk: &[u8]) {
        self.buf.extend_from_slice(chunk);
    }

    /// Removes and returns the next complete line, without its terminator.
    ///
    /// Returns `None` (leaving the bytes buffered) while no `\n` has arrived yet. Both
    /// `\n` and `\r\n` terminators are stripped; a lone `\r` is not treated as a line
    /// break. Invalid UTF-8 is replaced with U+FFFD rather than reported as an error.
    pub(crate) fn next_line(&mut self) -> Option<String> {
        let newline = self.buf.iter().position(|&b| b == b'\n')?;
        let mut line = self.buf.split_to(newline + 1);
        // SSE permits CRLF line endings; drop the CR as well so callers see the same
        // content (and the same empty event-separator lines) as with plain LF.
        let end = if newline > 0 && line[newline - 1] == b'\r' {
            newline - 1
        } else {
            newline
        };
        line.truncate(end);
        Some(String::from_utf8_lossy(&line).into_owned())
    }
}
69
70#[derive(Debug, Clone)]
75#[non_exhaustive]
76pub enum StreamDelta {
77 TextDelta {
79 delta: String,
81 block_index: usize,
83 },
84
85 ThinkingDelta {
87 delta: String,
89 block_index: usize,
91 },
92
93 ToolUseStart {
95 id: String,
97 name: String,
99 block_index: usize,
101 thought_signature: Option<String>,
103 },
104
105 ToolInputDelta {
107 id: String,
109 delta: String,
111 block_index: usize,
113 },
114
115 Usage(Usage),
117
118 Done {
120 stop_reason: Option<StopReason>,
122 },
123
124 SignatureDelta {
126 delta: String,
128 block_index: usize,
130 },
131
132 RedactedThinking {
134 data: String,
136 block_index: usize,
138 },
139
140 Error {
142 message: String,
144 kind: StreamErrorKind,
149 },
150}
151
/// Coarse classification attached to a [`StreamDelta::Error`].
///
/// [`StreamErrorKind::is_recoverable`] reports which kinds are classified as recoverable.
#[non_exhaustive]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum StreamErrorKind {
    /// Rate-limit failure; classified as recoverable.
    RateLimited,
    /// Server-side failure; classified as recoverable.
    ServerError,
    /// Invalid-request failure; classified as non-recoverable.
    InvalidRequest,
    /// Failure that fits no other kind; classified as non-recoverable.
    Unknown,
}
182
183impl StreamErrorKind {
184 #[must_use]
188 pub const fn is_recoverable(self) -> bool {
189 matches!(self, Self::RateLimited | Self::ServerError)
190 }
191}
192
193pub type StreamBox<'a> = Pin<Box<dyn Stream<Item = anyhow::Result<StreamDelta>> + Send + 'a>>;
195
196#[derive(Debug, Default)]
201pub struct StreamAccumulator {
202 text_blocks: Vec<String>,
204 thinking_blocks: Vec<String>,
206 thinking_signatures: HashMap<usize, String>,
208 redacted_thinking_blocks: Vec<(usize, String)>,
210 tool_uses: Vec<ToolUseAccumulator>,
212 usage: Option<Usage>,
214 stop_reason: Option<StopReason>,
216}
217
/// In-progress state of one streamed tool call.
#[derive(Default, Debug)]
pub struct ToolUseAccumulator {
    /// Tool-call identifier; input fragments are matched against it.
    pub id: String,
    /// Name of the invoked tool.
    pub name: String,
    /// Concatenated raw JSON argument fragments received so far.
    pub input_json: String,
    /// Content block index the tool call occupies.
    pub block_index: usize,
    /// Optional signature supplied with the tool call's start event.
    pub thought_signature: Option<String>,
}
232
233impl StreamAccumulator {
234 #[must_use]
236 pub fn new() -> Self {
237 Self::default()
238 }
239
240 pub fn apply(&mut self, delta: &StreamDelta) {
242 match delta {
243 StreamDelta::TextDelta { delta, block_index } => {
244 if *block_index > MAX_BLOCK_INDEX {
245 log::warn!(
246 "dropping TextDelta with out-of-range block_index {block_index} (max {MAX_BLOCK_INDEX})"
247 );
248 return;
249 }
250 while self.text_blocks.len() <= *block_index {
251 self.text_blocks.push(String::new());
252 }
253 self.text_blocks[*block_index].push_str(delta);
254 }
255 StreamDelta::ThinkingDelta { delta, block_index } => {
256 if *block_index > MAX_BLOCK_INDEX {
257 log::warn!(
258 "dropping ThinkingDelta with out-of-range block_index {block_index} (max {MAX_BLOCK_INDEX})"
259 );
260 return;
261 }
262 while self.thinking_blocks.len() <= *block_index {
263 self.thinking_blocks.push(String::new());
264 }
265 self.thinking_blocks[*block_index].push_str(delta);
266 }
267 StreamDelta::ToolUseStart {
268 id,
269 name,
270 block_index,
271 thought_signature,
272 } => {
273 self.tool_uses.push(ToolUseAccumulator {
274 id: id.clone(),
275 name: name.clone(),
276 input_json: String::new(),
277 block_index: *block_index,
278 thought_signature: thought_signature.clone(),
279 });
280 }
281 StreamDelta::ToolInputDelta { id, delta, .. } => {
282 if let Some(tool) = self.tool_uses.iter_mut().find(|t| t.id == *id) {
283 tool.input_json.push_str(delta);
284 }
285 }
286 StreamDelta::SignatureDelta { delta, block_index } => {
287 self.thinking_signatures
288 .entry(*block_index)
289 .or_default()
290 .push_str(delta);
291 }
292 StreamDelta::RedactedThinking { data, block_index } => {
293 self.redacted_thinking_blocks
294 .push((*block_index, data.clone()));
295 }
296 StreamDelta::Usage(u) => {
297 self.usage = Some(u.clone());
298 }
299 StreamDelta::Done { stop_reason } => {
300 self.stop_reason = *stop_reason;
301 }
302 StreamDelta::Error { .. } => {}
303 }
304 }
305
306 #[must_use]
308 pub const fn usage(&self) -> Option<&Usage> {
309 self.usage.as_ref()
310 }
311
312 #[must_use]
314 pub const fn stop_reason(&self) -> Option<&StopReason> {
315 self.stop_reason.as_ref()
316 }
317
318 #[must_use]
323 pub fn into_content_blocks(self) -> Vec<ContentBlock> {
324 let mut blocks: Vec<(usize, ContentBlock)> = Vec::new();
325
326 let mut signatures = self.thinking_signatures;
328 for (idx, thinking) in self.thinking_blocks.into_iter().enumerate() {
329 if !thinking.is_empty() {
330 let signature = signatures.remove(&idx).filter(|s| !s.is_empty());
331 blocks.push((
332 idx,
333 ContentBlock::Thinking {
334 thinking,
335 signature,
336 },
337 ));
338 }
339 }
340
341 for (idx, data) in self.redacted_thinking_blocks {
343 blocks.push((idx, ContentBlock::RedactedThinking { data }));
344 }
345
346 for (idx, text) in self.text_blocks.into_iter().enumerate() {
348 if !text.is_empty() {
349 blocks.push((idx, ContentBlock::Text { text }));
350 }
351 }
352
353 for tool in self.tool_uses {
355 let input: serde_json::Value =
356 serde_json::from_str(&tool.input_json).unwrap_or_else(|e| {
357 log::warn!(
358 "Failed to parse streamed tool input JSON for tool '{}' (id={}): {} — \
359 input_json ({} bytes): '{}'",
360 tool.name,
361 tool.id,
362 e,
363 tool.input_json.len(),
364 tool.input_json.chars().take(500).collect::<String>(),
365 );
366 serde_json::json!({})
367 });
368 blocks.push((
369 tool.block_index,
370 ContentBlock::ToolUse {
371 id: tool.id,
372 name: tool.name,
373 input,
374 thought_signature: tool.thought_signature,
375 },
376 ));
377 }
378
379 blocks.sort_by_key(|(idx, _)| *idx);
381
382 blocks.into_iter().map(|(_, block)| block).collect()
383 }
384
385 pub const fn take_usage(&mut self) -> Option<Usage> {
387 self.usage.take()
388 }
389
390 pub const fn take_stop_reason(&mut self) -> Option<StopReason> {
392 self.stop_reason.take()
393 }
394}
395
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_accumulator_text_deltas() {
        let mut acc = StreamAccumulator::new();

        acc.apply(&StreamDelta::TextDelta {
            delta: "Hello".to_string(),
            block_index: 0,
        });
        acc.apply(&StreamDelta::TextDelta {
            delta: " world".to_string(),
            block_index: 0,
        });

        let blocks = acc.into_content_blocks();
        assert_eq!(blocks.len(), 1);
        assert!(matches!(&blocks[0], ContentBlock::Text { text } if text == "Hello world"));
    }

    #[test]
    fn test_accumulator_multiple_text_blocks() {
        let mut acc = StreamAccumulator::new();

        acc.apply(&StreamDelta::TextDelta {
            delta: "First".to_string(),
            block_index: 0,
        });
        acc.apply(&StreamDelta::TextDelta {
            delta: "Second".to_string(),
            block_index: 1,
        });

        let blocks = acc.into_content_blocks();
        assert_eq!(blocks.len(), 2);
        assert!(matches!(&blocks[0], ContentBlock::Text { text } if text == "First"));
        assert!(matches!(&blocks[1], ContentBlock::Text { text } if text == "Second"));
    }

    #[test]
    fn test_accumulator_thinking_signature() {
        let mut acc = StreamAccumulator::new();

        acc.apply(&StreamDelta::ThinkingDelta {
            delta: "Reasoning".to_string(),
            block_index: 0,
        });
        acc.apply(&StreamDelta::SignatureDelta {
            delta: "sig_123".to_string(),
            block_index: 0,
        });

        let blocks = acc.into_content_blocks();
        assert_eq!(blocks.len(), 1);
        assert!(matches!(
            &blocks[0],
            ContentBlock::Thinking { thinking, signature }
                if thinking == "Reasoning" && signature.as_deref() == Some("sig_123")
        ));
    }

    #[test]
    fn test_accumulator_tool_use() {
        let mut acc = StreamAccumulator::new();

        acc.apply(&StreamDelta::ToolUseStart {
            id: "call_123".to_string(),
            name: "read_file".to_string(),
            block_index: 0,
            thought_signature: None,
        });
        acc.apply(&StreamDelta::ToolInputDelta {
            id: "call_123".to_string(),
            delta: r#"{"path":"#.to_string(),
            block_index: 0,
        });
        acc.apply(&StreamDelta::ToolInputDelta {
            id: "call_123".to_string(),
            delta: r#""test.txt"}"#.to_string(),
            block_index: 0,
        });

        let blocks = acc.into_content_blocks();
        assert_eq!(blocks.len(), 1);
        match &blocks[0] {
            ContentBlock::ToolUse {
                id, name, input, ..
            } => {
                assert_eq!(id, "call_123");
                assert_eq!(name, "read_file");
                assert_eq!(input["path"], "test.txt");
            }
            _ => panic!("Expected ToolUse block"),
        }
    }

    #[test]
    fn test_accumulator_mixed_content() {
        let mut acc = StreamAccumulator::new();

        acc.apply(&StreamDelta::TextDelta {
            delta: "Let me read that file.".to_string(),
            block_index: 0,
        });
        acc.apply(&StreamDelta::ToolUseStart {
            id: "call_456".to_string(),
            name: "read_file".to_string(),
            block_index: 1,
            thought_signature: None,
        });
        acc.apply(&StreamDelta::ToolInputDelta {
            id: "call_456".to_string(),
            delta: r#"{"path":"file.txt"}"#.to_string(),
            block_index: 1,
        });
        acc.apply(&StreamDelta::Usage(Usage {
            input_tokens: 100,
            output_tokens: 50,
            cached_input_tokens: 0,
            cache_creation_input_tokens: 0,
        }));
        acc.apply(&StreamDelta::Done {
            stop_reason: Some(StopReason::ToolUse),
        });

        assert!(acc.usage().is_some());
        assert_eq!(acc.usage().map(|u| u.input_tokens), Some(100));
        assert!(matches!(acc.stop_reason(), Some(StopReason::ToolUse)));

        let blocks = acc.into_content_blocks();
        assert_eq!(blocks.len(), 2);
        assert!(matches!(&blocks[0], ContentBlock::Text { .. }));
        assert!(matches!(&blocks[1], ContentBlock::ToolUse { .. }));
    }

    #[test]
    fn test_accumulator_invalid_tool_json() {
        let mut acc = StreamAccumulator::new();

        acc.apply(&StreamDelta::ToolUseStart {
            id: "call_789".to_string(),
            name: "test_tool".to_string(),
            block_index: 0,
            thought_signature: None,
        });
        acc.apply(&StreamDelta::ToolInputDelta {
            id: "call_789".to_string(),
            delta: "invalid json {".to_string(),
            block_index: 0,
        });

        let blocks = acc.into_content_blocks();
        assert_eq!(blocks.len(), 1);
        match &blocks[0] {
            ContentBlock::ToolUse { input, .. } => {
                assert!(input.is_object());
            }
            _ => panic!("Expected ToolUse block"),
        }
    }

    #[test]
    fn test_accumulator_empty_tool_input_falls_back_to_empty_object() {
        let mut acc = StreamAccumulator::new();

        // A tool call that never receives a ToolInputDelta leaves an empty JSON buffer;
        // it must still come out with a usable empty-object input.
        acc.apply(&StreamDelta::ToolUseStart {
            id: "call_empty".to_string(),
            name: "read".to_string(),
            block_index: 0,
            thought_signature: None,
        });
        let blocks = acc.into_content_blocks();

        assert_eq!(blocks.len(), 1);
        match &blocks[0] {
            ContentBlock::ToolUse { input, name, .. } => {
                assert_eq!(name, "read");
                assert_eq!(input, &serde_json::json!({}));
            }
            _ => panic!("Expected ToolUse block"),
        }
    }

    #[test]
    fn test_accumulator_mismatched_delta_id_drops_input() {
        let mut acc = StreamAccumulator::new();

        // Input fragments are routed by tool id, not by block_index.
        acc.apply(&StreamDelta::ToolUseStart {
            id: "call_A".to_string(),
            name: "bash".to_string(),
            block_index: 0,
            thought_signature: None,
        });
        acc.apply(&StreamDelta::ToolInputDelta {
            // Same block_index as call_A, but a different id.
            id: "call_B".to_string(),
            delta: r#"{"command":"ls"}"#.to_string(),
            block_index: 0,
        });

        let blocks = acc.into_content_blocks();
        assert_eq!(blocks.len(), 1);
        match &blocks[0] {
            ContentBlock::ToolUse { input, .. } => {
                // call_B's fragment must not leak into call_A.
                assert_eq!(input, &serde_json::json!({}));
            }
            _ => panic!("Expected ToolUse block"),
        }
    }

    #[test]
    fn test_accumulator_empty() {
        let acc = StreamAccumulator::new();
        let blocks = acc.into_content_blocks();
        assert!(blocks.is_empty());
    }

    #[test]
    fn test_accumulator_skips_empty_text() {
        let mut acc = StreamAccumulator::new();

        acc.apply(&StreamDelta::TextDelta {
            delta: String::new(),
            block_index: 0,
        });
        acc.apply(&StreamDelta::TextDelta {
            delta: "Hello".to_string(),
            block_index: 1,
        });

        let blocks = acc.into_content_blocks();
        assert_eq!(blocks.len(), 1);
        assert!(matches!(&blocks[0], ContentBlock::Text { text } if text == "Hello"));
    }

    #[test]
    fn test_accumulator_ignores_out_of_range_block_index() {
        let mut acc = StreamAccumulator::new();

        // Indices past MAX_BLOCK_INDEX must be dropped rather than grow the buffers
        // towards usize::MAX entries.
        acc.apply(&StreamDelta::TextDelta {
            delta: "ok".to_string(),
            block_index: 0,
        });
        acc.apply(&StreamDelta::TextDelta {
            delta: "boom".to_string(),
            block_index: usize::MAX,
        });
        acc.apply(&StreamDelta::ThinkingDelta {
            delta: "boom".to_string(),
            block_index: usize::MAX,
        });

        let blocks = acc.into_content_blocks();
        assert_eq!(blocks.len(), 1);
        assert!(matches!(&blocks[0], ContentBlock::Text { text } if text == "ok"));
    }

    #[test]
    fn test_accumulator_orders_blocks_across_kinds_by_index() {
        let mut acc = StreamAccumulator::new();

        // Applied out of order: output order must come from block_index alone.
        acc.apply(&StreamDelta::ToolUseStart {
            id: "call_1".to_string(),
            name: "bash".to_string(),
            block_index: 2,
            thought_signature: None,
        });
        acc.apply(&StreamDelta::ToolInputDelta {
            id: "call_1".to_string(),
            delta: r#"{"command":"ls"}"#.to_string(),
            block_index: 2,
        });
        acc.apply(&StreamDelta::TextDelta {
            delta: "Listing files.".to_string(),
            block_index: 1,
        });
        acc.apply(&StreamDelta::ThinkingDelta {
            delta: "Plan".to_string(),
            block_index: 0,
        });

        let blocks = acc.into_content_blocks();
        assert_eq!(blocks.len(), 3);
        assert!(matches!(
            &blocks[0],
            ContentBlock::Thinking { thinking, signature }
                if thinking == "Plan" && signature.is_none()
        ));
        assert!(matches!(&blocks[1], ContentBlock::Text { text } if text == "Listing files."));
        assert!(matches!(&blocks[2], ContentBlock::ToolUse { id, .. } if id == "call_1"));
    }

    #[test]
    fn test_accumulator_redacted_thinking_and_empty_signature() {
        let mut acc = StreamAccumulator::new();

        acc.apply(&StreamDelta::RedactedThinking {
            data: "opaque".to_string(),
            block_index: 0,
        });
        acc.apply(&StreamDelta::ThinkingDelta {
            delta: "Visible".to_string(),
            block_index: 1,
        });
        // An empty accumulated signature is reported as absent.
        acc.apply(&StreamDelta::SignatureDelta {
            delta: String::new(),
            block_index: 1,
        });

        let blocks = acc.into_content_blocks();
        assert_eq!(blocks.len(), 2);
        assert!(matches!(&blocks[0], ContentBlock::RedactedThinking { data } if data == "opaque"));
        assert!(matches!(
            &blocks[1],
            ContentBlock::Thinking { thinking, signature }
                if thinking == "Visible" && signature.is_none()
        ));
    }

    #[test]
    fn test_accumulator_take_usage_and_stop_reason() {
        let mut acc = StreamAccumulator::new();

        acc.apply(&StreamDelta::Usage(Usage {
            input_tokens: 7,
            output_tokens: 3,
            cached_input_tokens: 0,
            cache_creation_input_tokens: 0,
        }));
        acc.apply(&StreamDelta::Done {
            stop_reason: Some(StopReason::ToolUse),
        });

        assert_eq!(acc.take_usage().map(|u| u.output_tokens), Some(3));
        assert!(matches!(acc.take_stop_reason(), Some(StopReason::ToolUse)));
        // take_* moves the value out, so a second call sees nothing.
        assert!(acc.take_usage().is_none());
        assert!(acc.take_stop_reason().is_none());
    }

    #[cfg(any(feature = "openai", feature = "openai-codex"))]
    #[test]
    fn test_sse_line_buffer_splits_multiple_lines() {
        let mut buf = SseLineBuffer::new();
        buf.extend(b"data: one\ndata: two\n");
        assert_eq!(buf.next_line().as_deref(), Some("data: one"));
        assert_eq!(buf.next_line().as_deref(), Some("data: two"));
        assert_eq!(buf.next_line(), None);
    }

    #[cfg(any(feature = "openai", feature = "openai-codex"))]
    #[test]
    fn test_sse_line_buffer_buffers_partial_line_until_newline() {
        let mut buf = SseLineBuffer::new();
        buf.extend(b"data: par");
        assert_eq!(buf.next_line(), None);
        buf.extend(b"tial\n");
        assert_eq!(buf.next_line().as_deref(), Some("data: partial"));
    }

    #[cfg(any(feature = "openai", feature = "openai-codex"))]
    #[test]
    fn test_sse_line_buffer_handles_utf8_split_across_chunks() {
        let mut buf = SseLineBuffer::new();
        let line = "data: café\n";
        let bytes = line.as_bytes();
        // Cut between the two bytes of 'é' (0xC3 0xA9) so the first chunk ends in the
        // middle of a character.
        let split = bytes.len() - 2;
        buf.extend(&bytes[..split]);
        assert_eq!(buf.next_line(), None);
        buf.extend(&bytes[split..]);
        assert_eq!(buf.next_line().as_deref(), Some("data: café"));
    }
}