deepseek_api/async_impl/
json_stream.rs

1use anyhow::{Context, Result};
2use futures_util::{Stream, TryStreamExt};
3use reqwest::Response;
4use serde::de::DeserializeOwned;
5use std::{
6    pin::Pin,
7    task::{Context as TaskContext, Poll},
8};
9use tokio::io::{AsyncBufReadExt, BufReader};
10use tokio_stream::{wrappers::LinesStream, StreamExt};
11use tokio_util::io::StreamReader;
12
13/// A stream that processes Server-Sent Events (SSE) and deserializes JSON data.
14///
15/// The `JsonStream` struct wraps an asynchronous stream of lines from an HTTP response,
16/// where each line is expected to be a JSON object prefixed with "data: ". The stream
17/// terminates when it encounters a line with "data: [DONE]".
18///
19/// # Type Parameters
20///
21/// * `T`: The type of the deserialized JSON objects. It must implement `DeserializeOwned` and `Send`.
22///
23/// # Examples
24///
25/// ```rust
26/// use reqwest::Response;
27/// use serde::Deserialize;
28/// use tokio_stream::StreamExt;
29/// use deepseek_api::json_stream::JsonStream;
30///
31/// #[derive(Debug, Deserialize)]
32/// struct MyData {
33///     id: String,
34///     value: u32,
35/// }
36///
37/// async fn process_response(response: Response) {
38///     let mut stream = JsonStream::<MyData>::new(response);
39///
40///     while let Some(item) = stream.next().await {
41///         match item {
42///             Ok(data) => println!("{:?}", data),
43///             Err(e) => eprintln!("Error: {:?}", e),
44///         }
45///     }
46/// }
47/// ```
48///
49/// # Errors
50///
51/// The stream yields `anyhow::Error` if:
52/// - The line does not start with "data: "
53/// - The JSON deserialization fails
54///
55/// # Methods
56///
57/// * `new(response: Response) -> Self`: Creates a new `JsonStream` from an HTTP response.
58///
59/// # Trait Implementations
60///
61/// * `Stream` for `JsonStream<T>`: Allows the `JsonStream` to be used as a stream of `Result<T, anyhow::Error>`.
62pub struct JsonStream<T> {
63    inner: Pin<Box<dyn Stream<Item = Result<T, anyhow::Error>> + Send>>,
64}
65
66impl<T: DeserializeOwned + Send + 'static> JsonStream<T> {
67    pub fn new(response: Response) -> Self {
68        let byte_stream = response
69            .bytes_stream()
70            .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e));
71        let async_reader = StreamReader::new(byte_stream);
72
73        let line_stream = LinesStream::new(BufReader::new(async_reader).lines());
74
75        let processed_stream = line_stream
76            .map_ok(|line| line.trim().to_string())
77            .map_err(anyhow::Error::from) // 将 std::io::Error 转换为 anyhow::Error
78            .take_while(|line| line.as_ref().is_ok_and(|data| data != "data: [DONE]")) // 遇到 "data: [DONE]" 终止流
79            .try_filter_map(|line| async move {
80                if line.is_empty() || line == ": keep-alive" {
81                    return Ok(None);
82                }
83
84                let json_str = line
85                    .strip_prefix("data: ")
86                    .context("Missing 'data: ' prefix")?;
87
88                serde_json::from_str(json_str)
89                    .map(Some)
90                    .map_err(anyhow::Error::from)
91            });
92
93        JsonStream {
94            inner: Box::pin(processed_stream),
95        }
96    }
97}
98
99impl<T: Unpin> Stream for JsonStream<T> {
100    type Item = Result<T, anyhow::Error>;
101
102    fn poll_next(mut self: Pin<&mut Self>, cx: &mut TaskContext<'_>) -> Poll<Option<Self::Item>> {
103        self.inner.as_mut().poll_next(cx)
104    }
105}
106
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::Bytes;
    use http::StatusCode;
    use reqwest::Response;
    use serde::{Deserialize, Serialize};
    use tokio_stream::{self as stream, StreamExt};

    /// Minimal payload type the streams under test deserialize into.
    #[derive(Debug, PartialEq, Serialize, Deserialize)]
    struct TestData {
        id: String,
        value: u32,
    }

    /// Wraps pre-baked body chunks in a `200 OK` response so no network is involved.
    fn mock_response(data: Vec<Result<Bytes, reqwest::Error>>) -> Response {
        let payload = reqwest::Body::wrap_stream(stream::iter(data));
        let builder = http::response::Response::builder().status(StatusCode::OK);
        Response::from(builder.body(payload).unwrap())
    }

    /// Builds a `JsonStream<TestData>` whose body arrives as the given chunks, in order.
    fn stream_from(chunks: &[&'static str]) -> JsonStream<TestData> {
        let data = chunks
            .iter()
            .map(|chunk| Ok(Bytes::from(*chunk)))
            .collect();
        JsonStream::new(mock_response(data))
    }

    /// Shorthand for the expected value of a decoded event.
    fn record(id: &str, value: u32) -> TestData {
        TestData {
            id: id.into(),
            value,
        }
    }

    /// Two complete events, one per chunk, are decoded in order.
    #[tokio::test]
    async fn test_normal_sse_stream() {
        let stream = stream_from(&[
            "data: {\"id\":\"1\",\"value\":100}\n",
            "data: {\"id\":\"2\",\"value\":200}\n",
        ]);

        let results: Vec<TestData> = stream.map(Result::unwrap).collect().await;

        assert_eq!(results, vec![record("1", 100), record("2", 200)]);
    }

    /// An event split across two body chunks is reassembled before parsing.
    #[tokio::test]
    async fn test_chunked_data() {
        let mut stream = stream_from(&["data: {\"id\":\"3\",\"", "value\":300}\n"]);

        let first = stream.next().await.unwrap().unwrap();
        assert_eq!(first, record("3", 300));
        assert!(stream.next().await.is_none());
    }

    /// Blank lines are skipped and nothing after `data: [DONE]` is yielded.
    #[tokio::test]
    async fn test_empty_lines_and_done() {
        let mut stream = stream_from(&[
            "\n",
            "data: {\"id\":\"4\",\"value\":400}\n",
            "data: [DONE]\n",
            "data: {\"id\":\"5\",\"value\":500}\n",
        ]);

        let only = stream.next().await.unwrap().unwrap();
        assert_eq!(only, record("4", 400));
        assert!(stream.next().await.is_none());
    }

    /// A line without the `data: ` prefix surfaces as an error item.
    #[tokio::test]
    async fn test_invalid_prefix() {
        let mut stream = stream_from(&["invalid data\n"]);

        let failure = stream.next().await.unwrap().unwrap_err();
        assert!(failure.to_string().contains("Missing 'data: ' prefix"));
    }

    /// A payload that is not valid JSON surfaces the underlying `serde_json::Error`.
    #[tokio::test]
    async fn test_malformed_json() {
        let mut stream = stream_from(&["data: {invalid}\n"]);

        let failure = stream.next().await.unwrap().unwrap_err();
        assert!(failure.is::<serde_json::Error>());
    }
}