Skip to main content

ai_provider_sdk/
streaming.rs

1//! SSE 流式解码层。把 HTTP 字节流解析为 Server-Sent Events。
2
3use async_stream::try_stream;
4use bytes::{Buf, Bytes, BytesMut};
5use futures_core::Stream;
6use futures_util::StreamExt;
7use serde_json::Value;
8
9use crate::error::{Error, Result};
10
11#[derive(Debug, Clone, PartialEq, Eq)]
12pub struct ServerSentEvent {
13    /// SSE `event` 字段;为空时表示默认事件类型。
14    pub event: Option<String>,
15    /// SSE `data` 字段拼接结果(多行 `data:` 以换行连接)。
16    pub data: String,
17    /// SSE `id` 字段;用于客户端断线续传语义。
18    pub id: Option<String>,
19    /// SSE `retry` 字段;服务端建议的重连间隔(毫秒)。
20    pub retry: Option<u64>,
21}
22
23pub struct SseStream {
24    response: reqwest::Response,
25}
26
27impl SseStream {
28    pub(crate) fn new(response: reqwest::Response) -> Self {
29        Self { response }
30    }
31
32    /// 将 HTTP 响应体按 SSE 协议解码为事件流。
33    ///
34    /// 行为边界:
35    /// - 遇到 `data: [DONE]` 立即结束流。
36    /// - 若 `data` 可解析为 JSON 且包含 `error` 字段,抛出 `Error::Stream`。
37    pub fn events(self) -> impl Stream<Item = Result<ServerSentEvent>> {
38        let mut chunks = self.response.bytes_stream();
39
40        try_stream! {
41            let mut decoder = SseDecoder::new();
42            while let Some(chunk) = chunks.next().await {
43                let chunk = chunk.map_err(|err| Error::Stream(err.to_string()))?;
44                for event in decoder.push(chunk)? {
45                    if event.data.starts_with("[DONE]") {
46                        return;
47                    }
48                    if let Ok(data) = serde_json::from_str::<Value>(&event.data) {
49                        if let Some(error) = data.get("error") {
50                            Err(Error::Stream(
51                                error
52                                    .get("message")
53                                    .and_then(Value::as_str)
54                                    .unwrap_or("An error occurred during streaming")
55                                    .to_string(),
56                            ))?;
57                        }
58                    }
59                    yield event;
60                }
61            }
62
63            for event in decoder.finish()? {
64                if event.data.starts_with("[DONE]") {
65                    return;
66                }
67                yield event;
68            }
69        }
70    }
71}
72
73#[derive(Debug, Default)]
74/// SSE 解码器状态机。
75///
76/// 状态字段在 `push` 过程中跨 chunk 保持,用于处理分包与多行 `data:`。
77pub struct SseDecoder {
78    bytes: BytesMut,
79    event: Option<String>,
80    data: Vec<String>,
81    last_event_id: Option<String>,
82    retry: Option<u64>,
83}
84
85impl SseDecoder {
86    /// 创建空状态解码器。
87    pub fn new() -> Self {
88        Self::default()
89    }
90
91    /// 推入一个字节分片并返回当前可产出的完整事件集合。
92    ///
93    /// 该方法可重复调用;未组成完整行/事件的数据会保留在内部缓冲区。
94    pub fn push(&mut self, chunk: Bytes) -> Result<Vec<ServerSentEvent>> {
95        self.bytes.extend_from_slice(&chunk);
96        let mut events = Vec::new();
97
98        while let Some(line) = self.next_line()? {
99            if let Some(event) = self.decode_line(&line) {
100                events.push(event);
101            }
102        }
103
104        Ok(events)
105    }
106
107    /// 在底层流结束时冲刷缓冲区,产出剩余事件。
108    pub fn finish(&mut self) -> Result<Vec<ServerSentEvent>> {
109        let mut events = Vec::new();
110        if !self.bytes.is_empty() {
111            let line = std::str::from_utf8(&self.bytes)
112                .map_err(|err| Error::Stream(err.to_string()))?
113                .to_string();
114            self.bytes.clear();
115            if let Some(event) = self.decode_line(&line) {
116                events.push(event);
117            }
118        }
119
120        if let Some(event) = self.flush_event() {
121            events.push(event);
122        }
123
124        Ok(events)
125    }
126
127    /// 从内部缓冲区读取下一行(兼容 `\n` 和 `\r\n`)。
128    fn next_line(&mut self) -> Result<Option<String>> {
129        let Some(pos) = self
130            .bytes
131            .iter()
132            .position(|byte| *byte == b'\n' || *byte == b'\r')
133        else {
134            return Ok(None);
135        };
136
137        let line = self.bytes.split_to(pos);
138        let newline = self.bytes.get_u8();
139        if newline == b'\r' && self.bytes.first() == Some(&b'\n') {
140            self.bytes.advance(1);
141        }
142
143        let line = std::str::from_utf8(&line)
144            .map_err(|err| Error::Stream(err.to_string()))?
145            .to_string();
146        Ok(Some(line))
147    }
148
149    /// 解码单行字段;仅在遇到事件分隔空行时返回完整事件。
150    fn decode_line(&mut self, line: &str) -> Option<ServerSentEvent> {
151        if line.is_empty() {
152            return self.flush_event();
153        }
154
155        if line.starts_with(':') {
156            return None;
157        }
158
159        let (field, value) = line.split_once(':').unwrap_or((line, ""));
160        let value = value.strip_prefix(' ').unwrap_or(value);
161
162        match field {
163            "event" => self.event = Some(value.to_string()),
164            "data" => self.data.push(value.to_string()),
165            "id" if !value.contains('\0') => self.last_event_id = Some(value.to_string()),
166            "retry" => self.retry = value.parse().ok(),
167            _ => {}
168        }
169
170        None
171    }
172
173    /// 将当前累积字段封装为事件,并清理可重置状态。
174    fn flush_event(&mut self) -> Option<ServerSentEvent> {
175        if self.event.is_none()
176            && self.data.is_empty()
177            && self.last_event_id.is_none()
178            && self.retry.is_none()
179        {
180            return None;
181        }
182
183        let event = ServerSentEvent {
184            event: self.event.take(),
185            data: self.data.join("\n"),
186            id: self.last_event_id.clone(),
187            retry: self.retry.take(),
188        };
189        self.data.clear();
190        Some(event)
191    }
192}
193
194#[cfg(test)]
195mod tests {
196    use super::*;
197
198    #[test]
199    fn decodes_complete_event() {
200        let mut decoder = SseDecoder::new();
201        let events = decoder
202            .push(Bytes::from_static(
203                b"event: ping\ndata: {\"x\":1}\nid: abc\n\n",
204            ))
205            .unwrap();
206
207        assert_eq!(
208            events,
209            vec![ServerSentEvent {
210                event: Some("ping".to_string()),
211                data: "{\"x\":1}".to_string(),
212                id: Some("abc".to_string()),
213                retry: None,
214            }]
215        );
216    }
217
218    #[test]
219    fn decodes_split_event_and_multi_data_lines() {
220        let mut decoder = SseDecoder::new();
221        assert!(decoder
222            .push(Bytes::from_static(b"data: a\n"))
223            .unwrap()
224            .is_empty());
225        let events = decoder.push(Bytes::from_static(b"data: b\n\n")).unwrap();
226
227        assert_eq!(events[0].data, "a\nb");
228    }
229
230    #[test]
231    fn keeps_last_event_id_across_events() {
232        let mut decoder = SseDecoder::new();
233        let events = decoder
234            .push(Bytes::from_static(b"id: one\ndata: a\n\ndata: b\n\n"))
235            .unwrap();
236
237        assert_eq!(events[0].id.as_deref(), Some("one"));
238        assert_eq!(events[1].id.as_deref(), Some("one"));
239    }
240}