1use std::fmt;
2use std::pin::Pin;
3use std::task::{Context, Poll};
4
5use async_stream::try_stream;
6use futures_util::{Stream, StreamExt};
7
8use crate::error::{Result, SerializationError, StreamError};
9use crate::response_meta::ResponseMeta;
10
11#[derive(Debug, Default, Clone)]
13pub struct LineDecoder {
14 buffer: Vec<u8>,
15}
16
17impl LineDecoder {
18 pub fn push(&mut self, chunk: &[u8]) -> Result<Vec<String>> {
24 self.buffer.extend_from_slice(chunk);
25 let mut lines = Vec::new();
26 let mut start = 0usize;
27 let mut index = 0usize;
28
29 while index < self.buffer.len() {
30 match self.buffer[index] {
31 b'\n' => {
32 let end = if index > start && self.buffer[index - 1] == b'\r' {
33 index - 1
34 } else {
35 index
36 };
37 lines.push(bytes_to_string(&self.buffer[start..end])?);
38 start = index + 1;
39 }
40 b'\r' => {
41 let end = index;
42 if index + 1 < self.buffer.len() {
43 if self.buffer[index + 1] == b'\n' {
44 index += 1;
45 lines.push(bytes_to_string(&self.buffer[start..end])?);
46 start = index + 1;
47 } else {
48 lines.push(bytes_to_string(&self.buffer[start..end])?);
49 start = index + 1;
50 }
51 } else {
52 break;
53 }
54 }
55 _ => {}
56 }
57 index += 1;
58 }
59
60 if start > 0 {
61 self.buffer.drain(0..start);
62 }
63
64 Ok(lines)
65 }
66
67 pub fn finish(&mut self) -> Result<Option<String>> {
73 if self.buffer.is_empty() {
74 return Ok(None);
75 }
76
77 let line = if self.buffer.last() == Some(&b'\r') {
78 let length = self.buffer.len() - 1;
79 bytes_to_string(&self.buffer[..length])?
80 } else {
81 bytes_to_string(&self.buffer)?
82 };
83 self.buffer.clear();
84 Ok(Some(line))
85 }
86}
87
88fn bytes_to_string(bytes: &[u8]) -> Result<String> {
89 String::from_utf8(bytes.to_vec()).map_err(|error| {
90 SerializationError::new(format!("SSE 行解码失败,收到非法 UTF-8: {error}")).into()
91 })
92}
93
/// One dispatched Server-Sent Event, as assembled from its field lines.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct SseEvent {
    /// Value of the event's `event:` field, if one was present.
    pub event: Option<String>,
    /// All `data:` values of the event joined with `\n`; empty when there were none.
    pub data: String,
    /// Value of the event's `id:` field, if one was present.
    pub id: Option<String>,
    /// Integer value of the event's `retry:` field, if one was accepted.
    pub retry: Option<u64>,
}
106
/// Field values collected for the event currently being parsed, held until a blank line
/// (or the end of the stream) dispatches them as an [`SseEvent`].
#[derive(Default, Debug)]
struct PendingSseEvent {
    /// Latest `event:` value seen for this event.
    event: Option<String>,
    /// Every `data:` value seen for this event, in arrival order.
    data: Vec<String>,
    /// Latest `id:` value seen for this event.
    id: Option<String>,
    /// Latest parsed `retry:` value seen for this event.
    retry: Option<u64>,
}
114
115impl PendingSseEvent {
116 fn push_line(&mut self, line: &str) -> Result<Option<SseEvent>> {
117 if line.is_empty() {
118 if self.event.is_none()
119 && self.data.is_empty()
120 && self.id.is_none()
121 && self.retry.is_none()
122 {
123 return Ok(None);
124 }
125
126 let event = SseEvent {
127 event: self.event.take(),
128 data: self.data.join("\n"),
129 id: self.id.take(),
130 retry: self.retry.take(),
131 };
132 self.data.clear();
133 return Ok(Some(event));
134 }
135
136 if line.starts_with(':') {
137 return Ok(None);
138 }
139
140 let (field, value) = match line.split_once(':') {
141 Some((field, value)) => (field, value.strip_prefix(' ').unwrap_or(value)),
142 None => (line, ""),
143 };
144
145 match field {
146 "event" => self.event = Some(value.to_owned()),
147 "data" => self.data.push(value.to_owned()),
148 "id" => self.id = Some(value.to_owned()),
149 "retry" => {
150 self.retry = value.parse::<u64>().ok();
151 }
152 _ => {}
153 }
154
155 Ok(None)
156 }
157
158 fn flush(&mut self) -> Option<SseEvent> {
159 if self.event.is_none() && self.data.is_empty() && self.id.is_none() && self.retry.is_none()
160 {
161 return None;
162 }
163
164 let event = SseEvent {
165 event: self.event.take(),
166 data: self.data.join("\n"),
167 id: self.id.take(),
168 retry: self.retry.take(),
169 };
170 self.data.clear();
171 Some(event)
172 }
173}
174
175pub struct RawSseStream {
177 inner: Pin<Box<dyn Stream<Item = Result<SseEvent>> + Send>>,
178 meta: ResponseMeta,
179}
180
181impl RawSseStream {
182 #[allow(clippy::collapsible_if, tail_expr_drop_order)]
184 pub fn new(response: reqwest::Response, meta: ResponseMeta) -> Self {
185 let stream = try_stream! {
186 let mut decoder = LineDecoder::default();
187 let mut pending = PendingSseEvent::default();
188 let mut byte_stream = response.bytes_stream();
189
190 while let Some(chunk) = byte_stream.next().await {
191 let chunk = chunk.map_err(|error| StreamError::new(format!("读取 SSE 数据流失败: {error}")))?;
192 for line in decoder.push(&chunk)? {
193 if let Some(event) = pending.push_line(&line)? {
194 yield event;
195 }
196 }
197 }
198
199 if let Some(line) = decoder.finish()? {
200 if let Some(event) = pending.push_line(&line)? {
201 yield event;
202 }
203 }
204
205 if let Some(event) = pending.flush() {
206 yield event;
207 }
208 };
209
210 Self {
211 inner: Box::pin(stream),
212 meta,
213 }
214 }
215
216 pub fn meta(&self) -> &ResponseMeta {
218 &self.meta
219 }
220
221 #[allow(tail_expr_drop_order)]
223 pub fn into_typed<T>(self) -> SseStream<T>
224 where
225 T: serde::de::DeserializeOwned + Send + 'static,
226 {
227 let meta = self.meta.clone();
228 let stream = try_stream! {
229 let mut raw = self;
230 while let Some(event) = raw.next().await {
231 let event = event?;
232 if event.data == "[DONE]" {
233 break;
234 }
235 let item = serde_json::from_str::<T>(&event.data).map_err(|error| {
236 StreamError::new(format!("解析 SSE JSON 事件失败: {error}; payload={}", event.data))
237 })?;
238 yield item;
239 }
240 };
241
242 SseStream {
243 inner: Box::pin(stream),
244 meta,
245 }
246 }
247}
248
249impl fmt::Debug for RawSseStream {
250 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
251 f.debug_struct("RawSseStream")
252 .field("meta", &self.meta)
253 .finish()
254 }
255}
256
257impl Stream for RawSseStream {
258 type Item = Result<SseEvent>;
259
260 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
261 self.get_mut().inner.as_mut().poll_next(cx)
262 }
263}
264
#[cfg(test)]
mod property_tests {
    //! Property tests: `LineDecoder` must reproduce the original lines regardless of the
    //! line terminator used or of how the payload is split into chunks.

