// axum_streams/json_formats.rs

1use crate::stream_body_as::StreamBodyAsOptions;
2use crate::stream_format::StreamingFormat;
3use crate::{StreamBodyAs, StreamFormatEnvelope};
4use bytes::{BufMut, BytesMut};
5use futures::stream::BoxStream;
6use futures::Stream;
7use futures::StreamExt;
8use http::HeaderMap;
9use serde::Serialize;
10use std::io::Write;
11
12pub struct JsonArrayStreamFormat<E = ()>
13where
14    E: Serialize,
15{
16    envelope: Option<StreamFormatEnvelope<E>>,
17}
18
19impl JsonArrayStreamFormat {
20    pub fn new() -> JsonArrayStreamFormat<()> {
21        JsonArrayStreamFormat { envelope: None }
22    }
23
24    pub fn with_envelope<E>(envelope: E, array_field: &str) -> JsonArrayStreamFormat<E>
25    where
26        E: Serialize,
27    {
28        JsonArrayStreamFormat {
29            envelope: Some(StreamFormatEnvelope {
30                object: envelope,
31                array_field: array_field.to_string(),
32            }),
33        }
34    }
35}
36
37impl<T, E> StreamingFormat<T> for JsonArrayStreamFormat<E>
38where
39    T: Serialize + Send + Sync + 'static,
40    E: Serialize + Send + Sync + 'static,
41{
42    fn to_bytes_stream<'a, 'b>(
43        &'a self,
44        stream: BoxStream<'b, Result<T, axum::Error>>,
45        _: &'a StreamBodyAsOptions,
46    ) -> BoxStream<'b, Result<axum::body::Bytes, axum::Error>> {
47        let stream_bytes: BoxStream<Result<axum::body::Bytes, axum::Error>> = Box::pin({
48            stream.enumerate().map(|(index, obj_res)| match obj_res {
49                Err(e) => Err(e),
50                Ok(obj) => {
51                    let mut buf = BytesMut::new().writer();
52
53                    let sep_write_res = if index != 0 {
54                        buf.write_all(JSON_SEP_BYTES).map_err(axum::Error::new)
55                    } else {
56                        Ok(())
57                    };
58
59                    match sep_write_res {
60                        Ok(_) => {
61                            match serde_json::to_writer(&mut buf, &obj).map_err(axum::Error::new) {
62                                Ok(_) => Ok(buf.into_inner().freeze()),
63                                Err(e) => Err(e),
64                            }
65                        }
66                        Err(e) => Err(e),
67                    }
68                }
69            })
70        });
71
72        let prepend_stream: BoxStream<Result<axum::body::Bytes, axum::Error>> =
73            Box::pin(futures::stream::once(futures::future::ready({
74                if let Some(envelope) = &self.envelope {
75                    match serde_json::to_vec(&envelope.object) {
76                        Ok(envelope_bytes) if envelope_bytes.len() > 1 => {
77                            let mut buf = BytesMut::new().writer();
78                            let envelope_slice = envelope_bytes.as_slice();
79                            match buf
80                                .write_all(&envelope_slice[0..envelope_slice.len() - 1])
81                                .and_then(|_| {
82                                    if envelope_bytes.len() > 2 {
83                                        buf.write_all(JSON_SEP_BYTES)
84                                    } else {
85                                        Ok(())
86                                    }
87                                })
88                                .and_then(|_| {
89                                    buf.write_all(
90                                        format!("\"{}\":", envelope.array_field).as_bytes(),
91                                    )
92                                })
93                                .and_then(|_| buf.write_all(JSON_ARRAY_BEGIN_BYTES))
94                            {
95                                Ok(_) => Ok::<_, axum::Error>(buf.into_inner().freeze()),
96                                Err(e) => Err(axum::Error::new(e)),
97                            }
98                        }
99                        Ok(envelope_bytes) => Err(axum::Error::new(std::io::Error::new(
100                            std::io::ErrorKind::Other,
101                            format!("Too short envelope: {:?}", envelope_bytes),
102                        ))),
103                        Err(e) => Err(axum::Error::new(e)),
104                    }
105                } else {
106                    Ok::<_, axum::Error>(axum::body::Bytes::from(JSON_ARRAY_BEGIN_BYTES))
107                }
108            })));
109
110        let append_stream: BoxStream<Result<axum::body::Bytes, axum::Error>> =
111            Box::pin(futures::stream::once(futures::future::ready({
112                if self.envelope.is_some() {
113                    Ok::<_, axum::Error>(axum::body::Bytes::from(JSON_ARRAY_ENVELOP_END_BYTES))
114                } else {
115                    Ok::<_, axum::Error>(axum::body::Bytes::from(JSON_ARRAY_END_BYTES))
116                }
117            })));
118
119        Box::pin(prepend_stream.chain(stream_bytes.chain(append_stream)))
120    }
121
122    fn http_response_headers(&self, options: &StreamBodyAsOptions) -> Option<HeaderMap> {
123        let mut header_map = HeaderMap::new();
124        header_map.insert(
125            http::header::CONTENT_TYPE,
126            options
127                .content_type
128                .clone()
129                .unwrap_or_else(|| http::header::HeaderValue::from_static("application/json")),
130        );
131        Some(header_map)
132    }
133}
134
/// [`StreamingFormat`] that writes each item as compact JSON followed by a newline
/// (newline-delimited JSON).
#[derive(Debug, Clone, Copy, Default)]
pub struct JsonNewLineStreamFormat;

