Skip to main content

codineer_api/
sse.rs

1use crate::error::ApiError;
2use crate::types::StreamEvent;
3use serde_json::Value;
4
/// Top-level `type` discriminators that are deserialized into a `StreamEvent` today.
///
/// A payload whose `type` is not listed here is treated as a future event kind: if it fails to
/// deserialize it is ignored instead of being reported as an error. The list is only consulted
/// for membership, so the order of its entries carries no meaning.
const KNOWN_STREAM_EVENT_TYPES: &[&str] = &[
    // Message-level events.
    "message_start",
    "message_delta",
    "message_stop",
    // Content-block-level events.
    "content_block_start",
    "content_block_delta",
    "content_block_stop",
];
14
15#[derive(Debug, Default)]
16pub struct SseParser {
17    buffer: Vec<u8>,
18}
19
20impl SseParser {
21    #[must_use]
22    pub fn new() -> Self {
23        Self::default()
24    }
25
26    const MAX_BUFFER_SIZE: usize = 16 * 1024 * 1024;
27
28    pub fn push(&mut self, chunk: &[u8]) -> Result<Vec<StreamEvent>, ApiError> {
29        if self.buffer.len() + chunk.len() > Self::MAX_BUFFER_SIZE {
30            self.buffer.clear();
31            return Err(ApiError::ResponsePayloadTooLarge {
32                limit: Self::MAX_BUFFER_SIZE,
33            });
34        }
35        self.buffer.extend_from_slice(chunk);
36        let mut events = Vec::new();
37
38        while let Some(frame) = self.next_frame() {
39            if let Some(event) = parse_frame(&frame)? {
40                events.push(event);
41            }
42        }
43
44        Ok(events)
45    }
46
47    pub fn finish(&mut self) -> Result<Vec<StreamEvent>, ApiError> {
48        if self.buffer.is_empty() {
49            return Ok(Vec::new());
50        }
51
52        let trailing = std::mem::take(&mut self.buffer);
53        match parse_frame(&String::from_utf8_lossy(&trailing))? {
54            Some(event) => Ok(vec![event]),
55            None => Ok(Vec::new()),
56        }
57    }
58
59    fn next_frame(&mut self) -> Option<String> {
60        let separator = self
61            .buffer
62            .windows(2)
63            .position(|window| window == b"\n\n")
64            .map(|position| (position, 2))
65            .or_else(|| {
66                self.buffer
67                    .windows(4)
68                    .position(|window| window == b"\r\n\r\n")
69                    .map(|position| (position, 4))
70            })?;
71
72        let (position, separator_len) = separator;
73        let frame = self
74            .buffer
75            .drain(..position + separator_len)
76            .collect::<Vec<_>>();
77        let frame_len = frame.len().saturating_sub(separator_len);
78        Some(String::from_utf8_lossy(&frame[..frame_len]).into_owned())
79    }
80}
81
82pub fn parse_frame(frame: &str) -> Result<Option<StreamEvent>, ApiError> {
83    let trimmed = frame.trim();
84    if trimmed.is_empty() {
85        return Ok(None);
86    }
87
88    let mut data_lines = Vec::new();
89    let mut event_name: Option<&str> = None;
90
91    for line in trimmed.lines() {
92        if line.starts_with(':') {
93            continue;
94        }
95        if let Some(name) = line.strip_prefix("event:") {
96            event_name = Some(name.trim());
97            continue;
98        }
99        if let Some(data) = line.strip_prefix("data:") {
100            data_lines.push(data.trim_start());
101        }
102    }
103
104    if matches!(event_name, Some("ping")) {
105        return Ok(None);
106    }
107
108    if data_lines.is_empty() {
109        return Ok(None);
110    }
111
112    let payload = data_lines.join("\n");
113    if payload == "[DONE]" {
114        return Ok(None);
115    }
116
117    let value: Value = serde_json::from_str(&payload).map_err(ApiError::from)?;
118
119    // Anthropic Messages SSE: `type: "error"` must surface as an error, not be dropped as unknown.
120    if value.get("type").and_then(Value::as_str) == Some("error") {
121        let err_obj = value.get("error");
122        let error_type = err_obj
123            .and_then(|e| e.get("type"))
124            .and_then(Value::as_str)
125            .map(str::to_string);
126        let message = err_obj
127            .and_then(|e| e.get("message"))
128            .and_then(Value::as_str)
129            .map(str::to_string)
130            .unwrap_or_else(|| "unknown streaming error".to_string());
131        return Err(ApiError::StreamApplicationError {
132            error_type,
133            message,
134        });
135    }
136
137    let unknown_top_level = matches!(
138        value.get("type").and_then(Value::as_str),
139        Some(t) if !KNOWN_STREAM_EVENT_TYPES.contains(&t)
140    );
141    match serde_json::from_value::<StreamEvent>(value) {
142        Ok(event) => Ok(Some(event)),
143        Err(err) => {
144            // Forward-compatible: ignore only unknown *top-level* event kinds. A known `type` with
145            // a malformed payload must still surface as `Json` so we do not drop real bugs.
146            if unknown_top_level {
147                return Ok(None);
148            }
149            Err(ApiError::from(err))
150        }
151    }
152}
153
#[cfg(test)]
mod tests {
    use super::{parse_frame, SseParser};
    use crate::error::ApiError;
    use crate::types::{
        ContentBlockDelta, ContentBlockDeltaEvent, ContentBlockStartEvent, MessageDelta,
        MessageDeltaEvent, MessageStopEvent, OutputContentBlock, StreamEvent, Usage,
    };