    use proptest::prelude::*;

    use super::LineDecoder;

    /// Line terminator used to join the generated lines.
    #[derive(Debug, Clone, Copy)]
    enum Separator {
        Lf,
        Cr,
        CrLf,
    }

    impl Separator {
        /// Text of this terminator.
        fn as_str(self) -> &'static str {
            match self {
                Self::Lf => "\n",
                Self::Cr => "\r",
                Self::CrLf => "\r\n",
            }
        }
    }

    /// Picks one of the three terminators uniformly.
    fn separator_strategy() -> impl Strategy<Value = Separator> {
        prop_oneof![
            Just(Separator::Lf),
            Just(Separator::Cr),
            Just(Separator::CrLf),
        ]
    }

    proptest! {
        #[test]
        fn line_decoder_preserves_lines_across_arbitrary_chunking(
            lines in prop::collection::vec("[^\r\n]{0,16}", 1..8),
            separator in separator_strategy(),
            chunk_sizes in prop::collection::vec(1usize..8, 1..32),
        ) {
            // Every line, including the last, is followed by the chosen terminator.
            let payload: String = lines
                .iter()
                .flat_map(|line| [line.as_str(), separator.as_str()])
                .collect();
            let bytes = payload.as_bytes();

            let mut decoder = LineDecoder::default();
            let mut decoded: Vec<String> = Vec::new();
            let mut cursor = 0usize;

            // Feed the payload using the generated chunk sizes...
            for size in chunk_sizes {
                if cursor >= bytes.len() {
                    break;
                }
                let stop = bytes.len().min(cursor + size);
                decoded.extend(decoder.push(&bytes[cursor..stop]).unwrap());
                cursor = stop;
            }

            // ...then whatever the sizes did not cover, as one final chunk.
            if cursor < bytes.len() {
                decoded.extend(decoder.push(&bytes[cursor..]).unwrap());
            }

            decoded.extend(decoder.finish().unwrap());
            prop_assert_eq!(decoded, lines);
        }
    }

    #[test]
    fn line_decoder_flushes_final_partial_line() {
        let mut decoder = LineDecoder::default();
        let emitted = decoder.push(b"event: response.created").unwrap();
        assert!(emitted.is_empty());
        let tail = decoder.finish().unwrap();
        assert_eq!(tail, Some("event: response.created".into()));
    }
}
344
345pub struct SseStream<T> {
347 inner: Pin<Box<dyn Stream<Item = Result<T>> + Send>>,
348 meta: ResponseMeta,
349}
350
351impl<T> SseStream<T> {
352 pub fn meta(&self) -> &ResponseMeta {
354 &self.meta
355 }
356}
357
358impl<T> Stream for SseStream<T> {
359 type Item = Result<T>;
360
361 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
362 self.get_mut().inner.as_mut().poll_next(cx)
363 }
364}
365
366impl<T> fmt::Debug for SseStream<T> {
367 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
368 f.debug_struct("SseStream")
369 .field("meta", &self.meta)
370 .finish()
371 }
372}
373
#[cfg(test)]
mod tests {
    //! Unit tests for line splitting and SSE field parsing.