impl JsonNewLineStreamFormat {
    /// Creates the newline-delimited JSON format; equivalent to
    /// `JsonNewLineStreamFormat::default()`.
    pub fn new() -> Self {
        Self
    }
}
142
143impl<T> StreamingFormat<T> for JsonNewLineStreamFormat
144where
145    T: Serialize + Send + Sync + 'static,
146{
147    fn to_bytes_stream<'a, 'b>(
148        &'a self,
149        stream: BoxStream<'b, Result<T, axum::Error>>,
150        _: &'a StreamBodyAsOptions,
151    ) -> BoxStream<'b, Result<axum::body::Bytes, axum::Error>> {
152        Box::pin({
153            stream.map(|obj_res| match obj_res {
154                Err(e) => Err(e),
155                Ok(obj) => {
156                    let mut buf = BytesMut::new().writer();
157                    match serde_json::to_writer(&mut buf, &obj).map_err(axum::Error::new) {
158                        Ok(_) => match buf.write_all(JSON_NL_SEP_BYTES).map_err(axum::Error::new) {
159                            Ok(_) => Ok(buf.into_inner().freeze()),
160                            Err(e) => Err(e),
161                        },
162                        Err(e) => Err(e),
163                    }
164                }
165            })
166        })
167    }
168
169    fn http_response_headers(&self, _: &StreamBodyAsOptions) -> Option<HeaderMap> {
170        let mut header_map = HeaderMap::new();
171        header_map.insert(
172            http::header::CONTENT_TYPE,
173            http::header::HeaderValue::from_static("application/jsonstream"),
174        );
175        Some(header_map)
176    }
177}
178
/// Opens a JSON array.
const JSON_ARRAY_BEGIN_BYTES: &[u8] = b"[";
/// Closes a bare JSON array.
const JSON_ARRAY_END_BYTES: &[u8] = b"]";
/// Closes the array field and then the envelope object that contains it.
const JSON_ARRAY_ENVELOP_END_BYTES: &[u8] = b"]}";
/// Separates consecutive JSON array elements (and an envelope's fields from the array field).
const JSON_SEP_BYTES: &[u8] = b",";