    /// Wraps `content_block` in the `content_block_start` event expected at index 0.
    fn block_start_at_zero(content_block: OutputContentBlock) -> StreamEvent {
        StreamEvent::ContentBlockStart(ContentBlockStartEvent {
            index: 0,
            content_block,
        })
    }

    /// Wraps `delta` in the `content_block_delta` event expected at index 0.
    fn block_delta_at_zero(delta: ContentBlockDelta) -> StreamEvent {
        StreamEvent::ContentBlockDelta(ContentBlockDeltaEvent { index: 0, delta })
    }

    #[test]
    fn parses_single_frame() {
        let raw = concat!(
            "event: content_block_start\n",
            r#"data: {"type":"content_block_start","index":0,"content_block":{"type":"text","text":"Hi"}}"#,
            "\n\n"
        );

        let parsed = parse_frame(raw).expect("frame should parse");

        let expected = block_start_at_zero(OutputContentBlock::Text {
            text: "Hi".to_string(),
        });
        assert_eq!(parsed, Some(expected));
    }

    #[test]
    fn parses_chunked_stream() {
        // The JSON payload is cut in the middle of the text value.
        let head = b"event: content_block_delta\ndata: {\"type\":\"content_block_delta\",\"index\":0,\"delta\":{\"type\":\"text_delta\",\"text\":\"Hel";
        let tail = b"lo\"}}\n\n";
        let mut parser = SseParser::new();

        let buffered = parser.push(head).expect("first chunk should buffer");
        assert!(buffered.is_empty());

        let decoded = parser.push(tail).expect("second chunk should parse");
        let expected = block_delta_at_zero(ContentBlockDelta::TextDelta {
            text: "Hello".to_string(),
        });
        assert_eq!(decoded, vec![expected]);
    }

    #[test]
    fn ignores_ping_and_done() {
        let stream = concat!(
            ": keepalive\n",
            "event: ping\n",
            r#"data: {"type":"ping"}"#,
            "\n\n",
            "event: message_delta\n",
            r#"data: {"type":"message_delta","delta":{"stop_reason":"tool_use","stop_sequence":null},"usage":{"input_tokens":1,"output_tokens":2}}"#,
            "\n\n",
            "event: message_stop\n",
            r#"data: {"type":"message_stop"}"#,
            "\n\n",
            "data: [DONE]\n\n"
        );

        let decoded = SseParser::new()
            .push(stream.as_bytes())
            .expect("parser should succeed");

        let message_delta = StreamEvent::MessageDelta(MessageDeltaEvent {
            delta: MessageDelta {
                stop_reason: Some("tool_use".to_string()),
                stop_sequence: None,
            },
            usage: Usage {
                input_tokens: 1,
                cache_creation_input_tokens: 0,
                cache_read_input_tokens: 0,
                output_tokens: 2,
            },
        });
        let message_stop = StreamEvent::MessageStop(MessageStopEvent {});
        assert_eq!(decoded, vec![message_delta, message_stop]);
    }

