claude_agent/client/
streaming.rs

//! SSE streaming support for the Anthropic Messages API.
2
3use bytes::Bytes;
4use futures::Stream;
5use pin_project_lite::pin_project;
6use std::pin::Pin;
7use std::task::{Context, Poll};
8
9use super::recovery::StreamRecoveryState;
10use crate::Result;
11use crate::types::{Citation, ContentDelta, StreamEvent};
12
13#[derive(Debug, Clone)]
14pub enum StreamItem {
15    Event(StreamEvent),
16    Text(String),
17    Thinking(String),
18    Citation(Citation),
19}
20
21pin_project! {
22    pub struct StreamParser<S> {
23        #[pin]
24        inner: S,
25        buffer: Vec<u8>,
26        pos: usize,
27    }
28}
29
30impl<S> StreamParser<S>
31where
32    S: Stream<Item = std::result::Result<Bytes, reqwest::Error>>,
33{
34    pub fn new(inner: S) -> Self {
35        Self {
36            inner,
37            buffer: Vec::with_capacity(4096),
38            pos: 0,
39        }
40    }
41
42    #[inline]
43    fn find_delimiter(buf: &[u8]) -> Option<usize> {
44        buf.windows(2).position(|w| w == b"\n\n")
45    }
46
47    fn extract_json_data(event_block: &str) -> Option<&str> {
48        for line in event_block.lines() {
49            let line = line.trim();
50            if let Some(json_str) = line.strip_prefix("data: ") {
51                let json_str = json_str.trim();
52                if json_str == "[DONE]"
53                    || json_str.contains("\"type\": \"ping\"")
54                    || json_str.contains("\"type\":\"ping\"")
55                {
56                    return None;
57                }
58                if !json_str.is_empty() {
59                    return Some(json_str);
60                }
61            }
62        }
63        None
64    }
65
66    fn parse_event(event_block: &str) -> Option<StreamEvent> {
67        let trimmed = event_block.trim();
68        if trimmed.is_empty() || trimmed.starts_with(':') {
69            return None;
70        }
71        let json_str = Self::extract_json_data(event_block)?;
72        serde_json::from_str::<StreamEvent>(json_str)
73            .inspect_err(|e| {
74                tracing::warn!("Failed to parse stream event: {} - data: {}", e, json_str)
75            })
76            .ok()
77    }
78}
79
80impl<S> Stream for StreamParser<S>
81where
82    S: Stream<Item = std::result::Result<Bytes, reqwest::Error>>,
83{
84    type Item = Result<StreamItem>;
85
86    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
87        let mut this = self.project();
88
89        loop {
90            let search_slice = &this.buffer[*this.pos..];
91            if let Some(rel_pos) = Self::find_delimiter(search_slice) {
92                let start_pos = *this.pos;
93                let end_pos = start_pos + rel_pos;
94                let event_block = match std::str::from_utf8(&this.buffer[start_pos..end_pos]) {
95                    Ok(s) => s,
96                    Err(e) => {
97                        return Poll::Ready(Some(Err(crate::Error::Config(format!(
98                            "Invalid UTF-8 in event: {}",
99                            e
100                        )))));
101                    }
102                };
103
104                let event = Self::parse_event(event_block);
105                *this.pos = end_pos + 2;
106
107                if this.buffer.len() > 8192 && *this.pos > this.buffer.len() / 2 {
108                    this.buffer.drain(..*this.pos);
109                    *this.pos = 0;
110                }
111
112                if let Some(event) = event {
113                    let item = match &event {
114                        StreamEvent::ContentBlockDelta {
115                            delta: ContentDelta::TextDelta { text },
116                            ..
117                        } => StreamItem::Text(text.clone()),
118                        StreamEvent::ContentBlockDelta {
119                            delta: ContentDelta::ThinkingDelta { thinking },
120                            ..
121                        } => StreamItem::Thinking(thinking.clone()),
122                        StreamEvent::ContentBlockDelta {
123                            delta: ContentDelta::CitationsDelta { citation },
124                            ..
125                        } => StreamItem::Citation(citation.clone()),
126                        _ => StreamItem::Event(event),
127                    };
128                    return Poll::Ready(Some(Ok(item)));
129                }
130                continue;
131            }
132
133            match this.inner.as_mut().poll_next(cx) {
134                Poll::Ready(Some(Ok(bytes))) => {
135                    if *this.pos > 0 && this.buffer.len() + bytes.len() > 16384 {
136                        this.buffer.drain(..*this.pos);
137                        *this.pos = 0;
138                    }
139                    this.buffer.extend_from_slice(&bytes);
140                }
141                Poll::Ready(Some(Err(e))) => {
142                    return Poll::Ready(Some(Err(crate::Error::Network(e))));
143                }
144                Poll::Ready(None) => {
145                    if *this.pos < this.buffer.len() {
146                        let remaining = match std::str::from_utf8(&this.buffer[*this.pos..]) {
147                            Ok(s) => s,
148                            Err(_) => return Poll::Ready(None),
149                        };
150                        if let Some(event) = Self::parse_event(remaining) {
151                            return Poll::Ready(Some(Ok(StreamItem::Event(event))));
152                        }
153                    }
154                    return Poll::Ready(None);
155                }
156                Poll::Pending => return Poll::Pending,
157            }
158        }
159    }
160}
161
162pin_project! {
163    pub struct RecoverableStream<S> {
164        #[pin]
165        inner: StreamParser<S>,
166        recovery: StreamRecoveryState,
167        current_block_type: Option<BlockType>,
168    }
169}
170
/// Kind of content block a [`RecoverableStream`] is currently tracking.
#[derive(Debug, Clone, Copy)]
enum BlockType {
    /// Block receiving text deltas.
    Text,
    /// Block receiving thinking deltas.
    Thinking,
    /// Tool-use block.
    ToolUse,
}
177
178impl<S> RecoverableStream<S>
179where
180    S: Stream<Item = std::result::Result<Bytes, reqwest::Error>>,
181{
182    pub fn new(inner: S) -> Self {
183        Self {
184            inner: StreamParser::new(inner),
185            recovery: StreamRecoveryState::new(),
186            current_block_type: None,
187        }
188    }
189
190    pub fn recovery_state(&self) -> &StreamRecoveryState {
191        &self.recovery
192    }
193
194    pub fn take_recovery_state(self) -> StreamRecoveryState {
195        self.recovery
196    }
197}
198
199impl<S> Stream for RecoverableStream<S>
200where
201    S: Stream<Item = std::result::Result<Bytes, reqwest::Error>>,
202{
203    type Item = Result<StreamItem>;
204
205    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
206        let this = self.project();
207
208        match this.inner.poll_next(cx) {
209            Poll::Ready(Some(Ok(item))) => {
210                match &item {
211                    StreamItem::Text(text) => {
212                        *this.current_block_type = Some(BlockType::Text);
213                        this.recovery.append_text(text);
214                    }
215                    StreamItem::Thinking(thinking) => {
216                        *this.current_block_type = Some(BlockType::Thinking);
217                        this.recovery.append_thinking(thinking);
218                    }
219                    StreamItem::Event(event) => match event {
220                        StreamEvent::ContentBlockStart {
221                            content_block: crate::types::ContentBlock::ToolUse(tu),
222                            ..
223                        } => {
224                            *this.current_block_type = Some(BlockType::ToolUse);
225                            this.recovery.start_tool_use(tu.id.clone(), tu.name.clone());
226                        }
227                        StreamEvent::ContentBlockDelta {
228                            delta: ContentDelta::InputJsonDelta { partial_json },
229                            ..
230                        } => {
231                            this.recovery.append_tool_json(partial_json);
232                        }
233                        StreamEvent::ContentBlockDelta {
234                            delta: ContentDelta::SignatureDelta { signature },
235                            ..
236                        } => {
237                            this.recovery.append_signature(signature);
238                        }
239                        StreamEvent::ContentBlockStop { .. } => {
240                            match this.current_block_type.take() {
241                                Some(BlockType::Text) => this.recovery.complete_text_block(),
242                                Some(BlockType::Thinking) => {
243                                    this.recovery.complete_thinking_block()
244                                }
245                                Some(BlockType::ToolUse) => this.recovery.complete_tool_use_block(),
246                                None => {}
247                            }
248                        }
249                        _ => {}
250                    },
251                    StreamItem::Citation(_) => {}
252                }
253                Poll::Ready(Some(Ok(item)))
254            }
255            other => other,
256        }
257    }
258}
259
#[cfg(test)]
mod tests {
    use super::*;

