// noir_compute/operator/source/csv.rs
use std::fmt::Display;
2use std::fs::File;
3use std::io;
4use std::io::{BufRead, BufReader, Read, Seek, SeekFrom};
5use std::marker::PhantomData;
6use std::path::PathBuf;
7
8use csv::{ByteRecord, Reader, ReaderBuilder, Terminator, Trim};
9use serde::Deserialize;
10
11use crate::block::{BlockStructure, OperatorKind, OperatorStructure, Replication};
12use crate::operator::source::Source;
13use crate::operator::{Data, Operator, StreamElement};
14use crate::scheduler::ExecutionMetadata;
15use crate::Stream;
16
/// A reader adapter that yields at most `remaining` bytes from the inner
/// reader, analogous to `std::io::Take`. It is used to stop each replica of
/// the source at the end of its assigned byte range of the CSV file.
struct LimitedReader<R: Read> {
    inner: R,
    /// Number of bytes still allowed to be read before reporting EOF.
    remaining: usize,
}

impl<R: Read> LimitedReader<R> {
    /// Wrap `inner`, allowing at most `remaining` bytes to be read from it.
    fn new(inner: R, remaining: usize) -> Self {
        Self { inner, remaining }
    }
}

impl<R: Read> Read for LimitedReader<R> {
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        // Limit reached: report EOF without touching the inner reader.
        if self.remaining == 0 {
            return Ok(0);
        }
        // Cap the destination slice so the inner reader is never advanced
        // past the limit (same strategy as `std::io::Take`). The previous
        // implementation read into the whole `buf` and truncated the
        // reported count with `min`, which over-consumed the inner reader
        // (bytes already read past the limit were silently discarded) and
        // wrote into `buf` beyond the returned length.
        let max = buf.len().min(self.remaining);
        let read_bytes = self.inner.read(&mut buf[..max])?;
        self.remaining -= read_bytes;
        Ok(read_bytes)
    }
}
43
/// Parsing options for the CSV reader; every field is forwarded verbatim to
/// `csv::ReaderBuilder` in `CsvSource::setup`.
#[derive(Clone)]
struct CsvOptions {
    /// Comment-line marker byte, if comments are enabled.
    comment: Option<u8>,
    /// Field delimiter byte.
    delimiter: u8,
    /// Whether doubled quotes inside a quoted field are an escaped quote.
    double_quote: bool,
    /// Escape byte used inside quoted fields, if any.
    escape: Option<u8>,
    /// Whether records may have a varying number of fields.
    flexible: bool,
    /// Quote byte.
    quote: u8,
    /// Whether quoting is honored at all.
    quoting: bool,
    /// Record terminator; its last byte is also used by `setup` to find
    /// record boundaries when splitting the file among replicas.
    terminator: Terminator,
    /// Whitespace-trimming behavior.
    trim: Trim,
    /// Whether the first line of the file is a header row.
    has_headers: bool,
}
68
69impl Default for CsvOptions {
70 fn default() -> Self {
71 Self {
72 comment: None,
73 delimiter: b',',
74 double_quote: true,
75 escape: None,
76 flexible: false,
77 quote: b'"',
78 quoting: true,
79 terminator: Terminator::CRLF,
80 trim: Trim::None,
81 has_headers: true,
82 }
83 }
84}
85
/// Source operator that reads a CSV file in parallel: each replica parses a
/// disjoint byte range of the file and deserializes every record into `Out`.
pub struct CsvSource<Out: Data + for<'a> Deserialize<'a>> {
    /// Path of the CSV file to read.
    path: PathBuf,
    /// CSV reader over this replica's byte range; `None` until `setup` runs.
    csv_reader: Option<Reader<LimitedReader<BufReader<File>>>>,
    /// Parsing options forwarded to `csv::ReaderBuilder` in `setup`.
    options: CsvOptions,
    /// Set when the range is exhausted; `next` then emits `Terminate`.
    terminated: bool,
    _out: PhantomData<Out>,
    /// Reusable record buffer, avoiding a per-row allocation.
    buf: ByteRecord,
}
101
102impl<Out: Data + for<'a> Deserialize<'a>> Display for CsvSource<Out> {
103 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
104 write!(f, "CsvSource<{}>", std::any::type_name::<Out>())
105 }
106}
107
impl<Out: Data + for<'a> Deserialize<'a>> CsvSource<Out> {
    /// Create a new source reading CSV records of type `Out` from the file
    /// at `path`.
    ///
    /// The file is not opened here; it is opened and partitioned among the
    /// replicas in `setup`. Parsing options start at their defaults (comma
    /// delimiter, CRLF terminator, headers enabled).
    pub fn new<P: Into<PathBuf>>(path: P) -> Self {
        Self {
            path: path.into(),
            csv_reader: None,
            options: Default::default(),
            terminated: false,
            _out: PhantomData,
            buf: ByteRecord::new(),
        }
    }

