Skip to main content

codex_convert_proxy/util/
sse.rs

//! SSE (Server-Sent Events) parsing and serialization utilities.
//!
//! This module provides helpers for parsing, serializing, and buffering the
//! SSE format used by streaming responses.

6use bytes::{Bytes, BytesMut};
7
/// SSE event with optional event type and data.
///
/// Equality compares both the event type and the data, which makes parsed
/// events directly comparable in assertions.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SseEvent {
    /// Event type taken from the `event:` line (e.g. "response.created",
    /// "response.output_text.delta"); `None` when the event has no such line.
    pub event_type: Option<String>,
    /// Event payload from the `data:` field (a JSON string, or the `[DONE]` sentinel).
    pub data: String,
}
16
/// SSE parse error.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SseParseError {
    /// A `data:` payload opened a JSON object/array that was never closed.
    UnterminatedJson,
    /// A complete payload was not followed by the `\n\n` event delimiter.
    MissingDelimiter,
    /// Input was not valid UTF-8 (not produced by the parser in this module).
    InvalidUtf8,
}

impl std::fmt::Display for SseParseError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            SseParseError::UnterminatedJson => write!(f, "unterminated JSON object"),
            SseParseError::MissingDelimiter => write!(f, "missing SSE event delimiter"),
            SseParseError::InvalidUtf8 => write!(f, "invalid UTF-8"),
        }
    }
}

