deepseek_api/async_impl/
json_stream.rs1use anyhow::{Context, Result};
2use futures_util::{Stream, TryStreamExt};
3use reqwest::Response;
4use serde::de::DeserializeOwned;
5use std::{
6 pin::Pin,
7 task::{Context as TaskContext, Poll},
8};
9use tokio::io::{AsyncBufReadExt, BufReader};
10use tokio_stream::{wrappers::LinesStream, StreamExt};
11use tokio_util::io::StreamReader;
12
13pub struct JsonStream<T> {
63 inner: Pin<Box<dyn Stream<Item = Result<T, anyhow::Error>> + Send>>,
64}
65
66impl<T: DeserializeOwned + Send + 'static> JsonStream<T> {
67 pub fn new(response: Response) -> Self {
68 let byte_stream = response
69 .bytes_stream()
70 .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e));
71 let async_reader = StreamReader::new(byte_stream);
72
73 let line_stream = LinesStream::new(BufReader::new(async_reader).lines());
74
75 let processed_stream = line_stream
76 .map_ok(|line| line.trim().to_string())
77 .map_err(anyhow::Error::from) .take_while(|line| line.as_ref().is_ok_and(|data| data != "data: [DONE]")) .try_filter_map(|line| async move {
80 if line.is_empty() || line == ": keep-alive" {
81 return Ok(None);
82 }
83
84 let json_str = line
85 .strip_prefix("data: ")
86 .context("Missing 'data: ' prefix")?;
87
88 serde_json::from_str(json_str)
89 .map(Some)
90 .map_err(anyhow::Error::from)
91 });
92
93 JsonStream {
94 inner: Box::pin(processed_stream),
95 }
96 }
97}
98
99impl<T: Unpin> Stream for JsonStream<T> {
100 type Item = Result<T, anyhow::Error>;
101
102 fn poll_next(mut self: Pin<&mut Self>, cx: &mut TaskContext<'_>) -> Poll<Option<Self::Item>> {
103 self.inner.as_mut().poll_next(cx)
104 }
105}
106
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::Bytes;
    use http::StatusCode;
    use reqwest::Response;
    use serde::{Deserialize, Serialize};
    use tokio_stream::{self as stream, StreamExt};

    /// Payload shape decoded by every test.
    #[derive(Debug, PartialEq, Serialize, Deserialize)]
    struct TestData {
        id: String,
        value: u32,
    }

    /// Builds a 200 OK `reqwest::Response` whose body yields `data` chunk by chunk, so a
    /// test controls exactly where chunk boundaries fall.
    fn mock_response(data: Vec<Result<Bytes, reqwest::Error>>) -> Response {
        let body = reqwest::Body::wrap_stream(stream::iter(data));
        let http_response = http::response::Response::builder()
            .status(StatusCode::OK)
            .body(body)
            .unwrap();
        Response::from(http_response)
    }

    #[tokio::test]
    async fn test_normal_sse_stream() {
        let data = vec![
            Ok(Bytes::from("data: {\"id\":\"1\",\"value\":100}\n")),
            Ok(Bytes::from("data: {\"id\":\"2\",\"value\":200}\n")),
        ];
        let response = mock_response(data);
        let mut stream = JsonStream::<TestData>::new(response);

        let mut results = vec![];
        while let Some(item) = stream.next().await {
            results.push(item.unwrap());
        }

        assert_eq!(
            results,
            vec![
                TestData {
                    id: "1".into(),
                    value: 100
                },
                TestData {
                    id: "2".into(),
                    value: 200
                }
            ]
        );
    }

    /// A single SSE line split across two body chunks must be reassembled.
    #[tokio::test]
    async fn test_chunked_data() {
        let data = vec![
            Ok(Bytes::from("data: {\"id\":\"3\",\"")),
            Ok(Bytes::from("value\":300}\n")),
        ];
        let response = mock_response(data);
        let mut stream = JsonStream::<TestData>::new(response);

        assert_eq!(
            stream.next().await.unwrap().unwrap(),
            TestData {
                id: "3".into(),
                value: 300
            }
        );
        assert!(stream.next().await.is_none());
    }

    /// Blank lines are skipped and nothing after `data: [DONE]` is yielded.
    #[tokio::test]
    async fn test_empty_lines_and_done() {
        let data = vec![
            Ok(Bytes::from("\n")),
            Ok(Bytes::from("data: {\"id\":\"4\",\"value\":400}\n")),
            Ok(Bytes::from("data: [DONE]\n")),
            Ok(Bytes::from("data: {\"id\":\"5\",\"value\":500}\n")),
        ];
        let response = mock_response(data);
        let mut stream = JsonStream::<TestData>::new(response);

        let result = stream.next().await.unwrap().unwrap();
        assert_eq!(
            result,
            TestData {
                id: "4".into(),
                value: 400
            }
        );
        assert!(stream.next().await.is_none());
    }

    /// `: keep-alive` comment lines, before and after a payload, produce no items.
    #[tokio::test]
    async fn test_keep_alive_comments_are_skipped() {
        let data = vec![
            Ok(Bytes::from(": keep-alive\n")),
            Ok(Bytes::from("data: {\"id\":\"6\",\"value\":600}\n")),
            Ok(Bytes::from(": keep-alive\n")),
        ];
        let response = mock_response(data);
        let mut stream = JsonStream::<TestData>::new(response);

        assert_eq!(
            stream.next().await.unwrap().unwrap(),
            TestData {
                id: "6".into(),
                value: 600
            }
        );
        assert!(stream.next().await.is_none());
    }

    /// CRLF line endings and blank separator lines are tolerated, and `[DONE]` still ends
    /// the stream.
    #[tokio::test]
    async fn test_crlf_line_endings() {
        let data = vec![Ok(Bytes::from(
            "data: {\"id\":\"7\",\"value\":700}\r\n\r\ndata: [DONE]\r\n",
        ))];
        let response = mock_response(data);
        let mut stream = JsonStream::<TestData>::new(response);

        assert_eq!(
            stream.next().await.unwrap().unwrap(),
            TestData {
                id: "7".into(),
                value: 700
            }
        );
        assert!(stream.next().await.is_none());
    }

    #[tokio::test]
    async fn test_invalid_prefix() {
        let data = vec![Ok(Bytes::from("invalid data\n"))];
        let response = mock_response(data);
        let mut stream = JsonStream::<TestData>::new(response);

        let err = stream.next().await.unwrap().unwrap_err();
        assert!(err.to_string().contains("Missing 'data: ' prefix"));
    }

    /// Malformed JSON surfaces as a `serde_json::Error` that callers can downcast to.
    #[tokio::test]
    async fn test_malformed_json() {
        let data = vec![Ok(Bytes::from("data: {invalid}\n"))];
        let response = mock_response(data);
        let mut stream = JsonStream::<TestData>::new(response);

        let err = stream.next().await.unwrap().unwrap_err();
        assert!(err.is::<serde_json::Error>());
    }
}