deepseek_api/async_impl/
json_stream.rs

1use anyhow::{Context, Error, Result};
2use futures_util::future;
3use futures_util::io::{AsyncBufReadExt, BufReader};
4use futures_util::stream::{Stream, StreamExt, TryStreamExt};
5use reqwest::Response;
6use serde::de::DeserializeOwned;
7use std::{
8    pin::Pin,
9    task::{Context as TaskContext, Poll},
10};
11
12/// A stream that processes Server-Sent Events (SSE) and deserializes JSON data.
13///
14/// The `JsonStream` struct wraps an asynchronous stream of lines from an HTTP response,
15/// where each line is expected to be a JSON object prefixed with "data: ". The stream
16/// terminates when it encounters a line with "data: [DONE]".
17///
18/// # Type Parameters
19///
20/// * `T`: The type of the deserialized JSON objects. It must implement `DeserializeOwned` and `Send`.
21///
22/// # Examples
23///
24/// ```rust
25/// use reqwest::Response;
26/// use serde::Deserialize;
27/// use futures_util::stream::StreamExt;
28/// use deepseek_api::json_stream::JsonStream;
29///
30/// #[derive(Debug, Deserialize)]
31/// struct MyData {
32///     id: String,
33///     value: u32,
34/// }
35///
36/// async fn process_response(response: Response) {
37///     let mut stream = JsonStream::<MyData>::new(response);
38///
39///     while let Some(item) = stream.next().await {
40///         match item {
41///             Ok(data) => println!("{:?}", data),
42///             Err(e) => eprintln!("Error: {:?}", e),
43///         }
44///     }
45/// }
46/// ```
47///
48/// # Errors
49///
50/// The stream yields `anyhow::Error` if:
51/// - The line does not start with "data: "
52/// - The JSON deserialization fails
53///
54/// # Methods
55///
56/// * `new(response: Response) -> Self`: Creates a new `JsonStream` from an HTTP response.
57///
58/// # Trait Implementations
59///
60/// * `Stream` for `JsonStream<T>`: Allows the `JsonStream` to be used as a stream of `Result<T, anyhow::Error>`.
61pub struct JsonStream<T> {
62    inner: Pin<Box<dyn Stream<Item = Result<T, Error>> + Send>>,
63}
64
65impl<T: DeserializeOwned + Send + 'static> JsonStream<T> {
66    pub fn new(response: Response) -> Self {
67        let byte_stream = response
68            .bytes_stream()
69            .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e));
70
71        let async_read = byte_stream.into_async_read();
72        let processed = BufReader::new(async_read)
73            .lines()
74            .map_err(Error::from)
75            .take_while(|res| {
76                future::ready(match res {
77                    Ok(ref line) => line != "data: [DONE]",
78                    Err(_) => true,
79                })
80            })
81            .try_filter_map(|line| async move {
82                let line = line.trim();
83                if line.is_empty() || line == ": keep-alive" {
84                    return Ok(None);
85                }
86                let json = line
87                    .strip_prefix("data: ")
88                    .context("Missing 'data: ' prefix")?;
89                let obj = serde_json::from_str(json)?;
90                Ok(Some(obj))
91            });
92
93        JsonStream {
94            inner: Box::pin(processed),
95        }
96    }
97}
98
99impl<T: Unpin> Stream for JsonStream<T> {
100    type Item = Result<T, Error>;
101    fn poll_next(mut self: Pin<&mut Self>, cx: &mut TaskContext<'_>) -> Poll<Option<Self::Item>> {
102        self.inner.as_mut().poll_next(cx)
103    }
104}
105
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::Bytes;
    use futures_util::stream::StreamExt;
    use http::StatusCode;
    use reqwest::Response;
    use serde::{Deserialize, Serialize};

    #[derive(Debug, PartialEq, Serialize, Deserialize)]
    struct TestData {
        id: String,
        value: u32,
    }

    /// Builds a 200 OK `reqwest::Response` whose body yields `data` chunk by chunk.
    fn mock_response(data: Vec<Result<Bytes, reqwest::Error>>) -> Response {
        let body = reqwest::Body::wrap_stream(futures_util::stream::iter(data));
        let http_response = http::response::Response::builder()
            .status(StatusCode::OK)
            .body(body)
            .unwrap();
        Response::from(http_response)
    }

    #[tokio::test]
    async fn test_normal_sse_stream() {
        let data = vec![
            Ok(Bytes::from("data: {\"id\":\"1\",\"value\":100}\n")),
            Ok(Bytes::from("data: {\"id\":\"2\",\"value\":200}\n")),
        ];
        let response = mock_response(data);
        let mut stream = JsonStream::<TestData>::new(response);

        let mut results = vec![];
        while let Some(item) = stream.next().await {
            results.push(item.unwrap());
        }

        assert_eq!(
            results,
            vec![
                TestData {
                    id: "1".into(),
                    value: 100
                },
                TestData {
                    id: "2".into(),
                    value: 200
                }
            ]
        );
    }

    #[tokio::test]
    async fn test_chunked_data() {
        // A single line split across two body chunks must be reassembled.
        let data = vec![
            Ok(Bytes::from("data: {\"id\":\"3\",\"")),
            Ok(Bytes::from("value\":300}\n")),
        ];
        let response = mock_response(data);
        let mut stream = JsonStream::<TestData>::new(response);

        assert_eq!(
            stream.next().await.unwrap().unwrap(),
            TestData {
                id: "3".into(),
                value: 300
            }
        );
        assert!(stream.next().await.is_none());
    }

    #[tokio::test]
    async fn test_empty_lines_and_done() {
        // Anything after the `[DONE]` sentinel must not be yielded.
        let data = vec![
            Ok(Bytes::from("\n")),
            Ok(Bytes::from("data: {\"id\":\"4\",\"value\":400}\n")),
            Ok(Bytes::from("data: [DONE]\n")),
            Ok(Bytes::from("data: {\"id\":\"5\",\"value\":500}\n")),
        ];
        let response = mock_response(data);
        let mut stream = JsonStream::<TestData>::new(response);

        let result = stream.next().await.unwrap().unwrap();
        assert_eq!(
            result,
            TestData {
                id: "4".into(),
                value: 400
            }
        );
        assert!(stream.next().await.is_none());
    }

    #[tokio::test]
    async fn test_keep_alive_comment_skipped() {
        let data = vec![
            Ok(Bytes::from(": keep-alive\n")),
            Ok(Bytes::from("data: {\"id\":\"6\",\"value\":600}\n")),
            Ok(Bytes::from(": keep-alive\n")),
        ];
        let response = mock_response(data);
        let mut stream = JsonStream::<TestData>::new(response);

        assert_eq!(
            stream.next().await.unwrap().unwrap(),
            TestData {
                id: "6".into(),
                value: 600
            }
        );
        assert!(stream.next().await.is_none());
    }

    #[tokio::test]
    async fn test_crlf_line_endings() {
        // SSE permits CRLF line endings; the `[DONE]` sentinel must still terminate.
        let data = vec![
            Ok(Bytes::from("data: {\"id\":\"7\",\"value\":700}\r\n")),
            Ok(Bytes::from("\r\n")),
            Ok(Bytes::from("data: [DONE]\r\n")),
            Ok(Bytes::from("data: {\"id\":\"8\",\"value\":800}\r\n")),
        ];
        let response = mock_response(data);
        let mut stream = JsonStream::<TestData>::new(response);

        assert_eq!(
            stream.next().await.unwrap().unwrap(),
            TestData {
                id: "7".into(),
                value: 700
            }
        );
        assert!(stream.next().await.is_none());
    }

    #[tokio::test]
    async fn test_invalid_prefix() {
        let data = vec![Ok(Bytes::from("invalid data\n"))];
        let response = mock_response(data);
        let mut stream = JsonStream::<TestData>::new(response);

        let err = stream.next().await.unwrap().unwrap_err();
        assert!(err.to_string().contains("Missing 'data: ' prefix"));
    }

    #[tokio::test]
    async fn test_malformed_json() {
        let data = vec![Ok(Bytes::from("data: {invalid}\n"))];
        let response = mock_response(data);
        let mut stream = JsonStream::<TestData>::new(response);

        let err = stream.next().await.unwrap().unwrap_err();
        assert!(err.is::<serde_json::Error>());
    }
}