claude_agent/client/
streaming.rs1use bytes::Bytes;
4use futures::Stream;
5use pin_project_lite::pin_project;
6use std::pin::Pin;
7use std::task::{Context, Poll};
8
9use super::recovery::StreamRecoveryState;
10use crate::Result;
11use crate::types::{Citation, ContentDelta, StreamEvent};
12
13#[derive(Debug, Clone)]
14pub enum StreamItem {
15 Event(StreamEvent),
16 Text(String),
17 Thinking(String),
18 Citation(Citation),
19 ToolUseComplete(crate::types::ToolUseBlock),
20}
21
22pin_project! {
23 pub struct StreamParser<S> {
24 #[pin]
25 inner: S,
26 buffer: Vec<u8>,
27 pos: usize,
28 }
29}
30
31impl<S> StreamParser<S>
32where
33 S: Stream<Item = std::result::Result<Bytes, reqwest::Error>>,
34{
35 pub fn new(inner: S) -> Self {
36 Self {
37 inner,
38 buffer: Vec::with_capacity(4096),
39 pos: 0,
40 }
41 }
42
43 #[inline]
44 fn find_delimiter(buf: &[u8]) -> Option<usize> {
45 buf.windows(2).position(|w| w == b"\n\n")
46 }
47
48 fn extract_json_data(event_block: &str) -> Option<&str> {
49 for line in event_block.lines() {
50 let line = line.trim();
51 if let Some(json_str) = line.strip_prefix("data: ") {
52 let json_str = json_str.trim();
53 if json_str == "[DONE]"
54 || json_str.contains("\"type\": \"ping\"")
55 || json_str.contains("\"type\":\"ping\"")
56 {
57 return None;
58 }
59 if !json_str.is_empty() {
60 return Some(json_str);
61 }
62 }
63 }
64 None
65 }
66
67 fn parse_event(event_block: &str) -> Option<StreamEvent> {
68 let trimmed = event_block.trim();
69 if trimmed.is_empty() || trimmed.starts_with(':') {
70 return None;
71 }
72 let json_str = Self::extract_json_data(event_block)?;
73 serde_json::from_str::<StreamEvent>(json_str)
74 .inspect_err(|e| {
75 tracing::warn!("Failed to parse stream event: {} - data: {}", e, json_str)
76 })
77 .ok()
78 }
79}
80
81impl<S> Stream for StreamParser<S>
82where
83 S: Stream<Item = std::result::Result<Bytes, reqwest::Error>>,
84{
85 type Item = Result<StreamItem>;
86
87 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
88 let mut this = self.project();
89
90 loop {
91 let search_slice = &this.buffer[*this.pos..];
92 if let Some(rel_pos) = Self::find_delimiter(search_slice) {
93 let start_pos = *this.pos;
94 let end_pos = start_pos + rel_pos;
95 let event_block = match std::str::from_utf8(&this.buffer[start_pos..end_pos]) {
96 Ok(s) => s,
97 Err(e) => {
98 return Poll::Ready(Some(Err(crate::Error::Config(format!(
99 "Invalid UTF-8 in event: {}",
100 e
101 )))));
102 }
103 };
104
105 let event = Self::parse_event(event_block);
106 *this.pos = end_pos + 2;
107
108 if this.buffer.len() > 8192 && *this.pos > this.buffer.len() / 2 {
109 this.buffer.drain(..*this.pos);
110 *this.pos = 0;
111 }
112
113 if let Some(event) = event {
114 let item = match &event {
115 StreamEvent::ContentBlockDelta {
116 delta: ContentDelta::TextDelta { text },
117 ..
118 } => StreamItem::Text(text.clone()),
119 StreamEvent::ContentBlockDelta {
120 delta: ContentDelta::ThinkingDelta { thinking },
121 ..
122 } => StreamItem::Thinking(thinking.clone()),
123 StreamEvent::ContentBlockDelta {
124 delta: ContentDelta::CitationsDelta { citation },
125 ..
126 } => StreamItem::Citation(citation.clone()),
127 _ => StreamItem::Event(event),
128 };
129 return Poll::Ready(Some(Ok(item)));
130 }
131 continue;
132 }
133
134 match this.inner.as_mut().poll_next(cx) {
135 Poll::Ready(Some(Ok(bytes))) => {
136 if *this.pos > 0 && this.buffer.len() + bytes.len() > 16384 {
137 this.buffer.drain(..*this.pos);
138 *this.pos = 0;
139 }
140 this.buffer.extend_from_slice(&bytes);
141 }
142 Poll::Ready(Some(Err(e))) => {
143 return Poll::Ready(Some(Err(crate::Error::Network(e))));
144 }
145 Poll::Ready(None) => {
146 if *this.pos < this.buffer.len() {
147 let remaining = match std::str::from_utf8(&this.buffer[*this.pos..]) {
148 Ok(s) => s,
149 Err(_) => return Poll::Ready(None),
150 };
151 if let Some(event) = Self::parse_event(remaining) {
152 return Poll::Ready(Some(Ok(StreamItem::Event(event))));
153 }
154 }
155 return Poll::Ready(None);
156 }
157 Poll::Pending => return Poll::Pending,
158 }
159 }
160 }
161}
162
163pin_project! {
164 pub struct RecoverableStream<S> {
165 #[pin]
166 inner: StreamParser<S>,
167 recovery: StreamRecoveryState,
168 current_block_type: Option<BlockType>,
169 }
170}
171
/// Kind of content block being accumulated by `RecoverableStream`.
///
/// When a `ContentBlockStop` arrives, this selects which recovery-state completion call
/// to make.
#[derive(Debug, Clone, Copy)]
enum BlockType {
    /// A block fed by `StreamItem::Text` deltas.
    Text,
    /// A block fed by `StreamItem::Thinking` deltas.
    Thinking,
    /// A tool-use block opened by a `ContentBlockStart` event.
    ToolUse,
}
178
179impl<S> RecoverableStream<S>
180where
181 S: Stream<Item = std::result::Result<Bytes, reqwest::Error>>,
182{
183 pub fn new(inner: S) -> Self {
184 Self {
185 inner: StreamParser::new(inner),
186 recovery: StreamRecoveryState::new(),
187 current_block_type: None,
188 }
189 }
190
191 pub fn recovery_state(&self) -> &StreamRecoveryState {
192 &self.recovery
193 }
194
195 pub fn take_recovery_state(self) -> StreamRecoveryState {
196 self.recovery
197 }
198}
199
200impl<S> Stream for RecoverableStream<S>
201where
202 S: Stream<Item = std::result::Result<Bytes, reqwest::Error>>,
203{
204 type Item = Result<StreamItem>;
205
206 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
207 let this = self.project();
208
209 match this.inner.poll_next(cx) {
210 Poll::Ready(Some(Ok(item))) => {
211 match &item {
212 StreamItem::Text(text) => {
213 *this.current_block_type = Some(BlockType::Text);
214 this.recovery.append_text(text);
215 }
216 StreamItem::Thinking(thinking) => {
217 *this.current_block_type = Some(BlockType::Thinking);
218 this.recovery.append_thinking(thinking);
219 }
220 StreamItem::ToolUseComplete(_) => {}
221 StreamItem::Event(event) => match event {
222 StreamEvent::ContentBlockStart {
223 content_block: crate::types::ContentBlock::ToolUse(tu),
224 ..
225 } => {
226 *this.current_block_type = Some(BlockType::ToolUse);
227 this.recovery.start_tool_use(tu.id.clone(), tu.name.clone());
228 }
229 StreamEvent::ContentBlockDelta {
230 delta: ContentDelta::InputJsonDelta { partial_json },
231 ..
232 } => {
233 this.recovery.append_tool_json(partial_json);
234 }
235 StreamEvent::ContentBlockDelta {
236 delta: ContentDelta::SignatureDelta { signature },
237 ..
238 } => {
239 this.recovery.append_signature(signature);
240 }
241 StreamEvent::ContentBlockStop { .. } => {
242 match this.current_block_type.take() {
243 Some(BlockType::Text) => this.recovery.complete_text_block(),
244 Some(BlockType::Thinking) => {
245 this.recovery.complete_thinking_block()
246 }
247 Some(BlockType::ToolUse) => {
248 if let Some(tool_use) = this.recovery.complete_tool_use_block()
249 {
250 return Poll::Ready(Some(Ok(StreamItem::ToolUseComplete(
251 tool_use,
252 ))));
253 }
254 }
255 None => {}
256 }
257 }
258 _ => {}
259 },
260 StreamItem::Citation(_) => {}
261 }
262 Poll::Ready(Some(Ok(item)))
263 }
264 other => other,
265 }
266 }
267}
268
#[cfg(test)]
mod tests {
    use super::*;

