1use tokio::sync::mpsc::Sender;
4
5use crate::common::*;
6use crate::tokio_glue::{bytes_channel, try_forward_to_sender};
7
8pub(crate) fn concatenate_csv_streams(
14 ctx: Context,
15 mut csv_streams: BoxStream<CsvStream>,
16) -> Result<CsvStream> {
17 let (mut sender, receiver) = bytes_channel(1);
19 let worker_ctx = ctx.clone();
20 let worker = async move {
21 let mut first = true;
22 while let Some(result) = csv_streams.next().await {
23 match result {
24 Err(err) => {
25 error!("error reading stream of streams: {}", err,);
26 return send_err(sender, err).await;
27 }
28 Ok(csv_stream) => {
29 debug!("concatenating {}", csv_stream.name);
30 let mut data = csv_stream.data;
31
32 if first {
34 first = false;
35 } else {
36 data = strip_csv_header(worker_ctx.clone(), data)?;
37 }
38
39 try_forward_to_sender(data, &mut sender).await?;
41 }
42 }
43 }
44 trace!("end of CSV streams");
45 Ok(())
46 }
47 .instrument(debug_span!("concatenante_csv_streams"));
48
49 let new_csv_stream = CsvStream {
51 name: "combined".to_owned(),
52 data: receiver.boxed(),
53 };
54
55 ctx.spawn_worker(worker.boxed());
57 Ok(new_csv_stream)
58}
59
60#[tokio::test]
61async fn concatenate_csv_streams_strips_all_but_first_header() {
62 use tokio_stream::wrappers::ReceiverStream;
63
64 let input_1 = b"a,b\n1,2\n";
65 let input_2 = b"a,b\n3,4\n";
66 let expected = b"a,b\n1,2\n3,4\n";
67
68 let (ctx, worker_fut) = Context::create();
69
70 let cmd_fut = async move {
71 debug!("testing concatenate_csv_streams");
72
73 let (sender, receiver) = mpsc::channel::<Result<CsvStream>>(2);
75 sender
76 .send(Ok(CsvStream::from_bytes(&input_1[..]).await))
77 .await
78 .map_send_err()
79 .unwrap();
80 sender
81 .send(Ok(CsvStream::from_bytes(&input_2[..]).await))
82 .await
83 .map_send_err()
84 .unwrap();
85 let csv_streams = ReceiverStream::new(receiver).boxed();
86
87 drop(sender);
89
90 let combined = concatenate_csv_streams(ctx.clone(), csv_streams)
92 .unwrap()
93 .into_bytes()
94 .await
95 .unwrap();
96 assert_eq!(combined, &expected[..]);
97
98 Ok(())
99 };
100
101 try_join!(cmd_fut, worker_fut).unwrap();
102}
103
104fn strip_csv_header(
107 ctx: Context,
108 mut stream: BoxStream<BytesMut>,
109) -> Result<BoxStream<BytesMut>> {
110 let (mut sender, receiver) = bytes_channel(1);
112 let worker = async move {
113 let mut buffer: Option<BytesMut> = None;
115
116 while let Some(result) = stream.next().await {
118 match result {
119 Err(err) => {
120 error!("error reading stream: {}", err);
121 return send_err(sender, err).await;
122 }
123 Ok(bytes) => {
124 trace!("received {} bytes", bytes.len());
125 let mut new_buffer = if let Some(mut buffer) = buffer.take() {
126 buffer.extend_from_slice(&bytes);
127 buffer
128 } else {
129 bytes
130 };
131 match csv_header_length(&new_buffer) {
132 Ok(Some(header_len)) => {
133 trace!("stripping {} bytes of headers", header_len);
134 let _headers = new_buffer.split_to(header_len);
135 sender
136 .send(Ok(new_buffer))
137 .await
138 .context("broken pipe prevented sending data")?;
139 try_forward_to_sender(stream, &mut sender).await?;
140 return Ok(());
141 }
142 Ok(None) => {
143 trace!(
146 "didn't find full headers in {} bytes, looking...",
147 new_buffer.len(),
148 );
149 buffer = Some(new_buffer);
150 }
151 Err(err) => {
152 return send_err(sender, err).await;
153 }
154 }
155 }
156 }
157 }
158 trace!("end of stream");
159 let err = format_err!("end of CSV file while reading headers");
160 send_err(sender, err).await
161 }
162 .instrument(debug_span!("strip_csv_header"));
163
164 ctx.spawn_worker(worker.boxed());
166 Ok(receiver.boxed())
167}
168
169async fn send_err(sender: Sender<Result<BytesMut>>, err: Error) -> Result<()> {
171 sender
172 .send(Err(err))
173 .await
174 .context("broken pipe prevented sending error")?;
175 Ok(())
176}
177
178fn csv_header_length(data: &[u8]) -> Result<Option<usize>> {
181 if let Some(pos) = data.iter().position(|b| *b == b'\n') {
185 if data[..pos].iter().any(|b| *b == b'"') {
186 Err(format_err!(
187 "cannot yet concatenate CSV streams with quoted headers"
188 ))
189 } else {
190 Ok(Some(pos + 1))
191 }
192 } else {
193 Ok(None)
194 }
195}
196
#[test]
fn csv_header_length_handles_corner_cases() {
    // No newline yet: the header may continue in a later chunk.
    assert_eq!(csv_header_length(b"").unwrap(), None);
    assert_eq!(csv_header_length(b"a,b,c").unwrap(), None);
    // A quote is not rejected until a newline shows where the header ends.
    assert_eq!(csv_header_length(b"a,\"b\",c").unwrap(), None);

    // The length includes the terminating newline (and any `\r` before it).
    assert_eq!(csv_header_length(b"\n").unwrap(), Some(1));
    assert_eq!(csv_header_length(b"a,b,c\n").unwrap(), Some(6));
    assert_eq!(csv_header_length(b"a,b,c\nd,e,f\n").unwrap(), Some(6));
    assert_eq!(csv_header_length(b"a,b,c\r\n").unwrap(), Some(7));
    // Quotes after the header line are data and are not inspected.
    assert_eq!(csv_header_length(b"a,b\n\"x\",y\n").unwrap(), Some(4));

    // Quoted headers are rejected.
    assert!(csv_header_length(b"a,\"\n\",c\n").is_err());
    assert!(csv_header_length(b"\"a\",b\n").is_err());
}