// Lets the error be boxed as `dyn Error` and propagated with `?` through
// `Box<dyn Error>` / error-chaining APIs.
impl std::error::Error for SseParseError {}
34
35/// Iterator over SSE events in a text stream.
36pub struct SseEventIterator<'a> {
37    text: &'a str,
38    position: usize,
39}
40
41impl<'a> SseEventIterator<'a> {
42    /// Create a new iterator over SSE events.
43    pub fn new(text: &'a str) -> Self {
44        Self { text, position: 0 }
45    }
46
47    /// Get the current position in the text.
48    pub fn position(&self) -> usize {
49        self.position
50    }
51
52    /// Skip past an incomplete event at current position.
53    /// Returns the new position, or text.len() if already at end.
54    pub fn skip_incomplete_event(&mut self) -> usize {
55        // Try to find \n\n after current position to skip to
56        if let Some(next_delim) = self.text[self.position..].find("\n\n") {
57            self.position += next_delim + 2;
58        } else {
59            // No \n\n found, skip to end of text
60            self.position = self.text.len();
61        }
62        self.position
63    }
64
65    /// Parse the next SSE event from the text.
66    /// Returns None if no more events, or Some(Err) on parse error.
67    pub fn next_event(&mut self) -> Option<Result<SseEvent, SseParseError>> {
68        let base_pos = self.position;
69        let text = &self.text[base_pos..];
70
71        // Find "data:" prefix (with or without trailing space)
72        let data_start = text.find("data:")?;
73
74        // Look for "event:" line before this "data:" line
75        let event_type = if data_start > 0 {
76            // Get the text before "data:"
77            let before_data = &text[..data_start];
78            // Look for "event:" in the preceding lines (search backwards)
79            let mut result = None;
80            for line in before_data.lines().rev() {
81                let line = line.trim();
82                if let Some(stripped) = line.strip_prefix("event:") {
83                    result = Some(stripped.trim().to_string());
84                    break;
85                }
86                // If we hit a non-empty, non-event line, stop looking
87                if !line.is_empty() && !line.starts_with("event:") {
88                    break;
89                }
90            }
91            result
92        } else {
93            None
94        };
95
96        let after_prefix = data_start + 5; // after "data:"
97        if after_prefix >= text.len() {
98            return None;
99        }
100
101        let mut value_start = after_prefix;
102        while value_start < text.len() {
103            let c = text.as_bytes()[value_start];
104            if c == b' ' || c == b'\t' || c == b'\n' || c == b'\r' {
105                value_start += 1;
106            } else {
107                break;
108            }
109        }
110        if value_start >= text.len() {
111            return None;
112        }
113
114        let rest = &text[value_start..];
115
116        // Handle "[DONE]"
117        if rest.starts_with("[DONE]") {
118            self.position = base_pos + value_start + 6;
119            // Skip \n\n if present
120            if self.position + 2 <= self.text.len()
121                && &self.text[self.position..self.position + 2] == "\n\n"
122            {
123                self.position += 2;
124            }
125            return Some(Ok(SseEvent {
126                event_type: None, // [DONE] doesn't have an event type
127                data: "[DONE]".to_string(),
128            }));
129        }
130
131        // Find JSON end by brace matching
132        let json_end = match find_json_end(rest) {
133            Some(pos) => pos,
134            None => return Some(Err(SseParseError::UnterminatedJson)),
135        };
136
137        // Find \n\n SSE delimiter after JSON
138        let after_json = &rest[json_end..];
139        match after_json.find("\n\n") {
140            Some(delimiter_pos) => {
141                let json_content = &rest[..json_end];
142                self.position = base_pos + value_start + json_end + delimiter_pos + 2;
143                Some(Ok(SseEvent {
144                    event_type,
145                    data: json_content.to_string(),
146                }))
147            }
148            None => Some(Err(SseParseError::MissingDelimiter)),
149        }
150    }
151}
152
153/// Parse SSE format text into events.
154///
155/// SSE format example:
156/// - `event: response.created\ndata: {"id": "resp_123"}`
157/// - `event: response.output_text.delta\ndata: {"delta": "hello"}`
158///
159/// Events are separated by double newlines (`\n\n`).
160///
161/// Returns (events, parse_end_position) where parse_end_position is the byte
162/// offset in `text` where parsing stopped (either at end of text, or after
163/// skipping an incomplete event).
164pub fn parse_sse(text: &str) -> (Vec<SseEvent>, usize) {
165    let mut events = Vec::new();
166    let mut iter = SseEventIterator::new(text);
167
168    while let Some(result) = iter.next_event() {
169        match result {
170            Ok(event) => events.push(event),
171            Err(e) => {
172                tracing::debug!("SSE parse error: {}, skipping incomplete event", e);
173                // Skip past the incomplete event so we don't re-parse it
174                iter.skip_incomplete_event();
175                break;
176            }
177        }
178    }
179
180    // If we parsed all events without error, position will be at end of text
181    let end_pos = if !events.is_empty() && iter.position() >= text.len() {
182        text.len()
183    } else {
184        iter.position()
185    };
186
187    (events, end_pos)
188}
189
/// Find the end of a JSON object/array at the start of `text`, handling
/// nested brackets and escaped characters inside strings.
///
/// Returns the byte offset just past the closing `}`/`]` that balances the
/// opening bracket. Returns `None` if `text` does not start with `{` or `[`,
/// or if the value is not terminated within `text`. Bracket kinds are only
/// counted, not matched against each other (`{]` is accepted as balanced).
fn find_json_end(text: &str) -> Option<usize> {
    // Require an opening bracket up front. Without this check a non-JSON
    // payload such as `data: hello` is scanned forward into the *next*
    // event's JSON, merging two events into one.
    if !matches!(text.as_bytes().first(), Some(b'{') | Some(b'[')) {
        return None;
    }

    let mut depth: usize = 0;
    let mut in_string = false;
    let mut escaped = false;

    for (i, c) in text.char_indices() {
        if escaped {
            // The character after a backslash inside a string is literal.
            escaped = false;
            continue;
        }
        match c {
            '\\' if in_string => escaped = true,
            '"' => in_string = !in_string,
            '{' | '[' if !in_string => depth += 1,
            '}' | ']' if !in_string => {
                // Cannot underflow: the first character is an opener, and the
                // function returns as soon as depth falls back to zero.
                depth -= 1;
                if depth == 0 {
                    return Some(i + c.len_utf8());
                }
            }
            _ => {}
        }
    }
    None
}
223
224/// Serialize a single SSE event to string format.
225pub fn serialize_sse(event: &SseEvent) -> String {
226    let mut result = String::new();
227
228    if let Some(ref et) = event.event_type {
229        result.push_str("event: ");
230        result.push_str(et);
231        result.push('\n');
232    }
233
234    result.push_str("data: ");
235    result.push_str(&event.data);
236    result.push_str("\n\n");
237
238    result
239}
240
241/// Collect body frames into a single Bytes buffer.
242///
243/// This is a helper function to aggregate multiple body chunks
244/// into one continuous buffer for easier SSE parsing.
245pub fn collect_frames(frames: &[Bytes]) -> Bytes {
246    if frames.is_empty() {
247        return Bytes::new();
248    }
249
250    if frames.len() == 1 {
251        return frames[0].clone();
252    }
253
254    // Calculate total length
255    let total_len: usize = frames.iter().map(|f| f.len()).sum();
256
257    let mut result = BytesMut::with_capacity(total_len);
258    for frame in frames {
259        result.extend_from_slice(frame);
260    }
261
262    result.freeze()
263}
264
#[cfg(test)]
mod tests {
    use super::*;

