Skip to main content

axum_streams/
stream_body_as.rs

1use crate::stream_format::StreamingFormat;
2use axum::body::{Body, HttpBody};
3use axum::response::{IntoResponse, Response};
4use bytes::BytesMut;
5use futures::stream::BoxStream;
6use futures::StreamExt;
7use futures::{Stream, TryStreamExt};
8use http::{HeaderMap, HeaderValue};
9use http_body::Frame;
10use std::fmt::Formatter;
11use std::pin::Pin;
12use std::task::{Context, Poll};
13
14pub struct StreamBodyAs<'a> {
15    stream: BoxStream<'a, Result<Frame<axum::body::Bytes>, axum::Error>>,
16    headers: Option<HeaderMap>,
17}
18
19impl<'a> std::fmt::Debug for StreamBodyAs<'a> {
20    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
21        write!(f, "StreamBodyWithFormat")
22    }
23}
24
25impl<'a> StreamBodyAs<'a> {
26    /// Create a new `StreamBodyWith` providing a stream of your objects in the specified format.
27    pub fn new<S, T, FMT, E>(stream_format: FMT, stream: S) -> Self
28    where
29        FMT: StreamingFormat<T>,
30        S: Stream<Item = Result<T, E>> + 'a + Send,
31        E: Into<axum::Error>,
32    {
33        Self::with_options(stream_format, stream, StreamBodyAsOptions::new())
34    }
35
36    pub fn with_options<S, T, FMT, E>(
37        stream_format: FMT,
38        stream: S,
39        options: StreamBodyAsOptions,
40    ) -> Self
41    where
42        FMT: StreamingFormat<T>,
43        S: Stream<Item = Result<T, E>> + 'a + Send,
44        E: Into<axum::Error>,
45    {
46        Self {
47            stream: Self::create_stream_frames(&stream_format, stream, &options),
48            headers: stream_format.http_response_headers(&options),
49        }
50    }
51
52    pub fn headers(mut self, headers: HeaderMap) -> Self {
53        self.headers = Some(headers);
54        self
55    }
56
57    pub fn header<K, V>(mut self, key: K, value: V) -> Self
58    where
59        K: http::header::IntoHeaderName,
60        V: Into<HeaderValue>,
61    {
62        let current_headers = self.headers.get_or_insert(HeaderMap::new());
63        current_headers.append(key, value.into());
64        self
65    }
66
67    fn create_stream_frames<S, T, FMT, E>(
68        stream_format: &FMT,
69        stream: S,
70        options: &StreamBodyAsOptions,
71    ) -> BoxStream<'a, Result<Frame<axum::body::Bytes>, axum::Error>>
72    where
73        FMT: StreamingFormat<T>,
74        S: Stream<Item = Result<T, E>> + 'a + Send,
75        E: Into<axum::Error>,
76    {
77        let boxed_stream = Box::pin(stream.map_err(|e| e.into()));
78        match (options.buffering_ready_items, options.buffering_bytes) {
79            (Some(buffering_ready_items), _) => stream_format
80                .to_bytes_stream(boxed_stream, options)
81                .ready_chunks(buffering_ready_items)
82                .map(|chunks| {
83                    let mut buf = BytesMut::new();
84                    for chunk in chunks {
85                        buf.extend_from_slice(&chunk?);
86                    }
87                    Ok(Frame::data(buf.freeze()))
88                })
89                .boxed(),
90            (_, Some(buffering_bytes)) => {
91                let bytes_stream = stream_format.to_bytes_stream(boxed_stream, options).chain(
92                    futures::stream::once(futures::future::ready(Ok(bytes::Bytes::new()))),
93                );
94
95                bytes_stream
96                    .scan(
97                        BytesMut::with_capacity(buffering_bytes),
98                        move |current_buffer, maybe_bytes| {
99                            futures::future::ready(match maybe_bytes {
100                                Ok(bytes) if bytes.is_empty() => {
101                                    Some(vec![Ok(Frame::data(current_buffer.split().freeze()))])
102                                }
103                                Ok(bytes) => {
104                                    let mut frames = Vec::new();
105                                    current_buffer.extend_from_slice(&bytes);
106                                    while current_buffer.len() >= buffering_bytes {
107                                        let buffer =
108                                            current_buffer.split_to(buffering_bytes).freeze();
109                                        frames.push(Ok(Frame::data(buffer)));
110                                    }
111                                    Some(frames)
112                                }
113                                Err(_) => None,
114                            })
115                        },
116                    )
117                    .flat_map(|res| futures::stream::iter(res).boxed())
118                    .boxed()
119            }
120            (None, None) => stream_format
121                .to_bytes_stream(boxed_stream, options)
122                .map(|res| res.map(Frame::data))
123                .boxed(),
124        }
125    }
126}
127
128impl IntoResponse for StreamBodyAs<'static> {
129    fn into_response(mut self) -> Response {
130        let maybe_headers = self.headers.take();
131        let mut response: Response<Body> = Response::new(Body::new(self));
132        if let Some(headers) = maybe_headers {
133            *response.headers_mut() = headers;
134        }
135        response
136    }
137}
138
139impl<'a> HttpBody for StreamBodyAs<'a> {
140    type Data = axum::body::Bytes;
141    type Error = axum::Error;
142
143    fn poll_frame(
144        mut self: Pin<&mut Self>,
145        cx: &mut Context<'_>,
146    ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
147        Pin::new(&mut self.stream).poll_next(cx)
148    }
149}
150
151pub type HttpHeaderValue = http::header::HeaderValue;
152
153pub struct StreamBodyAsOptions {
154    pub buffering_ready_items: Option<usize>,
155    pub buffering_bytes: Option<usize>,
156    pub content_type: Option<HttpHeaderValue>,
157}
158
159impl StreamBodyAsOptions {
160    pub fn new() -> Self {
161        Self {
162            buffering_ready_items: None,
163            buffering_bytes: None,
164            content_type: None,
165        }
166    }
167
168    pub fn buffering_ready_items(mut self, ready_items: usize) -> Self {
169        self.buffering_ready_items = Some(ready_items);
170        self
171    }
172
173    pub fn buffering_bytes(mut self, ready_bytes: usize) -> Self {
174        self.buffering_bytes = Some(ready_bytes);
175        self
176    }
177
178    pub fn content_type(mut self, content_type: HttpHeaderValue) -> Self {
179        self.content_type = Some(content_type);
180        self
181    }
182}
183
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TextStreamFormat;
    use bytes::Bytes;
    use futures::TryStreamExt;

    /// Turn a list of words into a boxed stream of owned `String`s.
    fn words(items: &[&str]) -> BoxStream<'static, String> {
        let owned: Vec<String> = items.iter().map(|item| item.to_string()).collect();
        futures::stream::iter(owned).boxed()
    }

    /// Render `body` as a response, check the plain-text content type and return
    /// every data chunk of the response body.
    async fn text_chunks(body: StreamBodyAs<'static>) -> Vec<Bytes> {
        let response = body.into_response();
        assert_eq!(
            response.headers().get(http::header::CONTENT_TYPE).unwrap(),
            "text/plain; charset=utf-8"
        );
        response
            .into_body()
            .into_data_stream()
            .try_collect::<Vec<Bytes>>()
            .await
            .unwrap()
    }

    #[test]
    fn test_stream_body_as_options() {
        assert_eq!(StreamBodyAsOptions::new().buffering_ready_items, None);
        assert_eq!(
            StreamBodyAsOptions::new()
                .buffering_ready_items(10)
                .buffering_ready_items,
            Some(10)
        );
    }

    #[tokio::test]
    async fn test_stream_body_as() {
        let body = StreamBodyAs::new(
            TextStreamFormat::new(),
            words(&["First", "Second"]).map(Ok::<_, axum::Error>),
        );
        assert_eq!(
            text_chunks(body).await,
            vec![Bytes::from("First"), Bytes::from("Second")]
        );
    }

    #[tokio::test]
    async fn test_stream_body_as_buffering_items() {
        let body = StreamBodyAs::with_options(
            TextStreamFormat::new(),
            words(&["First", "Second", "Third"]).map(Ok::<_, axum::Error>),
            StreamBodyAsOptions::new().buffering_ready_items(2),
        );
        assert_eq!(
            text_chunks(body).await,
            vec![Bytes::from("FirstSecond"), Bytes::from("Third")]
        );
    }

    #[tokio::test]
    async fn test_stream_body_as_buffering_bytes() {
        let body = StreamBodyAs::with_options(
            TextStreamFormat::new(),
            words(&["First", "Second", "Third"]).map(Ok::<_, axum::Error>),
            StreamBodyAsOptions::new().buffering_bytes(3),
        );
        let expected: Vec<Bytes> = ["Fir", "stS", "eco", "ndT", "hir", "d"]
            .iter()
            .copied()
            .map(Bytes::from)
            .collect();
        assert_eq!(text_chunks(body).await, expected);
    }
}
266        assert_eq!(data[3], Bytes::from("ndT"));
267        assert_eq!(data[4], Bytes::from("hir"));
268        assert_eq!(data[5], Bytes::from("d"));
269    }
270}