claude_agent/client/
streaming.rs1use bytes::Bytes;
4use futures::Stream;
5use pin_project_lite::pin_project;
6use std::pin::Pin;
7use std::task::{Context, Poll};
8
9use super::recovery::StreamRecoveryState;
10use crate::Result;
11use crate::types::{Citation, ContentDelta, StreamEvent};
12
13#[derive(Debug, Clone)]
14pub enum StreamItem {
15 Event(StreamEvent),
16 Text(String),
17 Thinking(String),
18 Citation(Citation),
19}
20
21pin_project! {
22 pub struct StreamParser<S> {
23 #[pin]
24 inner: S,
25 buffer: Vec<u8>,
26 pos: usize,
27 }
28}
29
30impl<S> StreamParser<S>
31where
32 S: Stream<Item = std::result::Result<Bytes, reqwest::Error>>,
33{
34 pub fn new(inner: S) -> Self {
35 Self {
36 inner,
37 buffer: Vec::with_capacity(4096),
38 pos: 0,
39 }
40 }
41
42 #[inline]
43 fn find_delimiter(buf: &[u8]) -> Option<usize> {
44 buf.windows(2).position(|w| w == b"\n\n")
45 }
46
47 fn extract_json_data(event_block: &str) -> Option<&str> {
48 for line in event_block.lines() {
49 let line = line.trim();
50 if let Some(json_str) = line.strip_prefix("data: ") {
51 let json_str = json_str.trim();
52 if json_str == "[DONE]"
53 || json_str.contains("\"type\": \"ping\"")
54 || json_str.contains("\"type\":\"ping\"")
55 {
56 return None;
57 }
58 if !json_str.is_empty() {
59 return Some(json_str);
60 }
61 }
62 }
63 None
64 }
65
66 fn parse_event(event_block: &str) -> Option<StreamEvent> {
67 let trimmed = event_block.trim();
68 if trimmed.is_empty() || trimmed.starts_with(':') {
69 return None;
70 }
71 let json_str = Self::extract_json_data(event_block)?;
72 serde_json::from_str::<StreamEvent>(json_str)
73 .inspect_err(|e| {
74 tracing::warn!("Failed to parse stream event: {} - data: {}", e, json_str)
75 })
76 .ok()
77 }
78}
79
80impl<S> Stream for StreamParser<S>
81where
82 S: Stream<Item = std::result::Result<Bytes, reqwest::Error>>,
83{
84 type Item = Result<StreamItem>;
85
86 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
87 let mut this = self.project();
88
89 loop {
90 let search_slice = &this.buffer[*this.pos..];
91 if let Some(rel_pos) = Self::find_delimiter(search_slice) {
92 let start_pos = *this.pos;
93 let end_pos = start_pos + rel_pos;
94 let event_block = match std::str::from_utf8(&this.buffer[start_pos..end_pos]) {
95 Ok(s) => s,
96 Err(e) => {
97 return Poll::Ready(Some(Err(crate::Error::Config(format!(
98 "Invalid UTF-8 in event: {}",
99 e
100 )))));
101 }
102 };
103
104 let event = Self::parse_event(event_block);
105 *this.pos = end_pos + 2;
106
107 if this.buffer.len() > 8192 && *this.pos > this.buffer.len() / 2 {
108 this.buffer.drain(..*this.pos);
109 *this.pos = 0;
110 }
111
112 if let Some(event) = event {
113 let item = match &event {
114 StreamEvent::ContentBlockDelta {
115 delta: ContentDelta::TextDelta { text },
116 ..
117 } => StreamItem::Text(text.clone()),
118 StreamEvent::ContentBlockDelta {
119 delta: ContentDelta::ThinkingDelta { thinking },
120 ..
121 } => StreamItem::Thinking(thinking.clone()),
122 StreamEvent::ContentBlockDelta {
123 delta: ContentDelta::CitationsDelta { citation },
124 ..
125 } => StreamItem::Citation(citation.clone()),
126 _ => StreamItem::Event(event),
127 };
128 return Poll::Ready(Some(Ok(item)));
129 }
130 continue;
131 }
132
133 match this.inner.as_mut().poll_next(cx) {
134 Poll::Ready(Some(Ok(bytes))) => {
135 if *this.pos > 0 && this.buffer.len() + bytes.len() > 16384 {
136 this.buffer.drain(..*this.pos);
137 *this.pos = 0;
138 }
139 this.buffer.extend_from_slice(&bytes);
140 }
141 Poll::Ready(Some(Err(e))) => {
142 return Poll::Ready(Some(Err(crate::Error::Network(e))));
143 }
144 Poll::Ready(None) => {
145 if *this.pos < this.buffer.len() {
146 let remaining = match std::str::from_utf8(&this.buffer[*this.pos..]) {
147 Ok(s) => s,
148 Err(_) => return Poll::Ready(None),
149 };
150 if let Some(event) = Self::parse_event(remaining) {
151 return Poll::Ready(Some(Ok(StreamItem::Event(event))));
152 }
153 }
154 return Poll::Ready(None);
155 }
156 Poll::Pending => return Poll::Pending,
157 }
158 }
159 }
160}
161
162pin_project! {
163 pub struct RecoverableStream<S> {
164 #[pin]
165 inner: StreamParser<S>,
166 recovery: StreamRecoveryState,
167 current_block_type: Option<BlockType>,
168 }
169}
170
/// Kind of content block that [`RecoverableStream`] believes is currently open.
///
/// It selects which `complete_*` method of [`StreamRecoveryState`] runs when a
/// `ContentBlockStop` event arrives.
#[derive(Debug, Clone, Copy)]
enum BlockType {
    /// Recorded when a text delta is seen.
    Text,
    /// Recorded when a thinking delta is seen.
    Thinking,
    /// Recorded when a tool-use content block starts.
    ToolUse,
}
177
178impl<S> RecoverableStream<S>
179where
180 S: Stream<Item = std::result::Result<Bytes, reqwest::Error>>,
181{
182 pub fn new(inner: S) -> Self {
183 Self {
184 inner: StreamParser::new(inner),
185 recovery: StreamRecoveryState::new(),
186 current_block_type: None,
187 }
188 }
189
190 pub fn recovery_state(&self) -> &StreamRecoveryState {
191 &self.recovery
192 }
193
194 pub fn take_recovery_state(self) -> StreamRecoveryState {
195 self.recovery
196 }
197}
198
199impl<S> Stream for RecoverableStream<S>
200where
201 S: Stream<Item = std::result::Result<Bytes, reqwest::Error>>,
202{
203 type Item = Result<StreamItem>;
204
205 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
206 let this = self.project();
207
208 match this.inner.poll_next(cx) {
209 Poll::Ready(Some(Ok(item))) => {
210 match &item {
211 StreamItem::Text(text) => {
212 *this.current_block_type = Some(BlockType::Text);
213 this.recovery.append_text(text);
214 }
215 StreamItem::Thinking(thinking) => {
216 *this.current_block_type = Some(BlockType::Thinking);
217 this.recovery.append_thinking(thinking);
218 }
219 StreamItem::Event(event) => match event {
220 StreamEvent::ContentBlockStart {
221 content_block: crate::types::ContentBlock::ToolUse(tu),
222 ..
223 } => {
224 *this.current_block_type = Some(BlockType::ToolUse);
225 this.recovery.start_tool_use(tu.id.clone(), tu.name.clone());
226 }
227 StreamEvent::ContentBlockDelta {
228 delta: ContentDelta::InputJsonDelta { partial_json },
229 ..
230 } => {
231 this.recovery.append_tool_json(partial_json);
232 }
233 StreamEvent::ContentBlockDelta {
234 delta: ContentDelta::SignatureDelta { signature },
235 ..
236 } => {
237 this.recovery.append_signature(signature);
238 }
239 StreamEvent::ContentBlockStop { .. } => {
240 match this.current_block_type.take() {
241 Some(BlockType::Text) => this.recovery.complete_text_block(),
242 Some(BlockType::Thinking) => {
243 this.recovery.complete_thinking_block()
244 }
245 Some(BlockType::ToolUse) => this.recovery.complete_tool_use_block(),
246 None => {}
247 }
248 }
249 _ => {}
250 },
251 StreamItem::Citation(_) => {}
252 }
253 Poll::Ready(Some(Ok(item)))
254 }
255 other => other,
256 }
257 }
258}
259
#[cfg(test)]
mod tests {
    use super::*;

