claude_agent/client/
streaming.rs

1//! SSE streaming support for the Anthropic Messages API.
2
3use bytes::Bytes;
4use futures::Stream;
5use pin_project_lite::pin_project;
6use std::pin::Pin;
7use std::task::{Context, Poll};
8
9use super::recovery::StreamRecoveryState;
10use crate::Result;
11use crate::types::{Citation, ContentDelta, StreamEvent};
12
13#[derive(Debug, Clone)]
14pub enum StreamItem {
15    Event(StreamEvent),
16    Text(String),
17    Thinking(String),
18    Citation(Citation),
19    ToolUseComplete(crate::types::ToolUseBlock),
20}
21
22pin_project! {
23    pub struct StreamParser<S> {
24        #[pin]
25        inner: S,
26        buffer: Vec<u8>,
27        pos: usize,
28    }
29}
30
31impl<S> StreamParser<S>
32where
33    S: Stream<Item = std::result::Result<Bytes, reqwest::Error>>,
34{
35    pub fn new(inner: S) -> Self {
36        Self {
37            inner,
38            buffer: Vec::with_capacity(4096),
39            pos: 0,
40        }
41    }
42
43    #[inline]
44    fn find_delimiter(buf: &[u8]) -> Option<usize> {
45        buf.windows(2).position(|w| w == b"\n\n")
46    }
47
48    fn extract_json_data(event_block: &str) -> Option<&str> {
49        for line in event_block.lines() {
50            let line = line.trim();
51            if let Some(json_str) = line.strip_prefix("data: ") {
52                let json_str = json_str.trim();
53                if json_str == "[DONE]"
54                    || json_str.contains("\"type\": \"ping\"")
55                    || json_str.contains("\"type\":\"ping\"")
56                {
57                    return None;
58                }
59                if !json_str.is_empty() {
60                    return Some(json_str);
61                }
62            }
63        }
64        None
65    }
66
67    fn parse_event(event_block: &str) -> Option<StreamEvent> {
68        let trimmed = event_block.trim();
69        if trimmed.is_empty() || trimmed.starts_with(':') {
70            return None;
71        }
72        let json_str = Self::extract_json_data(event_block)?;
73        serde_json::from_str::<StreamEvent>(json_str)
74            .inspect_err(|e| {
75                tracing::warn!("Failed to parse stream event: {} - data: {}", e, json_str)
76            })
77            .ok()
78    }
79}
80
81impl<S> Stream for StreamParser<S>
82where
83    S: Stream<Item = std::result::Result<Bytes, reqwest::Error>>,
84{
85    type Item = Result<StreamItem>;
86
87    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
88        let mut this = self.project();
89
90        loop {
91            let search_slice = &this.buffer[*this.pos..];
92            if let Some(rel_pos) = Self::find_delimiter(search_slice) {
93                let start_pos = *this.pos;
94                let end_pos = start_pos + rel_pos;
95                let event_block = match std::str::from_utf8(&this.buffer[start_pos..end_pos]) {
96                    Ok(s) => s,
97                    Err(e) => {
98                        return Poll::Ready(Some(Err(crate::Error::Config(format!(
99                            "Invalid UTF-8 in event: {}",
100                            e
101                        )))));
102                    }
103                };
104
105                let event = Self::parse_event(event_block);
106                *this.pos = end_pos + 2;
107
108                if this.buffer.len() > 8192 && *this.pos > this.buffer.len() / 2 {
109                    this.buffer.drain(..*this.pos);
110                    *this.pos = 0;
111                }
112
113                if let Some(event) = event {
114                    let item = match &event {
115                        StreamEvent::ContentBlockDelta {
116                            delta: ContentDelta::TextDelta { text },
117                            ..
118                        } => StreamItem::Text(text.clone()),
119                        StreamEvent::ContentBlockDelta {
120                            delta: ContentDelta::ThinkingDelta { thinking },
121                            ..
122                        } => StreamItem::Thinking(thinking.clone()),
123                        StreamEvent::ContentBlockDelta {
124                            delta: ContentDelta::CitationsDelta { citation },
125                            ..
126                        } => StreamItem::Citation(citation.clone()),
127                        _ => StreamItem::Event(event),
128                    };
129                    return Poll::Ready(Some(Ok(item)));
130                }
131                continue;
132            }
133
134            match this.inner.as_mut().poll_next(cx) {
135                Poll::Ready(Some(Ok(bytes))) => {
136                    if *this.pos > 0 && this.buffer.len() + bytes.len() > 16384 {
137                        this.buffer.drain(..*this.pos);
138                        *this.pos = 0;
139                    }
140                    this.buffer.extend_from_slice(&bytes);
141                }
142                Poll::Ready(Some(Err(e))) => {
143                    return Poll::Ready(Some(Err(crate::Error::Network(e))));
144                }
145                Poll::Ready(None) => {
146                    if *this.pos < this.buffer.len() {
147                        let remaining = match std::str::from_utf8(&this.buffer[*this.pos..]) {
148                            Ok(s) => s,
149                            Err(_) => return Poll::Ready(None),
150                        };
151                        if let Some(event) = Self::parse_event(remaining) {
152                            return Poll::Ready(Some(Ok(StreamItem::Event(event))));
153                        }
154                    }
155                    return Poll::Ready(None);
156                }
157                Poll::Pending => return Poll::Pending,
158            }
159        }
160    }
161}
162
163pin_project! {
164    pub struct RecoverableStream<S> {
165        #[pin]
166        inner: StreamParser<S>,
167        recovery: StreamRecoveryState,
168        current_block_type: Option<BlockType>,
169    }
170}
171
/// Kind of content block being streamed; decides which recovery buffer is
/// finalized when the block's `ContentBlockStop` event arrives.
#[derive(Clone, Copy, Debug)]
enum BlockType {
    /// A block receiving text deltas.
    Text,
    /// A block receiving thinking deltas.
    Thinking,
    /// A tool-use block receiving partial JSON input.
    ToolUse,
}
178
179impl<S> RecoverableStream<S>
180where
181    S: Stream<Item = std::result::Result<Bytes, reqwest::Error>>,
182{
183    pub fn new(inner: S) -> Self {
184        Self {
185            inner: StreamParser::new(inner),
186            recovery: StreamRecoveryState::new(),
187            current_block_type: None,
188        }
189    }
190
191    pub fn recovery_state(&self) -> &StreamRecoveryState {
192        &self.recovery
193    }
194
195    pub fn take_recovery_state(self) -> StreamRecoveryState {
196        self.recovery
197    }
198}
199
200impl<S> Stream for RecoverableStream<S>
201where
202    S: Stream<Item = std::result::Result<Bytes, reqwest::Error>>,
203{
204    type Item = Result<StreamItem>;
205
206    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
207        let this = self.project();
208
209        match this.inner.poll_next(cx) {
210            Poll::Ready(Some(Ok(item))) => {
211                match &item {
212                    StreamItem::Text(text) => {
213                        *this.current_block_type = Some(BlockType::Text);
214                        this.recovery.append_text(text);
215                    }
216                    StreamItem::Thinking(thinking) => {
217                        *this.current_block_type = Some(BlockType::Thinking);
218                        this.recovery.append_thinking(thinking);
219                    }
220                    StreamItem::ToolUseComplete(_) => {}
221                    StreamItem::Event(event) => match event {
222                        StreamEvent::ContentBlockStart {
223                            content_block: crate::types::ContentBlock::ToolUse(tu),
224                            ..
225                        } => {
226                            *this.current_block_type = Some(BlockType::ToolUse);
227                            this.recovery.start_tool_use(tu.id.clone(), tu.name.clone());
228                        }
229                        StreamEvent::ContentBlockDelta {
230                            delta: ContentDelta::InputJsonDelta { partial_json },
231                            ..
232                        } => {
233                            this.recovery.append_tool_json(partial_json);
234                        }
235                        StreamEvent::ContentBlockDelta {
236                            delta: ContentDelta::SignatureDelta { signature },
237                            ..
238                        } => {
239                            this.recovery.append_signature(signature);
240                        }
241                        StreamEvent::ContentBlockStop { .. } => {
242                            match this.current_block_type.take() {
243                                Some(BlockType::Text) => this.recovery.complete_text_block(),
244                                Some(BlockType::Thinking) => {
245                                    this.recovery.complete_thinking_block()
246                                }
247                                Some(BlockType::ToolUse) => {
248                                    if let Some(tool_use) = this.recovery.complete_tool_use_block()
249                                    {
250                                        return Poll::Ready(Some(Ok(StreamItem::ToolUseComplete(
251                                            tool_use,
252                                        ))));
253                                    }
254                                }
255                                None => {}
256                            }
257                        }
258                        _ => {}
259                    },
260                    StreamItem::Citation(_) => {}
261                }
262                Poll::Ready(Some(Ok(item)))
263            }
264            other => other,
265        }
266    }
267}
268
#[cfg(test)]
mod tests {
    use super::*;

