1use std::pin::Pin;
4use std::task::{Context, Poll};
5
6use futures_core::Stream;
7
8use crate::error::OpenAIError;
9
/// A stream of server-sent events (SSE) decoded from an HTTP response body.
///
/// Bytes arriving from the underlying transport are accumulated in a text
/// buffer and parsed line-by-line; each `data:` payload is deserialized
/// into a `T`. The stream ends when the transport is exhausted or the
/// `[DONE]` sentinel is observed.
pub struct SseStream<T> {
    // Raw byte stream of the HTTP response body.
    inner: Pin<Box<dyn Stream<Item = Result<bytes::Bytes, reqwest::Error>> + Send>>,
    // Accumulates text across chunk boundaries until at least one
    // complete, newline-terminated SSE line is available.
    buffer: String,
    // Once true, every subsequent poll yields `None` (set on transport
    // EOF, transport error, or the `[DONE]` sentinel).
    done: bool,
    // `T` appears only in the `Stream::Item` type, not in any field.
    _phantom: std::marker::PhantomData<T>,
}
19
20impl<T> SseStream<T> {
21 pub(crate) fn new(response: reqwest::Response) -> Self {
22 Self {
23 inner: Box::pin(response.bytes_stream()),
24 buffer: String::new(),
25 done: false,
26 _phantom: std::marker::PhantomData,
27 }
28 }
29}
30
// `SseStream<T>` is `Unpin` for *any* `T`: every field is `Unpin`
// (`Pin<Box<_>>`, `String`, `bool`), and the only reason the auto impl
// would demand `T: Unpin` is the `PhantomData<T>` marker, which stores
// no `T` at all. The manual impl drops that spurious bound.
impl<T> Unpin for SseStream<T> {}
33
34impl<T: serde::de::DeserializeOwned> Stream for SseStream<T> {
35 type Item = Result<T, OpenAIError>;
36
37 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
38 let this = self.get_mut();
39
40 if this.done {
41 return Poll::Ready(None);
42 }
43
44 if let Some(item) = try_parse_next::<T>(&mut this.buffer, &mut this.done) {
46 return Poll::Ready(Some(item));
47 }
48
49 match this.inner.as_mut().poll_next(cx) {
51 Poll::Ready(Some(Ok(chunk))) => {
52 this.buffer.push_str(&String::from_utf8_lossy(&chunk));
53 match try_parse_next::<T>(&mut this.buffer, &mut this.done) {
54 Some(item) => Poll::Ready(Some(item)),
55 None => {
56 cx.waker().wake_by_ref();
57 Poll::Pending
58 }
59 }
60 }
61 Poll::Ready(Some(Err(e))) => {
62 this.done = true;
63 Poll::Ready(Some(Err(OpenAIError::RequestError(e))))
64 }
65 Poll::Ready(None) => {
66 this.done = true;
67 match try_parse_next::<T>(&mut this.buffer, &mut this.done) {
68 Some(item) => Poll::Ready(Some(item)),
69 None => Poll::Ready(None),
70 }
71 }
72 Poll::Pending => Poll::Pending,
73 }
74 }
75}
76
77fn try_parse_next<T: serde::de::DeserializeOwned>(
80 buffer: &mut String,
81 done: &mut bool,
82) -> Option<Result<T, OpenAIError>> {
83 loop {
84 let newline_pos = buffer.find('\n')?;
85 let line = buffer[..newline_pos].trim_end_matches('\r').to_string();
86 buffer.drain(..=newline_pos);
87
88 if line.is_empty() || line.starts_with(':') {
90 continue;
91 }
92
93 if let Some(data) = line
95 .strip_prefix("data: ")
96 .or_else(|| line.strip_prefix("data:"))
97 {
98 let data = data.trim();
99
100 if data == "[DONE]" {
101 *done = true;
102 return None;
103 }
104
105 match serde_json::from_str::<T>(data) {
106 Ok(value) => return Some(Ok(value)),
107 Err(e) => return Some(Err(OpenAIError::JsonError(e))),
108 }
109 }
110
111 }
113}
114
115pub fn parse_sse_events<T: serde::de::DeserializeOwned>(raw: &str) -> Vec<Result<T, OpenAIError>> {
118 let mut results = Vec::new();
119 let mut buffer = raw.to_string();
120 if !buffer.ends_with('\n') {
121 buffer.push('\n');
122 }
123 let mut done = false;
124
125 while !done {
126 match try_parse_next::<T>(&mut buffer, &mut done) {
127 Some(item) => results.push(item),
128 None => break,
129 }
130 }
131
132 results
133}
134
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::chat::ChatCompletionChunk;

    // Happy path: four content/finish chunks followed by [DONE]; the
    // sentinel terminates parsing and is not surfaced as an event.
    #[test]
    fn test_parse_sse_content_chunks() {
        let raw = r#"data: {"id":"chatcmpl-1","object":"chat.completion.chunk","created":1,"model":"gpt-4o","choices":[{"index":0,"delta":{"role":"assistant","content":""},"finish_reason":null}]}

data: {"id":"chatcmpl-1","object":"chat.completion.chunk","created":1,"model":"gpt-4o","choices":[{"index":0,"delta":{"content":"Hello"},"finish_reason":null}]}

data: {"id":"chatcmpl-1","object":"chat.completion.chunk","created":1,"model":"gpt-4o","choices":[{"index":0,"delta":{"content":" world"},"finish_reason":null}]}

data: {"id":"chatcmpl-1","object":"chat.completion.chunk","created":1,"model":"gpt-4o","choices":[{"index":0,"delta":{},"finish_reason":"stop"}]}

data: [DONE]

"#;

        let events = parse_sse_events::<ChatCompletionChunk>(raw);
        assert_eq!(events.len(), 4);

        // First chunk carries the role, later ones the content deltas,
        // and the last one only the finish reason.
        let chunk0 = events[0].as_ref().unwrap();
        assert_eq!(chunk0.choices[0].delta.role.as_deref(), Some("assistant"));

        let chunk1 = events[1].as_ref().unwrap();
        assert_eq!(chunk1.choices[0].delta.content.as_deref(), Some("Hello"));

        let chunk2 = events[2].as_ref().unwrap();
        assert_eq!(chunk2.choices[0].delta.content.as_deref(), Some(" world"));

        let chunk3 = events[3].as_ref().unwrap();
        assert_eq!(chunk3.choices[0].finish_reason.as_deref(), Some("stop"));
    }

    // Comment lines (leading ':') and blank keep-alive lines must be
    // skipped without producing events or errors.
    #[test]
    fn test_parse_sse_with_comments_and_empty_lines() {
        let raw = ": this is a comment
data: {\"id\":\"c1\",\"object\":\"chat.completion.chunk\",\"created\":1,\"model\":\"gpt-4o\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"Hi\"},\"finish_reason\":null}]}

data: [DONE]
";

        let events = parse_sse_events::<ChatCompletionChunk>(raw);
        assert_eq!(events.len(), 1);
        assert_eq!(
            events[0].as_ref().unwrap().choices[0]
                .delta
                .content
                .as_deref(),
            Some("Hi")
        );
    }

    // Anything after [DONE] is ignored: only the single event before the
    // sentinel is returned.
    #[test]
    fn test_parse_sse_done_stops_parsing() {
        let raw = r#"data: {"id":"c1","object":"chat.completion.chunk","created":1,"model":"gpt-4o","choices":[{"index":0,"delta":{"content":"A"},"finish_reason":null}]}

data: [DONE]

data: {"id":"c2","object":"chat.completion.chunk","created":1,"model":"gpt-4o","choices":[{"index":0,"delta":{"content":"B"},"finish_reason":null}]}
"#;

        let events = parse_sse_events::<ChatCompletionChunk>(raw);
        assert_eq!(events.len(), 1);
    }

    // A malformed JSON payload surfaces as an Err entry rather than
    // aborting the parse.
    #[test]
    fn test_parse_sse_invalid_json() {
        let raw = "data: {invalid json}\n\ndata: [DONE]\n";
        let events = parse_sse_events::<ChatCompletionChunk>(raw);
        assert_eq!(events.len(), 1);
        assert!(events[0].is_err());
    }

    // Tool-call deltas: the first chunk names the function, subsequent
    // chunks stream argument fragments, and the final chunk carries the
    // "tool_calls" finish reason.
    #[test]
    fn test_parse_sse_tool_call_chunks() {
        let raw = r#"data: {"id":"c1","object":"chat.completion.chunk","created":1,"model":"gpt-4o","choices":[{"index":0,"delta":{"role":"assistant","tool_calls":[{"index":0,"id":"call_1","type":"function","function":{"name":"get_weather","arguments":""}}]},"finish_reason":null}]}

data: {"id":"c1","object":"chat.completion.chunk","created":1,"model":"gpt-4o","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"{\"loc"}}]},"finish_reason":null}]}

data: {"id":"c1","object":"chat.completion.chunk","created":1,"model":"gpt-4o","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"ation\": \"Boston\"}"}}]},"finish_reason":null}]}

data: {"id":"c1","object":"chat.completion.chunk","created":1,"model":"gpt-4o","choices":[{"index":0,"delta":{},"finish_reason":"tool_calls"}]}

data: [DONE]
"#;

        let events = parse_sse_events::<ChatCompletionChunk>(raw);
        assert_eq!(events.len(), 4);

        let tc = events[0].as_ref().unwrap().choices[0]
            .delta
            .tool_calls
            .as_ref()
            .unwrap();
        assert_eq!(tc[0].id.as_deref(), Some("call_1"));
        assert_eq!(
            tc[0].function.as_ref().unwrap().name.as_deref(),
            Some("get_weather")
        );

        assert_eq!(
            events[3].as_ref().unwrap().choices[0]
                .finish_reason
                .as_deref(),
            Some("tool_calls")
        );
    }
}