Skip to main content

openai_core/stream/
sse.rs

1use std::fmt;
2use std::pin::Pin;
3use std::task::{Context, Poll};
4
5use async_stream::try_stream;
6use futures_util::{Stream, StreamExt};
7
8use crate::error::{Result, SerializationError, StreamError};
9use crate::response_meta::ResponseMeta;
10
11/// 用于把字节流切分为逻辑行。
12#[derive(Debug, Default, Clone)]
13pub struct LineDecoder {
14    buffer: Vec<u8>,
15}
16
17impl LineDecoder {
18    /// 向解码器推入一个新分片,并返回已经完整的行。
19    ///
20    /// # Errors
21    ///
22    /// 当 UTF-8 解码失败时返回错误。
23    pub fn push(&mut self, chunk: &[u8]) -> Result<Vec<String>> {
24        self.buffer.extend_from_slice(chunk);
25        let mut lines = Vec::new();
26        let mut start = 0usize;
27        let mut index = 0usize;
28
29        while index < self.buffer.len() {
30            match self.buffer[index] {
31                b'\n' => {
32                    let end = if index > start && self.buffer[index - 1] == b'\r' {
33                        index - 1
34                    } else {
35                        index
36                    };
37                    lines.push(bytes_to_string(&self.buffer[start..end])?);
38                    start = index + 1;
39                }
40                b'\r' => {
41                    let end = index;
42                    if index + 1 < self.buffer.len() {
43                        if self.buffer[index + 1] == b'\n' {
44                            index += 1;
45                            lines.push(bytes_to_string(&self.buffer[start..end])?);
46                            start = index + 1;
47                        } else {
48                            lines.push(bytes_to_string(&self.buffer[start..end])?);
49                            start = index + 1;
50                        }
51                    } else {
52                        break;
53                    }
54                }
55                _ => {}
56            }
57            index += 1;
58        }
59
60        if start > 0 {
61            self.buffer.drain(0..start);
62        }
63
64        Ok(lines)
65    }
66
67    /// 在输入结束时刷新最后一行。
68    ///
69    /// # Errors
70    ///
71    /// 当 UTF-8 解码失败时返回错误。
72    pub fn finish(&mut self) -> Result<Option<String>> {
73        if self.buffer.is_empty() {
74            return Ok(None);
75        }
76
77        let line = if self.buffer.last() == Some(&b'\r') {
78            let length = self.buffer.len() - 1;
79            bytes_to_string(&self.buffer[..length])?
80        } else {
81            bytes_to_string(&self.buffer)?
82        };
83        self.buffer.clear();
84        Ok(Some(line))
85    }
86}
87
88fn bytes_to_string(bytes: &[u8]) -> Result<String> {
89    String::from_utf8(bytes.to_vec()).map_err(|error| {
90        SerializationError::new(format!("SSE 行解码失败,收到非法 UTF-8: {error}")).into()
91    })
92}
93
/// A single Server-Sent Event as assembled from the wire format.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SseEvent {
    /// Value of the `event:` field, if one was sent.
    pub event: Option<String>,
    /// Payload built from the `data:` lines, joined with `\n`.
    pub data: String,
    /// Value of the `id:` field, if one was sent.
    pub id: Option<String>,
    /// Reconnection delay suggested by the server through the `retry:` field.
    pub retry: Option<u64>,
}
106
107#[derive(Debug, Default)]
108struct PendingSseEvent {
109    event: Option<String>,
110    data: Vec<String>,
111    id: Option<String>,
112    retry: Option<u64>,
113}
114
115impl PendingSseEvent {
116    fn push_line(&mut self, line: &str) -> Result<Option<SseEvent>> {
117        if line.is_empty() {
118            if self.event.is_none()
119                && self.data.is_empty()
120                && self.id.is_none()
121                && self.retry.is_none()
122            {
123                return Ok(None);
124            }
125
126            let event = SseEvent {
127                event: self.event.take(),
128                data: self.data.join("\n"),
129                id: self.id.take(),
130                retry: self.retry.take(),
131            };
132            self.data.clear();
133            return Ok(Some(event));
134        }
135
136        if line.starts_with(':') {
137            return Ok(None);
138        }
139
140        let (field, value) = match line.split_once(':') {
141            Some((field, value)) => (field, value.strip_prefix(' ').unwrap_or(value)),
142            None => (line, ""),
143        };
144
145        match field {
146            "event" => self.event = Some(value.to_owned()),
147            "data" => self.data.push(value.to_owned()),
148            "id" => self.id = Some(value.to_owned()),
149            "retry" => {
150                self.retry = value.parse::<u64>().ok();
151            }
152            _ => {}
153        }
154
155        Ok(None)
156    }
157
158    fn flush(&mut self) -> Option<SseEvent> {
159        if self.event.is_none() && self.data.is_empty() && self.id.is_none() && self.retry.is_none()
160        {
161            return None;
162        }
163
164        let event = SseEvent {
165            event: self.event.take(),
166            data: self.data.join("\n"),
167            id: self.id.take(),
168            retry: self.retry.take(),
169        };
170        self.data.clear();
171        Some(event)
172    }
173}
174
175/// 表示原始 SSE 流。
176pub struct RawSseStream {
177    inner: Pin<Box<dyn Stream<Item = Result<SseEvent>> + Send>>,
178    meta: ResponseMeta,
179}
180
181impl RawSseStream {
182    /// 从 `reqwest::Response` 创建原始 SSE 流。
183    #[allow(clippy::collapsible_if, tail_expr_drop_order)]
184    pub fn new(response: reqwest::Response, meta: ResponseMeta) -> Self {
185        let stream = try_stream! {
186            let mut decoder = LineDecoder::default();
187            let mut pending = PendingSseEvent::default();
188            let mut byte_stream = response.bytes_stream();
189
190            while let Some(chunk) = byte_stream.next().await {
191                let chunk = chunk.map_err(|error| StreamError::new(format!("读取 SSE 数据流失败: {error}")))?;
192                for line in decoder.push(&chunk)? {
193                    if let Some(event) = pending.push_line(&line)? {
194                        yield event;
195                    }
196                }
197            }
198
199            if let Some(line) = decoder.finish()? {
200                if let Some(event) = pending.push_line(&line)? {
201                    yield event;
202                }
203            }
204
205            if let Some(event) = pending.flush() {
206                yield event;
207            }
208        };
209
210        Self {
211            inner: Box::pin(stream),
212            meta,
213        }
214    }
215
216    /// 返回流对应的响应元信息。
217    pub fn meta(&self) -> &ResponseMeta {
218        &self.meta
219    }
220
221    /// 将原始 SSE 流转换为 JSON 事件流。
222    #[allow(tail_expr_drop_order)]
223    pub fn into_typed<T>(self) -> SseStream<T>
224    where
225        T: serde::de::DeserializeOwned + Send + 'static,
226    {
227        let meta = self.meta.clone();
228        let stream = try_stream! {
229            let mut raw = self;
230            while let Some(event) = raw.next().await {
231                let event = event?;
232                if event.data == "[DONE]" {
233                    break;
234                }
235                let item = serde_json::from_str::<T>(&event.data).map_err(|error| {
236                    StreamError::new(format!("解析 SSE JSON 事件失败: {error}; payload={}", event.data))
237                })?;
238                yield item;
239            }
240        };
241
242        SseStream {
243            inner: Box::pin(stream),
244            meta,
245        }
246    }
247}
248
249impl fmt::Debug for RawSseStream {
250    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
251        f.debug_struct("RawSseStream")
252            .field("meta", &self.meta)
253            .finish()
254    }
255}
256
257impl Stream for RawSseStream {
258    type Item = Result<SseEvent>;
259
260    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
261        self.get_mut().inner.as_mut().poll_next(cx)
262    }
263}
264
#[cfg(test)]
mod property_tests {
    use proptest::prelude::*;