    type EmptyStream = futures::stream::Empty<std::result::Result<Bytes, reqwest::Error>>;

    /// Shorthand for the parser's event decoder with a concrete stream type.
    fn parse(block: &str) -> Option<StreamEvent> {
        StreamParser::<EmptyStream>::parse_event(block)
    }

    /// Shorthand for the parser's `data:` payload extractor.
    fn extract(block: &str) -> Option<&str> {
        StreamParser::<EmptyStream>::extract_json_data(block)
    }

    #[test]
    fn test_parse_simple_data() {
        let block = r#"data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"Hello"}}"#;
        assert!(parse(block).is_some());
    }

    #[test]
    fn test_parse_event_with_type() {
        let block = concat!(
            "event: content_block_delta\n",
            r#"data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"Hi"}}"#,
        );
        assert!(parse(block).is_some());
    }

    #[test]
    fn test_parse_message_start() {
        let block = concat!(
            "event: message_start\n",
            r#"data: {"type":"message_start","message":{"model":"claude-sonnet-4-5","id":"msg_123","type":"message","role":"assistant","content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":10,"output_tokens":1}}}"#,
        );
        assert!(matches!(parse(block), Some(StreamEvent::MessageStart { .. })));
    }

    #[test]
    fn test_skip_done_marker() {
        assert!(parse("data: [DONE]").is_none());
    }

    #[test]
    fn test_skip_ping_event() {
        assert!(parse("event: ping\ndata: {\"type\": \"ping\"}").is_none());
    }

    #[test]
    fn test_skip_empty_block() {
        for block in ["", "   \n  "] {
            assert!(parse(block).is_none(), "block: {block:?}");
        }
    }

    #[test]
    fn test_skip_comment() {
        assert!(parse(": this is a comment").is_none());
    }

    #[test]
    fn test_extract_json_data() {
        let cases: [(&str, Option<&str>); 4] = [
            ("data: {\"foo\":\"bar\"}", Some("{\"foo\":\"bar\"}")),
            ("event: test\ndata: {\"foo\":\"bar\"}", Some("{\"foo\":\"bar\"}")),
            ("data: [DONE]", None),
            ("event: ping\ndata: {\"type\": \"ping\"}", None),
        ];
        for (input, expected) in cases {
            assert_eq!(extract(input), expected, "input: {input:?}");
        }
    }
}