csv_scout/
sniffer.rs

1use hashbrown::HashMap;
2use itertools::Itertools;
3use std::cell::RefCell;
4use std::fs::File;
5use std::io::{Read, Seek, SeekFrom};
6use std::path::Path;
7
8use csv::Reader;
9use csv_core as csvc;
10use regex::{Captures, Regex};
11
12use crate::{
13    chain::{Chain, STATE_STEADYFLEX, STATE_STEADYSTRICT, STATE_UNSTEADY, ViterbiResults},
14    error::{Result, SnifferError},
15    // field_type::DatePreference,
16    metadata::{Dialect, Metadata, Quote},
17    sample::{SampleIter, SampleSize, take_sample_from_start},
18};
19
/// Number of times a candidate byte occurs within one sampled line.
type NumberOfOccurrences = u32;
/// Number of sampled lines that share the same per-line occurrence count.
type NumberOfLines = u32;
/// Number of sampled lines whose occurrence count lies at, or within `TOLERANCE` of, the
/// most common (modal) count.
type AdjacentFrequency = u32;

/// Half-width of the window around the modal occurrence count used when scoring a candidate.
const TOLERANCE: u32 = 1;
/// Number of ASCII code points; sizes the per-byte tables.
const NUM_ASCII_CHARS: usize = 128;
/// Bytes tried as field delimiters when the caller has not supplied one.
const CANDIDATES: &[u8] = b"\t,;|:";

