Skip to main content

axum_streams/
csv_format.rs

1use crate::stream_body_as::StreamBodyAsOptions;
2use crate::stream_format::StreamingFormat;
3use crate::StreamBodyAs;
4use futures::stream::BoxStream;
5use futures::Stream;
6use futures::StreamExt;
7use http::HeaderMap;
8use serde::Serialize;
9
10pub struct CsvStreamFormat {
11    has_headers: bool,
12    delimiter: u8,
13    flexible: bool,
14    quote_style: csv::QuoteStyle,
15    quote: u8,
16    double_quote: bool,
17    escape: u8,
18    terminator: csv::Terminator,
19}
20
21impl Default for CsvStreamFormat {
22    fn default() -> Self {
23        Self {
24            has_headers: true,
25            delimiter: b',',
26            flexible: false,
27            quote_style: csv::QuoteStyle::Necessary,
28            quote: b'"',
29            double_quote: true,
30            escape: b'\\',
31            terminator: csv::Terminator::Any(b'\n'),
32        }
33    }
34}
35
36impl CsvStreamFormat {
37    pub fn new(has_headers: bool, delimiter: u8) -> Self {
38        Self {
39            has_headers,
40            delimiter,
41            ..Default::default()
42        }
43    }
44
45    /// Sets whether to use flexible serialize.
46    pub fn with_flexible(mut self, flexible: bool) -> Self {
47        self.flexible = flexible;
48        self
49    }
50
51    /// Sets the quote style to use.
52    pub fn with_quote_style(mut self, quote_style: csv::QuoteStyle) -> Self {
53        self.quote_style = quote_style;
54        self
55    }
56
57    /// Sets the quote character to use.
58    pub fn with_quote(mut self, quote: u8) -> Self {
59        self.quote = quote;
60        self
61    }
62
63    /// Sets whether to double quote.
64    pub fn with_double_quote(mut self, double_quote: bool) -> Self {
65        self.double_quote = double_quote;
66        self
67    }
68
69    /// Sets the escape character to use.
70    pub fn with_escape(mut self, escape: u8) -> Self {
71        self.escape = escape;
72        self
73    }
74
75    /// Sets the line terminator to use.
76    pub fn with_terminator(mut self, terminator: csv::Terminator) -> Self {
77        self.terminator = terminator;
78        self
79    }
80
81    /// Set the field delimiter to use.
82    pub fn with_delimiter(mut self, delimiter: u8) -> Self {
83        self.delimiter = delimiter;
84        self
85    }
86
87    /// Set whether to write headers.
88    pub fn with_has_headers(mut self, has_headers: bool) -> Self {
89        self.has_headers = has_headers;
90        self
91    }
92}
93
94impl<T> StreamingFormat<T> for CsvStreamFormat
95where
96    T: Serialize + Send + Sync + 'static,
97{
98    fn to_bytes_stream<'a, 'b>(
99        &'a self,
100        stream: BoxStream<'b, Result<T, axum::Error>>,
101        _: &'a StreamBodyAsOptions,
102    ) -> BoxStream<'b, Result<axum::body::Bytes, axum::Error>> {
103        let stream_with_header = self.has_headers;
104        let stream_delimiter = self.delimiter;
105        let stream_flexible = self.flexible;
106        let stream_quote_style = self.quote_style;
107        let stream_quote = self.quote;
108        let stream_double_quote = self.double_quote;
109        let stream_escape = self.escape;
110        let terminator = self.terminator;
111
112        Box::pin({
113            stream
114                .enumerate()
115                .map(move |(index, obj_res)| match obj_res {
116                    Err(e) => Err(e),
117                    Ok(obj) => {
118                        let mut writer = csv::WriterBuilder::new()
119                            .has_headers(index == 0 && stream_with_header)
120                            .delimiter(stream_delimiter)
121                            .flexible(stream_flexible)
122                            .quote_style(stream_quote_style)
123                            .quote(stream_quote)
124                            .double_quote(stream_double_quote)
125                            .escape(stream_escape)
126                            .terminator(terminator)
127                            .from_writer(vec![]);
128
129                        writer.serialize(obj).map_err(axum::Error::new)?;
130                        writer.flush().map_err(axum::Error::new)?;
131                        writer
132                            .into_inner()
133                            .map_err(axum::Error::new)
134                            .map(axum::body::Bytes::from)
135                    }
136                })
137        })
138    }
139
140    fn http_response_headers(&self, options: &StreamBodyAsOptions) -> Option<HeaderMap> {
141        let mut header_map = HeaderMap::new();
142        header_map.insert(
143            http::header::CONTENT_TYPE,
144            options
145                .content_type
146                .clone()
147                .unwrap_or_else(|| http::header::HeaderValue::from_static("text/csv")),
148        );
149        Some(header_map)
150    }
151}
152
153impl<'a> StreamBodyAs<'a> {
154    pub fn csv<S, T>(stream: S) -> Self
155    where
156        T: Serialize + Send + Sync + 'static,
157        S: Stream<Item = T> + 'a + Send,
158    {
159        Self::new(
160            CsvStreamFormat::new(false, b','),
161            stream.map(Ok::<T, axum::Error>),
162        )
163    }
164
165    pub fn csv_with_errors<S, T, E>(stream: S) -> Self
166    where
167        T: Serialize + Send + Sync + 'static,
168        S: Stream<Item = Result<T, E>> + 'a + Send,
169        E: Into<axum::Error> + 'static,
170    {
171        Self::new(CsvStreamFormat::new(false, b','), stream)
172    }
173}
174
175impl StreamBodyAsOptions {
176    pub fn csv<'a, S, T>(self, stream: S) -> StreamBodyAs<'a>
177    where
178        T: Serialize + Send + Sync + 'static,
179        S: Stream<Item = T> + 'a + Send,
180    {
181        StreamBodyAs::with_options(
182            CsvStreamFormat::new(false, b','),
183            stream.map(Ok::<T, axum::Error>),
184            self,
185        )
186    }
187
188    pub fn csv_with_errors<'a, S, T, E>(self, stream: S) -> StreamBodyAs<'a>
189    where
190        T: Serialize + Send + Sync + 'static,
191        S: Stream<Item = Result<T, E>> + 'a + Send,
192        E: Into<axum::Error> + 'static,
193    {
194        StreamBodyAs::with_options(CsvStreamFormat::new(false, b','), stream, self)
195    }
196}
197
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_client::*;
    use crate::StreamBodyAs;
    use axum::{routing::*, Router};
    use futures::stream;

    /// Streams seven identical rows through a header-less, comma-delimited format and
    /// checks both the `content-type` header and the exact response body.
    #[tokio::test]
    async fn serialize_csv_stream_format() {
        #[derive(Debug, Clone, Serialize)]
        struct TestOutputStructure {
            foo1: String,
            foo2: String,
        }

        let row = TestOutputStructure {
            foo1: "bar1".to_string(),
            foo2: "bar2".to_string(),
        };
        let rows = vec![row; 7];

        // One newline-terminated `foo1,foo2` line per streamed row.
        let expected_csv: String = rows
            .iter()
            .map(|r| format!("{},{}\n", r.foo1, r.foo2))
            .collect();

        let test_stream = Box::pin(stream::iter(rows));

        // The constructor is given `.` as delimiter on purpose; `with_delimiter` must
        // override it for the expected body to match.
        let app = Router::new().route(
            "/",
            get(|| async {
                StreamBodyAs::new(
                    CsvStreamFormat::new(false, b'.').with_delimiter(b','),
                    test_stream.map(Ok::<_, axum::Error>),
                )
            }),
        );

        let client = TestClient::new(app).await;
        let res = client.get("/").send().await.unwrap();

        let content_type = res
            .headers()
            .get("content-type")
            .and_then(|h| h.to_str().ok());
        assert_eq!(content_type, Some("text/csv"));

        let body = res.text().await.unwrap();
        assert_eq!(body, expected_csv);
    }
}