    use super::LineDecoder;

    /// Line terminators exercised by the property test.
    #[derive(Debug, Clone, Copy)]
    enum Separator {
        Lf,
        Cr,
        CrLf,
    }

    impl Separator {
        /// Returns the literal terminator text of this variant.
        fn as_str(self) -> &'static str {
            match self {
                Self::Lf => "\n",
                Self::Cr => "\r",
                Self::CrLf => "\r\n",
            }
        }
    }

    /// Picks one of the three terminators.
    fn separator_strategy() -> impl Strategy<Value = Separator> {
        prop_oneof![
            Just(Separator::Lf),
            Just(Separator::Cr),
            Just(Separator::CrLf),
        ]
    }

    // However the payload is cut into chunks, the decoder must hand back exactly the
    // lines that were joined to build it.
    proptest! {
        #[test]
        fn line_decoder_preserves_lines_across_arbitrary_chunking(
            lines in prop::collection::vec("[^\r\n]{0,16}", 1..8),
            separator in separator_strategy(),
            chunk_sizes in prop::collection::vec(1usize..8, 1..32),
        ) {
            // Every line, including the last one, is followed by the separator.
            let payload: String = lines
                .iter()
                .flat_map(|line| [line.as_str(), separator.as_str()])
                .collect();
            let mut remaining = payload.as_bytes();

            let mut decoder = LineDecoder::default();
            let mut decoded = Vec::new();

            for chunk_size in chunk_sizes {
                if remaining.is_empty() {
                    break;
                }
                let (chunk, rest) = remaining.split_at(chunk_size.min(remaining.len()));
                decoded.extend(decoder.push(chunk).unwrap());
                remaining = rest;
            }

            // The generated chunk sizes may not cover the whole payload.
            if !remaining.is_empty() {
                decoded.extend(decoder.push(remaining).unwrap());
            }

            if let Some(tail) = decoder.finish().unwrap() {
                decoded.push(tail);
            }
            prop_assert_eq!(decoded, lines);
        }
    }

    #[test]
    fn line_decoder_flushes_final_partial_line() {
        let mut decoder = LineDecoder::default();
        let completed = decoder.push(b"event: response.created").unwrap();
        assert!(completed.is_empty());

        let tail = decoder.finish().unwrap();
        assert_eq!(tail, Some("event: response.created".into()));
    }
}
344
345/// 表示一个类型化后的 SSE 流。
346pub struct SseStream<T> {
347    inner: Pin<Box<dyn Stream<Item = Result<T>> + Send>>,
348    meta: ResponseMeta,
349}
350
351impl<T> SseStream<T> {
352    /// 返回流对应的响应元信息。
353    pub fn meta(&self) -> &ResponseMeta {
354        &self.meta
355    }
356}
357
358impl<T> Stream for SseStream<T> {
359    type Item = Result<T>;
360
361    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
362        self.get_mut().inner.as_mut().poll_next(cx)
363    }
364}
365
366impl<T> fmt::Debug for SseStream<T> {
367    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
368        f.debug_struct("SseStream")
369            .field("meta", &self.meta)
370            .finish()
371    }
372}
373
#[cfg(test)]
mod tests {
    use super::{LineDecoder, PendingSseEvent};