    #[test]
    fn ignores_data_less_event_frames() {
        let parsed =
            parse_frame("event: ping\n\n").expect("frame without data should be ignored");
        assert_eq!(parsed, None);
    }

    #[test]
    fn parses_split_json_across_data_lines() {
        let raw = concat!(
            "event: content_block_delta\n",
            r#"data: {"type":"content_block_delta","index":0,"#,
            "\n",
            r#"data: "delta":{"type":"text_delta","text":"Hello"}}"#,
            "\n\n"
        );

        let parsed = parse_frame(raw).expect("frame should parse");

        let expected = block_delta_at_zero(ContentBlockDelta::TextDelta {
            text: "Hello".to_string(),
        });
        assert_eq!(parsed, Some(expected));
    }

    #[test]
    fn parses_thinking_content_block_start() {
        let raw = concat!(
            "event: content_block_start\n",
            r#"data: {"type":"content_block_start","index":0,"content_block":{"type":"thinking","thinking":"","signature":null}}"#,
            "\n\n"
        );

        let parsed = parse_frame(raw).expect("frame should parse");

        let expected = block_start_at_zero(OutputContentBlock::Thinking {
            thinking: String::new(),
            signature: None,
        });
        assert_eq!(parsed, Some(expected));
    }

    #[test]
    fn parses_thinking_related_deltas() {
        let thinking_frame = concat!(
            "event: content_block_delta\n",
            r#"data: {"type":"content_block_delta","index":0,"delta":{"type":"thinking_delta","thinking":"step 1"}}"#,
            "\n\n"
        );
        let signature_frame = concat!(
            "event: content_block_delta\n",
            r#"data: {"type":"content_block_delta","index":0,"delta":{"type":"signature_delta","signature":"sig_123"}}"#,
            "\n\n"
        );

        let thinking_event = parse_frame(thinking_frame).expect("thinking delta should parse");
        let signature_event = parse_frame(signature_frame).expect("signature delta should parse");

        assert_eq!(
            thinking_event,
            Some(block_delta_at_zero(ContentBlockDelta::ThinkingDelta {
                thinking: "step 1".to_string(),
            }))
        );
        assert_eq!(
            signature_event,
            Some(block_delta_at_zero(ContentBlockDelta::SignatureDelta {
                signature: "sig_123".to_string(),
            }))
        );
    }

    #[test]
    fn rejects_oversized_buffer() {
        // One byte over the cap is enough to trip the size guard.
        let oversized = vec![b'x'; SseParser::MAX_BUFFER_SIZE + 1];
        let mut parser = SseParser::new();

        let err = parser.push(&oversized).unwrap_err();

        assert!(err.to_string().contains("limit"));
    }

    #[test]
    fn stream_error_frame_returns_err() {
        let raw = concat!(
            "event: error\n",
            r#"data: {"type":"error","error":{"type":"overloaded_error","message":"try again"}}"#,
            "\n\n"
        );

        let err = parse_frame(raw).unwrap_err();

        assert!(err.is_retryable());
        match &err {
            ApiError::StreamApplicationError {
                error_type,
                message,
            } => {
                assert_eq!(error_type.as_deref(), Some("overloaded_error"));
                assert_eq!(message, "try again");
            }
            other => panic!("expected StreamApplicationError, got {other:?}"),
        }
    }

    #[test]
    fn unknown_event_type_is_skipped() {
        let raw = concat!(
            "event: future_event\n",
            r#"data: {"type":"hypothetical_future_event","index":0}"#,
            "\n\n"
        );

        let parsed = parse_frame(raw).expect("skip unknown");
        assert_eq!(parsed, None);
    }
}