csv_scout/
sniffer.rs

1use hashbrown::HashMap;
2use std::cell::RefCell;
3use std::fs::File;
4use std::io::{Read, Seek, SeekFrom};
5use std::path::Path;
6
7use csv::Reader;
8use csv_core as csvc;
9use regex::Regex;
10
11use crate::{
12    chain::{Chain, STATE_STEADYFLEX, STATE_STEADYSTRICT, STATE_UNSTEADY, ViterbiResults},
13    error::{Result, SnifferError},
14    // field_type::DatePreference,
15    metadata::{Dialect, Metadata, Quote},
16    sample::{SampleIter, SampleSize, take_sample_from_start},
17};
18
/// Number of 7-bit ASCII byte values (`0x00..=0x7F`); delimiter inference runs one candidate
/// chain per value.
const NUM_ASCII_CHARS: usize = 0x80;

thread_local! {
    /// Thread-local UTF-8 flag. `Sniffer::sniff_reader` resets it to `true` before sniffing and
    /// copies its final value into the sniffer afterwards. Presumably cleared by other modules
    /// when invalid UTF-8 is encountered — not visible in this file.
    pub static IS_UTF8: RefCell<bool> = const { RefCell::new(true) };
}
23
24/// A CSV sniffer.
25///
26/// The sniffer examines a CSV file, passed in either through a file or a reader.
27#[derive(Debug, Default)]
28pub struct Sniffer {
29    // CSV file dialect guesses
30    delimiter: Option<u8>,
31    // num_preamble_rows: Option<usize>,
32    // has_header_row: Option<bool>,
33    quote: Option<Quote>,
34    flexible: Option<bool>,
35    is_utf8: Option<bool>,
36
37    // Metadata guesses
38    delimiter_freq: Option<usize>,
39    // fields: Vec<String>,
40    // types: Vec<Type>,
41    // avg_record_len: Option<usize>,
42
43    // sample size to sniff
44    sample_size: Option<SampleSize>,
45    // date format preference
46    // date_preference: Option<DatePreference>,
47}
48impl Sniffer {
49    /// Create a new CSV sniffer.
50    pub fn new() -> Self {
51        Self::default()
52    }
53    /// Specify the delimiter character.
54    pub fn delimiter(&mut self, delimiter: u8) -> &mut Self {
55        self.delimiter = Some(delimiter);
56        self
57    }
58    /// Specify the header type (whether the CSV file has a header row, and where the data starts).
59    // pub fn header(&mut self, header: &Header) -> &mut Self {
60    //     self.num_preamble_rows = Some(header.num_preamble_rows);
61    //     self.has_header_row = Some(header.has_header_row);
62    //     self
63    // }
64    /// Specify the quote character (if any), and whether two quotes in a row as to be interpreted
65    /// as an escaped quote.
66    pub fn quote(&mut self, quote: Quote) -> &mut Self {
67        self.quote = Some(quote);
68        self
69    }
70
71    /// The size of the sample to examine while sniffing. If using `SampleSize::Records`, the
72    /// sniffer will use the `Terminator::CRLF` as record separator.
73    ///
74    /// The sample size defaults to `SampleSize::Bytes(4096)`.
75    pub fn sample_size(&mut self, sample_size: SampleSize) -> &mut Self {
76        self.sample_size = Some(sample_size);
77        self
78    }
79
80    fn get_sample_size(&self) -> SampleSize {
81        self.sample_size.unwrap_or(SampleSize::Bytes(1 << 14))
82    }
83
84    // The date format preference when sniffing.
85    //
86    // The date format preference defaults to `DatePreference::MDY`.
87    // pub fn date_preference(&mut self, date_preference: DatePreference) -> &mut Self {
88    //     DATE_PREFERENCE.with(|preference| {
89    //         *preference.borrow_mut() = date_preference;
90    //     });
91    //     self.date_preference = Some(date_preference);
92    //     self
93    // }
94
95    /// Sniff the CSV file located at the provided path, and return a `Reader` (from the
96    /// [`csv`](https://docs.rs/csv) crate) ready to ready the file.
97    ///
98    /// Fails on file opening or rendering errors, or on an error examining the file.
99    pub fn open_path<P: AsRef<Path>>(&mut self, path: P) -> Result<Reader<File>> {
100        self.open_reader(File::open(path)?)
101    }
102    /// Sniff the CSV file provided by the reader, and return a [`csv`](https://docs.rs/csv)
103    /// `Reader` object.
104    ///
105    /// Fails on file opening or rendering errors, or on an error examining the file.
106    pub fn open_reader<R: Read + Seek>(&mut self, mut reader: R) -> Result<Reader<R>> {
107        let metadata = self.sniff_reader(&mut reader)?;
108        reader.seek(SeekFrom::Start(0))?;
109        metadata.dialect.open_reader(reader)
110    }
111
112    /// Sniff the CSV file located at the provided path, and return a
113    /// [`Metadata`](struct.Metadata.html) object containing information about the CSV file.
114    ///
115    /// Fails on file opening or rendering errors, or on an error examining the file.
116    pub fn sniff_path<P: AsRef<Path>>(&mut self, path: P) -> Result<Metadata> {
117        let file = File::open(path)?;
118        self.sniff_reader(&file)
119    }
120    /// Sniff the CSV file provider by the reader, and return a
121    /// [`Metadata`](struct.Metadata.html) object containing information about the CSV file.
122    ///
123    /// Fails on file opening or readering errors, or on an error examining the file.
124    pub fn sniff_reader<R: Read + Seek>(&mut self, mut reader: R) -> Result<Metadata> {
125        // init IS_UTF8 global var to true
126        IS_UTF8.with(|flag| {
127            *flag.borrow_mut() = true;
128        });
129        // guess quotes & delim
130        self.infer_quotes_delim(&mut reader)?;
131
132        // if we have a delimiter, we just need to search for num_preamble_rows and check for
133        // flexible. Otherwise, we need to guess a delimiter as well.
134        if self.delimiter.is_some() {
135            self.infer_preamble_known_delim(&mut reader)?;
136        } else {
137            self.infer_delim_preamble(&mut reader)?;
138        }
139
140        // self.infer_types(&mut reader)?;
141        self.is_utf8 = Some(IS_UTF8.with(|flag| *flag.borrow()));
142
143        // as this point of the process, we should have all these filled in.
144        // assert!(
145        //     self.delimiter.is_some()
146        //         && self.num_preamble_rows.is_some()
147        //         && self.quote.is_some()
148        //         && self.flexible.is_some()
149        //         && self.is_utf8.is_some()
150        //         && self.delimiter_freq.is_some()
151        //         && self.has_header_row.is_some()
152        //         && self.avg_record_len.is_some()
153        //         && self.delimiter_freq.is_some()
154        // );
155        if !(self.delimiter.is_some()
156            // && self.num_preamble_rows.is_some()
157            && self.quote.is_some()
158            && self.flexible.is_some()
159            && self.is_utf8.is_some()
160            && self.delimiter_freq.is_some()
161            // && self.has_header_row.is_some()
162            // && self.avg_record_len.is_some()
163            && self.delimiter_freq.is_some())
164        {
165            return Err(SnifferError::SniffingFailed(format!(
166                "Failed to infer all metadata: {self:?}"
167            )));
168        }
169        // safety: we just checked that all these are Some, so it's safe to unwrap
170        Ok(Metadata {
171            dialect: Dialect {
172                delimiter: self.delimiter.unwrap(),
173                // header: Header {
174                //     num_preamble_rows: self.num_preamble_rows.unwrap(),
175                //     has_header_row: self.has_header_row.unwrap(),
176                // },
177                quote: self.quote.clone().unwrap(),
178                // flexible: self.flexible.unwrap(),
179                // is_utf8: self.is_utf8.unwrap(),
180            },
181            // avg_record_len: self.avg_record_len.unwrap(),
182            // num_fields: self.delimiter_freq.unwrap() + 1,
183            // fields: self.fields.clone(),
184            // types: self.types.clone(),
185        })
186    }
187
188    // Infers quotes and delimiter from quoted (or possibly quoted) files. If quotes detected,
189    // updates self.quote and self.delimiter. If quotes not detected, updates self.quote to
190    // Quote::None. Only valid quote characters: " (double-quote), ' (single-quote), ` (back-tick).
191    fn infer_quotes_delim<R: Read + Seek>(&mut self, reader: &mut R) -> Result<()> {
192        if let (&Some(_), &Some(_)) = (&self.quote, &self.delimiter) {
193            // nothing left to infer!
194            return Ok(());
195        }
196        let quote_guesses = match self.quote {
197            Some(Quote::Some(chr)) => vec![chr],
198            Some(Quote::None) => {
199                // this function only checks quoted (or possibly quoted) files, nothing left to
200                // do if we know there are no quotes
201                return Ok(());
202            }
203            None => vec![b'\'', b'"', b'`'],
204        };
205        let (quote_chr, (quote_cnt, delim_guess)) = quote_guesses.iter().try_fold(
206            (b'"', (0, b'\0')),
207            |acc, &chr| -> Result<(u8, (usize, u8))> {
208                let mut sample_reader = take_sample_from_start(reader, self.get_sample_size())?;
209                if let Some((cnt, delim_chr)) =
210                    quote_count(&mut sample_reader, char::from(chr), self.delimiter)?
211                {
212                    Ok(if cnt > acc.1.0 {
213                        (chr, (cnt, delim_chr))
214                    } else {
215                        acc
216                    })
217                } else {
218                    Ok(acc)
219                }
220            },
221        )?;
222        if quote_cnt == 0 {
223            self.quote = Some(Quote::None);
224        } else {
225            self.quote = Some(Quote::Some(quote_chr));
226            self.delimiter = Some(delim_guess);
227        };
228        Ok(())
229    }
230
231    // Updates delimiter frequency, number of preamble rows, and flexible boolean.
232    fn infer_preamble_known_delim<R: Read + Seek>(&mut self, reader: &mut R) -> Result<()> {
233        // prerequisites for calling this function:
234        if !(self.delimiter.is_some() && self.quote.is_some()) {
235            // instead of assert, return an error
236            // assert!(self.delimiter.is_some() && self.quote.is_some());
237            return Err(SnifferError::SniffingFailed(
238                "infer_preamble_known_delim called without delimiter and quote".into(),
239            ));
240        }
241        // safety: unwraps for delimiter and quote are safe since we just checked above
242        let (quote, delim) = (self.quote.clone().unwrap(), self.delimiter.unwrap());
243
244        let sample_iter = take_sample_from_start(reader, self.get_sample_size())?;
245
246        let mut chain = Chain::default();
247
248        if let Quote::Some(character) = quote {
249            // since we have a quote, we need to run this data through the csv_core::Reader (which
250            // properly escapes quoted fields
251            let mut csv_reader = csvc::ReaderBuilder::new()
252                .delimiter(delim)
253                .quote(character)
254                .build();
255
256            let mut output = vec![];
257            let mut ends = vec![];
258            for line in sample_iter {
259                let line = line?;
260                if line.len() > output.len() {
261                    output.resize(line.len(), 0);
262                }
263                if line.len() > ends.len() {
264                    ends.resize(line.len(), 0);
265                }
266                let (result, _, _, n_ends) =
267                    csv_reader.read_record(line.as_bytes(), &mut output, &mut ends);
268                // check to make sure record was read correctly
269                match result {
270                    csvc::ReadRecordResult::OutputFull | csvc::ReadRecordResult::OutputEndsFull => {
271                        return Err(SnifferError::SniffingFailed(format!(
272                            "failure to read quoted CSV record: {result:?}"
273                        )));
274                    }
275                    _ => {} // non-error results, do nothing
276                }
277                // n_ends is the number of barries between fields, so it's the same as the number
278                // of delimiters
279                chain.add_observation(n_ends);
280            }
281        } else {
282            for line in sample_iter {
283                let line = line?;
284                let freq = bytecount::count(line.as_bytes(), delim);
285                chain.add_observation(freq);
286            }
287        }
288        self.run_chains(vec![chain])
289    }
290
291    // Updates delimiter, delimiter frequency, number of preamble rows, and flexible boolean.
292    fn infer_delim_preamble<R: Read + Seek>(&mut self, reader: &mut R) -> Result<()> {
293        let sample_iter = take_sample_from_start(reader, self.get_sample_size())?;
294
295        let mut chains = vec![Chain::default(); NUM_ASCII_CHARS];
296        for line in sample_iter {
297            let line = line?;
298            let mut freqs = [0; NUM_ASCII_CHARS];
299            for &chr in line.as_bytes() {
300                if chr < NUM_ASCII_CHARS as u8 {
301                    freqs[chr as usize] += 1;
302                }
303            }
304            for (chr, &freq) in freqs.iter().enumerate() {
305                chains[chr].add_observation(freq);
306            }
307        }
308
309        self.run_chains(chains)
310    }
311
312    // Updates delimiter (if not already known), delimiter frequency, number of preamble rows, and
313    // flexible boolean.
314    fn run_chains(&mut self, mut chains: Vec<Chain>) -> Result<()> {
315        // Find the 'best' delimiter: choose strict (non-flexible) delimiters over flexible ones,
316        // and choose the one that had the highest probability markov chain in the end.
317        //
318        // In the case where delim is already known, 'best_delim' will be incorrect (since it won't
319        // correspond with position in a vector of Chains), but we'll just ignore it when
320        // constructing our return value later. 'best_state' and 'path' are necessary, though, to
321        // compute the preamble rows.
322        let (best_delim, delim_freq, best_state, _, _) = chains.iter_mut().enumerate().fold(
323            (b',', 0, STATE_UNSTEADY, vec![], 0.0),
324            |acc, (i, ref mut chain)| {
325                let (_, _, best_state, _, best_state_prob) = acc;
326                let ViterbiResults {
327                    max_delim_freq,
328                    path,
329                } = chain.viterbi();
330                if path.is_empty() {
331                    return acc;
332                }
333                let (final_state, final_viter) = path[path.len() - 1];
334                if final_state < best_state
335                    || (final_state == best_state && final_viter.prob > best_state_prob)
336                {
337                    (i as u8, max_delim_freq, final_state, path, final_viter.prob)
338                } else {
339                    acc
340                }
341            },
342        );
343        self.flexible = Some(match best_state {
344            STATE_STEADYSTRICT => false,
345            STATE_STEADYFLEX => true,
346            _ => {
347                return Err(SnifferError::SniffingFailed(
348                    "unable to find valid delimiter".to_string(),
349                ));
350            }
351        });
352
353        // Find the number of preamble rows (the number of rows during which the state fluctuated
354        // before getting to the final state).
355        // let mut num_preamble_rows = 0;
356        // since path has an extra state as the beginning, skip one
357        // for &(state, _) in path.iter().skip(2) {
358        //     if state == best_state {
359        //         break;
360        //     }
361        //     num_preamble_rows += 1;
362        // }
363        // if num_preamble_rows > 0 {
364        //     num_preamble_rows += 1;
365        // }
366        if self.delimiter.is_none() {
367            self.delimiter = Some(best_delim);
368        }
369        self.delimiter_freq = Some(delim_freq);
370        // self.num_preamble_rows = Some(num_preamble_rows);
371        Ok(())
372    }
373
374    // fn infer_types<R: Read + Seek>(&mut self, reader: &mut R) -> Result<()> {
375    //     // prerequisites for calling this function:
376    //     if self.delimiter_freq.is_none() {
377    //         // instead of assert, return error
378    //         // assert!(self.delimiter_freq.is_some());
379    //         return Err(SnifferError::SniffingFailed(
380    //             "delimiter frequency not known".to_string(),
381    //         ));
382    //     }
383    //     // safety: unwrap is safe as we just checked that delimiter_freq is Some
384    //     let field_count = self.delimiter_freq.unwrap() + 1;
385
386    //     let mut csv_reader = self.create_csv_reader(reader)?;
387    //     let mut records_iter = csv_reader.byte_records();
388    //     let mut n_bytes = 0;
389    //     let mut n_records = 0;
390    //     let sample_size = self.get_sample_size();
391
392    //     // Infer types for the top row. We'll save this set of types to check against the types
393    //     // of the remaining rows to see if this is part of the data or a separate header row.
394    //     let header_row_types = match records_iter.next() {
395    //         Some(record) => {
396    //             let byte_record = record?;
397    //             let str_record = StringRecord::from_byte_record_lossy(byte_record);
398    //             n_records += 1;
399    //             n_bytes += count_bytes(&str_record);
400    //             infer_record_types(&str_record)
401    //         }
402    //         None => {
403    //             return Err(SnifferError::SniffingFailed(
404    //                 "CSV empty (after preamble)".into(),
405    //             ));
406    //         }
407    //     };
408    //     let mut row_types = vec![TypeGuesses::all(); field_count];
409
410    //     for record in records_iter {
411    //         let record = record?;
412    //         for (i, field) in record.iter().enumerate() {
413    //             let str_field = String::from_utf8_lossy(field).to_string();
414    //             row_types[i] &= infer_types(&str_field);
415    //         }
416    //         n_records += 1;
417    //         n_bytes += record.as_slice().len();
418    //         // break if we pass sample size limits
419    //         match sample_size {
420    //             SampleSize::Records(recs) => {
421    //                 if n_records > recs {
422    //                     break;
423    //                 }
424    //             }
425    //             SampleSize::Bytes(bytes) => {
426    //                 if n_bytes > bytes {
427    //                     break;
428    //                 }
429    //             }
430    //             SampleSize::All => {}
431    //         }
432    //     }
433    //     if n_records == 1 {
434    //         // there's only one row in the whole data file (the top row already parsed),
435    //         // so we're going to assume it's a data row, not a header row.
436    //         self.has_header_row = Some(false);
437    //         self.types = get_best_types(&header_row_types);
438    //         self.avg_record_len = Some(n_bytes);
439    //         return Ok(());
440    //     }
441
442    //     if header_row_types
443    //         .iter()
444    //         .zip(&row_types)
445    //         .any(|(header, data)| !data.allows(*header))
446    //     {
447    //         self.has_header_row = Some(true);
448    //         // get field names in header
449    //         for field in csv_reader.byte_headers()? {
450    //             self.fields.push(String::from_utf8_lossy(field).to_string());
451    //         }
452    //     } else {
453    //         self.has_header_row = Some(false);
454    //     }
455
456    //     self.types = get_best_types(&row_types);
457    //     self.avg_record_len = Some(n_bytes / n_records);
458    //     Ok(())
459    // }
460
461    // fn create_csv_reader<'a, R: Read + Seek>(
462    //     &self,
463    //     mut reader: &'a mut R,
464    // ) -> Result<Reader<&'a mut R>> {
465    //     reader.seek(SeekFrom::Start(0))?;
466    //     if let Some(num_preamble_rows) = self.num_preamble_rows {
467    //         snip_preamble(&mut reader, num_preamble_rows)?;
468    //     }
469
470    //     let mut builder = csv::ReaderBuilder::new();
471    //     if let Some(delim) = self.delimiter {
472    //         builder.delimiter(delim);
473    //     }
474    //     if let Some(has_header_row) = self.has_header_row {
475    //         builder.has_headers(has_header_row);
476    //     }
477    //     match self.quote {
478    //         Some(Quote::Some(chr)) => {
479    //             builder.quoting(true);
480    //             builder.quote(chr);
481    //         }
482    //         Some(Quote::None) => {
483    //             builder.quoting(false);
484    //         }
485    //         _ => {}
486    //     }
487    //     if let Some(flexible) = self.flexible {
488    //         builder.flexible(flexible);
489    //     }
490
491    //     Ok(builder.from_reader(reader))
492    // }
493}
494
495fn quote_count<R: Read>(
496    sample_iter: &mut SampleIter<R>,
497    character: char,
498    delim: Option<u8>,
499) -> Result<Option<(usize, u8)>> {
500    let pattern = delim.map_or_else(
501        || format!(r#"{character}\s*?(?P<delim>[^\w\n'"`])\s*{character}"#),
502        |delim| format!(r"{character}\s*?{delim}\s*{character}"),
503    );
504    // safety: unwrap is safe as we know the pattern is valid
505    let re = Regex::new(&pattern).unwrap();
506
507    let mut delim_count_map: HashMap<String, usize> = HashMap::new();
508    let mut count = 0;
509    for line in sample_iter {
510        let line = line?;
511        for cap in re.captures_iter(&line) {
512            count += 1;
513            // if we already know delimiter, we don't need to count
514            if delim.is_some() {
515            } else {
516                *delim_count_map.entry(cap["delim"].to_string()).or_insert(0) += 1;
517            }
518        }
519    }
520    if count == 0 {
521        return Ok(None);
522    }
523
524    // if we already know delimiter, no need to go through map
525    if let Some(delim) = delim {
526        return Ok(Some((count, delim)));
527    }
528
529    // find the highest-count delimiter in the map
530    let (delim_count, delim) =
531        delim_count_map
532            .iter()
533            .fold((0, b'\0'), |acc, (delim, &delim_count)| {
534                // assert!(delim.len() == 1);
535                if delim.len() != 1 {
536                    // instead of assert, we set delim count to 0 and delim to null byte
537                    // this will be picked up the delim_count == 0 check below
538                    (0, b'\0')
539                } else if delim_count > acc.0 {
540                    (delim_count, (delim.as_ref() as &[u8])[0])
541                } else {
542                    acc
543                }
544            });
545
546    // delim_count should be nonzero; delim should always match at least something
547    // instead of the assert, we return an error
548    if delim_count == 0 {
549        // assert_ne!(delim_count, 0, "invalid regex match: no delimiter found");
550        return Err(SnifferError::SniffingFailed(
551            "invalid regex match: no delimiter found".into(),
552        ));
553    }
554    Ok(Some((count, delim)))
555}