    type EmptyStream = futures::stream::Empty<std::result::Result<Bytes, reqwest::Error>>;
    // The helpers under test are associated functions, so any concrete stream
    // type works as the parser's parameter.
    type Parser = StreamParser<EmptyStream>;

    #[test]
    fn test_parse_simple_data() {
        let block = r#"data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"Hello"}}"#;
        assert!(Parser::parse_event(block).is_some());
    }

    #[test]
    fn test_parse_event_with_type() {
        let block = "event: content_block_delta\ndata: {\"type\":\"content_block_delta\",\"index\":0,\"delta\":{\"type\":\"text_delta\",\"text\":\"Hi\"}}";
        assert!(Parser::parse_event(block).is_some());
    }

    #[test]
    fn test_parse_message_start() {
        let block = r#"event: message_start
data: {"type":"message_start","message":{"model":"claude-sonnet-4-5","id":"msg_123","type":"message","role":"assistant","content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":10,"output_tokens":1}}}"#;
        let parsed = Parser::parse_event(block);
        assert!(parsed.is_some());
        assert!(matches!(parsed, Some(StreamEvent::MessageStart { .. })));
    }

    #[test]
    fn test_skip_done_marker() {
        assert!(Parser::parse_event("data: [DONE]").is_none());
    }

    #[test]
    fn test_skip_ping_event() {
        let block = "event: ping\ndata: {\"type\": \"ping\"}";
        assert!(Parser::parse_event(block).is_none());
    }

    #[test]
    fn test_skip_empty_block() {
        for block in ["", "   \n  "] {
            assert!(Parser::parse_event(block).is_none());
        }
    }

    #[test]
    fn test_skip_comment() {
        assert!(Parser::parse_event(": this is a comment").is_none());
    }

    #[test]
    fn test_extract_json_data() {
        // Payload is returned with or without a preceding `event:` line.
        assert_eq!(
            Parser::extract_json_data("data: {\"foo\":\"bar\"}"),
            Some("{\"foo\":\"bar\"}")
        );
        assert_eq!(
            Parser::extract_json_data("event: test\ndata: {\"foo\":\"bar\"}"),
            Some("{\"foo\":\"bar\"}")
        );

        // The end-of-stream marker and pings yield no payload.
        assert!(Parser::extract_json_data("data: [DONE]").is_none());
        assert!(Parser::extract_json_data("event: ping\ndata: {\"type\": \"ping\"}").is_none());
    }
}
333}