    /// Parse `text` and keep only the events, discarding the end position.
    fn events_of(text: &str) -> Vec<SseEvent> {
        let (events, _end) = parse_sse(text);
        events
    }

    /// Shorthand for an expected `Some(event_type)`.
    fn typed(name: &str) -> Option<String> {
        Some(name.to_owned())
    }

    #[test]
    fn test_parse_single_event() {
        let parsed =
            events_of("data: {\"type\": \"response.created\", \"id\": \"resp_123\"}\n\n");

        assert_eq!(parsed.len(), 1);
        let only = &parsed[0];
        assert_eq!(only.event_type, None);
        assert_eq!(
            only.data,
            "{\"type\": \"response.created\", \"id\": \"resp_123\"}"
        );
    }

    #[test]
    fn test_parse_event_with_type() {
        let parsed = events_of("event: response.created\ndata: {\"id\": \"resp_123\"}\n\n");

        assert_eq!(parsed.len(), 1);
        assert_eq!(parsed[0].event_type, typed("response.created"));
        assert_eq!(parsed[0].data, "{\"id\": \"resp_123\"}");
    }

    #[test]
    fn test_parse_event_without_space_after_data_colon() {
        let parsed = events_of("event: response.created\ndata:{\"id\":\"resp_123\"}\n\n");

        assert_eq!(parsed.len(), 1);
        assert_eq!(parsed[0].event_type, typed("response.created"));
        assert_eq!(parsed[0].data, "{\"id\":\"resp_123\"}");
    }

    #[test]
    fn test_parse_event_with_newline_after_data_colon() {
        let parsed = events_of("event: response.created\ndata:\n{\"id\":\"resp_123\"}\n\n");

        assert_eq!(parsed.len(), 1);
        assert_eq!(parsed[0].event_type, typed("response.created"));
        assert_eq!(parsed[0].data, "{\"id\":\"resp_123\"}");
    }

    #[test]
    fn test_parse_done_event() {
        // The `[DONE]` sentinel arrives without an `event:` line, so the
        // event type stays `None` and the sentinel itself becomes the data.
        let parsed = events_of("data: [DONE]\n\n");

        assert_eq!(parsed.len(), 1);
        assert_eq!(parsed[0].event_type, None);
        assert_eq!(parsed[0].data, "[DONE]");
    }

    #[test]
    fn test_parse_multiple_events() {
        let stream = concat!(
            "event: response.created\ndata: {\"id\": \"1\"}\n\n",
            "event: response.output_text.delta\ndata: {\"delta\": \"hello\"}\n\n"
        );

        let kinds: Vec<Option<String>> = events_of(stream)
            .into_iter()
            .map(|event| event.event_type)
            .collect();
        assert_eq!(
            kinds,
            vec![
                typed("response.created"),
                typed("response.output_text.delta")
            ]
        );
    }

    #[test]
    fn test_parse_empty_data() {
        // A `data:` field without a value must not produce an event.
        assert!(events_of("event: done\ndata: \n\n").is_empty());
    }

    #[test]
    fn test_serialize_sse() {
        let rendered = serialize_sse(&SseEvent {
            event_type: typed("response.created"),
            data: String::from("{\"id\": \"resp_123\"}"),
        });

        assert!(rendered.contains("event: response.created\n"));
        assert!(rendered.contains("data: {\"id\": \"resp_123\"}\n\n"));
    }

    #[test]
    fn test_collect_frames_empty() {
        let no_frames: &[Bytes] = &[];
        assert!(collect_frames(no_frames).is_empty());
    }

    #[test]
    fn test_collect_frames_single() {
        let merged = collect_frames(&[Bytes::from("hello")]);
        assert_eq!(&merged[..], b"hello");
    }

    #[test]
    fn test_collect_frames_multiple() {
        let parts: Vec<Bytes> = ["hello", " world", "!"]
            .iter()
            .copied()
            .map(Bytes::from)
            .collect();

        let merged = collect_frames(&parts);
        assert_eq!(&merged[..], b"hello world!");
    }

    #[test]
    fn test_find_json_end() {
        let cases: [(&str, Option<usize>); 5] = [
            // Simple object
            (r#"{"key": "value"}"#, Some(16)),
            // Nested object
            (r#"{"outer": {"inner": "value"}}"#, Some(29)),
            // Array - "[1, 2, 3]" is 9 chars
            (r#"[1, 2, 3]"#, Some(9)),
            // Empty
            ("", None),
            // Unterminated
            (r#"{"key": "value"#, None),
        ];

        for (input, expected) in cases.iter() {
            assert_eq!(find_json_end(input), *expected, "input: {:?}", input);
        }
    }
}