    /// Set the comment marker byte (e.g. `Some(b'#')`); `None` (the default)
    /// disables comment lines. Forwarded to `csv::ReaderBuilder::comment`.
    pub fn comment(mut self, comment: Option<u8>) -> Self {
        self.options.comment = comment;
        self
    }

    /// Set the field delimiter byte (default `b','`).
    pub fn delimiter(mut self, delimiter: u8) -> Self {
        self.options.delimiter = delimiter;
        self
    }

    /// Enable/disable doubled-quote escaping inside quoted fields
    /// (default `true`).
    pub fn double_quote(mut self, double_quote: bool) -> Self {
        self.options.double_quote = double_quote;
        self
    }

    /// Set the escape byte used inside quoted fields; `None` (the default)
    /// relies on doubled quotes instead.
    pub fn escape(mut self, escape: Option<u8>) -> Self {
        self.options.escape = escape;
        self
    }

    /// Allow records with a varying number of fields (default `false`).
    pub fn flexible(mut self, flexible: bool) -> Self {
        self.options.flexible = flexible;
        self
    }

    /// Set the quote byte (default `b'"'`).
    pub fn quote(mut self, quote: u8) -> Self {
        self.options.quote = quote;
        self
    }

    /// Enable/disable quoting altogether (default `true`).
    pub fn quoting(mut self, quoting: bool) -> Self {
        self.options.quoting = quoting;
        self
    }

    /// Set the record terminator (default CRLF). Note that `setup` also uses
    /// the terminator's last byte to locate record boundaries when splitting
    /// the file among replicas.
    pub fn terminator(mut self, terminator: Terminator) -> Self {
        self.options.terminator = terminator;
        self
    }

    /// Set the whitespace-trimming behavior for headers and fields
    /// (default `Trim::None`).
    pub fn trim(mut self, trim: Trim) -> Self {
        self.options.trim = trim;
        self
    }

    /// Declare whether the first line of the file is a header row (default
    /// `true`). When enabled, the header is parsed once and installed into
    /// every replica's reader so records can be deserialized by field name.
    pub fn has_headers(mut self, has_headers: bool) -> Self {
        self.options.has_headers = has_headers;
        self
    }
}
262
impl<Out: Data + for<'a> Deserialize<'a>> Source for CsvSource<Out> {
    fn replication(&self) -> Replication {
        // The source can be replicated freely: in `setup` each replica
        // computes its own disjoint byte range of the file to parse.
        Replication::Unlimited
    }
}
268
impl<Out: Data + for<'a> Deserialize<'a>> Operator for CsvSource<Out> {
    type Out = Out;

