Skip to main content

openai_oxide/
streaming.rs

// SSE (server-sent events) stream parser for OpenAI streaming responses.
2
3use std::pin::Pin;
4use std::task::{Context, Poll};
5
6use futures_core::Stream;
7
8use crate::error::OpenAIError;
9
10/// A stream of parsed SSE events from an OpenAI streaming response.
11///
12/// Wraps a byte stream from reqwest and yields deserialized items.
13pub struct SseStream<T> {
14    inner: Pin<Box<dyn Stream<Item = Result<bytes::Bytes, reqwest::Error>> + Send>>,
15    buffer: String,
16    done: bool,
17    _phantom: std::marker::PhantomData<T>,
18}
19
20impl<T> SseStream<T> {
21    pub(crate) fn new(response: reqwest::Response) -> Self {
22        Self {
23            inner: Box::pin(response.bytes_stream()),
24            buffer: String::new(),
25            done: false,
26            _phantom: std::marker::PhantomData,
27        }
28    }
29}
30
31// SAFETY: SseStream has no self-referential data; inner is heap-boxed.
32impl<T> Unpin for SseStream<T> {}
33
34impl<T: serde::de::DeserializeOwned> Stream for SseStream<T> {
35    type Item = Result<T, OpenAIError>;
36
37    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
38        let this = self.get_mut();
39
40        if this.done {
41            return Poll::Ready(None);
42        }
43
44        // Check if we already have a complete event in the buffer
45        if let Some(item) = try_parse_next::<T>(&mut this.buffer, &mut this.done) {
46            return Poll::Ready(Some(item));
47        }
48
49        // Poll for more data from the byte stream
50        match this.inner.as_mut().poll_next(cx) {
51            Poll::Ready(Some(Ok(chunk))) => {
52                this.buffer.push_str(&String::from_utf8_lossy(&chunk));
53                match try_parse_next::<T>(&mut this.buffer, &mut this.done) {
54                    Some(item) => Poll::Ready(Some(item)),
55                    None => {
56                        cx.waker().wake_by_ref();
57                        Poll::Pending
58                    }
59                }
60            }
61            Poll::Ready(Some(Err(e))) => {
62                this.done = true;
63                Poll::Ready(Some(Err(OpenAIError::RequestError(e))))
64            }
65            Poll::Ready(None) => {
66                this.done = true;
67                match try_parse_next::<T>(&mut this.buffer, &mut this.done) {
68                    Some(item) => Poll::Ready(Some(item)),
69                    None => Poll::Ready(None),
70                }
71            }
72            Poll::Pending => Poll::Pending,
73        }
74    }
75}
76
77/// Try to extract and parse the next SSE event from the buffer.
78/// Returns `Some` if an event was found (success or error), `None` if more data is needed.
79fn try_parse_next<T: serde::de::DeserializeOwned>(
80    buffer: &mut String,
81    done: &mut bool,
82) -> Option<Result<T, OpenAIError>> {
83    loop {
84        let newline_pos = buffer.find('\n')?;
85        let line = buffer[..newline_pos].trim_end_matches('\r').to_string();
86        buffer.drain(..=newline_pos);
87
88        // Skip empty lines and comments
89        if line.is_empty() || line.starts_with(':') {
90            continue;
91        }
92
93        // Parse "data: ..." lines
94        if let Some(data) = line
95            .strip_prefix("data: ")
96            .or_else(|| line.strip_prefix("data:"))
97        {
98            let data = data.trim();
99
100            if data == "[DONE]" {
101                *done = true;
102                return None;
103            }
104
105            match serde_json::from_str::<T>(data) {
106                Ok(value) => return Some(Ok(value)),
107                Err(e) => return Some(Err(OpenAIError::JsonError(e))),
108            }
109        }
110
111        // Skip non-data SSE fields (event:, id:, retry:)
112    }
113}
114
115/// Parse SSE lines from raw text and yield data payloads.
116/// Useful for testing without HTTP. Returns items until `[DONE]` or end of input.
117pub fn parse_sse_events<T: serde::de::DeserializeOwned>(raw: &str) -> Vec<Result<T, OpenAIError>> {
118    let mut results = Vec::new();
119    let mut buffer = raw.to_string();
120    if !buffer.ends_with('\n') {
121        buffer.push('\n');
122    }
123    let mut done = false;
124
125    while !done {
126        match try_parse_next::<T>(&mut buffer, &mut done) {
127            Some(item) => results.push(item),
128            None => break,
129        }
130    }
131
132    results
133}
134
#[cfg(test)]
mod tests {
    // Unit tests for the SSE line parser, driven through `parse_sse_events`.
    use super::*;
    use crate::types::chat::ChatCompletionChunk;

