reqwest_streams/
csv_stream.rs

1use crate::error::StreamBodyKind;
2use crate::{StreamBodyError, StreamBodyResult};
3use async_trait::*;
4use futures::{StreamExt, TryStreamExt};
5use serde::Deserialize;
6use tokio_util::io::StreamReader;
7
8/// Extension trait for [`reqwest::Response`] that provides streaming support for the CSV format.
9#[async_trait]
10pub trait CsvStreamResponse {
11    /// Streams the response as CSV, where each line is a CSV row.
12    ///
13    /// The stream will [`Deserialize`] entries as type `T` with a maximum size of `max_obj_len`
14    /// bytes. If `max_obj_len` is [`usize::MAX`], lines will be read until a newline (`\n`)
15    /// character is reached.
16    ///
17    /// If `with_csv_header` is `true`, the stream will skip the first row (the CSV header).
18    ///
19    /// The `delimiter` is the byte value of the delimiter character.
20    ///
21    /// # Example
22    ///
23    /// ```rust,no_run
24    /// use futures::stream::BoxStream as _;
25    /// use reqwest_streams::CsvStreamResponse as _;
26    /// use serde::{Deserialize, Serialize};
27    ///
28    /// #[derive(Debug, Clone, Deserialize)]
29    /// struct MyTestStructure {
30    ///     some_test_field: String
31    /// }
32    ///
33    /// #[tokio::main]
34    /// async fn main() -> Result<(), Box<dyn std::error::Error>> {
35    ///     const MAX_OBJ_LEN: usize = 64 * 1024;
36    ///
37    ///     let _stream = reqwest::get("http://localhost:8080/csv")
38    ///         .await?
39    ///         .csv_stream::<MyTestStructure>(MAX_OBJ_LEN, true, b',');
40    ///
41    ///     Ok(())
42    /// }
43    /// ```
44    fn csv_stream<'a, 'b, T>(
45        self,
46        max_obj_len: usize,
47        with_csv_header: bool,
48        delimiter: u8,
49    ) -> impl futures::Stream<Item = StreamBodyResult<T>>  + Send + 'b
50    where
51        T: for<'de> Deserialize<'de>;
52}
53
54#[async_trait]
55impl CsvStreamResponse for reqwest::Response {
56    fn csv_stream<'a, 'b, T>(
57        self,
58        max_obj_len: usize,
59        with_csv_header: bool,
60        delimiter: u8,
61    ) -> impl futures::Stream<Item = StreamBodyResult<T>>  + Send + 'b
62    where
63        T: for<'de> Deserialize<'de>,
64    {
65        let reader = StreamReader::new(
66            self.bytes_stream()
67                .map_err(|err| std::io::Error::new(std::io::ErrorKind::Other, err)),
68        );
69
70        let codec = tokio_util::codec::LinesCodec::new_with_max_length(max_obj_len);
71        let frames_reader = tokio_util::codec::FramedRead::new(reader, codec);
72
73        #[allow(clippy::bool_to_int_with_if)] // false positive: it is not bool to int
74        let skip_header_if_expected = if with_csv_header { 1 } else { 0 };
75
76        frames_reader
77            .into_stream()
78            .skip(skip_header_if_expected)
79            .map(move |frame_res| match frame_res {
80                Ok(frame_str) => {
81                    let mut csv_reader = csv::ReaderBuilder::new()
82                        .delimiter(delimiter)
83                        .has_headers(false)
84                        .from_reader(frame_str.as_bytes());
85
86                    let mut iter = csv_reader.deserialize::<T>();
87
88                    if let Some(csv_res) = iter.next() {
89                        match csv_res {
90                            Ok(result) => Ok(result),
91                            Err(err) => Err(StreamBodyError::new(
92                                StreamBodyKind::CodecError,
93                                Some(Box::new(err)),
94                                None,
95                            )),
96                        }
97                    } else {
98                        Err(StreamBodyError::new(StreamBodyKind::CodecError, None, None))
99                    }
100                }
101                Err(err) => Err(StreamBodyError::new(
102                    StreamBodyKind::CodecError,
103                    Some(Box::new(err)),
104                    None,
105                )),
106            })
107    }
108}
109
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_client::*;
    use axum::{routing::*, Router};
    use axum_streams::*;
    use futures::stream;
    use serde::Serialize;

    #[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
    struct MyTestStructure {
        some_test_field1: String,
        some_test_field2: String,
    }

    /// Builds 100 rows whose values differ per row, so equality assertions also catch rows
    /// that are dropped, duplicated or reordered (e.g. a data row skipped instead of the
    /// header).
    fn generate_test_structures() -> Vec<MyTestStructure> {
        (0..100)
            .map(|idx| MyTestStructure {
                some_test_field1: format!("TestValue1-{idx}"),
                some_test_field2: format!("TestValue2-{idx}"),
            })
            .collect()
    }

    #[tokio::test]
    async fn deserialize_csv_stream() {
        let test_stream_vec = generate_test_structures();

        let test_stream = Box::pin(stream::iter(test_stream_vec.clone()));

        let app = Router::new().route("/", get(|| async { StreamBodyAs::csv(test_stream) }));

        let client = TestClient::new(app).await;

        let res = client
            .get("/")
            .send()
            .await
            .unwrap()
            .csv_stream::<MyTestStructure>(1024, false, b',');
        let items: Vec<MyTestStructure> = res.try_collect().await.unwrap();

        assert_eq!(items, test_stream_vec);
    }

    #[tokio::test]
    async fn deserialize_csv_stream_with_header() {
        let test_stream_vec = generate_test_structures();

        let test_stream = Box::pin(stream::iter(
            test_stream_vec
                .clone()
                .into_iter()
                .map(Ok::<_, axum::Error>),
        ));

        let app = Router::new().route(
            "/",
            get(|| async { StreamBodyAs::new(CsvStreamFormat::new(true, b','), test_stream) }),
        );

        let client = TestClient::new(app).await;

        let res = client
            .get("/")
            .send()
            .await
            .unwrap()
            .csv_stream::<MyTestStructure>(1024, true, b',');
        let items: Vec<MyTestStructure> = res.try_collect().await.unwrap();

        assert_eq!(items, test_stream_vec);
    }

    #[tokio::test]
    async fn deserialize_csv_stream_with_custom_delimiter() {
        let test_stream_vec = generate_test_structures();

        let test_stream = Box::pin(stream::iter(
            test_stream_vec
                .clone()
                .into_iter()
                .map(Ok::<_, axum::Error>),
        ));

        let app = Router::new().route(
            "/",
            get(|| async { StreamBodyAs::new(CsvStreamFormat::new(false, b';'), test_stream) }),
        );

        let client = TestClient::new(app).await;

        let res = client
            .get("/")
            .send()
            .await
            .unwrap()
            .csv_stream::<MyTestStructure>(1024, false, b';');
        let items: Vec<MyTestStructure> = res.try_collect().await.unwrap();

        assert_eq!(items, test_stream_vec);
    }

    #[tokio::test]
    async fn deserialize_csv_check_max_len() {
        // Serve well-formed CSV so that the only reason for failure is a line exceeding
        // `max_obj_len`, rather than a body in some other format.
        let test_stream = Box::pin(stream::iter(generate_test_structures()));

        let app = Router::new().route("/", get(|| async { StreamBodyAs::csv(test_stream) }));

        let client = TestClient::new(app).await;

        let res = client
            .get("/")
            .send()
            .await
            .unwrap()
            .csv_stream::<MyTestStructure>(5, false, b',');
        res.try_collect::<Vec<MyTestStructure>>()
            .await
            .expect_err("MaxLenReachedError");
    }
}