    /// CRLF, lone CR and LF terminators may be mixed within one chunk.
    #[test]
    fn test_should_decode_lines_for_mixed_newlines() {
        let mut decoder = LineDecoder::default();
        let decoded = decoder
            .push(b"data: one\r\ndata: two\rdata: three\n")
            .unwrap();
        assert_eq!(decoded, ["data: one", "data: two", "data: three"]);
        assert_eq!(decoder.finish().unwrap(), None);
    }

    /// A multi-byte character cut in half by a chunk boundary is decoded only once
    /// its line is complete.
    #[test]
    fn test_should_decode_utf8_split_across_chunks() {
        let mut decoder = LineDecoder::default();
        let text = "你好";
        let (head, tail) = text.as_bytes().split_at(2);

        assert!(decoder.push(head).unwrap().is_empty());
        assert!(decoder.push(tail).unwrap().is_empty());
        assert_eq!(decoder.push(b"\n").unwrap(), [text]);
    }

    /// A CRLF pair split across two chunks still terminates exactly one line.
    #[test]
    fn test_should_preserve_crlf_split_across_chunks() {
        let mut decoder = LineDecoder::default();
        assert!(decoder.push(b"data: one\r").unwrap().is_empty());
        assert_eq!(decoder.push(b"\n").unwrap(), ["data: one"]);
        assert_eq!(decoder.finish().unwrap(), None);
    }

    /// An empty `data:` line contributes an empty segment to the joined payload.
    #[test]
    fn test_should_parse_empty_and_multiline_sse_data_fields() {
        let mut pending = PendingSseEvent::default();
        for line in ["event: message", "data:", "data: hello"] {
            assert!(pending.push_line(line).unwrap().is_none());
        }

        let event = pending
            .push_line("")
            .unwrap()
            .expect("a blank line must complete the pending event");
        assert_eq!(event.event.as_deref(), Some("message"));
        assert_eq!(event.data, "\nhello");
    }
}