    type EmptyStream = futures::stream::Empty<std::result::Result<Bytes, reqwest::Error>>;

    /// Runs the parser's block decoder with a placeholder stream type.
    fn parse(block: &str) -> Option<StreamEvent> {
        StreamParser::<EmptyStream>::parse_event(block)
    }

    /// Runs the parser's `data` field extractor with a placeholder stream type.
    fn extract(block: &str) -> Option<&str> {
        StreamParser::<EmptyStream>::extract_json_data(block)
    }

    #[test]
    fn test_parse_simple_data() {
        let block = r#"data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"Hello"}}"#;
        assert!(parse(block).is_some());
    }

    #[test]
    fn test_parse_event_with_type() {
        let block = concat!(
            "event: content_block_delta\n",
            r#"data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"Hi"}}"#,
        );
        assert!(parse(block).is_some());
    }

    #[test]
    fn test_parse_message_start() {
        let block = concat!(
            "event: message_start\n",
            r#"data: {"type":"message_start","message":{"model":"claude-sonnet-4-5","id":"msg_123","type":"message","role":"assistant","content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":10,"output_tokens":1}}}"#,
        );
        let parsed = parse(block);
        assert!(parsed.is_some());
        assert!(matches!(parsed, Some(StreamEvent::MessageStart { .. })));
    }

    #[test]
    fn test_skip_done_marker() {
        assert!(parse("data: [DONE]").is_none());
    }

    #[test]
    fn test_skip_ping_event() {
        assert!(parse("event: ping\ndata: {\"type\": \"ping\"}").is_none());
    }

    #[test]
    fn test_skip_empty_block() {
        for blank in ["", " \n "] {
            assert!(parse(blank).is_none());
        }
    }

    #[test]
    fn test_skip_comment() {
        assert!(parse(": this is a comment").is_none());
    }

    #[test]
    fn test_extract_json_data() {
        let foo_bar = "{\"foo\":\"bar\"}";
        assert_eq!(extract("data: {\"foo\":\"bar\"}"), Some(foo_bar));
        assert_eq!(extract("event: test\ndata: {\"foo\":\"bar\"}"), Some(foo_bar));

        for skipped in ["data: [DONE]", "event: ping\ndata: {\"type\": \"ping\"}"] {
            assert!(extract(skipped).is_none());
        }
    }
}