deepseek_api/async_impl/
json_stream.rs1use anyhow::{Context, Error, Result};
2use futures_util::future;
3use futures_util::io::{AsyncBufReadExt, BufReader};
4use futures_util::stream::{Stream, StreamExt, TryStreamExt};
5use reqwest::Response;
6use serde::de::DeserializeOwned;
7use std::{
8 pin::Pin,
9 task::{Context as TaskContext, Poll},
10};
11
12pub struct JsonStream<T> {
62 inner: Pin<Box<dyn Stream<Item = Result<T, Error>> + Send>>,
63}
64
65impl<T: DeserializeOwned + Send + 'static> JsonStream<T> {
66 pub fn new(response: Response) -> Self {
67 let byte_stream = response
68 .bytes_stream()
69 .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e));
70
71 let async_read = byte_stream.into_async_read();
72 let processed = BufReader::new(async_read)
73 .lines()
74 .map_err(Error::from)
75 .take_while(|res| {
76 future::ready(match res {
77 Ok(ref line) => line != "data: [DONE]",
78 Err(_) => true,
79 })
80 })
81 .try_filter_map(|line| async move {
82 let line = line.trim();
83 if line.is_empty() || line == ": keep-alive" {
84 return Ok(None);
85 }
86 let json = line
87 .strip_prefix("data: ")
88 .context("Missing 'data: ' prefix")?;
89 let obj = serde_json::from_str(json)?;
90 Ok(Some(obj))
91 });
92
93 JsonStream {
94 inner: Box::pin(processed),
95 }
96 }
97}
98
99impl<T: Unpin> Stream for JsonStream<T> {
100 type Item = Result<T, Error>;
101 fn poll_next(mut self: Pin<&mut Self>, cx: &mut TaskContext<'_>) -> Poll<Option<Self::Item>> {
102 self.inner.as_mut().poll_next(cx)
103 }
104}
105
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::Bytes;
    use futures_util::stream::StreamExt;
    use http::StatusCode;
    use reqwest::Response;
    use serde::{Deserialize, Serialize};

    #[derive(Debug, PartialEq, Serialize, Deserialize)]
    struct TestData {
        id: String,
        value: u32,
    }

    /// Builds a `200 OK` response whose body yields `data` chunk by chunk, so each
    /// test controls exactly where chunk boundaries fall.
    fn mock_response(data: Vec<Result<Bytes, reqwest::Error>>) -> Response {
        let body = reqwest::Body::wrap_stream(futures_util::stream::iter(data));
        let http_response = http::response::Response::builder()
            .status(StatusCode::OK)
            .body(body)
            .unwrap();
        Response::from(http_response)
    }

    #[tokio::test]
    async fn test_normal_sse_stream() {
        let data = vec![
            Ok(Bytes::from("data: {\"id\":\"1\",\"value\":100}\n")),
            Ok(Bytes::from("data: {\"id\":\"2\",\"value\":200}\n")),
        ];
        let response = mock_response(data);
        let mut stream = JsonStream::<TestData>::new(response);

        let mut results = vec![];
        while let Some(item) = stream.next().await {
            results.push(item.unwrap());
        }

        assert_eq!(
            results,
            vec![
                TestData {
                    id: "1".into(),
                    value: 100
                },
                TestData {
                    id: "2".into(),
                    value: 200
                }
            ]
        );
    }

    #[tokio::test]
    async fn test_chunked_data() {
        let data = vec![
            Ok(Bytes::from("data: {\"id\":\"3\",\"")),
            Ok(Bytes::from("value\":300}\n")),
        ];
        let response = mock_response(data);
        let mut stream = JsonStream::<TestData>::new(response);

        assert_eq!(
            stream.next().await.unwrap().unwrap(),
            TestData {
                id: "3".into(),
                value: 300
            }
        );
        assert!(stream.next().await.is_none());
    }

    #[tokio::test]
    async fn test_empty_lines_and_done() {
        let data = vec![
            Ok(Bytes::from("\n")),
            Ok(Bytes::from("data: {\"id\":\"4\",\"value\":400}\n")),
            Ok(Bytes::from("data: [DONE]\n")),
            Ok(Bytes::from("data: {\"id\":\"5\",\"value\":500}\n")),
        ];
        let response = mock_response(data);
        let mut stream = JsonStream::<TestData>::new(response);

        let result = stream.next().await.unwrap().unwrap();
        assert_eq!(
            result,
            TestData {
                id: "4".into(),
                value: 400
            }
        );
        assert!(stream.next().await.is_none());
    }

    #[tokio::test]
    async fn test_keep_alive_comments_are_skipped() {
        let data = vec![
            Ok(Bytes::from(": keep-alive\n")),
            Ok(Bytes::from("data: {\"id\":\"6\",\"value\":600}\n")),
            Ok(Bytes::from(": keep-alive\n\n")),
        ];
        let response = mock_response(data);
        let mut stream = JsonStream::<TestData>::new(response);

        assert_eq!(
            stream.next().await.unwrap().unwrap(),
            TestData {
                id: "6".into(),
                value: 600
            }
        );
        assert!(stream.next().await.is_none());
    }

    #[tokio::test]
    async fn test_crlf_line_endings() {
        let data = vec![
            Ok(Bytes::from("data: {\"id\":\"7\",\"value\":700}\r\n\r\n")),
            Ok(Bytes::from("data: [DONE]\r\n")),
            Ok(Bytes::from("data: {\"id\":\"8\",\"value\":800}\r\n")),
        ];
        let response = mock_response(data);
        let mut stream = JsonStream::<TestData>::new(response);

        assert_eq!(
            stream.next().await.unwrap().unwrap(),
            TestData {
                id: "7".into(),
                value: 700
            }
        );
        assert!(stream.next().await.is_none());
    }

    #[tokio::test]
    async fn test_error_does_not_end_stream() {
        let data = vec![
            Ok(Bytes::from("invalid data\n")),
            Ok(Bytes::from("data: {\"id\":\"9\",\"value\":900}\n")),
        ];
        let response = mock_response(data);
        let mut stream = JsonStream::<TestData>::new(response);

        assert!(stream.next().await.unwrap().is_err());
        assert_eq!(
            stream.next().await.unwrap().unwrap(),
            TestData {
                id: "9".into(),
                value: 900
            }
        );
        assert!(stream.next().await.is_none());
    }

    #[tokio::test]
    async fn test_invalid_prefix() {
        let data = vec![Ok(Bytes::from("invalid data\n"))];
        let response = mock_response(data);
        let mut stream = JsonStream::<TestData>::new(response);

        let err = stream.next().await.unwrap().unwrap_err();
        assert!(err.to_string().contains("Missing 'data: ' prefix"));
    }

    #[tokio::test]
    async fn test_malformed_json() {
        let data = vec![Ok(Bytes::from("data: {invalid}\n"))];
        let response = mock_response(data);
        let mut stream = JsonStream::<TestData>::new(response);

        let err = stream.next().await.unwrap().unwrap_err();
        assert!(err.is::<serde_json::Error>());
    }
}