1use crate::stream_format::StreamingFormat;
2use axum::body::{Body, HttpBody};
3use axum::response::{IntoResponse, Response};
4use bytes::BytesMut;
5use futures::stream::BoxStream;
6use futures::StreamExt;
7use futures::{Stream, TryStreamExt};
8use http::{HeaderMap, HeaderValue};
9use http_body::Frame;
10use std::fmt::Formatter;
11use std::pin::Pin;
12use std::task::{Context, Poll};
13
/// An HTTP response body that is produced lazily from a stream of frames.
///
/// The frames are built by the associated constructors from a user supplied
/// item stream and a `StreamingFormat`; this type only stores the resulting
/// frame stream and the headers to attach to the response.
pub struct StreamBodyAs<'a> {
    /// Encoded body frames (or the error that interrupted the source stream).
    stream: BoxStream<'a, Result<Frame<axum::body::Bytes>, axum::Error>>,
    /// Headers to install on the response; `None` leaves the response headers untouched.
    headers: Option<HeaderMap>,
}
18
19impl<'a> std::fmt::Debug for StreamBodyAs<'a> {
20 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
21 write!(f, "StreamBodyWithFormat")
22 }
23}
24
25impl<'a> StreamBodyAs<'a> {
26 pub fn new<S, T, FMT, E>(stream_format: FMT, stream: S) -> Self
28 where
29 FMT: StreamingFormat<T>,
30 S: Stream<Item = Result<T, E>> + 'a + Send,
31 E: Into<axum::Error>,
32 {
33 Self::with_options(stream_format, stream, StreamBodyAsOptions::new())
34 }
35
36 pub fn with_options<S, T, FMT, E>(
37 stream_format: FMT,
38 stream: S,
39 options: StreamBodyAsOptions,
40 ) -> Self
41 where
42 FMT: StreamingFormat<T>,
43 S: Stream<Item = Result<T, E>> + 'a + Send,
44 E: Into<axum::Error>,
45 {
46 Self {
47 stream: Self::create_stream_frames(&stream_format, stream, &options),
48 headers: stream_format.http_response_headers(&options),
49 }
50 }
51
52 pub fn headers(mut self, headers: HeaderMap) -> Self {
53 self.headers = Some(headers);
54 self
55 }
56
57 pub fn header<K, V>(mut self, key: K, value: V) -> Self
58 where
59 K: http::header::IntoHeaderName,
60 V: Into<HeaderValue>,
61 {
62 let current_headers = self.headers.get_or_insert(HeaderMap::new());
63 current_headers.append(key, value.into());
64 self
65 }
66
67 fn create_stream_frames<S, T, FMT, E>(
68 stream_format: &FMT,
69 stream: S,
70 options: &StreamBodyAsOptions,
71 ) -> BoxStream<'a, Result<Frame<axum::body::Bytes>, axum::Error>>
72 where
73 FMT: StreamingFormat<T>,
74 S: Stream<Item = Result<T, E>> + 'a + Send,
75 E: Into<axum::Error>,
76 {
77 let boxed_stream = Box::pin(stream.map_err(|e| e.into()));
78 match (options.buffering_ready_items, options.buffering_bytes) {
79 (Some(buffering_ready_items), _) => stream_format
80 .to_bytes_stream(boxed_stream, options)
81 .ready_chunks(buffering_ready_items)
82 .map(|chunks| {
83 let mut buf = BytesMut::new();
84 for chunk in chunks {
85 buf.extend_from_slice(&chunk?);
86 }
87 Ok(Frame::data(buf.freeze()))
88 })
89 .boxed(),
90 (_, Some(buffering_bytes)) => {
91 let bytes_stream = stream_format.to_bytes_stream(boxed_stream, options).chain(
92 futures::stream::once(futures::future::ready(Ok(bytes::Bytes::new()))),
93 );
94
95 bytes_stream
96 .scan(
97 BytesMut::with_capacity(buffering_bytes),
98 move |current_buffer, maybe_bytes| {
99 futures::future::ready(match maybe_bytes {
100 Ok(bytes) if bytes.is_empty() => {
101 Some(vec![Ok(Frame::data(current_buffer.split().freeze()))])
102 }
103 Ok(bytes) => {
104 let mut frames = Vec::new();
105 current_buffer.extend_from_slice(&bytes);
106 while current_buffer.len() >= buffering_bytes {
107 let buffer =
108 current_buffer.split_to(buffering_bytes).freeze();
109 frames.push(Ok(Frame::data(buffer)));
110 }
111 Some(frames)
112 }
113 Err(_) => None,
114 })
115 },
116 )
117 .flat_map(|res| futures::stream::iter(res).boxed())
118 .boxed()
119 }
120 (None, None) => stream_format
121 .to_bytes_stream(boxed_stream, options)
122 .map(|res| res.map(Frame::data))
123 .boxed(),
124 }
125 }
126}
127
128impl IntoResponse for StreamBodyAs<'static> {
129 fn into_response(mut self) -> Response {
130 let maybe_headers = self.headers.take();
131 let mut response: Response<Body> = Response::new(Body::new(self));
132 if let Some(headers) = maybe_headers {
133 *response.headers_mut() = headers;
134 }
135 response
136 }
137}
138
139impl<'a> HttpBody for StreamBodyAs<'a> {
140 type Data = axum::body::Bytes;
141 type Error = axum::Error;
142
143 fn poll_frame(
144 mut self: Pin<&mut Self>,
145 cx: &mut Context<'_>,
146 ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
147 Pin::new(&mut self.stream).poll_next(cx)
148 }
149}
150
/// Alias for `http::header::HeaderValue`, used as the type of
/// `StreamBodyAsOptions::content_type`.
pub type HttpHeaderValue = http::header::HeaderValue;
152
/// Options controlling how a `StreamBodyAs` groups encoded output into frames.
///
/// When both buffering limits are set, `buffering_ready_items` takes
/// precedence over `buffering_bytes`.
pub struct StreamBodyAsOptions {
    /// Maximum number of immediately available encoded chunks that are
    /// concatenated into a single frame.
    pub buffering_ready_items: Option<usize>,
    /// Frame size in bytes: the encoded output is re-sliced into frames of this
    /// size, with the remainder sent as a last, shorter frame.
    pub buffering_bytes: Option<usize>,
    /// Content type value handed to the streaming format.
    // NOTE(review): presumably overrides the format's default `Content-Type`
    // header; it is not read in this file — confirm in the format implementations.
    pub content_type: Option<HttpHeaderValue>,
}
158
159impl StreamBodyAsOptions {
160 pub fn new() -> Self {
161 Self {
162 buffering_ready_items: None,
163 buffering_bytes: None,
164 content_type: None,
165 }
166 }
167
168 pub fn buffering_ready_items(mut self, ready_items: usize) -> Self {
169 self.buffering_ready_items = Some(ready_items);
170 self
171 }
172
173 pub fn buffering_bytes(mut self, ready_bytes: usize) -> Self {
174 self.buffering_bytes = Some(ready_bytes);
175 self
176 }
177
178 pub fn content_type(mut self, content_type: HttpHeaderValue) -> Self {
179 self.content_type = Some(content_type);
180 self
181 }
182}
183
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TextStreamFormat;
    use bytes::Bytes;
    use futures::TryStreamExt;

    /// Builds an infallible source stream that yields the given words in order.
    fn words(items: &[&str]) -> BoxStream<'static, Result<String, axum::Error>> {
        let owned: Vec<String> = items.iter().map(|item| String::from(*item)).collect();
        futures::stream::iter(owned)
            .map(Ok::<_, axum::Error>)
            .boxed()
    }

    /// Checks that the response advertises the plain-text content type.
    fn assert_plain_text(response: &Response) {
        assert_eq!(
            response.headers().get(http::header::CONTENT_TYPE).unwrap(),
            "text/plain; charset=utf-8"
        );
    }

    /// Drains the response body and returns its data chunks in arrival order.
    async fn body_chunks(response: Response) -> Vec<Bytes> {
        response
            .into_body()
            .into_data_stream()
            .try_collect::<Vec<Bytes>>()
            .await
            .unwrap()
    }

    #[test]
    fn test_stream_body_as_options() {
        let untouched = StreamBodyAsOptions::new();
        assert_eq!(untouched.buffering_ready_items, None);

        let configured = StreamBodyAsOptions::new().buffering_ready_items(10);
        assert_eq!(configured.buffering_ready_items, Some(10));
    }

    #[tokio::test]
    async fn test_stream_body_as() {
        let body = StreamBodyAs::new(TextStreamFormat::new(), words(&["First", "Second"]));
        let response = body.into_response();
        assert_plain_text(&response);

        // Without buffering every item is delivered as its own chunk.
        let chunks = body_chunks(response).await;
        assert_eq!(chunks.len(), 2);
        assert_eq!(chunks[0], Bytes::from("First"));
        assert_eq!(chunks[1], Bytes::from("Second"));
    }

    #[tokio::test]
    async fn test_stream_body_as_buffering_items() {
        let body = StreamBodyAs::with_options(
            TextStreamFormat::new(),
            words(&["First", "Second", "Third"]),
            StreamBodyAsOptions::new().buffering_ready_items(2),
        );
        let response = body.into_response();
        assert_plain_text(&response);

        // Ready items are merged pairwise; the odd one out arrives alone.
        let chunks = body_chunks(response).await;
        assert_eq!(chunks.len(), 2);
        assert_eq!(chunks[0], Bytes::from("FirstSecond"));
        assert_eq!(chunks[1], Bytes::from("Third"));
    }

    #[tokio::test]
    async fn test_stream_body_as_buffering_bytes() {
        let body = StreamBodyAs::with_options(
            TextStreamFormat::new(),
            words(&["First", "Second", "Third"]),
            StreamBodyAsOptions::new().buffering_bytes(3),
        );
        let response = body.into_response();
        assert_plain_text(&response);

        // 16 bytes of text cut into 3-byte frames: five full frames plus a remainder.
        let chunks = body_chunks(response).await;
        assert_eq!(chunks.len(), 6);
        let expected = ["Fir", "stS", "eco", "ndT", "hir", "d"];
        for (chunk, want) in chunks.iter().zip(expected) {
            assert_eq!(*chunk, Bytes::from(want));
        }
    }
}