reqwest_streams/
json_stream.rs1use crate::error::StreamBodyKind;
2use crate::json_array_codec::JsonArrayCodec;
3use crate::{StreamBodyError, StreamBodyResult};
4use async_trait::*;
5use futures::{StreamExt, TryStreamExt};
6use serde::Deserialize;
7use tokio_util::io::StreamReader;
8
9#[async_trait]
12pub trait JsonStreamResponse {
13 fn json_array_stream<'a, 'b, T>(self, max_obj_len: usize) -> impl futures::Stream<Item = StreamBodyResult<T>> + Send + 'b
43 where
44 T: for<'de> Deserialize<'de> + Send + 'b;
45
46 fn json_array_stream_with_capacity<'a, 'b, T>(
79 self,
80 max_obj_len: usize,
81 buf_capacity: usize,
82 ) -> impl futures::Stream<Item = StreamBodyResult<T>> + Send + 'b
83 where
84 T: for<'de> Deserialize<'de> + Send + 'b;
85
86 fn json_nl_stream<'a, 'b, T>(self, max_obj_len: usize) -> impl futures::Stream<Item = StreamBodyResult<T>> + Send + 'b
116 where
117 T: for<'de> Deserialize<'de> + Send + 'b;
118
119 fn json_nl_stream_with_capacity<'a, 'b, T>(
150 self,
151 max_obj_len: usize,
152 buf_capacity: usize,
153 ) -> impl futures::Stream<Item = StreamBodyResult<T>> + Send + 'b
154 where
155 T: for<'de> Deserialize<'de> + Send + 'b;
156}
157
/// Initial read-buffer capacity (8 KiB) used by the convenience methods that do not take an
/// explicit `buf_capacity` argument.
const INITIAL_CAPACITY: usize = 8192;
161#[async_trait]
162impl JsonStreamResponse for reqwest::Response {
163 fn json_nl_stream<'a, 'b, T>(self, max_obj_len: usize) -> impl futures::Stream<Item = StreamBodyResult<T>> + Send + 'b
164 where
165 T: for<'de> Deserialize<'de> + Send + 'b,
166 {
167 self.json_nl_stream_with_capacity(max_obj_len, INITIAL_CAPACITY)
168 }
169
170 fn json_nl_stream_with_capacity<'a, 'b, T>(
171 self,
172 max_obj_len: usize,
173 buf_capacity: usize,
174 ) -> impl futures::Stream<Item = StreamBodyResult<T>> + Send + 'b
175 where
176 T: for<'de> Deserialize<'de> + Send + 'b
177 {
178 let reader = StreamReader::new(
179 self.bytes_stream()
180 .map_err(|err| std::io::Error::new(std::io::ErrorKind::Other, err)),
181 );
182
183 let codec = tokio_util::codec::LinesCodec::new_with_max_length(max_obj_len);
184 let frames_reader =
185 tokio_util::codec::FramedRead::with_capacity(reader, codec, buf_capacity);
186
187 frames_reader
188 .into_stream()
189 .map(|frame_res| match frame_res {
190 Ok(frame_str) => serde_json::from_str(frame_str.as_str()).map_err(|err| {
191 StreamBodyError::new(StreamBodyKind::CodecError, Some(Box::new(err)), None)
192 }),
193 Err(err) => Err(StreamBodyError::new(
194 StreamBodyKind::CodecError,
195 Some(Box::new(err)),
196 None,
197 )),
198 })
199 }
200
201 fn json_array_stream<'a, 'b, T>(self, max_obj_len: usize) -> impl futures::Stream<Item = StreamBodyResult<T>> + Send + 'b
202 where
203 T: for<'de> Deserialize<'de> + Send + 'b,
204 {
205 self.json_array_stream_with_capacity(max_obj_len, INITIAL_CAPACITY)
206 }
207
208 fn json_array_stream_with_capacity<'a, 'b, T>(
209 self,
210 max_obj_len: usize,
211 buf_capacity: usize,
212 ) -> impl futures::Stream<Item = StreamBodyResult<T>> + Send + 'b
213 where
214 T: for<'de> Deserialize<'de> + Send + 'b,
215 {
216 let reader = StreamReader::new(
217 self.bytes_stream()
218 .map_err(|err| std::io::Error::new(std::io::ErrorKind::Other, err)),
219 );
220
221 let codec = JsonArrayCodec::<T>::new_with_max_length(max_obj_len);
223 let frames_reader =
224 tokio_util::codec::FramedRead::with_capacity(reader, codec, buf_capacity);
225
226 frames_reader.into_stream()
227 }
228}
229
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_client::*;
    use axum::{routing::*, Router};
    use axum_streams::*;
    use futures::stream;
    use serde::Serialize;

    #[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
    struct MyTestStructure {
        some_test_field: String,
        test_arr: Vec<MyChildTest>,
    }

    #[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
    struct MyChildTest {
        test_field: String,
    }

    /// Builds 100 identical fixtures, each with a nested array field.
    fn generate_test_structures() -> Vec<MyTestStructure> {
        let item = MyTestStructure {
            some_test_field: "TestValue".to_string(),
            test_arr: vec![
                MyChildTest {
                    test_field: "TestValue1".to_string(),
                },
                MyChildTest {
                    test_field: "TestValue2".to_string(),
                },
            ],
        };
        vec![item; 100]
    }

    #[tokio::test]
    async fn deserialize_json_array_stream() {
        let test_stream_vec = generate_test_structures();

        let test_stream = Box::pin(stream::iter(test_stream_vec.clone()));

        let app = Router::new().route("/", get(|| async { StreamBodyAs::json_array(test_stream) }));

        let client = TestClient::new(app).await;

        let res = client
            .get("/")
            .send()
            .await
            .unwrap()
            .json_array_stream::<MyTestStructure>(1024);
        let items: Vec<MyTestStructure> = res.try_collect().await.unwrap();

        assert_eq!(items, test_stream_vec);
    }

    #[tokio::test]
    async fn deserialize_json_array_stream_check_max_len() {
        let test_stream_vec = generate_test_structures();

        let test_stream = Box::pin(stream::iter(test_stream_vec.clone()));

        let app = Router::new().route("/", get(|| async { StreamBodyAs::json_array(test_stream) }));

        let client = TestClient::new(app).await;

        let res = client
            .get("/")
            .send()
            .await
            .unwrap()
            .json_array_stream::<MyTestStructure>(10);
        res.try_collect::<Vec<MyTestStructure>>()
            .await
            .expect_err("MaxLenReachedError");
    }

    #[tokio::test]
    async fn deserialize_json_array_stream_check_len_capacity() {
        let test_stream_vec = generate_test_structures();

        let test_stream = Box::pin(stream::iter(test_stream_vec.clone()));

        let app = Router::new().route("/", get(|| async { StreamBodyAs::json_array(test_stream) }));

        let client = TestClient::new(app).await;

        let res = client
            .get("/")
            .send()
            .await
            .unwrap()
            .json_array_stream_with_capacity::<MyTestStructure>(1024, 50);

        let items: Vec<MyTestStructure> = res.try_collect().await.unwrap();

        assert_eq!(items, test_stream_vec);
    }

    #[tokio::test]
    async fn deserialize_json_nl_stream() {
        let test_stream_vec = generate_test_structures();

        let test_stream = Box::pin(stream::iter(test_stream_vec.clone()));

        let app = Router::new().route("/", get(|| async { StreamBodyAs::json_nl(test_stream) }));

        let client = TestClient::new(app).await;

        let res = client
            .get("/")
            .send()
            .await
            .unwrap()
            .json_nl_stream::<MyTestStructure>(1024);
        let items: Vec<MyTestStructure> = res.try_collect().await.unwrap();

        assert_eq!(items, test_stream_vec);
    }

    #[tokio::test]
    async fn deserialize_json_nl_stream_check_max_len() {
        let test_stream_vec = generate_test_structures();

        let test_stream = Box::pin(stream::iter(test_stream_vec.clone()));

        let app = Router::new().route("/", get(|| async { StreamBodyAs::json_nl(test_stream) }));

        let client = TestClient::new(app).await;

        let res = client
            .get("/")
            .send()
            .await
            .unwrap()
            .json_nl_stream::<MyTestStructure>(10);
        res.try_collect::<Vec<MyTestStructure>>()
            .await
            .expect_err("MaxLenReachedError");
    }

    // A read buffer much smaller than one line must still reassemble every line.
    #[tokio::test]
    async fn deserialize_json_nl_stream_check_len_capacity() {
        let test_stream_vec = generate_test_structures();

        let test_stream = Box::pin(stream::iter(test_stream_vec.clone()));

        let app = Router::new().route("/", get(|| async { StreamBodyAs::json_nl(test_stream) }));

        let client = TestClient::new(app).await;

        let res = client
            .get("/")
            .send()
            .await
            .unwrap()
            .json_nl_stream_with_capacity::<MyTestStructure>(1024, 50);

        let items: Vec<MyTestStructure> = res.try_collect().await.unwrap();

        assert_eq!(items, test_stream_vec);
    }
}