reqwest_streams/
json_stream.rs

1use crate::error::StreamBodyKind;
2use crate::json_array_codec::JsonArrayCodec;
3use crate::{StreamBodyError, StreamBodyResult};
4use async_trait::*;
5use futures::{StreamExt, TryStreamExt};
6use serde::Deserialize;
7use tokio_util::io::StreamReader;
8
9/// Extension trait for [`reqwest::Response`] that provides streaming support for the JSON array
10/// and JSON Lines (NL/NewLines) formats.
11#[async_trait]
12pub trait JsonStreamResponse {
13    /// Streams the response as a JSON array.
14    ///
15    /// The stream will [`Deserialize`] entries as type `T` with a maximum size of `max_obj_len`
16    /// bytes. If `max_obj_len` is [`usize::MAX`], lines will be read until a newline (`\n`)
17    /// character is reached.
18    ///
19    /// # Example
20    ///
21    /// ```rust,no_run
22    /// use futures::stream::BoxStream as _;
23    /// use reqwest_streams::JsonStreamResponse as _;
24    /// use serde::{Deserialize, Serialize};
25    ///
26    /// #[derive(Debug, Clone, Deserialize)]
27    /// struct MyTestStructure {
28    ///     some_test_field: String
29    /// }
30    ///
31    /// #[tokio::main]
32    /// async fn main() -> Result<(), Box<dyn std::error::Error>> {
33    ///     const MAX_OBJ_LEN: usize = 64 * 1024;
34    ///
35    ///     let _stream = reqwest::get("http://localhost:8080/json-array")
36    ///         .await?
37    ///         .json_array_stream::<MyTestStructure>(MAX_OBJ_LEN);
38    ///
39    ///     Ok(())
40    /// }
41    /// ```
42    fn json_array_stream<'a, 'b, T>(self, max_obj_len: usize) -> impl futures::Stream<Item = StreamBodyResult<T>>  + Send + 'b
43    where
44        T: for<'de> Deserialize<'de> + Send + 'b;
45
46    /// Streams the response as a JSON array.
47    ///
48    /// The stream will [`Deserialize`] entries as type `T` with a maximum size of `max_obj_len`
49    /// bytes. If `max_obj_len` is [`usize::MAX`], lines will be read until a newline (`\n`)
50    /// character is reached.
51    ///
52    /// `buf_capacity` is the initial capacity of the stream's decoding buffer.
53    ///
54    /// # Example
55    ///
56    /// ```rust,no_run
57    /// use futures::stream::BoxStream as _;
58    /// use reqwest_streams::JsonStreamResponse as _;
59    /// use serde::{Deserialize, Serialize};
60    ///
61    /// #[derive(Debug, Clone, Deserialize)]
62    /// struct MyTestStructure {
63    ///     some_test_field: String
64    /// }
65    ///
66    /// #[tokio::main]
67    /// async fn main() -> Result<(), Box<dyn std::error::Error>> {
68    ///     const MAX_OBJ_LEN: usize = 64 * 1024;
69    ///     const INITIAL_BUF_CAPACITY: usize = 16 * 1024;
70    ///
71    ///     let _stream = reqwest::get("http://localhost:8080/json-array")
72    ///         .await?
73    ///         .json_array_stream_with_capacity::<MyTestStructure>(MAX_OBJ_LEN, INITIAL_BUF_CAPACITY);
74    ///
75    ///     Ok(())
76    /// }
77    /// ```
78    fn json_array_stream_with_capacity<'a, 'b, T>(
79        self,
80        max_obj_len: usize,
81        buf_capacity: usize,
82    ) -> impl futures::Stream<Item = StreamBodyResult<T>>  + Send + 'b
83    where
84        T: for<'de> Deserialize<'de> + Send + 'b;
85
86    /// Streams the response as JSON lines (NL/NewLines), where each line contains a JSON object.
87    ///
88    /// The stream will [`Deserialize`] entries as type `T` with a maximum size of `max_obj_len`
89    /// bytes. If `max_obj_len` is [`usize::MAX`], lines will be read until a newline (`\n`)
90    /// character is reached.
91    ///
92    /// # Example
93    ///
94    /// ```rust,no_run
95    /// use futures::stream::BoxStream as _;
96    /// use reqwest_streams::JsonStreamResponse as _;
97    /// use serde::{Deserialize, Serialize};
98    ///
99    /// #[derive(Debug, Clone, Deserialize)]
100    /// struct MyTestStructure {
101    ///     some_test_field: String
102    /// }
103    ///
104    /// #[tokio::main]
105    /// async fn main() -> Result<(), Box<dyn std::error::Error>> {
106    ///     const MAX_OBJ_LEN: usize = 64 * 1024;
107    ///
108    ///     let _stream = reqwest::get("http://localhost:8080/json-nl")
109    ///         .await?
110    ///         .json_nl_stream::<MyTestStructure>(MAX_OBJ_LEN);
111    ///
112    ///     Ok(())
113    /// }
114    /// ```
115    fn json_nl_stream<'a, 'b, T>(self, max_obj_len: usize) -> impl futures::Stream<Item = StreamBodyResult<T>>  + Send + 'b
116    where
117        T: for<'de> Deserialize<'de> + Send + 'b;
118
119    /// Streams the response as JSON lines (NL/NewLines), where each line contains a JSON object.
120    ///
121    /// The stream will [`Deserialize`] entries as type `T` with a maximum size of `max_obj_len`
122    /// bytes. If `max_obj_len` is [`usize::MAX`], lines will be read until a `\n` character
123    /// is reached.
124    ///
125    /// # Example
126    ///
127    /// ```rust,no_run
128    /// use futures::stream::BoxStream as _;
129    /// use reqwest_streams::JsonStreamResponse as _;
130    /// use serde::{Deserialize, Serialize};
131    ///
132    /// #[derive(Debug, Clone, Deserialize)]
133    /// struct MyTestStructure {
134    ///     some_test_field: String
135    /// }
136    ///
137    /// #[tokio::main]
138    /// async fn main() -> Result<(), Box<dyn std::error::Error>> {
139    ///     const MAX_OBJ_LEN: usize = 64 * 1024;
140    ///     const INITIAL_BUF_CAPACITY: usize = 16 * 1024;
141    ///
142    ///     let _stream = reqwest::get("http://localhost:8080/json-nl")
143    ///         .await?
144    ///         .json_nl_stream_with_capacity::<MyTestStructure>(MAX_OBJ_LEN, INITIAL_BUF_CAPACITY);
145    ///
146    ///     Ok(())
147    /// }
148    /// ```
149    fn json_nl_stream_with_capacity<'a, 'b, T>(
150        self,
151        max_obj_len: usize,
152        buf_capacity: usize,
153    ) -> impl futures::Stream<Item = StreamBodyResult<T>> + Send + 'b
154    where
155        T: for<'de> Deserialize<'de> + Send + 'b;
156}
157
/// Initial decoding-buffer capacity, in bytes, used by the stream methods that do not take a
/// `buf_capacity` argument.
///
/// This matches the initial capacity `tokio_util::codec::FramedRead::new` allocates (8 KiB), so
/// the default methods behave like a plain `FramedRead` while still going through
/// `FramedRead::with_capacity`.
const INITIAL_CAPACITY: usize = 8 * 1024;
160
161#[async_trait]
162impl JsonStreamResponse for reqwest::Response {
163    fn json_nl_stream<'a, 'b, T>(self, max_obj_len: usize) -> impl futures::Stream<Item = StreamBodyResult<T>>  + Send + 'b
164    where
165        T: for<'de> Deserialize<'de> + Send + 'b,
166    {
167        self.json_nl_stream_with_capacity(max_obj_len, INITIAL_CAPACITY)
168    }
169
170    fn json_nl_stream_with_capacity<'a, 'b, T>(
171        self,
172        max_obj_len: usize,
173        buf_capacity: usize,
174    ) -> impl futures::Stream<Item = StreamBodyResult<T>>  + Send + 'b
175    where
176        T: for<'de> Deserialize<'de> + Send + 'b
177    {
178        let reader = StreamReader::new(
179            self.bytes_stream()
180                .map_err(|err| std::io::Error::new(std::io::ErrorKind::Other, err)),
181        );
182
183        let codec = tokio_util::codec::LinesCodec::new_with_max_length(max_obj_len);
184        let frames_reader =
185            tokio_util::codec::FramedRead::with_capacity(reader, codec, buf_capacity);
186
187        frames_reader
188            .into_stream()
189            .map(|frame_res| match frame_res {
190                Ok(frame_str) => serde_json::from_str(frame_str.as_str()).map_err(|err| {
191                    StreamBodyError::new(StreamBodyKind::CodecError, Some(Box::new(err)), None)
192                }),
193                Err(err) => Err(StreamBodyError::new(
194                    StreamBodyKind::CodecError,
195                    Some(Box::new(err)),
196                    None,
197                )),
198            })
199    }
200
201    fn json_array_stream<'a, 'b, T>(self, max_obj_len: usize) -> impl futures::Stream<Item = StreamBodyResult<T>>  + Send + 'b
202    where
203        T: for<'de> Deserialize<'de> + Send + 'b,
204    {
205        self.json_array_stream_with_capacity(max_obj_len, INITIAL_CAPACITY)
206    }
207
208    fn json_array_stream_with_capacity<'a, 'b, T>(
209        self,
210        max_obj_len: usize,
211        buf_capacity: usize,
212    ) -> impl futures::Stream<Item = StreamBodyResult<T>>  + Send + 'b
213    where
214        T: for<'de> Deserialize<'de> + Send + 'b,
215    {
216        let reader = StreamReader::new(
217            self.bytes_stream()
218                .map_err(|err| std::io::Error::new(std::io::ErrorKind::Other, err)),
219        );
220
221        //serde_json::from_reader(read);
222        let codec = JsonArrayCodec::<T>::new_with_max_length(max_obj_len);
223        let frames_reader =
224            tokio_util::codec::FramedRead::with_capacity(reader, codec, buf_capacity);
225
226        frames_reader.into_stream()
227    }
228}
229
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_client::*;
    use axum::{routing::*, Router};
    use axum_streams::*;
    use futures::stream;
    use serde::Serialize;

    #[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
    struct MyTestStructure {
        some_test_field: String,
        test_arr: Vec<MyChildTest>,
    }

    #[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
    struct MyChildTest {
        test_field: String,
    }

    /// Returns 100 identical fixtures, each with two children; this is the payload every test
    /// serves and, where the stream should succeed, expects back.
    fn generate_test_structures() -> Vec<MyTestStructure> {
        vec![
            MyTestStructure {
                some_test_field: "TestValue".to_string(),
                test_arr: vec![
                    MyChildTest {
                        test_field: "TestValue1".to_string(),
                    },
                    MyChildTest {
                        test_field: "TestValue2".to_string(),
                    },
                ],
            };
            100
        ]
    }

    #[tokio::test]
    async fn deserialize_json_array_stream() {
        let test_stream_vec = generate_test_structures();

        let test_stream = Box::pin(stream::iter(test_stream_vec.clone()));

        let app = Router::new().route("/", get(|| async { StreamBodyAs::json_array(test_stream) }));

        let client = TestClient::new(app).await;

        let res = client
            .get("/")
            .send()
            .await
            .unwrap()
            .json_array_stream::<MyTestStructure>(1024);
        let items: Vec<MyTestStructure> = res.try_collect().await.unwrap();

        assert_eq!(items, test_stream_vec);
    }

    #[tokio::test]
    async fn deserialize_json_array_stream_check_max_len() {
        // Every serialized fixture is far longer than 10 bytes, so decoding must fail.
        let test_stream = Box::pin(stream::iter(generate_test_structures()));

        let app = Router::new().route("/", get(|| async { StreamBodyAs::json_array(test_stream) }));

        let client = TestClient::new(app).await;

        let res = client
            .get("/")
            .send()
            .await
            .unwrap()
            .json_array_stream::<MyTestStructure>(10);
        res.try_collect::<Vec<MyTestStructure>>()
            .await
            .expect_err("MaxLenReachedError");
    }

    #[tokio::test]
    async fn deserialize_json_array_stream_check_len_capacity() {
        let test_stream_vec = generate_test_structures();

        let test_stream = Box::pin(stream::iter(test_stream_vec.clone()));

        let app = Router::new().route("/", get(|| async { StreamBodyAs::json_array(test_stream) }));

        let client = TestClient::new(app).await;

        // A buffer smaller than one element must still decode correctly (it grows as needed).
        let res = client
            .get("/")
            .send()
            .await
            .unwrap()
            .json_array_stream_with_capacity::<MyTestStructure>(1024, 50);

        let items: Vec<MyTestStructure> = res.try_collect().await.unwrap();

        assert_eq!(items, test_stream_vec);
    }

    #[tokio::test]
    async fn deserialize_json_nl_stream() {
        let test_stream_vec = generate_test_structures();

        let test_stream = Box::pin(stream::iter(test_stream_vec.clone()));

        let app = Router::new().route("/", get(|| async { StreamBodyAs::json_nl(test_stream) }));

        let client = TestClient::new(app).await;

        let res = client
            .get("/")
            .send()
            .await
            .unwrap()
            .json_nl_stream::<MyTestStructure>(1024);
        let items: Vec<MyTestStructure> = res.try_collect().await.unwrap();

        assert_eq!(items, test_stream_vec);
    }

    #[tokio::test]
    async fn deserialize_json_nl_stream_check_max_len() {
        // Every serialized fixture line is far longer than 10 bytes, so decoding must fail.
        let test_stream = Box::pin(stream::iter(generate_test_structures()));

        let app = Router::new().route("/", get(|| async { StreamBodyAs::json_nl(test_stream) }));

        let client = TestClient::new(app).await;

        let res = client
            .get("/")
            .send()
            .await
            .unwrap()
            .json_nl_stream::<MyTestStructure>(10);
        res.try_collect::<Vec<MyTestStructure>>()
            .await
            .expect_err("MaxLenReachedError");
    }

    #[tokio::test]
    async fn deserialize_json_nl_stream_check_len_capacity() {
        let test_stream_vec = generate_test_structures();

        let test_stream = Box::pin(stream::iter(test_stream_vec.clone()));

        let app = Router::new().route("/", get(|| async { StreamBodyAs::json_nl(test_stream) }));

        let client = TestClient::new(app).await;

        // A buffer smaller than one line must still decode correctly (it grows as needed).
        let res = client
            .get("/")
            .send()
            .await
            .unwrap()
            .json_nl_stream_with_capacity::<MyTestStructure>(1024, 50);

        let items: Vec<MyTestStructure> = res.try_collect().await.unwrap();

        assert_eq!(items, test_stream_vec);
    }
}
375}