// openai_oxide/streaming.rs
//
// SSE stream parser for OpenAI streaming responses.
2
3use std::pin::Pin;
4use std::task::{Context, Poll};
5
6use futures_core::Stream;
7
8use crate::error::OpenAIError;
9
10/// A stream of parsed SSE events from an OpenAI streaming response.
11///
12/// See [OpenAI streaming guide](https://platform.openai.com/docs/api-reference/streaming).
13///
14/// Wraps a byte stream from reqwest and yields deserialized items.
15pub struct SseStream<T> {
16    #[cfg(not(target_arch = "wasm32"))]
17    inner: Pin<Box<dyn Stream<Item = Result<bytes::Bytes, reqwest::Error>> + Send>>,
18    #[cfg(target_arch = "wasm32")]
19    inner: Pin<Box<dyn Stream<Item = Result<bytes::Bytes, reqwest::Error>>>>,
20    buffer: String,
21    done: bool,
22    _phantom: std::marker::PhantomData<T>,
23}
24
25impl<T> SseStream<T> {
26    pub(crate) fn new(response: reqwest::Response) -> Self {
27        Self {
28            inner: Box::pin(response.bytes_stream()),
29            buffer: String::new(),
30            done: false,
31            _phantom: std::marker::PhantomData,
32        }
33    }
34}
35
36// SAFETY: SseStream has no self-referential data; inner is heap-boxed.
37impl<T> Unpin for SseStream<T> {}
38
39impl<T: serde::de::DeserializeOwned> Stream for SseStream<T> {
40    type Item = Result<T, OpenAIError>;
41
42    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
43        let this = self.get_mut();
44
45        loop {
46            if this.done {
47                return Poll::Ready(None);
48            }
49
50            // Check if we already have a complete event in the buffer
51            if let Some(item) = try_parse_next::<T>(&mut this.buffer, &mut this.done) {
52                return Poll::Ready(Some(item));
53            }
54
55            // Poll for more data from the byte stream
56            match this.inner.as_mut().poll_next(cx) {
57                Poll::Ready(Some(Ok(chunk))) => {
58                    this.buffer.push_str(&String::from_utf8_lossy(&chunk));
59                    // Safety cap: 4MB max buffer to prevent unbounded growth on malformed streams
60                    if this.buffer.len() > 4 * 1024 * 1024 {
61                        this.done = true;
62                        return Poll::Ready(Some(Err(OpenAIError::StreamError(
63                            "SSE buffer exceeded 4MB".into(),
64                        ))));
65                    }
66                    // Loop back to try_parse_next — avoids wake_by_ref() busy-poll.
67                    // If no complete event yet, we'll poll inner again which will
68                    // either give us more data or return Pending (registering waker).
69                    continue;
70                }
71                Poll::Ready(Some(Err(e))) => {
72                    this.done = true;
73                    return Poll::Ready(Some(Err(OpenAIError::RequestError(e))));
74                }
75                Poll::Ready(None) => {
76                    this.done = true;
77                    return match try_parse_next::<T>(&mut this.buffer, &mut this.done) {
78                        Some(item) => Poll::Ready(Some(item)),
79                        None => Poll::Ready(None),
80                    };
81                }
82                Poll::Pending => return Poll::Pending,
83            }
84        }
85    }
86}
87
88/// Try to extract and parse the next SSE event from the buffer.
89/// Returns `Some` if an event was found (success or error), `None` if more data is needed.
90fn try_parse_next<T: serde::de::DeserializeOwned>(
91    buffer: &mut String,
92    done: &mut bool,
93) -> Option<Result<T, OpenAIError>> {
94    loop {
95        let newline_pos = buffer.find('\n')?;
96        let line = buffer[..newline_pos].trim_end_matches('\r').to_string();
97        buffer.drain(..=newline_pos);
98
99        // Skip empty lines and comments
100        if line.is_empty() || line.starts_with(':') {
101            continue;
102        }
103
104        // Parse "data: ..." lines
105        if let Some(data) = line
106            .strip_prefix("data: ")
107            .or_else(|| line.strip_prefix("data:"))
108        {
109            let data = data.trim();
110
111            if data == "[DONE]" {
112                *done = true;
113                return None;
114            }
115
116            match serde_json::from_str::<T>(data) {
117                Ok(value) => return Some(Ok(value)),
118                Err(e) => return Some(Err(OpenAIError::JsonError(e))),
119            }
120        }
121
122        // Skip non-data SSE fields (event:, id:, retry:)
123    }
124}
125
126/// Parse SSE lines from raw text and yield data payloads.
127/// Useful for testing without HTTP. Returns items until `[DONE]` or end of input.
128pub fn parse_sse_events<T: serde::de::DeserializeOwned>(raw: &str) -> Vec<Result<T, OpenAIError>> {
129    let mut results = Vec::new();
130    let mut buffer = raw.to_string();
131    if !buffer.ends_with('\n') {
132        buffer.push('\n');
133    }
134    let mut done = false;
135
136    while !done {
137        match try_parse_next::<T>(&mut buffer, &mut done) {
138            Some(item) => results.push(item),
139            None => break,
140        }
141    }
142
143    results
144}
145
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::chat::ChatCompletionChunk;

    #[test]
    fn test_parse_sse_content_chunks() {
        let raw = r#"data: {"id":"chatcmpl-1","object":"chat.completion.chunk","created":1,"model":"gpt-4o","choices":[{"index":0,"delta":{"role":"assistant","content":""},"finish_reason":null}]}

data: {"id":"chatcmpl-1","object":"chat.completion.chunk","created":1,"model":"gpt-4o","choices":[{"index":0,"delta":{"content":"Hello"},"finish_reason":null}]}

data: {"id":"chatcmpl-1","object":"chat.completion.chunk","created":1,"model":"gpt-4o","choices":[{"index":0,"delta":{"content":" world"},"finish_reason":null}]}

data: {"id":"chatcmpl-1","object":"chat.completion.chunk","created":1,"model":"gpt-4o","choices":[{"index":0,"delta":{},"finish_reason":"stop"}]}

data: [DONE]

"#;

        let events = parse_sse_events::<ChatCompletionChunk>(raw);
        assert_eq!(events.len(), 4);

        let chunk0 = events[0].as_ref().unwrap();
        assert_eq!(
            chunk0.choices[0].delta.role,
            Some(crate::types::common::Role::Assistant)
        );

        let chunk1 = events[1].as_ref().unwrap();
        assert_eq!(chunk1.choices[0].delta.content.as_deref(), Some("Hello"));

        let chunk2 = events[2].as_ref().unwrap();
        assert_eq!(chunk2.choices[0].delta.content.as_deref(), Some(" world"));

        let chunk3 = events[3].as_ref().unwrap();
        assert_eq!(
            chunk3.choices[0].finish_reason,
            Some(crate::types::common::FinishReason::Stop)
        );
    }

    #[test]
    fn test_parse_sse_with_comments_and_empty_lines() {
        let raw = ": this is a comment
data: {\"id\":\"c1\",\"object\":\"chat.completion.chunk\",\"created\":1,\"model\":\"gpt-4o\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"Hi\"},\"finish_reason\":null}]}

data: [DONE]
";

        let events = parse_sse_events::<ChatCompletionChunk>(raw);
        assert_eq!(events.len(), 1);
        assert_eq!(
            events[0].as_ref().unwrap().choices[0]
                .delta
                .content
                .as_deref(),
            Some("Hi")
        );
    }

    /// CRLF line endings are accepted, and the non-data fields `event:`, `id:` and
    /// `retry:` are consumed without producing events.
    #[test]
    fn test_parse_sse_crlf_and_non_data_fields() {
        let raw = "event: message\r\nid: 7\r\nretry: 1000\r\ndata: {\"id\":\"c1\",\"object\":\"chat.completion.chunk\",\"created\":1,\"model\":\"gpt-4o\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"Hi\"},\"finish_reason\":null}]}\r\n\r\ndata: [DONE]\r\n\r\n";

        let events = parse_sse_events::<ChatCompletionChunk>(raw);
        assert_eq!(events.len(), 1);
        assert_eq!(
            events[0].as_ref().unwrap().choices[0]
                .delta
                .content
                .as_deref(),
            Some("Hi")
        );
    }

    /// A final `data:` line with no space after the colon and no trailing newline
    /// is still parsed.
    #[test]
    fn test_parse_sse_unterminated_last_line() {
        let raw = "data:{\"id\":\"c1\",\"object\":\"chat.completion.chunk\",\"created\":1,\"model\":\"gpt-4o\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"end\"},\"finish_reason\":null}]}";

        let events = parse_sse_events::<ChatCompletionChunk>(raw);
        assert_eq!(events.len(), 1);
        assert_eq!(
            events[0].as_ref().unwrap().choices[0]
                .delta
                .content
                .as_deref(),
            Some("end")
        );
    }

    #[test]
    fn test_parse_sse_done_stops_parsing() {
        let raw = r#"data: {"id":"c1","object":"chat.completion.chunk","created":1,"model":"gpt-4o","choices":[{"index":0,"delta":{"content":"A"},"finish_reason":null}]}

data: [DONE]

data: {"id":"c2","object":"chat.completion.chunk","created":1,"model":"gpt-4o","choices":[{"index":0,"delta":{"content":"B"},"finish_reason":null}]}
"#;

        let events = parse_sse_events::<ChatCompletionChunk>(raw);
        assert_eq!(events.len(), 1);
    }

    #[test]
    fn test_parse_sse_response_stream_events() {
        use crate::types::responses::ResponseStreamEvent;

        let raw = r#"data: {"type":"response.created","response":{"id":"resp-1","object":"response","created_at":1.0,"model":"gpt-4o","output":[],"status":"in_progress"}}

data: {"type":"response.output_text.delta","delta":"Hello","output_index":0,"content_index":0}

data: {"type":"response.output_text.delta","delta":" world","output_index":0,"content_index":0}

data: {"type":"response.completed","response":{"id":"resp-1","object":"response","created_at":1.0,"model":"gpt-4o","output":[],"status":"completed"}}

data: [DONE]
"#;

        let events = parse_sse_events::<ResponseStreamEvent>(raw);
        assert_eq!(events.len(), 4);
        assert_eq!(events[0].as_ref().unwrap().event_type(), "response.created");
        assert_eq!(
            events[1].as_ref().unwrap().event_type(),
            "response.output_text.delta"
        );
        match events[1].as_ref().unwrap() {
            ResponseStreamEvent::ResponseOutputTextDelta(evt) => assert_eq!(evt.delta, "Hello"),
            other => panic!("expected ResponseOutputTextDelta, got: {other:?}"),
        }
        match events[2].as_ref().unwrap() {
            ResponseStreamEvent::ResponseOutputTextDelta(evt) => assert_eq!(evt.delta, " world"),
            other => panic!("expected ResponseOutputTextDelta, got: {other:?}"),
        }
        assert_eq!(
            events[3].as_ref().unwrap().event_type(),
            "response.completed"
        );
    }

    /// Test SSE streaming through actual HTTP (mockito), not just parsing.
    #[tokio::test]
    async fn test_sse_stream_via_http() {
        use futures_util::StreamExt;
        let mut server = mockito::Server::new_async().await;
        let sse_body = "data: {\"id\":\"c1\",\"object\":\"chat.completion.chunk\",\"created\":1,\"model\":\"gpt-4o\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"Hi\"},\"finish_reason\":null}]}\n\ndata: {\"id\":\"c1\",\"object\":\"chat.completion.chunk\",\"created\":1,\"model\":\"gpt-4o\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\" there\"},\"finish_reason\":null}]}\n\ndata: [DONE]\n\n";

        let _mock = server
            .mock("POST", "/chat/completions")
            .with_status(200)
            .with_header("content-type", "text/event-stream")
            .with_body(sse_body)
            .create_async()
            .await;

        let client = crate::OpenAI::with_config(
            crate::config::ClientConfig::new("sk-test").base_url(server.url()),
        );
        let request = crate::types::chat::ChatCompletionRequest::new(
            "gpt-4o",
            vec![crate::types::chat::ChatCompletionMessageParam::User {
                content: crate::types::chat::UserContent::Text("Hi".into()),
                name: None,
            }],
        );
        let stream = client
            .chat()
            .completions()
            .create_stream(request)
            .await
            .unwrap();

        let chunks: Vec<_> = stream
            .collect::<Vec<_>>()
            .await
            .into_iter()
            .filter_map(|r| r.ok())
            .collect();

        assert_eq!(chunks.len(), 2);
        assert_eq!(chunks[0].choices[0].delta.content.as_deref(), Some("Hi"));
        assert_eq!(
            chunks[1].choices[0].delta.content.as_deref(),
            Some(" there")
        );
    }

    /// Test that SSE stream surfaces API errors from the server.
    #[tokio::test]
    async fn test_sse_stream_api_error() {
        let mut server = mockito::Server::new_async().await;
        let _mock = server
            .mock("POST", "/chat/completions")
            .with_status(429)
            .with_body(r#"{"error":{"message":"Rate limit exceeded","type":"rate_limit","param":null,"code":null}}"#)
            .create_async()
            .await;

        let client = crate::OpenAI::with_config(
            crate::config::ClientConfig::new("sk-test")
                .base_url(server.url())
                .max_retries(0),
        );
        let request = crate::types::chat::ChatCompletionRequest::new(
            "gpt-4o",
            vec![crate::types::chat::ChatCompletionMessageParam::User {
                content: crate::types::chat::UserContent::Text("Hi".into()),
                name: None,
            }],
        );
        let err = client
            .chat()
            .completions()
            .create_stream(request)
            .await
            .err()
            .expect("expected error");

        match err {
            OpenAIError::ApiError { status, .. } => assert_eq!(status, 429),
            other => panic!("expected ApiError, got: {other:?}"),
        }
    }

    /// Test SSE with multi-byte UTF-8 (emoji) content.
    ///
    /// This parses a single in-memory string, so it does NOT cover a character
    /// split across two network chunks.
    #[test]
    fn test_parse_sse_multibyte_utf8() {
        let raw = "data: {\"id\":\"c1\",\"object\":\"chat.completion.chunk\",\"created\":1,\"model\":\"gpt-4o\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"Hello 🌍\"},\"finish_reason\":null}]}\n\ndata: [DONE]\n";
        let events = parse_sse_events::<ChatCompletionChunk>(raw);
        assert_eq!(events.len(), 1);
        assert_eq!(
            events[0].as_ref().unwrap().choices[0]
                .delta
                .content
                .as_deref(),
            Some("Hello 🌍")
        );
    }

    #[test]
    fn test_parse_sse_invalid_json() {
        let raw = "data: {invalid json}\n\ndata: [DONE]\n";
        let events = parse_sse_events::<ChatCompletionChunk>(raw);
        assert_eq!(events.len(), 1);
        assert!(matches!(&events[0], Err(OpenAIError::JsonError(_))));
    }

    #[test]
    fn test_parse_sse_tool_call_chunks() {
        let raw = r#"data: {"id":"c1","object":"chat.completion.chunk","created":1,"model":"gpt-4o","choices":[{"index":0,"delta":{"role":"assistant","tool_calls":[{"index":0,"id":"call_1","type":"function","function":{"name":"get_weather","arguments":""}}]},"finish_reason":null}]}

data: {"id":"c1","object":"chat.completion.chunk","created":1,"model":"gpt-4o","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"{\"loc"}}]},"finish_reason":null}]}

data: {"id":"c1","object":"chat.completion.chunk","created":1,"model":"gpt-4o","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"ation\": \"Boston\"}"}}]},"finish_reason":null}]}

data: {"id":"c1","object":"chat.completion.chunk","created":1,"model":"gpt-4o","choices":[{"index":0,"delta":{},"finish_reason":"tool_calls"}]}

data: [DONE]
"#;

        let events = parse_sse_events::<ChatCompletionChunk>(raw);
        assert_eq!(events.len(), 4);

        let tc = events[0].as_ref().unwrap().choices[0]
            .delta
            .tool_calls
            .as_ref()
            .unwrap();
        assert_eq!(tc[0].id.as_deref(), Some("call_1"));
        assert_eq!(
            tc[0].function.as_ref().unwrap().name.as_deref(),
            Some("get_weather")
        );

        assert_eq!(
            events[3].as_ref().unwrap().choices[0].finish_reason,
            Some(crate::types::common::FinishReason::ToolCalls)
        );
    }
}