Skip to main content

albert_api/
sse.rs

1use crate::error::ApiError;
2use crate::types::StreamEvent;
3
/// Incremental Server-Sent Events parser.
///
/// Raw bytes are accumulated until a complete, blank-line-terminated frame is
/// available; frames are then decoded one at a time.
#[derive(Default, Debug)]
pub struct SseParser {
    /// Bytes received but not yet consumed as part of a complete frame.
    buffer: Vec<u8>,
}
8
9impl SseParser {
10    #[must_use]
11    pub fn new() -> Self {
12        Self::default()
13    }
14
15    pub fn push(&mut self, chunk: &[u8]) -> Result<Vec<StreamEvent>, ApiError> {
16        self.buffer.extend_from_slice(chunk);
17        let mut events = Vec::new();
18
19        while let Some(frame) = self.next_frame() {
20            if let Some(event) = parse_frame(&frame)? {
21                events.push(event);
22            }
23        }
24
25        Ok(events)
26    }
27
28    pub fn finish(&mut self) -> Result<Vec<StreamEvent>, ApiError> {
29        if self.buffer.is_empty() {
30            return Ok(Vec::new());
31        }
32
33        let trailing = std::mem::take(&mut self.buffer);
34        match parse_frame(&String::from_utf8_lossy(&trailing))? {
35            Some(event) => Ok(vec![event]),
36            None => Ok(Vec::new()),
37        }
38    }
39
40    fn next_frame(&mut self) -> Option<String> {
41        let separator = self
42            .buffer
43            .windows(2)
44            .position(|window| window == b"\n\n")
45            .map(|position| (position, 2))
46            .or_else(|| {
47                self.buffer
48                    .windows(4)
49                    .position(|window| window == b"\r\n\r\n")
50                    .map(|position| (position, 4))
51            })?;
52
53        let (position, separator_len) = separator;
54        let frame = self
55            .buffer
56            .drain(..position + separator_len)
57            .collect::<Vec<_>>();
58        let frame_len = frame.len().saturating_sub(separator_len);
59        Some(String::from_utf8_lossy(&frame[..frame_len]).into_owned())
60    }
61}
62
63pub fn parse_frame(frame: &str) -> Result<Option<StreamEvent>, ApiError> {
64    let trimmed = frame.trim();
65    if trimmed.is_empty() {
66        return Ok(None);
67    }
68
69    let mut data_lines = Vec::new();
70    let mut event_name: Option<&str> = None;
71
72    for line in trimmed.lines() {
73        if line.starts_with(':') {
74            continue;
75        }
76        if let Some(name) = line.strip_prefix("event:") {
77            event_name = Some(name.trim());
78            continue;
79        }
80        if let Some(data) = line.strip_prefix("data:") {
81            data_lines.push(data.trim_start());
82        }
83    }
84
85    if matches!(event_name, Some("ping")) {
86        return Ok(None);
87    }
88
89    if data_lines.is_empty() {
90        return Ok(None);
91    }
92
93    let payload = data_lines.join("\n");
94    if payload == "[DONE]" {
95        return Ok(None);
96    }
97
98    serde_json::from_str::<StreamEvent>(&payload)
99        .map(Some)
100        .map_err(ApiError::from)
101}
102
#[cfg(test)]
mod tests {
    use super::{parse_frame, SseParser};
    use crate::types::{
        ContentBlockDelta, ContentBlockDeltaEvent, ContentBlockStartEvent, MessageDelta,
        MessageDeltaEvent, MessageStopEvent, OutputContentBlock, StreamEvent, Usage,
    };

    /// Expected `content_block_delta` event for block 0 carrying `text`.
    fn text_delta_event(text: &str) -> StreamEvent {
        StreamEvent::ContentBlockDelta(ContentBlockDeltaEvent {
            index: 0,
            delta: ContentBlockDelta::TextDelta {
                text: text.to_string(),
            },
        })
    }

    /// A single `content_block_start` frame decodes into its event.
    #[test]
    fn parses_single_frame() {
        let raw = concat!(
            "event: content_block_start\n",
            "data: {\"type\":\"content_block_start\",\"index\":0,\"content_block\":{\"type\":\"text\",\"text\":\"Hi\"}}\n\n"
        );
        let expected = StreamEvent::ContentBlockStart(ContentBlockStartEvent {
            index: 0,
            content_block: OutputContentBlock::Text {
                text: "Hi".to_string(),
            },
        });

        let decoded = parse_frame(raw).expect("frame should parse");
        assert_eq!(decoded, Some(expected));
    }

    /// A frame split mid-JSON across two pushes is emitted only once complete.
    #[test]
    fn parses_chunked_stream() {
        let head = b"event: content_block_delta\ndata: {\"type\":\"content_block_delta\",\"index\":0,\"delta\":{\"type\":\"text_delta\",\"text\":\"Hel";
        let tail = b"lo\"}}\n\n";
        let mut parser = SseParser::new();

        let pending = parser.push(head).expect("first chunk should buffer");
        assert!(pending.is_empty());

        let completed = parser.push(tail).expect("second chunk should parse");
        assert_eq!(completed, vec![text_delta_event("Hello")]);
    }

    /// Comments, `ping` events and the `[DONE]` sentinel yield no events.
    #[test]
    fn ignores_ping_and_done() {
        let stream = concat!(
            ": keepalive\n",
            "event: ping\n",
            "data: {\"type\":\"ping\"}\n\n",
            "event: message_delta\n",
            "data: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"tool_use\",\"stop_sequence\":null},\"usage\":{\"input_tokens\":1,\"output_tokens\":2}}\n\n",
            "event: message_stop\n",
            "data: {\"type\":\"message_stop\"}\n\n",
            "data: [DONE]\n\n"
        );
        let mut parser = SseParser::new();
        let events = parser
            .push(stream.as_bytes())
            .expect("parser should succeed");

        let message_delta = StreamEvent::MessageDelta(MessageDeltaEvent {
            delta: MessageDelta {
                stop_reason: Some("tool_use".to_string()),
                stop_sequence: None,
            },
            usage: Usage {
                input_tokens: 1,
                cache_creation_input_tokens: 0,
                cache_read_input_tokens: 0,
                output_tokens: 2,
            },
        });
        let message_stop = StreamEvent::MessageStop(MessageStopEvent {});
        assert_eq!(events, vec![message_delta, message_stop]);
    }

    /// A frame with an `event` field but no `data` lines is skipped.
    #[test]
    fn ignores_data_less_event_frames() {
        let outcome =
            parse_frame("event: ping\n\n").expect("frame without data should be ignored");
        assert!(outcome.is_none());
    }

    /// Consecutive `data` lines are joined before JSON decoding.
    #[test]
    fn parses_split_json_across_data_lines() {
        let raw = concat!(
            "event: content_block_delta\n",
            "data: {\"type\":\"content_block_delta\",\"index\":0,\n",
            "data: \"delta\":{\"type\":\"text_delta\",\"text\":\"Hello\"}}\n\n"
        );

        let decoded = parse_frame(raw).expect("frame should parse");
        assert_eq!(decoded, Some(text_delta_event("Hello")));
    }
}