thread_local! {
    /// Per-thread UTF-8 flag. It is reset to `true` at the start of every sniff and copied into
    /// the sniffer's `is_utf8` guess at the end. Nothing in this file clears it; presumably the
    /// sampling code does so when it meets invalid UTF-8 (TODO confirm).
    pub static IS_UTF8: RefCell<bool> = const { RefCell::new(true) };
}
29// thread_local! (pub static DATE_PREFERENCE: RefCell<DatePreference> = const { RefCell::new(DatePreference::MdyFormat) });
30
31/// A CSV sniffer.
32///
33/// The sniffer examines a CSV file, passed in either through a file or a reader.
34#[derive(Debug, Default)]
35pub struct Sniffer {
36    // CSV file dialect guesses
37    delimiter: Option<u8>,
38    // num_preamble_rows: Option<usize>,
39    // has_header_row: Option<bool>,
40    quote: Option<Quote>,
41    flexible: Option<bool>,
42    is_utf8: Option<bool>,
43
44    // Metadata guesses
45    // delimiter_freq: Option<usize>,
46    // fields: Vec<String>,
47    // types: Vec<Type>,
48    // avg_record_len: Option<usize>,
49
50    // sample size to sniff
51    sample_size: Option<SampleSize>,
52    // date format preference
53    // date_preference: Option<DatePreference>,
54}
55impl Sniffer {
56    /// Create a new CSV sniffer.
57    pub fn new() -> Self {
58        Self::default()
59    }
60    /// Specify the delimiter character.
61    pub fn delimiter(&mut self, delimiter: u8) -> &mut Self {
62        self.delimiter = Some(delimiter);
63        self
64    }
    // Specify the header type (whether the CSV file has a header row, and where the data starts).
66    // pub fn header(&mut self, header: &Header) -> &mut Self {
67    //     self.num_preamble_rows = Some(header.num_preamble_rows);
68    //     self.has_header_row = Some(header.has_header_row);
69    //     self
70    // }
71    /// Specify the quote character (if any), and whether two quotes in a row as to be interpreted
72    /// as an escaped quote.
73    pub fn quote(&mut self, quote: Quote) -> &mut Self {
74        self.quote = Some(quote);
75        self
76    }
77
78    /// The size of the sample to examine while sniffing. If using `SampleSize::Records`, the
79    /// sniffer will use the `Terminator::CRLF` as record separator.
80    ///
81    /// The sample size defaults to `SampleSize::Bytes(4096)`.
82    pub fn sample_size(&mut self, sample_size: SampleSize) -> &mut Self {
83        self.sample_size = Some(sample_size);
84        self
85    }
86
87    fn get_sample_size(&self) -> SampleSize {
88        self.sample_size.unwrap_or(SampleSize::Bytes(1 << 14))
89    }
90
91    // The date format preference when sniffing.
92    //
93    // The date format preference defaults to `DatePreference::MDY`.
94    // pub fn date_preference(&mut self, date_preference: DatePreference) -> &mut Self {
95    //     DATE_PREFERENCE.with(|preference| {
96    //         *preference.borrow_mut() = date_preference;
97    //     });
98    //     self.date_preference = Some(date_preference);
99    //     self
100    // }
101
102    /// Sniff the CSV file located at the provided path, and return a `Reader` (from the
103    /// [`csv`](https://docs.rs/csv) crate) ready to ready the file.
104    ///
105    /// Fails on file opening or rendering errors, or on an error examining the file.
106    pub fn open_path<P: AsRef<Path>>(&mut self, path: P) -> Result<Reader<File>> {
107        self.open_reader(File::open(path)?)
108    }
109    /// Sniff the CSV file provided by the reader, and return a [`csv`](https://docs.rs/csv)
110    /// `Reader` object.
111    ///
112    /// Fails on file opening or rendering errors, or on an error examining the file.
113    pub fn open_reader<R: Read + Seek>(&mut self, mut reader: R) -> Result<Reader<R>> {
114        let metadata = self.sniff_reader(&mut reader)?;
115        reader.seek(SeekFrom::Start(0))?;
116        metadata.dialect.open_reader(reader)
117    }
118
119    /// Sniff the CSV file located at the provided path, and return a
120    /// [`Metadata`](struct.Metadata.html) object containing information about the CSV file.
121    ///
122    /// Fails on file opening or rendering errors, or on an error examining the file.
123    pub fn sniff_path<P: AsRef<Path>>(&mut self, path: P) -> Result<Metadata> {
124        let file = File::open(path)?;
125        self.sniff_reader(&file)
126    }
127    /// Sniff the CSV file provider by the reader, and return a
128    /// [`Metadata`](struct.Metadata.html) object containing information about the CSV file.
129    ///
130    /// Fails on file opening or readering errors, or on an error examining the file.
131    pub fn sniff_reader<R: Read + Seek>(&mut self, mut reader: R) -> Result<Metadata> {
132        // init IS_UTF8 global var to true
133        IS_UTF8.with(|flag| {
134            *flag.borrow_mut() = true;
135        });
136        // guess quotes & delim
137        self.infer_quotes_delim(&mut reader)?;
138
139        // if we have a delimiter, we just need to search for num_preamble_rows and check for
140        // flexible. Otherwise, we need to guess a delimiter as well.
141        if self.delimiter.is_some() {
142            self.infer_preamble_known_delim(&mut reader)?;
143        } else {
144            self.infer_delim_preamble(&mut reader)?;
145        }
146
147        // self.infer_types(&mut reader)?;
148        self.is_utf8 = Some(IS_UTF8.with(|flag| *flag.borrow()));
149
150        // as this point of the process, we should have all these filled in.
151        // assert!(
152        //     self.delimiter.is_some()
153        //         && self.num_preamble_rows.is_some()
154        //         && self.quote.is_some()
155        //         && self.flexible.is_some()
156        //         && self.is_utf8.is_some()
157        //         && self.delimiter_freq.is_some()
158        //         && self.has_header_row.is_some()
159        //         && self.avg_record_len.is_some()
160        //         && self.delimiter_freq.is_some()
161        // );
162        if !(
163            self.delimiter.is_some()
164            // && self.num_preamble_rows.is_some()
165            && self.quote.is_some()
166            && self.flexible.is_some()
167            && self.is_utf8.is_some()
168            // && self.has_header_row.is_some()
169            // && self.avg_record_len.is_some()
170        ) {
171            return Err(SnifferError::SniffingFailed(format!(
172                "Failed to infer all metadata: {self:?}"
173            )));
174        }
175        // safety: we just checked that all these are Some, so it's safe to unwrap
176        Ok(Metadata {
177            dialect: Dialect {
178                delimiter: self.delimiter.unwrap(),
179                // header: Header {
180                //     num_preamble_rows: self.num_preamble_rows.unwrap(),
181                //     has_header_row: self.has_header_row.unwrap(),
182                // },
183                quote: self.quote.clone().unwrap(),
184                // flexible: self.flexible.unwrap(),
185                // is_utf8: self.is_utf8.unwrap(),
186            },
187            // avg_record_len: self.avg_record_len.unwrap(),
188            // num_fields: self.delimiter_freq.unwrap() + 1,
189            // fields: self.fields.clone(),
190            // types: self.types.clone(),
191        })
192    }
193
194    // Infers quotes and delimiter from quoted (or possibly quoted) files. If quotes detected,
195    // updates self.quote and self.delimiter. If quotes not detected, updates self.quote to
196    // Quote::None. Only valid quote characters: " (double-quote), ' (single-quote), ` (back-tick).
197    fn infer_quotes_delim<R: Read + Seek>(&mut self, reader: &mut R) -> Result<()> {
198        if let (&Some(_), &Some(_)) = (&self.quote, &self.delimiter) {
199            // nothing left to infer!
200            return Ok(());
201        }
202        let quote_guesses = match self.quote {
203            Some(Quote::Some(chr)) => vec![chr],
204            Some(Quote::None) => {
205                // this function only checks quoted (or possibly quoted) files, nothing left to
206                // do if we know there are no quotes
207                return Ok(());
208            }
209            None => vec![b'\'', b'"', b'`'],
210        };
211        let (quote_chr, (quote_cnt, delim_guess)) = quote_guesses.iter().try_fold(
212            (b'"', (0, b'\0')),
213            |acc, &chr| -> Result<(u8, (usize, u8))> {
214                let mut sample_reader = take_sample_from_start(reader, self.get_sample_size())?;
215                if let Some((cnt, delim_chr)) =
216                    quote_count(&mut sample_reader, char::from(chr), self.delimiter)?
217                {
218                    Ok(if cnt > acc.1.0 {
219                        (chr, (cnt, delim_chr))
220                    } else {
221                        acc
222                    })
223                } else {
224                    Ok(acc)
225                }
226            },
227        )?;
228        if quote_cnt == 0 {
229            self.quote = Some(Quote::None);
230        } else {
231            self.quote = Some(Quote::Some(quote_chr));
232            self.delimiter = Some(delim_guess);
233        };
234        Ok(())
235    }
236
237    // Updates delimiter frequency, number of preamble rows, and flexible boolean.
238    fn infer_preamble_known_delim<R: Read + Seek>(&mut self, reader: &mut R) -> Result<()> {
239        // prerequisites for calling this function:
240        if !(self.delimiter.is_some() && self.quote.is_some()) {
241            // instead of assert, return an error
242            // assert!(self.delimiter.is_some() && self.quote.is_some());
243            return Err(SnifferError::SniffingFailed(
244                "infer_preamble_known_delim called without delimiter and quote".into(),
245            ));
246        }
247        // safety: unwraps for delimiter and quote are safe since we just checked above
248        let (quote, delim) = (self.quote.clone().unwrap(), self.delimiter.unwrap());
249
250        let sample_iter = take_sample_from_start(reader, self.get_sample_size())?;
251
252        let mut chain = Chain::default();
253
254        if let Quote::Some(character) = quote {
255            // since we have a quote, we need to run this data through the csv_core::Reader (which
256            // properly escapes quoted fields
257            let mut csv_reader = csvc::ReaderBuilder::new()
258                .delimiter(delim)
259                .quote(character)
260                .build();
261
262            let mut output = vec![];
263            let mut ends = vec![];
264            for line in sample_iter {
265                let line = line?;
266                if line.len() > output.len() {
267                    output.resize(line.len(), 0);
268                }
269                if line.len() > ends.len() {
270                    ends.resize(line.len(), 0);
271                }
272                let (result, _, _, n_ends) =
273                    csv_reader.read_record(line.as_bytes(), &mut output, &mut ends);
274                // check to make sure record was read correctly
275                match result {
276                    csvc::ReadRecordResult::OutputFull | csvc::ReadRecordResult::OutputEndsFull => {
277                        return Err(SnifferError::SniffingFailed(format!(
278                            "failure to read quoted CSV record: {result:?}"
279                        )));
280                    }
281                    _ => {} // non-error results, do nothing
282                }
283                // n_ends is the number of barries between fields, so it's the same as the number
284                // of delimiters
285                chain.add_observation(n_ends);
286            }
287        } else {
288            for line in sample_iter {
289                let line = line?;
290                let freq = bytecount::count(line.as_bytes(), delim);
291                chain.add_observation(freq);
292            }
293        }
294        self.run_chains(vec![chain])
295    }
296
297    // Updates delimiter, delimiter frequency, number of preamble rows, and flexible boolean.
298    fn infer_delim_preamble<R: Read + Seek>(&mut self, reader: &mut R) -> Result<()> {
299        let sample_iter =
300            take_sample_from_start(reader, self.get_sample_size())?.collect::<Result<Vec<_>>>()?;
301
302        let mut chars_frequency: HashMap<u8, HashMap<NumberOfOccurrences, NumberOfLines>> =
303            HashMap::with_capacity(NUM_ASCII_CHARS);
304
305        let mut modes: HashMap<u8, (NumberOfOccurrences, AdjacentFrequency)> =
306            HashMap::with_capacity(NUM_ASCII_CHARS);
307
308        for line in &sample_iter {
309            let mut line_frequency = HashMap::with_capacity(128);
310            for character in line.chars() {
311                let Ok(ascii_char) = u8::try_from(character) else {
312                    continue;
313                };
314                if !CANDIDATES.contains(&ascii_char) {
315                    continue;
316                }
317                *line_frequency.entry(ascii_char).or_default() += 1;
318            }
319            for (ascii_char, freq) in line_frequency {
320                let char_frequency = chars_frequency.entry(ascii_char).or_default();
321                *char_frequency.entry(freq).or_default() += 1;
322            }
323        }
324        for (&ascii_char, line_count_map) in &chars_frequency {
325            let Some((&mode_value, _)) = line_count_map
326                .iter()
327                .max_by_key(|&(_count, num_lines)| num_lines)
328            else {
329                continue; // skip empty maps, just in case
330            };
331
332            let mut adjusted_count = 0;
333            for delta in 0..=TOLERANCE {
334                for count in [mode_value.saturating_sub(delta), mode_value + delta] {
335                    if let Some(&lines) = line_count_map.get(&count) {
336                        adjusted_count += lines;
337                    }
338                }
339            }
340            if TOLERANCE > 0 {
341                if let Some(&lines) = line_count_map.get(&mode_value) {
342                    adjusted_count -= lines;
343                }
344            }
345
346            modes.insert(ascii_char, (mode_value, adjusted_count));
347        }
348        let top_candidates: Vec<u8> = modes
349            .iter()
350            .filter(|(_, (_, score))| *score > 0)
351            .sorted_by_key(|&(_, &(_, score))| std::cmp::Reverse(score)) // needs itertools or just sort
352            .take(6)
353            .map(|(&ch, _)| ch)
354            .collect();
355
356        let mut chains = vec![Chain::default(); NUM_ASCII_CHARS];
357
358        for line in sample_iter {
359            let mut freqs = [0; NUM_ASCII_CHARS];
360            for &chr in line.as_bytes() {
361                if chr < NUM_ASCII_CHARS as u8 {
362                    freqs[chr as usize] += 1;
363                }
364            }
365            for &ch in &top_candidates {
366                chains[ch as usize].add_observation(freqs[ch as usize]);
367            }
368        }
369
370        self.run_chains(chains)
371    }
372
373    // Updates delimiter (if not already known), delimiter frequency, number of preamble rows, and
374    // flexible boolean.
375    fn run_chains(&mut self, mut chains: Vec<Chain>) -> Result<()> {
376        // Find the 'best' delimiter: choose strict (non-flexible) delimiters over flexible ones,
377        // and choose the one that had the highest probability markov chain in the end.
378        //
379        // In the case where delim is already known, 'best_delim' will be incorrect (since it won't
380        // correspond with position in a vector of Chains), but we'll just ignore it when
381        // constructing our return value later. 'best_state' and 'path' are necessary, though, to
382        // compute the preamble rows.
383        let (best_delim, _, best_state, _, _) = chains.iter_mut().enumerate().fold(
384            (b',', 0, STATE_UNSTEADY, vec![], 0.0),
385            |acc, (i, ref mut chain)| {
386                let (_, _, best_state, _, best_state_prob) = acc;
387                let ViterbiResults {
388                    max_delim_freq,
389                    path,
390                } = chain.viterbi();
391                if path.is_empty() {
392                    return acc;
393                }
394                let (final_state, final_viter) = path[path.len() - 1];
395                if final_state < best_state
396                    || (final_state == best_state && final_viter.prob > best_state_prob)
397                {
398                    (i as u8, max_delim_freq, final_state, path, final_viter.prob)
399                } else {
400                    acc
401                }
402            },
403        );
404        self.flexible = Some(match best_state {
405            STATE_STEADYSTRICT => false,
406            STATE_STEADYFLEX => true,
407            _ => {
408                return Err(SnifferError::SniffingFailed(
409                    "unable to find valid delimiter".to_string(),
410                ));
411            }
412        });
413
414        // Find the number of preamble rows (the number of rows during which the state fluctuated
415        // before getting to the final state).
416        // let mut num_preamble_rows = 0;
417        // since path has an extra state as the beginning, skip one
418        // for &(state, _) in path.iter().skip(2) {
419        //     if state == best_state {
420        //         break;
421        //     }
422        //     num_preamble_rows += 1;
423        // }
424        // if num_preamble_rows > 0 {
425        //     num_preamble_rows += 1;
426        // }
427        if self.delimiter.is_none() {
428            self.delimiter = Some(best_delim);
429        }
430        // self.num_preamble_rows = Some(num_preamble_rows);
431        Ok(())
432    }
433
434    // fn infer_types<R: Read + Seek>(&mut self, reader: &mut R) -> Result<()> {
435    //     // prerequisites for calling this function:
436    //     if self.delimiter_freq.is_none() {
437    //         // instead of assert, return error
438    //         // assert!(self.delimiter_freq.is_some());
439    //         return Err(SnifferError::SniffingFailed(
440    //             "delimiter frequency not known".to_string(),
441    //         ));
442    //     }
443    //     // safety: unwrap is safe as we just checked that delimiter_freq is Some
444    //     let field_count = self.delimiter_freq.unwrap() + 1;
445
446    //     let mut csv_reader = self.create_csv_reader(reader)?;
447    //     let mut records_iter = csv_reader.byte_records();
448    //     let mut n_bytes = 0;
449    //     let mut n_records = 0;
450    //     let sample_size = self.get_sample_size();
451
452    //     // Infer types for the top row. We'll save this set of types to check against the types
453    //     // of the remaining rows to see if this is part of the data or a separate header row.
454    //     let header_row_types = match records_iter.next() {
455    //         Some(record) => {
456    //             let byte_record = record?;
457    //             let str_record = StringRecord::from_byte_record_lossy(byte_record);
458    //             n_records += 1;
459    //             n_bytes += count_bytes(&str_record);
460    //             infer_record_types(&str_record)
461    //         }
462    //         None => {
463    //             return Err(SnifferError::SniffingFailed(
464    //                 "CSV empty (after preamble)".into(),
465    //             ));
466    //         }
467    //     };
468    //     let mut row_types = vec![TypeGuesses::all(); field_count];
469
470    //     for record in records_iter {
471    //         let record = record?;
472    //         for (i, field) in record.iter().enumerate() {
473    //             let str_field = String::from_utf8_lossy(field).to_string();
474    //             row_types[i] &= infer_types(&str_field);
475    //         }
476    //         n_records += 1;
477    //         n_bytes += record.as_slice().len();
478    //         // break if we pass sample size limits
479    //         match sample_size {
480    //             SampleSize::Records(recs) => {
481    //                 if n_records > recs {
482    //                     break;
483    //                 }
484    //             }
485    //             SampleSize::Bytes(bytes) => {
486    //                 if n_bytes > bytes {
487    //                     break;
488    //                 }
489    //             }
490    //             SampleSize::All => {}
491    //         }
492    //     }
493    //     if n_records == 1 {
494    //         // there's only one row in the whole data file (the top row already parsed),
495    //         // so we're going to assume it's a data row, not a header row.
496    //         self.has_header_row = Some(false);
497    //         self.types = get_best_types(&header_row_types);
498    //         self.avg_record_len = Some(n_bytes);
499    //         return Ok(());
500    //     }
501
502    //     if header_row_types
503    //         .iter()
504    //         .zip(&row_types)
505    //         .any(|(header, data)| !data.allows(*header))
506    //     {
507    //         self.has_header_row = Some(true);
508    //         // get field names in header
509    //         for field in csv_reader.byte_headers()? {
510    //             self.fields.push(String::from_utf8_lossy(field).to_string());
511    //         }
512    //     } else {
513    //         self.has_header_row = Some(false);
514    //     }
515
516    //     self.types = get_best_types(&row_types);
517    //     self.avg_record_len = Some(n_bytes / n_records);
518    //     Ok(())
519    // }
520
521    // fn create_csv_reader<'a, R: Read + Seek>(
522    //     &self,
523    //     mut reader: &'a mut R,
524    // ) -> Result<Reader<&'a mut R>> {
525    //     reader.seek(SeekFrom::Start(0))?;
526    //     if let Some(num_preamble_rows) = self.num_preamble_rows {
527    //         snip_preamble(&mut reader, num_preamble_rows)?;
528    //     }
529
530    //     let mut builder = csv::ReaderBuilder::new();
531    //     if let Some(delim) = self.delimiter {
532    //         builder.delimiter(delim);
533    //     }
534    //     if let Some(has_header_row) = self.has_header_row {
535    //         builder.has_headers(has_header_row);
536    //     }
537    //     match self.quote {
538    //         Some(Quote::Some(chr)) => {
539    //             builder.quoting(true);
540    //             builder.quote(chr);
541    //         }
542    //         Some(Quote::None) => {
543    //             builder.quoting(false);
544    //         }
545    //         _ => {}
546    //     }
547    //     if let Some(flexible) = self.flexible {
548    //         builder.flexible(flexible);
549    //     }
550
551    //     Ok(builder.from_reader(reader))
552    // }
553}
554
555fn quote_count<R: Read>(
556    sample_iter: &mut SampleIter<R>,
557    character: char,
558    delim: Option<u8>,
559) -> Result<Option<(usize, u8)>> {
560    // Build a regex that matches a quoted CSV cell,
561    // optionally followed by a delimiter.
562    // If delim is None, we try to capture a candidate delimiter.
563    let pattern = delim.map_or_else(
564        || {
565            // When delim is not provided, capture candidate delimiters in a group.
566            format!(
567                r#"(?<delim1>[^\w\n\"ֿ\'])(?: ?)(?:{character}).*?(?:{character})(?<delim2>[^\w\n\"\'])|
568                (?:^|\n)(?:{character}).*?(?:{character})(?<delim3>[^\w\n\"\'])(?: ?)|
569                (?<delim4>[^\w\n\"\'])(?: ?)(?:{character}).*?(?:{character})(?:$|\n)|
570                (?:^|\n)(?:{character}).*?(?:{character})(?:$|\n)"#
571            )
572        },
573        |delim| {
574            // When a delimiter is provided, enforce its presence if it appears.
575            format!(
576                r"{q}(?P<field>(?:[^{q}]|{q}{q})*){q}(?:\s*{d}\s*)?",
577                q = character,
578                d = delim as char
579            )
580        },
581    );
582    // Safety: unwrap is safe here because we control the regex pattern.
583    let re = Regex::new(&pattern).unwrap();
584
585    let mut delim_count_map: HashMap<u8, usize> = HashMap::new();
586    let mut count = 0;
587    for line in sample_iter {
588        let line = line?;
589        // Iterate through all quoted cell matches in the line.
590        for cap in re.captures_iter(&line) {
591            count += 1;
592
593            if let Some(delim) = get_delimiter(&cap) {
594                *delim_count_map.entry(delim).or_insert(0) += 1;
595            }
596        }
597    }
598    if count == 0 {
599        return Ok(None);
600    }
601
602    // If a delimiter was provided, just return it.
603    if let Some(delim) = delim {
604        return Ok(Some((count, delim)));
605    }
606
607    // Otherwise, select the candidate delimiter that was matched most frequently.
608    let (delim_count, delim) =
609        delim_count_map
610            .into_iter()
611            .fold((0, b'\0'), |acc, (delim, d_count)| {
612                if d_count > acc.0 {
613                    (d_count, delim)
614                } else {
615                    acc
616                }
617            });
618
619    if delim_count == 0 {
620        return Err(SnifferError::SniffingFailed(
621            "invalid regex match: no delimiter found".into(),
622        ));
623    }
624    Ok(Some((count, delim)))
625}
626
627fn get_delimiter(captures: &Captures<'_>) -> Option<u8> {
628    let mut counts: HashMap<char, usize> = HashMap::new();
629    // Check groups delim1 through delim4.
630    for i in 1..=4 {
631        let group_name = format!("delim{i}");
632        if let Some(matched) = captures.name(&group_name) {
633            if let Some(ch) = matched.as_str().chars().next() {
634                *counts.entry(ch).or_insert(0) += 1;
635            }
636        }
637    }
638
639    // If no candidates were found, return None.
640    if counts.is_empty() {
641        return None;
642    }
643
644    // Select the candidate with the highest frequency.
645    let (candidate, _) = counts.into_iter().max_by_key(|&(_, count)| count)?;
646    u8::try_from(candidate).ok()
647}