    /// Open the file and restrict this replica to its own byte range.
    ///
    /// The file body (after the optional header line) is split into
    /// `instances` equal-sized ranges. A range boundary usually falls in the
    /// middle of a record, so each boundary is moved forward to the next
    /// record terminator: every replica skips the (partial) first line of
    /// its range and reads past its nominal end up to the end of the last
    /// line it started, so each record is parsed by exactly one replica.
    fn setup(&mut self, metadata: &mut ExecutionMetadata) {
        let global_id = metadata.global_id;
        let instances = metadata.replicas.len();

        let file = File::options()
            .read(true)
            .write(false)
            .open(&self.path)
            .unwrap_or_else(|err| {
                panic!(
                    "CsvSource: error while opening file {:?}: {:?}",
                    self.path, err
                )
            });

        let file_size = file.metadata().unwrap().len();

        let mut buf_reader = BufReader::new(file);

        // Byte that ends a record: `\n` for CRLF (enough to find line
        // boundaries) or the custom terminator byte.
        // NOTE(review): `Terminator` is non-exhaustive upstream; a variant
        // added in a future csv release would hit this `unreachable!()`.
        let last_byte_terminator = match self.options.terminator {
            Terminator::CRLF => b'\n',
            Terminator::Any(terminator) => terminator,
            _ => unreachable!(),
        };

        // If present, consume the header line so it is excluded from the
        // per-replica ranges; its raw bytes are kept to rebuild the headers
        // for this replica's reader below.
        let mut header = Vec::new();
        let header_size = if self.options.has_headers {
            buf_reader
                .read_until(last_byte_terminator, &mut header)
                .expect("Error while reading CSV header") as u64
        } else {
            0
        };

        // Nominal [start, end) byte range of this replica; the last replica
        // also absorbs the remainder of the integer division.
        let body_size = file_size - header_size;
        let range_size = body_size / instances as u64;
        let mut start = header_size + range_size * global_id;
        let mut end = if global_id as usize == instances - 1 {
            file_size
        } else {
            start + range_size
        };

        // Every replica but the first skips ahead to the start of the next
        // full line: the partial line at `start` belongs to the previous
        // replica, which extends its own `end` symmetrically below.
        if global_id != 0 {
            buf_reader
                .seek(SeekFrom::Start(start))
                .expect("Error while seeking BufReader to start");
            let mut buf = Vec::new();
            start += buf_reader
                .read_until(last_byte_terminator, &mut buf)
                .expect("Error while reading first line from file") as u64;
        }

        // Every replica but the last extends `end` to the end of the line
        // crossing the boundary, mirroring the skip above.
        if global_id as usize != instances - 1 {
            buf_reader
                .seek(SeekFrom::Start(end))
                .expect("Error while seeking BufReader to end");
            let mut buf = Vec::new();
            end += buf_reader
                .read_until(last_byte_terminator, &mut buf)
                .expect("Error while reading last line from file") as u64;
        }

        buf_reader
            .seek(SeekFrom::Start(start))
            .expect("Error while rewinding BufReader");

        // Cap the reader at `end` so the csv parser naturally stops at the
        // end of this replica's range.
        let limited_reader = LimitedReader::new(buf_reader, (end - start) as usize);

        let mut csv_reader = ReaderBuilder::new()
            .comment(self.options.comment)
            .delimiter(self.options.delimiter)
            .double_quote(self.options.double_quote)
            .escape(self.options.escape)
            .flexible(self.options.flexible)
            .quote(self.options.quote)
            .quoting(self.options.quoting)
            .terminator(self.options.terminator)
            .trim(self.options.trim)
            .has_headers(self.options.has_headers)
            .from_reader(limited_reader);

        if self.options.has_headers {
            // Replicas other than the first never see the header line, so
            // re-parse the saved header bytes and install them explicitly.
            csv_reader.set_byte_headers(
                Reader::from_reader(header.as_slice())
                    .byte_headers()
                    .unwrap()
                    .to_owned(),
            );
        }

        self.csv_reader = Some(csv_reader);
    }