    /// A minimal chat-completion chunk whose delta content is `"Hi"`.
    const HI_CHUNK: &str = r#"{"id":"c1","object":"chat.completion.chunk","created":1,"model":"gpt-4o","choices":[{"index":0,"delta":{"content":"Hi"},"finish_reason":null}]}"#;

    /// Asserts that `events` holds exactly one successfully parsed `HI_CHUNK`.
    fn assert_single_hi(events: &[Result<ChatCompletionChunk, OpenAIError>]) {
        assert_eq!(events.len(), 1);
        assert_eq!(
            events[0].as_ref().unwrap().choices[0]
                .delta
                .content
                .as_deref(),
            Some("Hi")
        );
    }

    #[test]
    fn test_parse_sse_content_chunks() {
        let raw = r#"data: {"id":"chatcmpl-1","object":"chat.completion.chunk","created":1,"model":"gpt-4o","choices":[{"index":0,"delta":{"role":"assistant","content":""},"finish_reason":null}]}

data: {"id":"chatcmpl-1","object":"chat.completion.chunk","created":1,"model":"gpt-4o","choices":[{"index":0,"delta":{"content":"Hello"},"finish_reason":null}]}

data: {"id":"chatcmpl-1","object":"chat.completion.chunk","created":1,"model":"gpt-4o","choices":[{"index":0,"delta":{"content":" world"},"finish_reason":null}]}

data: {"id":"chatcmpl-1","object":"chat.completion.chunk","created":1,"model":"gpt-4o","choices":[{"index":0,"delta":{},"finish_reason":"stop"}]}

data: [DONE]

"#;

        let events = parse_sse_events::<ChatCompletionChunk>(raw);
        assert_eq!(events.len(), 4);

        let chunk0 = events[0].as_ref().unwrap();
        assert_eq!(
            chunk0.choices[0].delta.role,
            Some(crate::types::common::Role::Assistant)
        );

        let chunk1 = events[1].as_ref().unwrap();
        assert_eq!(chunk1.choices[0].delta.content.as_deref(), Some("Hello"));

        let chunk2 = events[2].as_ref().unwrap();
        assert_eq!(chunk2.choices[0].delta.content.as_deref(), Some(" world"));

        let chunk3 = events[3].as_ref().unwrap();
        assert_eq!(chunk3.choices[0].finish_reason.as_deref(), Some("stop"));
    }