    use super::{LineDecoder, PendingSseEvent};

    #[test]
    fn test_should_decode_lines_for_mixed_newlines() {
        let mut decoder = LineDecoder::default();
        let expected: Vec<String> = ["data: one", "data: two", "data: three"]
            .iter()
            .map(|line| line.to_string())
            .collect();

        let lines = decoder
            .push(b"data: one\r\ndata: two\rdata: three\n")
            .unwrap();
        assert_eq!(lines, expected);
        assert_eq!(decoder.finish().unwrap(), None);
    }

    #[test]
    fn test_should_decode_utf8_split_across_chunks() {
        // "你好" is 6 bytes of UTF-8; the first split lands inside the first character.
        let text = "你好";
        let bytes = text.as_bytes();
        let mut decoder = LineDecoder::default();

        assert!(decoder.push(&bytes[..2]).unwrap().is_empty());
        assert!(decoder.push(&bytes[2..]).unwrap().is_empty());
        assert_eq!(decoder.push(b"\n").unwrap(), vec![text.to_string()]);
    }

    #[test]
    fn test_should_preserve_crlf_split_across_chunks() {
        let mut decoder = LineDecoder::default();
        let no_lines: Vec<String> = Vec::new();

        assert_eq!(decoder.push(b"data: one\r").unwrap(), no_lines);
        assert_eq!(decoder.push(b"\n").unwrap(), vec!["data: one".to_string()]);
        assert_eq!(decoder.finish().unwrap(), None);
    }

    #[test]
    fn test_should_parse_empty_and_multiline_sse_data_fields() {
        let mut pending = PendingSseEvent::default();
        for line in ["event: message", "data:", "data: hello"] {
            assert_eq!(pending.push_line(line).unwrap(), None);
        }

        let dispatched = pending
            .push_line("")
            .unwrap()
            .expect("a blank line should dispatch the pending event");
        assert_eq!(dispatched.event.as_deref(), Some("message"));
        assert_eq!(dispatched.data, "\nhello");
    }
}