csv_scout/
sample.rs

1use std::io::{BufRead, BufReader, Read, Seek, SeekFrom};
2
3use crate::error::Result;
4use crate::sniffer::IS_UTF8;
5
/// Argument used when calling `sample_size` on `Sniffer`.
///
/// Controls how much of the input the sampler reads before it stops handing
/// out records.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SampleSize {
    /// Use a number of records as the size of the sample to sniff.
    ///
    /// At most this many lines are yielded.
    Records(usize),
    /// Use a number of bytes as the size of the sample to sniff.
    ///
    /// Lines are yielded while the running total of bytes read (line
    /// terminators included) stays at or below this limit; the line that would
    /// push the total past it is not yielded.
    Bytes(usize),
    /// Sniff the entire sample.
    All,
}
16
17pub fn take_sample_from_start<R>(reader: &mut R, sample_size: SampleSize) -> Result<SampleIter<R>>
18where
19    R: Read + Seek,
20{
21    reader.seek(SeekFrom::Start(0))?;
22    Ok(SampleIter::new(reader, sample_size))
23}
24
25pub struct SampleIter<'a, R: 'a + Read> {
26    reader: BufReader<&'a mut R>,
27    sample_size: SampleSize,
28    n_bytes: usize,
29    n_records: usize,
30    is_done: bool,
31}
32
33impl<'a, R: Read> SampleIter<'a, R> {
34    fn new(reader: &'a mut R, sample_size: SampleSize) -> Self {
35        let buf_reader = BufReader::new(reader);
36        SampleIter {
37            reader: buf_reader,
38            sample_size,
39            n_bytes: 0,
40            n_records: 0,
41            is_done: false,
42        }
43    }
44}
45
46impl<R: Read> Iterator for SampleIter<'_, R> {
47    type Item = Result<String>;
48
49    fn next(&mut self) -> Option<Result<String>> {
50        if self.is_done {
51            return None;
52        }
53
54        let mut buf = Vec::new();
55        let n_bytes_read = match self.reader.read_until(b'\n', &mut buf) {
56            Ok(n_bytes_read) => n_bytes_read,
57            Err(e) => {
58                return Some(Err(e.into()));
59            }
60        };
61        if n_bytes_read == 0 {
62            self.is_done = true;
63            return None;
64        }
65
66        let mut output = simdutf8::basic::from_utf8(&buf).map_or_else(
67            |_| {
68                // Its not all utf-8, set IS_UTF8 global to false
69                IS_UTF8.with(|flag| {
70                    *flag.borrow_mut() = false;
71                });
72                String::from_utf8_lossy(&buf).to_string()
73            },
74            std::string::ToString::to_string,
75        );
76
77        let last_byte = (output.as_ref() as &[u8])[output.len() - 1];
78        if last_byte != b'\n' && last_byte != b'\r' {
79            // non CR/LF-ended line
80            // line was cut off before ending, so we ignore it!
81            self.is_done = true;
82            return None;
83        }
84
85        output = output.trim_matches(|c| c == '\n' || c == '\r').into();
86        self.n_bytes += n_bytes_read;
87        self.n_records += 1;
88        match self.sample_size {
89            SampleSize::Records(max_records) => {
90                if self.n_records > max_records {
91                    self.is_done = true;
92                    return None;
93                }
94            }
95            SampleSize::Bytes(max_bytes) => {
96                if self.n_bytes > max_bytes {
97                    self.is_done = true;
98                    return None;
99                }
100            }
101            SampleSize::All => {}
102        }
103        Some(Ok(output))
104    }
105}