// dbcrossbarlib/concat.rs

//! Support for concatenating multiple CSV streams.

use tokio::sync::mpsc::Sender;

use crate::common::*;
use crate::tokio_glue::{bytes_channel, try_forward_to_sender};
8/// Given a stream of CSV streams, merge them into a single CSV stream, removing
9/// the headers from every CSV stream except the first.
10///
11/// This is a bit complicated because it needs to be asynchronous, and it tries
12/// to impose near-zero overhead on the underlying data copies.
13pub(crate) fn concatenate_csv_streams(
14    ctx: Context,
15    mut csv_streams: BoxStream<CsvStream>,
16) -> Result<CsvStream> {
17    // Create an asynchronous background worker to do the actual work.
18    let (mut sender, receiver) = bytes_channel(1);
19    let worker_ctx = ctx.clone();
20    let worker = async move {
21        let mut first = true;
22        while let Some(result) = csv_streams.next().await {
23            match result {
24                Err(err) => {
25                    error!("error reading stream of streams: {}", err,);
26                    return send_err(sender, err).await;
27                }
28                Ok(csv_stream) => {
29                    debug!("concatenating {}", csv_stream.name);
30                    let mut data = csv_stream.data;
31
32                    // If we're not the first CSV stream, remove the CSV header.
33                    if first {
34                        first = false;
35                    } else {
36                        data = strip_csv_header(worker_ctx.clone(), data)?;
37                    }
38
39                    // Forward the rest of the stream.
40                    try_forward_to_sender(data, &mut sender).await?;
41                }
42            }
43        }
44        trace!("end of CSV streams");
45        Ok(())
46    }
47    .instrument(debug_span!("concatenante_csv_streams"));
48
49    // Build our combined `CsvStream`.
50    let new_csv_stream = CsvStream {
51        name: "combined".to_owned(),
52        data: receiver.boxed(),
53    };
54
55    // Run the worker in the background, and return our combined stream.
56    ctx.spawn_worker(worker.boxed());
57    Ok(new_csv_stream)
58}
59
/// End-to-end check that `concatenate_csv_streams` keeps the first stream's
/// header and strips the header from every subsequent stream.
#[tokio::test]
async fn concatenate_csv_streams_strips_all_but_first_header() {
    use tokio_stream::wrappers::ReceiverStream;

    // Two inputs with identical headers; only the first header should survive.
    let input_1 = b"a,b\n1,2\n";
    let input_2 = b"a,b\n3,4\n";
    let expected = b"a,b\n1,2\n3,4\n";

    let (ctx, worker_fut) = Context::create();

    let cmd_fut = async move {
        debug!("testing concatenate_csv_streams");

        // Build our `BoxStream<CsvStream>`.
        let (sender, receiver) = mpsc::channel::<Result<CsvStream>>(2);
        sender
            .send(Ok(CsvStream::from_bytes(&input_1[..]).await))
            .await
            .map_send_err()
            .unwrap();
        sender
            .send(Ok(CsvStream::from_bytes(&input_2[..]).await))
            .await
            .map_send_err()
            .unwrap();
        let csv_streams = ReceiverStream::new(receiver).boxed();

        // Close our sender so that our receiver knows we're done.
        drop(sender);

        // Test concatenation: collect the combined stream into one buffer and
        // compare against the expected bytes.
        let combined = concatenate_csv_streams(ctx.clone(), csv_streams)
            .unwrap()
            .into_bytes()
            .await
            .unwrap();
        assert_eq!(combined, &expected[..]);

        Ok(())
    };

    // Run the test future and the context's background workers together,
    // failing if either side reports an error.
    try_join!(cmd_fut, worker_fut).unwrap();
}
103
/// Remove the CSV header from a CSV stream, passing everything else through
/// untouched.
///
/// Returns a new byte stream immediately; the actual header-stripping happens
/// in a background worker spawned on `ctx`. Errors from `stream` (or a stream
/// that ends before a complete header is seen) are forwarded into the
/// returned stream rather than raised here.
fn strip_csv_header(
    ctx: Context,
    mut stream: BoxStream<BytesMut>,
) -> Result<BoxStream<BytesMut>> {
    // Create an asynchronous background worker to do the actual work.
    let (mut sender, receiver) = bytes_channel(1);
    let worker = async move {
        // Accumulate bytes in this buffer until we see a full CSV header.
        let mut buffer: Option<BytesMut> = None;

        // Look for a full CSV header.
        while let Some(result) = stream.next().await {
            match result {
                Err(err) => {
                    error!("error reading stream: {}", err);
                    return send_err(sender, err).await;
                }
                Ok(bytes) => {
                    trace!("received {} bytes", bytes.len());
                    // Append this chunk to any previously-buffered bytes, so
                    // `csv_header_length` always inspects data from the very
                    // start of the stream.
                    let mut new_buffer = if let Some(mut buffer) = buffer.take() {
                        buffer.extend_from_slice(&bytes);
                        buffer
                    } else {
                        bytes
                    };
                    match csv_header_length(&new_buffer) {
                        Ok(Some(header_len)) => {
                            trace!("stripping {} bytes of headers", header_len);
                            // Discard the header prefix and send the remainder.
                            let _headers = new_buffer.split_to(header_len);
                            sender
                                .send(Ok(new_buffer))
                                .await
                                .context("broken pipe prevented sending data")?;
                            // Header is gone; the rest of the stream needs no
                            // inspection and is forwarded as-is.
                            try_forward_to_sender(stream, &mut sender).await?;
                            return Ok(());
                        }
                        Ok(None) => {
                            // Save our buffer and keep looking for the end of
                            // the headers.
                            trace!(
                                "didn't find full headers in {} bytes, looking...",
                                new_buffer.len(),
                            );
                            buffer = Some(new_buffer);
                        }
                        Err(err) => {
                            // Header was malformed (e.g. quoted); report it.
                            return send_err(sender, err).await;
                        }
                    }
                }
            }
        }
        // The input ended before a complete header line appeared, so we
        // cannot produce valid headerless output.
        trace!("end of stream");
        let err = format_err!("end of CSV file while reading headers");
        send_err(sender, err).await
    }
    .instrument(debug_span!("strip_csv_header"));

    // Run the worker in the background, and return our receiver.
    ctx.spawn_worker(worker.boxed());
    Ok(receiver.boxed())
}
168
169// Send `err` using `sender`.
170async fn send_err(sender: Sender<Result<BytesMut>>, err: Error) -> Result<()> {
171    sender
172        .send(Err(err))
173        .await
174        .context("broken pipe prevented sending error")?;
175    Ok(())
176}
177
178/// Given a slice of bytes, determine if it contains a complete set of CSV
179/// headers, and if so, return their length.
180fn csv_header_length(data: &[u8]) -> Result<Option<usize>> {
181    // We could try to use the `csv` crate for this, but the `csv` crate will
182    // go to great lengths to recover from malformed CSV files, so it's not
183    // very useful for detecting whether we have a complete header line.
184    if let Some(pos) = data.iter().position(|b| *b == b'\n') {
185        if data[..pos].iter().any(|b| *b == b'"') {
186            Err(format_err!(
187                "cannot yet concatenate CSV streams with quoted headers"
188            ))
189        } else {
190            Ok(Some(pos + 1))
191        }
192    } else {
193        Ok(None)
194    }
195}
196
/// Exercise `csv_header_length` on empty, partial, complete, and CRLF inputs.
#[test]
fn csv_header_length_handles_corner_cases() {
    // Pairs of (input, expected header length) for unquoted headers.
    let cases: &[(&[u8], Option<usize>)] = &[
        (b"", None),
        (b"a,b,c", None),
        (b"a,b,c\n", Some(6)),
        (b"a,b,c\nd,e,f\n", Some(6)),
        (b"a,b,c\r\n", Some(7)),
    ];
    for (input, expected) in cases {
        assert_eq!(csv_header_length(input).unwrap(), *expected);
    }

    // If we wanted to be more clever, we could handle quoted headers with
    // embedded newlines, and other such complications.
    assert!(csv_header_length(b"a,\"\n\",c\n").is_err());
    //assert_eq!(csv_header_length(b"a,\"\na").unwrap(), None);
    //assert_eq!(csv_header_length(b"a,\"\n\",c\n").unwrap(), Some(8));
}