1use std::pin::Pin;
4use std::task::{Context, Poll};
5
6use futures_core::Stream;
7
8use crate::error::OpenAIError;
9
/// Adapter that decodes a `reqwest` byte stream as server-sent events and
/// yields each `data:` payload deserialized into `T`.
pub struct SseStream<T> {
    // Underlying HTTP byte stream. Off-wasm the trait object must be `Send`
    // so the stream can be driven from another thread; wasm32 has no such
    // requirement (single-threaded), hence the cfg split.
    #[cfg(not(target_arch = "wasm32"))]
    inner: Pin<Box<dyn Stream<Item = Result<bytes::Bytes, reqwest::Error>> + Send>>,
    #[cfg(target_arch = "wasm32")]
    inner: Pin<Box<dyn Stream<Item = Result<bytes::Bytes, reqwest::Error>>>>,
    // Text accumulated from received chunks until a complete line is available.
    buffer: String,
    // Set when the transport ends/errors or the `[DONE]` sentinel is parsed;
    // once set, `poll_next` returns `Ready(None)`.
    done: bool,
    // `T` only appears in the `Stream` impl's item type; no `T` is stored.
    _phantom: std::marker::PhantomData<T>,
}
24
25impl<T> SseStream<T> {
26 pub(crate) fn new(response: reqwest::Response) -> Self {
27 Self {
28 inner: Box::pin(response.bytes_stream()),
29 buffer: String::new(),
30 done: false,
31 _phantom: std::marker::PhantomData,
32 }
33 }
34}
35
36impl<T> Unpin for SseStream<T> {}
38
39impl<T: serde::de::DeserializeOwned> Stream for SseStream<T> {
40 type Item = Result<T, OpenAIError>;
41
42 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
43 let this = self.get_mut();
44
45 loop {
46 if this.done {
47 return Poll::Ready(None);
48 }
49
50 if let Some(item) = try_parse_next::<T>(&mut this.buffer, &mut this.done) {
52 return Poll::Ready(Some(item));
53 }
54
55 match this.inner.as_mut().poll_next(cx) {
57 Poll::Ready(Some(Ok(chunk))) => {
58 this.buffer.push_str(&String::from_utf8_lossy(&chunk));
59 if this.buffer.len() > 4 * 1024 * 1024 {
61 this.done = true;
62 return Poll::Ready(Some(Err(OpenAIError::StreamError(
63 "SSE buffer exceeded 4MB".into(),
64 ))));
65 }
66 continue;
70 }
71 Poll::Ready(Some(Err(e))) => {
72 this.done = true;
73 return Poll::Ready(Some(Err(OpenAIError::RequestError(e))));
74 }
75 Poll::Ready(None) => {
76 this.done = true;
77 return match try_parse_next::<T>(&mut this.buffer, &mut this.done) {
78 Some(item) => Poll::Ready(Some(item)),
79 None => Poll::Ready(None),
80 };
81 }
82 Poll::Pending => return Poll::Pending,
83 }
84 }
85 }
86}
87
88fn try_parse_next<T: serde::de::DeserializeOwned>(
91 buffer: &mut String,
92 done: &mut bool,
93) -> Option<Result<T, OpenAIError>> {
94 loop {
95 let newline_pos = buffer.find('\n')?;
96 let line = buffer[..newline_pos].trim_end_matches('\r').to_string();
97 buffer.drain(..=newline_pos);
98
99 if line.is_empty() || line.starts_with(':') {
101 continue;
102 }
103
104 if let Some(data) = line
106 .strip_prefix("data: ")
107 .or_else(|| line.strip_prefix("data:"))
108 {
109 let data = data.trim();
110
111 if data == "[DONE]" {
112 *done = true;
113 return None;
114 }
115
116 match serde_json::from_str::<T>(data) {
117 Ok(value) => return Some(Ok(value)),
118 Err(e) => return Some(Err(OpenAIError::JsonError(e))),
119 }
120 }
121
122 }
124}
125
126pub fn parse_sse_events<T: serde::de::DeserializeOwned>(raw: &str) -> Vec<Result<T, OpenAIError>> {
129 let mut results = Vec::new();
130 let mut buffer = raw.to_string();
131 if !buffer.ends_with('\n') {
132 buffer.push('\n');
133 }
134 let mut done = false;
135
136 while !done {
137 match try_parse_next::<T>(&mut buffer, &mut done) {
138 Some(item) => results.push(item),
139 None => break,
140 }
141 }
142
143 results
144}
145
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::chat::ChatCompletionChunk;

    // Happy path: four chat chunks (role, two content deltas, finish_reason)
    // followed by [DONE]; the sentinel itself must not appear as an event.
    #[test]
    fn test_parse_sse_content_chunks() {
        let raw = r#"data: {"id":"chatcmpl-1","object":"chat.completion.chunk","created":1,"model":"gpt-4o","choices":[{"index":0,"delta":{"role":"assistant","content":""},"finish_reason":null}]}

data: {"id":"chatcmpl-1","object":"chat.completion.chunk","created":1,"model":"gpt-4o","choices":[{"index":0,"delta":{"content":"Hello"},"finish_reason":null}]}

data: {"id":"chatcmpl-1","object":"chat.completion.chunk","created":1,"model":"gpt-4o","choices":[{"index":0,"delta":{"content":" world"},"finish_reason":null}]}

data: {"id":"chatcmpl-1","object":"chat.completion.chunk","created":1,"model":"gpt-4o","choices":[{"index":0,"delta":{},"finish_reason":"stop"}]}

data: [DONE]

"#;

        let events = parse_sse_events::<ChatCompletionChunk>(raw);
        assert_eq!(events.len(), 4);

        let chunk0 = events[0].as_ref().unwrap();
        assert_eq!(
            chunk0.choices[0].delta.role,
            Some(crate::types::common::Role::Assistant)
        );

        let chunk1 = events[1].as_ref().unwrap();
        assert_eq!(chunk1.choices[0].delta.content.as_deref(), Some("Hello"));

        let chunk2 = events[2].as_ref().unwrap();
        assert_eq!(chunk2.choices[0].delta.content.as_deref(), Some(" world"));

        let chunk3 = events[3].as_ref().unwrap();
        assert_eq!(
            chunk3.choices[0].finish_reason,
            Some(crate::types::common::FinishReason::Stop)
        );
    }

    // SSE comment lines (leading ':') and blank separator lines are skipped.
    #[test]
    fn test_parse_sse_with_comments_and_empty_lines() {
        let raw = ": this is a comment
data: {\"id\":\"c1\",\"object\":\"chat.completion.chunk\",\"created\":1,\"model\":\"gpt-4o\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"Hi\"},\"finish_reason\":null}]}

data: [DONE]
";

        let events = parse_sse_events::<ChatCompletionChunk>(raw);
        assert_eq!(events.len(), 1);
        assert_eq!(
            events[0].as_ref().unwrap().choices[0]
                .delta
                .content
                .as_deref(),
            Some("Hi")
        );
    }

    // Anything after [DONE] must be ignored: parsing stops at the sentinel.
    #[test]
    fn test_parse_sse_done_stops_parsing() {
        let raw = r#"data: {"id":"c1","object":"chat.completion.chunk","created":1,"model":"gpt-4o","choices":[{"index":0,"delta":{"content":"A"},"finish_reason":null}]}

data: [DONE]

data: {"id":"c2","object":"chat.completion.chunk","created":1,"model":"gpt-4o","choices":[{"index":0,"delta":{"content":"B"},"finish_reason":null}]}
"#;

        let events = parse_sse_events::<ChatCompletionChunk>(raw);
        assert_eq!(events.len(), 1);
    }

    // The same parser also decodes Responses-API stream events (type-tagged
    // enum payloads) — verifies variant selection and the event_type() label.
    #[test]
    fn test_parse_sse_response_stream_events() {
        use crate::types::responses::ResponseStreamEvent;

        let raw = r#"data: {"type":"response.created","response":{"id":"resp-1","object":"response","created_at":1.0,"model":"gpt-4o","output":[],"status":"in_progress"}}

data: {"type":"response.output_text.delta","delta":"Hello","output_index":0,"content_index":0}

data: {"type":"response.output_text.delta","delta":" world","output_index":0,"content_index":0}

data: {"type":"response.completed","response":{"id":"resp-1","object":"response","created_at":1.0,"model":"gpt-4o","output":[],"status":"completed"}}

data: [DONE]
"#;

        let events = parse_sse_events::<ResponseStreamEvent>(raw);
        assert_eq!(events.len(), 4);
        assert_eq!(events[0].as_ref().unwrap().event_type(), "response.created");
        assert_eq!(
            events[1].as_ref().unwrap().event_type(),
            "response.output_text.delta"
        );
        match events[1].as_ref().unwrap() {
            ResponseStreamEvent::ResponseOutputTextDelta(evt) => assert_eq!(evt.delta, "Hello"),
            other => panic!("expected ResponseOutputTextDelta, got: {other:?}"),
        }
        match events[2].as_ref().unwrap() {
            ResponseStreamEvent::ResponseOutputTextDelta(evt) => assert_eq!(evt.delta, " world"),
            other => panic!("expected ResponseOutputTextDelta, got: {other:?}"),
        }
        assert_eq!(
            events[3].as_ref().unwrap().event_type(),
            "response.completed"
        );
    }

    // End-to-end: SseStream decodes an SSE body served over HTTP by a mock
    // server through the client's create_stream path.
    #[tokio::test]
    async fn test_sse_stream_via_http() {
        use futures_util::StreamExt;
        let mut server = mockito::Server::new_async().await;
        let sse_body = "data: {\"id\":\"c1\",\"object\":\"chat.completion.chunk\",\"created\":1,\"model\":\"gpt-4o\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"Hi\"},\"finish_reason\":null}]}\n\ndata: {\"id\":\"c1\",\"object\":\"chat.completion.chunk\",\"created\":1,\"model\":\"gpt-4o\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\" there\"},\"finish_reason\":null}]}\n\ndata: [DONE]\n\n";

        let _mock = server
            .mock("POST", "/chat/completions")
            .with_status(200)
            .with_header("content-type", "text/event-stream")
            .with_body(sse_body)
            .create_async()
            .await;

        let client = crate::OpenAI::with_config(
            crate::config::ClientConfig::new("sk-test").base_url(server.url()),
        );
        let request = crate::types::chat::ChatCompletionRequest::new(
            "gpt-4o",
            vec![crate::types::chat::ChatCompletionMessageParam::User {
                content: crate::types::chat::UserContent::Text("Hi".into()),
                name: None,
            }],
        );
        let stream = client
            .chat()
            .completions()
            .create_stream(request)
            .await
            .unwrap();

        // Drop errors; only the two content chunks should survive.
        let chunks: Vec<_> = stream
            .collect::<Vec<_>>()
            .await
            .into_iter()
            .filter_map(|r| r.ok())
            .collect();

        assert_eq!(chunks.len(), 2);
        assert_eq!(chunks[0].choices[0].delta.content.as_deref(), Some("Hi"));
        assert_eq!(
            chunks[1].choices[0].delta.content.as_deref(),
            Some(" there")
        );
    }

    // A non-2xx response must surface as ApiError from create_stream itself,
    // before any SSE parsing happens (retries disabled to keep it fast).
    #[tokio::test]
    async fn test_sse_stream_api_error() {
        let mut server = mockito::Server::new_async().await;
        let _mock = server
            .mock("POST", "/chat/completions")
            .with_status(429)
            .with_body(r#"{"error":{"message":"Rate limit exceeded","type":"rate_limit","param":null,"code":null}}"#)
            .create_async()
            .await;

        let client = crate::OpenAI::with_config(
            crate::config::ClientConfig::new("sk-test")
                .base_url(server.url())
                .max_retries(0),
        );
        let request = crate::types::chat::ChatCompletionRequest::new(
            "gpt-4o",
            vec![crate::types::chat::ChatCompletionMessageParam::User {
                content: crate::types::chat::UserContent::Text("Hi".into()),
                name: None,
            }],
        );
        let err = client
            .chat()
            .completions()
            .create_stream(request)
            .await
            .err()
            .expect("expected error");

        match err {
            OpenAIError::ApiError { status, .. } => assert_eq!(status, 429),
            other => panic!("expected ApiError, got: {other:?}"),
        }
    }

    // Multi-byte UTF-8 content must round-trip intact when the payload
    // arrives whole (split-across-chunk behavior is not covered here).
    #[test]
    fn test_parse_sse_multibyte_utf8() {
        let raw = "data: {\"id\":\"c1\",\"object\":\"chat.completion.chunk\",\"created\":1,\"model\":\"gpt-4o\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"Hello 🌍\"},\"finish_reason\":null}]}\n\ndata: [DONE]\n";
        let events = parse_sse_events::<ChatCompletionChunk>(raw);
        assert_eq!(events.len(), 1);
        assert_eq!(
            events[0].as_ref().unwrap().choices[0]
                .delta
                .content
                .as_deref(),
            Some("Hello 🌍")
        );
    }

    // Malformed JSON in a data line yields an Err item rather than a panic
    // or a silently skipped event.
    #[test]
    fn test_parse_sse_invalid_json() {
        let raw = "data: {invalid json}\n\ndata: [DONE]\n";
        let events = parse_sse_events::<ChatCompletionChunk>(raw);
        assert_eq!(events.len(), 1);
        assert!(events[0].is_err());
    }

    // Tool-call deltas: id/name arrive in the first chunk, argument fragments
    // in later chunks, and the run ends with finish_reason "tool_calls".
    #[test]
    fn test_parse_sse_tool_call_chunks() {
        let raw = r#"data: {"id":"c1","object":"chat.completion.chunk","created":1,"model":"gpt-4o","choices":[{"index":0,"delta":{"role":"assistant","tool_calls":[{"index":0,"id":"call_1","type":"function","function":{"name":"get_weather","arguments":""}}]},"finish_reason":null}]}

data: {"id":"c1","object":"chat.completion.chunk","created":1,"model":"gpt-4o","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"{\"loc"}}]},"finish_reason":null}]}

data: {"id":"c1","object":"chat.completion.chunk","created":1,"model":"gpt-4o","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"ation\": \"Boston\"}"}}]},"finish_reason":null}]}

data: {"id":"c1","object":"chat.completion.chunk","created":1,"model":"gpt-4o","choices":[{"index":0,"delta":{},"finish_reason":"tool_calls"}]}

data: [DONE]
"#;

        let events = parse_sse_events::<ChatCompletionChunk>(raw);
        assert_eq!(events.len(), 4);

        let tc = events[0].as_ref().unwrap().choices[0]
            .delta
            .tool_calls
            .as_ref()
            .unwrap();
        assert_eq!(tc[0].id.as_deref(), Some("call_1"));
        assert_eq!(
            tc[0].function.as_ref().unwrap().name.as_deref(),
            Some("get_weather")
        );

        assert_eq!(
            events[3].as_ref().unwrap().choices[0].finish_reason,
            Some(crate::types::common::FinishReason::ToolCalls)
        );
    }
}