reqwest_streams/
csv_stream.rs1use crate::error::StreamBodyKind;
2use crate::{StreamBodyError, StreamBodyResult};
3use async_trait::*;
4use futures::{StreamExt, TryStreamExt};
5use serde::Deserialize;
6use tokio_util::io::StreamReader;
7
8#[async_trait]
10pub trait CsvStreamResponse {
11 fn csv_stream<'a, 'b, T>(
45 self,
46 max_obj_len: usize,
47 with_csv_header: bool,
48 delimiter: u8,
49 ) -> impl futures::Stream<Item = StreamBodyResult<T>> + Send + 'b
50 where
51 T: for<'de> Deserialize<'de>;
52}
53
54#[async_trait]
55impl CsvStreamResponse for reqwest::Response {
56 fn csv_stream<'a, 'b, T>(
57 self,
58 max_obj_len: usize,
59 with_csv_header: bool,
60 delimiter: u8,
61 ) -> impl futures::Stream<Item = StreamBodyResult<T>> + Send + 'b
62 where
63 T: for<'de> Deserialize<'de>,
64 {
65 let reader = StreamReader::new(
66 self.bytes_stream()
67 .map_err(|err| std::io::Error::new(std::io::ErrorKind::Other, err)),
68 );
69
70 let codec = tokio_util::codec::LinesCodec::new_with_max_length(max_obj_len);
71 let frames_reader = tokio_util::codec::FramedRead::new(reader, codec);
72
73 #[allow(clippy::bool_to_int_with_if)] let skip_header_if_expected = if with_csv_header { 1 } else { 0 };
75
76 frames_reader
77 .into_stream()
78 .skip(skip_header_if_expected)
79 .map(move |frame_res| match frame_res {
80 Ok(frame_str) => {
81 let mut csv_reader = csv::ReaderBuilder::new()
82 .delimiter(delimiter)
83 .has_headers(false)
84 .from_reader(frame_str.as_bytes());
85
86 let mut iter = csv_reader.deserialize::<T>();
87
88 if let Some(csv_res) = iter.next() {
89 match csv_res {
90 Ok(result) => Ok(result),
91 Err(err) => Err(StreamBodyError::new(
92 StreamBodyKind::CodecError,
93 Some(Box::new(err)),
94 None,
95 )),
96 }
97 } else {
98 Err(StreamBodyError::new(StreamBodyKind::CodecError, None, None))
99 }
100 }
101 Err(err) => Err(StreamBodyError::new(
102 StreamBodyKind::CodecError,
103 Some(Box::new(err)),
104 None,
105 )),
106 })
107 }
108}
109
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_client::*;
    use axum::{routing::*, Router};
    use axum_streams::*;
    use futures::stream;
    use serde::Serialize;

    /// Two-column record used as the payload in every test below.
    #[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
    struct MyTestStructure {
        some_test_field1: String,
        some_test_field2: String,
    }

    /// Builds 100 identical sample records.
    fn generate_test_structures() -> Vec<MyTestStructure> {
        let sample = MyTestStructure {
            some_test_field1: "TestValue1".to_string(),
            some_test_field2: "TestValue2".to_string(),
        };
        vec![sample; 100]
    }

    /// Round trip without a header line: everything served as CSV must be
    /// read back unchanged.
    #[tokio::test]
    async fn deserialize_csv_stream() {
        let expected = generate_test_structures();
        let source = Box::pin(stream::iter(expected.clone()));

        let app = Router::new().route("/", get(|| async { StreamBodyAs::csv(source) }));
        let client = TestClient::new(app).await;

        let response = client.get("/").send().await.unwrap();
        let rows = response.csv_stream::<MyTestStructure>(1024, false, b',');
        let received: Vec<MyTestStructure> = rows.try_collect().await.unwrap();

        assert_eq!(received, expected);
    }

    /// Round trip where the client is told to skip a leading header line.
    #[tokio::test]
    async fn deserialize_csv_stream_with_header() {
        let expected = generate_test_structures();
        let source = Box::pin(stream::iter(
            expected.clone().into_iter().map(Ok::<_, axum::Error>),
        ));

        // NOTE(review): the `true` passed to `CsvStreamFormat::new` presumably
        // makes the server emit a header line; confirm against axum_streams.
        let app = Router::new().route(
            "/",
            get(|| async { StreamBodyAs::new(CsvStreamFormat::new(true, b','), source) }),
        );
        let client = TestClient::new(app).await;

        let response = client.get("/").send().await.unwrap();
        let rows = response.csv_stream::<MyTestStructure>(1024, true, b',');
        let received: Vec<MyTestStructure> = rows.try_collect().await.unwrap();

        assert_eq!(received, expected);
    }

    /// A tiny `max_obj_len` must make the stream fail instead of yielding data.
    #[tokio::test]
    async fn deserialize_csv_check_max_len() {
        let records = generate_test_structures();
        let source = Box::pin(stream::iter(records.clone()));

        // The body is served as a JSON array here; the test only requires
        // that collecting the stream with a limit of 5 ends in an error.
        let app = Router::new().route("/", get(|| async { StreamBodyAs::json_array(source) }));
        let client = TestClient::new(app).await;

        let response = client.get("/").send().await.unwrap();
        let rows = response.csv_stream::<MyTestStructure>(5, false, b',');

        rows.try_collect::<Vec<MyTestStructure>>()
            .await
            .expect_err("MaxLenReachedError");
    }
}