    #[test]
    fn test_parse_sse_with_comments_and_empty_lines() {
        let raw = ": this is a comment
data: {\"id\":\"c1\",\"object\":\"chat.completion.chunk\",\"created\":1,\"model\":\"gpt-4o\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"Hi\"},\"finish_reason\":null}]}

data: [DONE]
";

        let events = parse_sse_events::<ChatCompletionChunk>(raw);
        assert_eq!(events.len(), 1);
        assert_eq!(
            events[0].as_ref().unwrap().choices[0]
                .delta
                .content
                .as_deref(),
            Some("Hi")
        );
    }

    #[test]
    fn test_parse_sse_done_stops_parsing() {
        let raw = r#"data: {"id":"c1","object":"chat.completion.chunk","created":1,"model":"gpt-4o","choices":[{"index":0,"delta":{"content":"A"},"finish_reason":null}]}

data: [DONE]

data: {"id":"c2","object":"chat.completion.chunk","created":1,"model":"gpt-4o","choices":[{"index":0,"delta":{"content":"B"},"finish_reason":null}]}
"#;

        let events = parse_sse_events::<ChatCompletionChunk>(raw);
        assert_eq!(events.len(), 1);
    }

    #[test]
    fn test_parse_sse_response_stream_events() {
        use crate::types::responses::ResponseStreamEvent;

        let raw = r#"data: {"type":"response.created","response":{"id":"resp-1","object":"response","status":"in_progress"}}

data: {"type":"response.output_text.delta","delta":"Hello","output_index":0,"content_index":0}

data: {"type":"response.output_text.delta","delta":" world","output_index":0,"content_index":0}

data: {"type":"response.completed","response":{"id":"resp-1","status":"completed"}}

data: [DONE]
"#;

        let events = parse_sse_events::<ResponseStreamEvent>(raw);
        assert_eq!(events.len(), 4);
        assert_eq!(events[0].as_ref().unwrap().type_, "response.created");
        assert_eq!(
            events[1].as_ref().unwrap().type_,
            "response.output_text.delta"
        );
        assert_eq!(events[1].as_ref().unwrap().data["delta"], "Hello");
        assert_eq!(events[2].as_ref().unwrap().data["delta"], " world");
        assert_eq!(events[3].as_ref().unwrap().type_, "response.completed");
    }

    #[test]
    fn test_parse_sse_invalid_json() {
        let raw = "data: {invalid json}\n\ndata: [DONE]\n";
        let events = parse_sse_events::<ChatCompletionChunk>(raw);
        assert_eq!(events.len(), 1);
        assert!(events[0].is_err());
    }

    #[test]
    fn test_parse_sse_tool_call_chunks() {
        let raw = r#"data: {"id":"c1","object":"chat.completion.chunk","created":1,"model":"gpt-4o","choices":[{"index":0,"delta":{"role":"assistant","tool_calls":[{"index":0,"id":"call_1","type":"function","function":{"name":"get_weather","arguments":""}}]},"finish_reason":null}]}

data: {"id":"c1","object":"chat.completion.chunk","created":1,"model":"gpt-4o","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"{\"loc"}}]},"finish_reason":null}]}

data: {"id":"c1","object":"chat.completion.chunk","created":1,"model":"gpt-4o","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"ation\": \"Boston\"}"}}]},"finish_reason":null}]}

data: {"id":"c1","object":"chat.completion.chunk","created":1,"model":"gpt-4o","choices":[{"index":0,"delta":{},"finish_reason":"tool_calls"}]}

data: [DONE]
"#;

        let events = parse_sse_events::<ChatCompletionChunk>(raw);
        assert_eq!(events.len(), 4);

        let tc = events[0].as_ref().unwrap().choices[0]
            .delta
            .tool_calls
            .as_ref()
            .unwrap();
        assert_eq!(tc[0].id.as_deref(), Some("call_1"));
        assert_eq!(
            tc[0].function.as_ref().unwrap().name.as_deref(),
            Some("get_weather")
        );

        assert_eq!(
            events[3].as_ref().unwrap().choices[0]
                .finish_reason
                .as_deref(),
            Some("tool_calls")
        );
    }

    #[test]
    fn test_parse_sse_crlf_line_endings() {
        let raw = format!("data: {}\r\n\r\ndata: [DONE]\r\n\r\n", HI_CHUNK);
        assert_single_hi(&parse_sse_events::<ChatCompletionChunk>(&raw));
    }

    #[test]
    fn test_parse_sse_data_without_space() {
        let raw = format!("data:{}\n\ndata:[DONE]\n", HI_CHUNK);
        assert_single_hi(&parse_sse_events::<ChatCompletionChunk>(&raw));
    }

    #[test]
    fn test_parse_sse_skips_non_data_fields() {
        let raw = format!(
            "event: message\nid: 42\nretry: 1000\ndata: {}\n\ndata: [DONE]\n",
            HI_CHUNK
        );
        assert_single_hi(&parse_sse_events::<ChatCompletionChunk>(&raw));
    }

    #[test]
    fn test_parse_sse_final_event_without_trailing_newline() {
        let raw = format!("data: {}", HI_CHUNK);
        assert_single_hi(&parse_sse_events::<ChatCompletionChunk>(&raw));
    }

    #[test]
    fn test_parse_sse_empty_input() {
        assert!(parse_sse_events::<ChatCompletionChunk>("").is_empty());
    }
}