    /// Read and deserialize the next record of this replica's range.
    ///
    /// Panics on malformed CSV or on records that do not match `Out`.
    fn next(&mut self) -> StreamElement<Out> {
        if self.terminated {
            return StreamElement::Terminate;
        }
        let csv_reader = self
            .csv_reader
            .as_mut()
            .expect("CsvSource was not initialized");

        match csv_reader.read_byte_record(&mut self.buf) {
            Ok(true) => {
                let item = self
                    .buf
                    .deserialize::<Out>(None)
                    .expect("csv does not match type");
                StreamElement::Item(item)
            }
            // End of the byte range: flush once, then emit `Terminate` on
            // the following call (see the `terminated` check above).
            Ok(false) => {
                self.terminated = true;
                StreamElement::FlushAndRestart
            }
            Err(e) => panic!("Error while reading CSV file: {:?}", e),
        }
    }

    fn structure(&self) -> BlockStructure {
        let mut operator = OperatorStructure::new::<Out, _>("CSVSource");
        operator.kind = OperatorKind::Source;
        BlockStructure::default().add_operator(operator)
    }
}
409
410impl<Out: Data + for<'a> Deserialize<'a>> Clone for CsvSource<Out> {
411 fn clone(&self) -> Self {
412 assert!(
413 self.csv_reader.is_none(),
414 "CsvSource must be cloned before calling setup"
415 );
416 Self {
417 path: self.path.clone(),
418 csv_reader: None,
419 options: self.options.clone(),
420 terminated: false,
421 _out: PhantomData,
422 buf: ByteRecord::new(),
423 }
424 }
425}
426
427impl crate::StreamContext {
428 pub fn stream_csv<T: Data + for<'a> Deserialize<'a>>(
430 &self,
431 path: impl Into<PathBuf>,
432 ) -> Stream<CsvSource<T>> {
433 let source = CsvSource::new(path);
434 self.stream(source)
435 }
436}
437
#[cfg(test)]
mod tests {
    use std::io::Write;

    use itertools::Itertools;
    use serde::{Deserialize, Serialize};
    use tempfile::NamedTempFile;

    use crate::config::RuntimeConfig;
    use crate::environment::StreamContext;
    use crate::operator::source::CsvSource;

    /// Headerless files of 0..100 records, with both LF and CRLF line
    /// endings, read by 4 local replicas: every tuple must come back exactly
    /// once regardless of where the byte-range boundaries fall.
    #[test]
    fn csv_without_headers() {
        for num_records in 0..100 {
            for terminator in &["\n", "\r\n"] {
                let file = NamedTempFile::new().unwrap();
                for i in 0..num_records {
                    write!(file.as_file(), "{},{}{}", i, i + 1, terminator).unwrap();
                }

                let env = StreamContext::new(RuntimeConfig::local(4));
                let source = CsvSource::<(i32, i32)>::new(file.path()).has_headers(false);
                let res = env.stream(source).shuffle().collect_vec();
                env.execute_blocking();

                // Sort because shuffling makes the output order unspecified.
                let mut res = res.get().unwrap();
                res.sort_unstable();
                assert_eq!(res, (0..num_records).map(|x| (x, x + 1)).collect_vec());
            }
        }
    }

    /// Same as `csv_without_headers` but with a header row, deserializing
    /// into a named struct: exercises header propagation to the replicas
    /// whose byte range does not include the first line.
    #[test]
    fn csv_with_headers() {
        #[derive(Clone, Serialize, Deserialize)]
        struct T {
            a: i32,
            b: i32,
        }

        for num_records in 0..100 {
            for terminator in &["\n", "\r\n"] {
                let file = NamedTempFile::new().unwrap();
                write!(file.as_file(), "a,b{terminator}").unwrap();
                for i in 0..num_records {
                    write!(file.as_file(), "{},{}{}", i, i + 1, terminator).unwrap();
                }

                let env = StreamContext::new(RuntimeConfig::local(4));
                let source = CsvSource::<T>::new(file.path());
                let res = env.stream(source).shuffle().collect_vec();
                env.execute_blocking();

                let res = res
                    .get()
                    .unwrap()
                    .into_iter()
                    .map(|x| (x.a, x.b))
                    .sorted()
                    .collect_vec();
                assert_eq!(res, (0..num_records).map(|x| (x, x + 1)).collect_vec());
            }
        }
    }
}