noir_compute/operator/source/csv.rs

use std::fmt::Display;
use std::fs::File;
use std::io;
use std::io::{BufRead, BufReader, Read, Seek, SeekFrom};
use std::marker::PhantomData;
use std::path::PathBuf;

use csv::{ByteRecord, Reader, ReaderBuilder, Terminator, Trim};
use serde::Deserialize;

use crate::block::{BlockStructure, OperatorKind, OperatorStructure, Replication};
use crate::operator::source::Source;
use crate::operator::{Data, Operator, StreamElement};
use crate::scheduler::ExecutionMetadata;
use crate::Stream;

/// Wrapper that limits the bytes that can be read from a type that implements `io::Read`.
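/// Conceptually this plays the same role as the reader returned by [`Read::take`].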
struct LimitedReader<R: Read> {
    inner: R,
    /// Bytes remaining to be read.
    remaining: usize,
}

impl<R: Read> LimitedReader<R> {
    fn new(inner: R, remaining: usize) -> Self {
        Self { inner, remaining }
    }
}

impl<R: Read> Read for LimitedReader<R> {
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        if self.remaining == 0 {
            // the byte budget is exhausted: report end-of-file
            return Ok(0);
        }
        // clamp the destination buffer so that the inner reader is never asked
        // for (and never consumes) more bytes than the remaining budget
        let max = buf.len().min(self.remaining);
        let read_bytes = self.inner.read(&mut buf[..max])?;
        self.remaining -= read_bytes;
        Ok(read_bytes)
    }
}

/// Options for the CSV parser.
#[derive(Clone)]
struct CsvOptions {
    /// Byte used to mark a line as a comment.
    comment: Option<u8>,
    /// Field delimiter.
    delimiter: u8,
    /// Whether quotes are escaped by using doubled quotes.
    double_quote: bool,
    /// Byte used to escape quotes.
    escape: Option<u8>,
    /// Whether to allow records with a different number of fields.
    flexible: bool,
    /// Byte used to quote fields.
    quote: u8,
    /// Whether to enable field quoting.
    quoting: bool,
    /// Line terminator.
    terminator: Terminator,
    /// Whether to trim fields and/or headers.
    trim: Trim,
    /// Whether the CSV file has headers.
    has_headers: bool,
}

impl Default for CsvOptions {
    fn default() -> Self {
        Self {
            comment: None,
            delimiter: b',',
            double_quote: true,
            escape: None,
            flexible: false,
            quote: b'"',
            quoting: true,
            terminator: Terminator::CRLF,
            trim: Trim::None,
            has_headers: true,
        }
    }
}

/// Source that reads and parses a CSV file.
///
/// The file is divided into chunks and is read concurrently by multiple replicas.
pub struct CsvSource<Out: Data + for<'a> Deserialize<'a>> {
    /// Path of the file.
    path: PathBuf,
    /// Reader used to parse the CSV file.
    csv_reader: Option<Reader<LimitedReader<BufReader<File>>>>,
    /// Options to customize the CSV parser.
    options: CsvOptions,
    /// Whether the reader has terminated its job.
    terminated: bool,
    _out: PhantomData<Out>,
    buf: ByteRecord,
}

impl<Out: Data + for<'a> Deserialize<'a>> Display for CsvSource<Out> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "CsvSource<{}>", std::any::type_name::<Out>())
    }
}

impl<Out: Data + for<'a> Deserialize<'a>> CsvSource<Out> {
    /// Create a new source that reads and parses the lines of a CSV file.
    ///
    /// The file is partitioned into as many chunks as there are replicas, and every replica must
    /// have the **same** file at the same path. It is guaranteed that each line of the file is
    /// emitted by exactly one replica.
    ///
    /// After creating the source it's possible to customize its behaviour using one of the
    /// available methods. By default it is assumed that the delimiter is `,` and the CSV has
    /// headers.
    ///
    /// Each line will be deserialized into the type `Out`, so the structure of the CSV must be
    /// valid for that deserialization. The [`csv`](https://crates.io/crates/csv) crate is used for
    /// the parsing.
    ///
    /// **Note**: the file must be readable and its size must be available. This means that only
    /// regular files can be read.
    ///
    /// ## Example
    ///
    /// ```
    /// # use noir_compute::{StreamContext, RuntimeConfig};
    /// # use noir_compute::operator::source::CsvSource;
    /// # use serde::{Deserialize, Serialize};
    /// # let mut env = StreamContext::new(RuntimeConfig::local(1));
    /// #[derive(Clone, Deserialize, Serialize)]
    /// struct Thing {
    ///     what: String,
    ///     count: u64,
    /// }
    /// let source = CsvSource::<Thing>::new("/datasets/huge.csv");
    /// let s = env.stream(source);
    /// ```
    pub fn new<P: Into<PathBuf>>(path: P) -> Self {
        Self {
            path: path.into(),
            csv_reader: None,
            options: Default::default(),
            terminated: false,
            _out: PhantomData,
            buf: ByteRecord::new(),
        }
    }

    /// The comment character to use when parsing CSV.
    ///
    /// If the start of a record begins with the byte given here, then that line is ignored by the
    /// CSV parser.
    ///
    /// This is disabled by default.
    pub fn comment(mut self, comment: Option<u8>) -> Self {
        self.options.comment = comment;
        self
    }

    /// The field delimiter to use when parsing CSV.
    ///
    /// The default is `,`.
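    ///
    /// For example, a tab-separated file could be read like this (a sketch; the path is
    /// hypothetical):
    ///
    /// ```
    /// # use noir_compute::operator::source::CsvSource;
    /// let source = CsvSource::<(String, u64)>::new("table.tsv").delimiter(b'\t');
    /// ```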
    pub fn delimiter(mut self, delimiter: u8) -> Self {
        self.options.delimiter = delimiter;
        self
    }

    /// Enable double quote escapes.
    ///
    /// This is enabled by default, but it may be disabled. When disabled, doubled quotes are not
    /// interpreted as escapes.
    pub fn double_quote(mut self, double_quote: bool) -> Self {
        self.options.double_quote = double_quote;
        self
    }

    /// The escape character to use when parsing CSV.
    ///
    /// In some variants of CSV, quotes are escaped using a special escape character like `\`
    /// (instead of escaping quotes by doubling them).
    ///
    /// By default, recognizing these idiosyncratic escapes is disabled.
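    ///
    /// For example, to treat `\` as the escape byte (a sketch; the path is hypothetical):
    ///
    /// ```
    /// # use noir_compute::operator::source::CsvSource;
    /// let source = CsvSource::<(String, u64)>::new("data.csv").escape(Some(b'\\'));
    /// ```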
    pub fn escape(mut self, escape: Option<u8>) -> Self {
        self.options.escape = escape;
        self
    }

    /// Whether the number of fields in records is allowed to change or not.
    ///
    /// When disabled (which is the default), parsing CSV data will return an error if a record is
    /// found with a number of fields different from the number of fields in a previous record.
    ///
    /// When enabled, this error checking is turned off.
    pub fn flexible(mut self, flexible: bool) -> Self {
        self.options.flexible = flexible;
        self
    }

    /// The quote character to use when parsing CSV.
    ///
    /// The default is `"`.
    pub fn quote(mut self, quote: u8) -> Self {
        self.options.quote = quote;
        self
    }

    /// Enable or disable quoting.
    ///
    /// This is enabled by default, but it may be disabled. When disabled, quotes are not treated
    /// specially.
    pub fn quoting(mut self, quoting: bool) -> Self {
        self.options.quoting = quoting;
        self
    }

    /// The record terminator to use when parsing CSV.
    ///
    /// A record terminator can be any single byte. The default is a special value,
    /// `Terminator::CRLF`, which treats any occurrence of `\r`, `\n` or `\r\n` as a single record
    /// terminator.
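    ///
    /// For example, to terminate records on `;` instead (a sketch; the path is hypothetical):
    ///
    /// ```
    /// # use noir_compute::operator::source::CsvSource;
    /// use csv::Terminator;
    /// let source = CsvSource::<(String, u64)>::new("rows.csv").terminator(Terminator::Any(b';'));
    /// ```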
    pub fn terminator(mut self, terminator: Terminator) -> Self {
        self.options.terminator = terminator;
        self
    }

    /// Whether fields are trimmed of leading and trailing whitespace or not.
    ///
    /// By default, no trimming is performed. This method permits one to override that behavior and
    /// choose one of the following options:
    ///
    /// 1. `Trim::Headers` trims only header values.
    /// 2. `Trim::Fields` trims only non-header or "field" values.
    /// 3. `Trim::All` trims both header and non-header values.
    ///
    /// A value is only interpreted as a header value if this CSV reader is configured to read a
    /// header record (which is the default).
    ///
    /// When reading string records, characters meeting the definition of Unicode whitespace are
    /// trimmed. When reading byte records, characters meeting the definition of ASCII whitespace
    /// are trimmed. ASCII whitespace characters correspond to the set `[\t\n\v\f\r ]`.
    pub fn trim(mut self, trim: Trim) -> Self {
        self.options.trim = trim;
        self
    }

    /// Whether to treat the first row as a special header row.
    ///
    /// By default, the first row is treated as a special header row, which means the header is
    /// never returned by any of the record reading methods or iterators. When this is disabled
    /// (`has_headers` set to `false`), the first row is not treated specially.
    ///
    /// Note that the `headers` and `byte_headers` methods are unaffected by whether this is set.
    /// Those methods always return the first record.
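    ///
    /// For example, a headerless file of `(i32, i32)` pairs could be read like this (a sketch;
    /// the path is hypothetical):
    ///
    /// ```
    /// # use noir_compute::operator::source::CsvSource;
    /// let source = CsvSource::<(i32, i32)>::new("pairs.csv").has_headers(false);
    /// ```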
    pub fn has_headers(mut self, has_headers: bool) -> Self {
        self.options.has_headers = has_headers;
        self
    }
}

impl<Out: Data + for<'a> Deserialize<'a>> Source for CsvSource<Out> {
    fn replication(&self) -> Replication {
        Replication::Unlimited
    }
}

impl<Out: Data + for<'a> Deserialize<'a>> Operator for CsvSource<Out> {
    type Out = Out;

    fn setup(&mut self, metadata: &mut ExecutionMetadata) {
        let global_id = metadata.global_id;
        let instances = metadata.replicas.len();

        let file = File::options()
            .read(true)
            .write(false)
            .open(&self.path)
            .unwrap_or_else(|err| {
                panic!(
                    "CsvSource: error while opening file {:?}: {:?}",
                    self.path, err
                )
            });

        let file_size = file.metadata().unwrap().len();

        let mut buf_reader = BufReader::new(file);

        let last_byte_terminator = match self.options.terminator {
            Terminator::CRLF => b'\n',
            Terminator::Any(terminator) => terminator,
            _ => unreachable!(),
        };

        // Handle the header
        let mut header = Vec::new();
        let header_size = if self.options.has_headers {
            buf_reader
                .read_until(last_byte_terminator, &mut header)
                .expect("Error while reading CSV header") as u64
        } else {
            0
        };

        // Calculate start and end offset of this replica
        let body_size = file_size - header_size;
        let range_size = body_size / instances as u64;
        let mut start = header_size + range_size * global_id;
        let mut end = if global_id as usize == instances - 1 {
            file_size
        } else {
            start + range_size
        };
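
        // Illustrative numbers (hypothetical): with file_size = 1020, header_size = 20
        // and 4 replicas, body_size = 1000 and range_size = 250, so replica 2 gets the
        // raw range [520, 770); the alignment steps below then move each boundary
        // forward to the end of the record that straddles it.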

        // Align start byte
        if global_id != 0 {
            // Seek reader to the first byte to be read
            buf_reader
                .seek(SeekFrom::Start(start))
                .expect("Error while seeking BufReader to start");
            // discard the first (possibly partial) line: it belongs to the previous replica
            let mut buf = Vec::new();
            start += buf_reader
                .read_until(last_byte_terminator, &mut buf)
                .expect("Error while reading first line from file") as u64;
        }

        // Align end byte
        if global_id as usize != instances - 1 {
            // Seek reader to the last byte to be read
            buf_reader
                .seek(SeekFrom::Start(end))
                .expect("Error while seeking BufReader to end");
            // get to the end of the line
            let mut buf = Vec::new();
            end += buf_reader
                .read_until(last_byte_terminator, &mut buf)
                .expect("Error while reading last line from file") as u64;
        }

        // Rewind BufReader to the start
        buf_reader
            .seek(SeekFrom::Start(start))
            .expect("Error while rewinding BufReader");

        // Limit the number of bytes to be read
        let limited_reader = LimitedReader::new(buf_reader, (end - start) as usize);

        // Create csv::Reader
        let mut csv_reader = ReaderBuilder::new()
            .comment(self.options.comment)
            .delimiter(self.options.delimiter)
            .double_quote(self.options.double_quote)
            .escape(self.options.escape)
            .flexible(self.options.flexible)
            .quote(self.options.quote)
            .quoting(self.options.quoting)
            .terminator(self.options.terminator)
            .trim(self.options.trim)
            .has_headers(self.options.has_headers)
            .from_reader(limited_reader);

        if self.options.has_headers {
            // install the headers parsed earlier: the underlying reader is positioned
            // past the header row, so the csv reader must not treat the first record
            // it sees as the header
            csv_reader.set_byte_headers(
                Reader::from_reader(header.as_slice())
                    .byte_headers()
                    .unwrap()
                    .to_owned(),
            );
        }

        self.csv_reader = Some(csv_reader);
    }

    fn next(&mut self) -> StreamElement<Out> {
        if self.terminated {
            return StreamElement::Terminate;
        }
        let csv_reader = self
            .csv_reader
            .as_mut()
            .expect("CsvSource was not initialized");

        match csv_reader.read_byte_record(&mut self.buf) {
            Ok(true) => {
                let item = self
                    .buf
                    .deserialize::<Out>(None)
                    .expect("CSV record does not match the expected type");
                StreamElement::Item(item)
            }
            Ok(false) => {
                self.terminated = true;
                StreamElement::FlushAndRestart
            }
            Err(e) => panic!("Error while reading CSV file: {:?}", e),
        }
    }

    fn structure(&self) -> BlockStructure {
        let mut operator = OperatorStructure::new::<Out, _>("CSVSource");
        operator.kind = OperatorKind::Source;
        BlockStructure::default().add_operator(operator)
    }
}

impl<Out: Data + for<'a> Deserialize<'a>> Clone for CsvSource<Out> {
    fn clone(&self) -> Self {
        assert!(
            self.csv_reader.is_none(),
            "CsvSource must be cloned before calling setup"
        );
        Self {
            path: self.path.clone(),
            csv_reader: None,
            options: self.options.clone(),
            terminated: false,
            _out: PhantomData,
            buf: ByteRecord::new(),
        }
    }
}

impl crate::StreamContext {
    /// Convenience method that creates a `CsvSource` and makes a stream from it using
    /// `StreamContext::stream`.
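    ///
    /// ## Example
    ///
    /// A sketch along the lines of the [`CsvSource::new`] example:
    ///
    /// ```
    /// # use noir_compute::{StreamContext, RuntimeConfig};
    /// # use serde::{Deserialize, Serialize};
    /// # let env = StreamContext::new(RuntimeConfig::local(1));
    /// #[derive(Clone, Deserialize, Serialize)]
    /// struct Thing {
    ///     what: String,
    ///     count: u64,
    /// }
    /// let s = env.stream_csv::<Thing>("/datasets/huge.csv");
    /// ```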
    pub fn stream_csv<T: Data + for<'a> Deserialize<'a>>(
        &self,
        path: impl Into<PathBuf>,
    ) -> Stream<CsvSource<T>> {
        let source = CsvSource::new(path);
        self.stream(source)
    }
}

#[cfg(test)]
mod tests {
    use std::io::Write;

    use itertools::Itertools;
    use serde::{Deserialize, Serialize};
    use tempfile::NamedTempFile;

    use crate::config::RuntimeConfig;
    use crate::environment::StreamContext;
    use crate::operator::source::CsvSource;

    #[test]
    fn csv_without_headers() {
        for num_records in 0..100 {
            for terminator in &["\n", "\r\n"] {
                let file = NamedTempFile::new().unwrap();
                for i in 0..num_records {
                    write!(file.as_file(), "{},{}{}", i, i + 1, terminator).unwrap();
                }

                let env = StreamContext::new(RuntimeConfig::local(4));
                let source = CsvSource::<(i32, i32)>::new(file.path()).has_headers(false);
                let res = env.stream(source).shuffle().collect_vec();
                env.execute_blocking();

                let mut res = res.get().unwrap();
                res.sort_unstable();
                assert_eq!(res, (0..num_records).map(|x| (x, x + 1)).collect_vec());
            }
        }
    }

    #[test]
    fn csv_with_headers() {
        #[derive(Clone, Serialize, Deserialize)]
        struct T {
            a: i32,
            b: i32,
        }

        for num_records in 0..100 {
            for terminator in &["\n", "\r\n"] {
                let file = NamedTempFile::new().unwrap();
                write!(file.as_file(), "a,b{terminator}").unwrap();
                for i in 0..num_records {
                    write!(file.as_file(), "{},{}{}", i, i + 1, terminator).unwrap();
                }

                let env = StreamContext::new(RuntimeConfig::local(4));
                let source = CsvSource::<T>::new(file.path());
                let res = env.stream(source).shuffle().collect_vec();
                env.execute_blocking();

                let res = res
                    .get()
                    .unwrap()
                    .into_iter()
                    .map(|x| (x.a, x.b))
                    .sorted()
                    .collect_vec();
                assert_eq!(res, (0..num_records).map(|x| (x, x + 1)).collect_vec());
            }
        }
    }
503}