/// Terminates each record in newline-delimited JSON output.
const JSON_NL_SEP_BYTES: &[u8] = b"\n";
185
186impl<'a> crate::StreamBodyAs<'a> {
187    pub fn json_array<S, T>(stream: S) -> Self
188    where
189        T: Serialize + Send + Sync + 'static,
190        S: Stream<Item = T> + 'a + Send,
191    {
192        Self::new(
193            JsonArrayStreamFormat::new(),
194            stream.map(Ok::<T, axum::Error>),
195        )
196    }
197
198    pub fn json_array_with_errors<S, T, E>(stream: S) -> Self
199    where
200        T: Serialize + Send + Sync + 'static,
201        S: Stream<Item = Result<T, E>> + 'a + Send,
202        E: Into<axum::Error>,
203    {
204        Self::new(JsonArrayStreamFormat::new(), stream)
205    }
206
207    pub fn json_array_with_envelope<S, T, EN>(stream: S, envelope: EN, array_field: &str) -> Self
208    where
209        T: Serialize + Send + Sync + 'static,
210        S: Stream<Item = T> + 'a + Send,
211        EN: Serialize + Send + Sync + 'static,
212    {
213        Self::new(
214            JsonArrayStreamFormat::with_envelope(envelope, array_field),
215            stream.map(Ok::<T, axum::Error>),
216        )
217    }
218
219    pub fn json_array_with_envelope_errors<S, T, E, EN>(
220        stream: S,
221        envelope: EN,
222        array_field: &str,
223    ) -> Self
224    where
225        T: Serialize + Send + Sync + 'static,
226        S: Stream<Item = Result<T, E>> + 'a + Send,
227        E: Into<axum::Error>,
228        EN: Serialize + Send + Sync + 'static,
229    {
230        Self::new(
231            JsonArrayStreamFormat::with_envelope(envelope, array_field),
232            stream,
233        )
234    }
235
236    pub fn json_nl<S, T>(stream: S) -> Self
237    where
238        T: Serialize + Send + Sync + 'static,
239        S: Stream<Item = T> + 'a + Send,
240    {
241        Self::new(
242            JsonNewLineStreamFormat::new(),
243            stream.map(Ok::<T, axum::Error>),
244        )
245    }
246
247    pub fn json_nl_with_errors<S, T, E>(stream: S) -> Self
248    where
249        T: Serialize + Send + Sync + 'static,
250        S: Stream<Item = Result<T, E>> + 'a + Send,
251        E: Into<axum::Error>,
252    {
253        Self::new(JsonNewLineStreamFormat::new(), stream)
254    }
255}
256
257impl StreamBodyAsOptions {
258    pub fn json_array<'a, S, T>(self, stream: S) -> StreamBodyAs<'a>
259    where
260        T: Serialize + Send + Sync + 'static,
261        S: Stream<Item = T> + 'a + Send,
262    {
263        StreamBodyAs::with_options(
264            JsonArrayStreamFormat::new(),
265            stream.map(Ok::<T, axum::Error>),
266            self,
267        )
268    }
269
270    pub fn json_array_with_errors<'a, S, T, E>(self, stream: S) -> StreamBodyAs<'a>
271    where
272        T: Serialize + Send + Sync + 'static,
273        S: Stream<Item = Result<T, E>> + 'a + Send,
274        E: Into<axum::Error>,
275    {
276        StreamBodyAs::with_options(JsonArrayStreamFormat::new(), stream, self)
277    }
278
279    pub fn json_array_with_envelope<'a, S, T, EN>(
280        self,
281        stream: S,
282        envelope: EN,
283        array_field: &str,
284    ) -> StreamBodyAs<'a>
285    where
286        T: Serialize + Send + Sync + 'static,
287        S: Stream<Item = T> + 'a + Send,
288        EN: Serialize + Send + Sync + 'static,
289    {
290        StreamBodyAs::with_options(
291            JsonArrayStreamFormat::with_envelope(envelope, array_field),
292            stream.map(Ok::<T, axum::Error>),
293            self,
294        )
295    }
296
297    pub fn json_array_with_envelope_errors<'a, S, T, E, EN>(
298        self,
299        stream: S,
300        envelope: EN,
301        array_field: &str,
302    ) -> StreamBodyAs<'a>
303    where
304        T: Serialize + Send + Sync + 'static,
305        S: Stream<Item = Result<T, E>> + 'a + Send,
306        E: Into<axum::Error>,
307        EN: Serialize + Send + Sync + 'static,
308    {
309        StreamBodyAs::with_options(
310            JsonArrayStreamFormat::with_envelope(envelope, array_field),
311            stream,
312            self,
313        )
314    }
315
316    pub fn json_nl<'a, S, T>(self, stream: S) -> StreamBodyAs<'a>
317    where
318        T: Serialize + Send + Sync + 'static,
319        S: Stream<Item = T> + 'a + Send,
320    {
321        StreamBodyAs::with_options(
322            JsonNewLineStreamFormat::new(),
323            stream.map(Ok::<T, axum::Error>),
324            self,
325        )
326    }
327
328    pub fn json_nl_with_errors<'a, S, T, E>(self, stream: S) -> StreamBodyAs<'a>
329    where
330        T: Serialize + Send + Sync + 'static,
331        S: Stream<Item = Result<T, E>> + 'a + Send,
332        E: Into<axum::Error>,
333    {
334        StreamBodyAs::with_options(JsonNewLineStreamFormat::new(), stream, self)
335    }
336}
337
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_client::*;
    use crate::StreamBodyAs;
    use axum::{routing::*, Router};
    use futures::stream;

    /// Element type streamed by every test.
    #[derive(Debug, Clone, Serialize)]
    struct TestItemStructure {
        foo: String,
    }

    /// Seven identical `{"foo":"bar"}` items.
    fn sample_items() -> Vec<TestItemStructure> {
        vec![
            TestItemStructure {
                foo: "bar".to_string()
            };
            7
        ]
    }

    /// Builds a `TestClient` around `app`, issues `GET /`, and returns the `content-type`
    /// header (when present and convertible to `&str`) together with the body text.
    async fn get_root(app: Router) -> (Option<String>, String) {
        let client = TestClient::new(app).await;
        let res = client.get("/").send().await.unwrap();
        let content_type = res
            .headers()
            .get("content-type")
            .and_then(|h| h.to_str().ok())
            .map(str::to_owned);
        let body = res.text().await.unwrap();
        (content_type, body)
    }

    #[tokio::test]
    async fn serialize_json_array_stream_format() {
        let items = sample_items();
        let test_stream = Box::pin(stream::iter(items.clone()));

        let app = Router::new().route(
            "/",
            get(|| async {
                StreamBodyAs::new(
                    JsonArrayStreamFormat::new(),
                    test_stream.map(Ok::<_, axum::Error>),
                )
            }),
        );

        let (content_type, body) = get_root(app).await;
        assert_eq!(content_type.as_deref(), Some("application/json"));
        assert_eq!(body, serde_json::to_string(&items).unwrap());
    }

    #[tokio::test]
    async fn serialize_json_nl_stream_format() {
        let items = sample_items();
        let test_stream = Box::pin(stream::iter(items.clone()));

        let app = Router::new().route(
            "/",
            get(|| async {
                StreamBodyAs::new(
                    JsonNewLineStreamFormat::new(),
                    test_stream.map(Ok::<_, axum::Error>),
                )
            }),
        );

        let (content_type, body) = get_root(app).await;
        assert_eq!(content_type.as_deref(), Some("application/jsonstream"));

        // Every record, including the last one, is terminated by a newline.
        let expected: String = items
            .iter()
            .map(|item| serde_json::to_string(item).unwrap() + "\n")
            .collect();
        assert_eq!(body, expected);
    }

    #[tokio::test]
    async fn serialize_json_array_stream_with_envelope_format() {
        #[derive(Debug, Clone, Serialize)]
        struct TestEnvelopeStructure {
            envelope_field: String,
            // Skipped while empty so the streamed array is the only `my_array` key.
            #[serde(skip_serializing_if = "Vec::is_empty")]
            my_array: Vec<TestItemStructure>,
        }

        let items = sample_items();
        let test_stream = Box::pin(stream::iter(items.clone()));
        let test_envelope = TestEnvelopeStructure {
            envelope_field: "test_envelope".to_string(),
            my_array: Vec::new(),
        };

        let app = Router::new().route(
            "/",
            get(|| async {
                StreamBodyAs::new(
                    JsonArrayStreamFormat::with_envelope(test_envelope, "my_array"),
                    test_stream.map(Ok::<_, axum::Error>),
                )
            }),
        );

        let (content_type, body) = get_root(app).await;
        assert_eq!(content_type.as_deref(), Some("application/json"));

        // The streamed document must equal the envelope serialized with the items inlined.
        let expected = TestEnvelopeStructure {
            envelope_field: "test_envelope".to_string(),
            my_array: items,
        };
        assert_eq!(body, serde_json::to_string(&expected).unwrap());
    }

    #[tokio::test]
    async fn serialize_json_array_stream_with_empty_envelope_format() {
        #[derive(Debug, Clone, Serialize)]
        struct TestEnvelopeStructure {
            // With this field skipped the envelope serializes as `{}`, exercising the
            // "no existing fields" path.
            #[serde(skip_serializing_if = "Vec::is_empty")]
            my_array: Vec<TestItemStructure>,
        }

        let items = sample_items();
        let test_stream = Box::pin(stream::iter(items.clone()));
        let test_envelope = TestEnvelopeStructure {
            my_array: Vec::new(),
        };

        let app = Router::new().route(
            "/",
            get(|| async {
                StreamBodyAs::new(
                    JsonArrayStreamFormat::with_envelope(test_envelope, "my_array"),
                    test_stream.map(Ok::<_, axum::Error>),
                )
            }),
        );

        let (content_type, body) = get_root(app).await;
        assert_eq!(content_type.as_deref(), Some("application/json"));

        let expected = TestEnvelopeStructure { my_array: items };
        assert_eq!(body, serde_json::to_string(&expected).unwrap());
    }
}