    type EmptyStream = futures::stream::Empty<std::result::Result<Bytes, reqwest::Error>>;

    /// Shorthand for the parser's block-level entry point.
    fn parse(block: &str) -> Option<StreamEvent> {
        StreamParser::<EmptyStream>::parse_event(block)
    }

    /// Shorthand for the parser's `data:` payload extractor.
    fn extract(block: &str) -> Option<&str> {
        StreamParser::<EmptyStream>::extract_json_data(block)
    }

    #[test]
    fn test_parse_simple_data() {
        let block = r#"data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"Hello"}}"#;
        assert!(parse(block).is_some());
    }

    #[test]
    fn test_parse_event_with_type() {
        let block = concat!(
            "event: content_block_delta\n",
            r#"data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"Hi"}}"#,
        );
        assert!(parse(block).is_some());
    }

    #[test]
    fn test_parse_message_start() {
        let block = concat!(
            "event: message_start\n",
            r#"data: {"type":"message_start","message":{"model":"claude-sonnet-4-5","id":"msg_123","type":"message","role":"assistant","content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":10,"output_tokens":1}}}"#,
        );
        assert!(matches!(parse(block), Some(StreamEvent::MessageStart { .. })));
    }

    #[test]
    fn test_skip_done_marker() {
        assert!(parse("data: [DONE]").is_none());
    }

    #[test]
    fn test_skip_ping_event() {
        assert!(parse("event: ping\ndata: {\"type\": \"ping\"}").is_none());
    }

    #[test]
    fn test_skip_empty_block() {
        for block in ["", " \n "] {
            assert!(parse(block).is_none());
        }
    }

    #[test]
    fn test_skip_comment() {
        assert!(parse(": this is a comment").is_none());
    }

    #[test]
    fn test_extract_json_data() {
        let expected = Some("{\"foo\":\"bar\"}");
        assert_eq!(extract("data: {\"foo\":\"bar\"}"), expected);
        assert_eq!(extract("event: test\ndata: {\"foo\":\"bar\"}"), expected);
        assert!(extract("data: [DONE]").is_none());
        assert!(extract("event: ping\ndata: {\"type\": \"ping\"}").is_none());
    }
}