csv_scout/
sniffer.rs

1use hashbrown::HashMap;
2use itertools::Itertools;
3use std::cell::RefCell;
4use std::fs::File;
5use std::io::{Read, Seek, SeekFrom};
6use std::path::Path;
7
8use csv::Reader;
9use csv_core as csvc;
10use regex::{Captures, Regex};
11
12use crate::{
13    chain::{Chain, STATE_STEADYFLEX, STATE_STEADYSTRICT, STATE_UNSTEADY, ViterbiResults},
14    error::{Result, SnifferError},
15    // field_type::DatePreference,
16    metadata::{Dialect, Metadata, Quote},
17    sample::{SampleIter, SampleSize, take_sample_from_start},
18};
19
/// How many times a candidate delimiter occurs within one sampled line.
type NumberOfOccurrences = u32;
/// How many sampled lines share a given occurrence count.
type NumberOfLines = u32;
/// How many sampled lines have an occurrence count within `TOLERANCE` of the modal count.
type AdjacentFrequency = u32;

/// Maximum distance (in occurrences) between a line's count and the modal count for the line
/// to still be credited to that mode.
const TOLERANCE: u32 = 1;
/// Size of the per-byte lookup tables; only 7-bit ASCII bytes are tracked.
const NUM_ASCII_CHARS: usize = 128;
/// Bytes considered as delimiters when the caller did not specify one.
const CANDIDATES: &[u8] = b"\t,;|:";

thread_local! {
    /// Per-thread UTF-8 flag: reset to `true` when a sniff starts and read back when it ends.
    /// Nothing in this file sets it to `false` (presumably the sampling code does — TODO confirm).
    pub static IS_UTF8: RefCell<bool> = const { RefCell::new(true) };
}
// thread_local! (pub static DATE_PREFERENCE: RefCell<DatePreference> = const { RefCell::new(DatePreference::MdyFormat) });
30
31/// A CSV sniffer.
32///
33/// The sniffer examines a CSV file, passed in either through a file or a reader.
34#[derive(Debug, Default)]
35pub struct Sniffer {
36    // CSV file dialect guesses
37    delimiter: Option<u8>,
38    // num_preamble_rows: Option<usize>,
39    // has_header_row: Option<bool>,
40    quote: Option<Quote>,
41    is_utf8: Option<bool>,
42    // Metadata guesses
43    // delimiter_freq: Option<usize>,
44    // fields: Vec<String>,
45    // types: Vec<Type>,
46    // avg_record_len: Option<usize>,
47
48    // sample size to sniff
49    sample_size: Option<SampleSize>,
50    // date format preference
51    // date_preference: Option<DatePreference>,
52}
53impl Sniffer {
54    /// Create a new CSV sniffer.
55    pub fn new() -> Self {
56        Self::default()
57    }
58    /// Specify the delimiter character.
59    pub fn delimiter(&mut self, delimiter: u8) -> &mut Self {
60        self.delimiter = Some(delimiter);
61        self
62    }
63    /// Specify the header type (whether the CSV file has a header row, and where the data starts).
64    // pub fn header(&mut self, header: &Header) -> &mut Self {
65    //     self.num_preamble_rows = Some(header.num_preamble_rows);
66    //     self.has_header_row = Some(header.has_header_row);
67    //     self
68    // }
69    /// Specify the quote character (if any), and whether two quotes in a row as to be interpreted
70    /// as an escaped quote.
71    pub fn quote(&mut self, quote: Quote) -> &mut Self {
72        self.quote = Some(quote);
73        self
74    }
75
76    /// The size of the sample to examine while sniffing. If using `SampleSize::Records`, the
77    /// sniffer will use the `Terminator::CRLF` as record separator.
78    ///
79    /// The sample size defaults to `SampleSize::Bytes(4096)`.
80    pub fn sample_size(&mut self, sample_size: SampleSize) -> &mut Self {
81        self.sample_size = Some(sample_size);
82        self
83    }
84
85    fn get_sample_size(&self) -> SampleSize {
86        self.sample_size.unwrap_or(SampleSize::Bytes(1 << 14))
87    }
88
89    // The date format preference when sniffing.
90    //
91    // The date format preference defaults to `DatePreference::MDY`.
92    // pub fn date_preference(&mut self, date_preference: DatePreference) -> &mut Self {
93    //     DATE_PREFERENCE.with(|preference| {
94    //         *preference.borrow_mut() = date_preference;
95    //     });
96    //     self.date_preference = Some(date_preference);
97    //     self
98    // }
99
100    /// Sniff the CSV file located at the provided path, and return a `Reader` (from the
101    /// [`csv`](https://docs.rs/csv) crate) ready to ready the file.
102    ///
103    /// Fails on file opening or rendering errors, or on an error examining the file.
104    pub fn open_path<P: AsRef<Path>>(&mut self, path: P) -> Result<Reader<File>> {
105        self.open_reader(File::open(path)?)
106    }
107    /// Sniff the CSV file provided by the reader, and return a [`csv`](https://docs.rs/csv)
108    /// `Reader` object.
109    ///
110    /// Fails on file opening or rendering errors, or on an error examining the file.
111    pub fn open_reader<R: Read + Seek>(&mut self, mut reader: R) -> Result<Reader<R>> {
112        let metadata = self.sniff_reader(&mut reader)?;
113        reader.seek(SeekFrom::Start(0))?;
114        metadata.dialect.open_reader(reader)
115    }
116
117    /// Sniff the CSV file located at the provided path, and return a
118    /// [`Metadata`](struct.Metadata.html) object containing information about the CSV file.
119    ///
120    /// Fails on file opening or rendering errors, or on an error examining the file.
121    pub fn sniff_path<P: AsRef<Path>>(&mut self, path: P) -> Result<Metadata> {
122        let file = File::open(path)?;
123        self.sniff_reader(&file)
124    }
125    /// Sniff the CSV file provider by the reader, and return a
126    /// [`Metadata`](struct.Metadata.html) object containing information about the CSV file.
127    ///
128    /// Fails on file opening or readering errors, or on an error examining the file.
129    pub fn sniff_reader<R: Read + Seek>(&mut self, mut reader: R) -> Result<Metadata> {
130        // init IS_UTF8 global var to true
131        IS_UTF8.with(|flag| {
132            *flag.borrow_mut() = true;
133        });
134        // guess quotes & delim
135        self.infer_quotes_delim(&mut reader)?;
136
137        // if we have a delimiter, we just need to search for num_preamble_rows and check for
138        // flexible. Otherwise, we need to guess a delimiter as well.
139        if self.delimiter.is_none() {
140            self.infer_delim_preamble(&mut reader)?;
141        }
142
143        // self.infer_types(&mut reader)?;
144        self.is_utf8 = Some(IS_UTF8.with(|flag| *flag.borrow()));
145
146        // as this point of the process, we should have all these filled in.
147        // assert!(
148        //     self.delimiter.is_some()
149        //         && self.num_preamble_rows.is_some()
150        //         && self.quote.is_some()
151        //         && self.flexible.is_some()
152        //         && self.is_utf8.is_some()
153        //         && self.delimiter_freq.is_some()
154        //         && self.has_header_row.is_some()
155        //         && self.avg_record_len.is_some()
156        //         && self.delimiter_freq.is_some()
157        // );
158        if !(
159            self.delimiter.is_some()
160            // && self.num_preamble_rows.is_some()
161            && self.quote.is_some()
162            && self.is_utf8.is_some()
163            // && self.has_header_row.is_some()
164            // && self.avg_record_len.is_some()
165        ) {
166            return Err(SnifferError::SniffingFailed(format!(
167                "Failed to infer all metadata: {self:?}"
168            )));
169        }
170        // safety: we just checked that all these are Some, so it's safe to unwrap
171        Ok(Metadata {
172            dialect: Dialect {
173                delimiter: self.delimiter.unwrap(),
174                // header: Header {
175                //     num_preamble_rows: self.num_preamble_rows.unwrap(),
176                //     has_header_row: self.has_header_row.unwrap(),
177                // },
178                quote: self.quote.clone().unwrap(),
179                // flexible: self.flexible.unwrap(),
180                // is_utf8: self.is_utf8.unwrap(),
181            },
182            // avg_record_len: self.avg_record_len.unwrap(),
183            // num_fields: self.delimiter_freq.unwrap() + 1,
184            // fields: self.fields.clone(),
185            // types: self.types.clone(),
186        })
187    }
188
189    // Infers quotes and delimiter from quoted (or possibly quoted) files. If quotes detected,
190    // updates self.quote and self.delimiter. If quotes not detected, updates self.quote to
191    // Quote::None. Only valid quote characters: " (double-quote), ' (single-quote), ` (back-tick).
192    fn infer_quotes_delim<R: Read + Seek>(&mut self, reader: &mut R) -> Result<()> {
193        if let (&Some(_), &Some(_)) = (&self.quote, &self.delimiter) {
194            // nothing left to infer!
195            return Ok(());
196        }
197        let quote_guesses = match self.quote {
198            Some(Quote::Some(chr)) => vec![chr],
199            Some(Quote::None) => {
200                // this function only checks quoted (or possibly quoted) files, nothing left to
201                // do if we know there are no quotes
202                return Ok(());
203            }
204            None => vec![b'\'', b'"', b'`'],
205        };
206        let (quote_chr, (quote_cnt, delim_guess)) = quote_guesses.iter().try_fold(
207            (b'"', (0, b'\0')),
208            |acc, &chr| -> Result<(u8, (usize, u8))> {
209                let mut sample_reader = take_sample_from_start(reader, self.get_sample_size())?;
210                if let Some((cnt, delim_chr)) =
211                    quote_count(&mut sample_reader, char::from(chr), self.delimiter)?
212                {
213                    Ok(if cnt > acc.1.0 {
214                        (chr, (cnt, delim_chr))
215                    } else {
216                        acc
217                    })
218                } else {
219                    Ok(acc)
220                }
221            },
222        )?;
223        if quote_cnt == 0 {
224            self.quote = Some(Quote::None);
225        } else {
226            self.quote = Some(Quote::Some(quote_chr));
227            self.delimiter = Some(delim_guess);
228        };
229        Ok(())
230    }
231
232    // Updates delimiter frequency, number of preamble rows, and flexible boolean.
233    fn infer_preamble_known_delim<R: Read + Seek>(&mut self, reader: &mut R) -> Result<()> {
234        // prerequisites for calling this function:
235        if !(self.delimiter.is_some() && self.quote.is_some()) {
236            // instead of assert, return an error
237            // assert!(self.delimiter.is_some() && self.quote.is_some());
238            return Err(SnifferError::SniffingFailed(
239                "infer_preamble_known_delim called without delimiter and quote".into(),
240            ));
241        }
242        // safety: unwraps for delimiter and quote are safe since we just checked above
243        let (quote, delim) = (self.quote.clone().unwrap(), self.delimiter.unwrap());
244
245        let sample_iter = take_sample_from_start(reader, self.get_sample_size())?;
246
247        let mut chain = Chain::default();
248
249        if let Quote::Some(character) = quote {
250            // since we have a quote, we need to run this data through the csv_core::Reader (which
251            // properly escapes quoted fields
252            let mut csv_reader = csvc::ReaderBuilder::new()
253                .delimiter(delim)
254                .quote(character)
255                .build();
256
257            let mut output = vec![];
258            let mut ends = vec![];
259            for line in sample_iter {
260                let line = line?;
261                if line.len() > output.len() {
262                    output.resize(line.len(), 0);
263                }
264                if line.len() > ends.len() {
265                    ends.resize(line.len(), 0);
266                }
267                let (result, _, _, n_ends) =
268                    csv_reader.read_record(line.as_bytes(), &mut output, &mut ends);
269                // check to make sure record was read correctly
270                match result {
271                    csvc::ReadRecordResult::OutputFull | csvc::ReadRecordResult::OutputEndsFull => {
272                        return Err(SnifferError::SniffingFailed(format!(
273                            "failure to read quoted CSV record: {result:?}"
274                        )));
275                    }
276                    _ => {} // non-error results, do nothing
277                }
278                // n_ends is the number of barries between fields, so it's the same as the number
279                // of delimiters
280                chain.add_observation(n_ends);
281            }
282        } else {
283            for line in sample_iter {
284                let line = line?;
285                let freq = bytecount::count(line.as_bytes(), delim);
286                chain.add_observation(freq);
287            }
288        }
289        self.run_chains(vec![chain])
290    }
291
292    // Updates delimiter, delimiter frequency, number of preamble rows, and flexible boolean.
293    fn infer_delim_preamble<R: Read + Seek>(&mut self, reader: &mut R) -> Result<()> {
294        let sample_iter =
295            take_sample_from_start(reader, self.get_sample_size())?.collect::<Result<Vec<_>>>()?;
296
297        let mut chars_frequency: HashMap<u8, HashMap<NumberOfOccurrences, NumberOfLines>> =
298            HashMap::with_capacity(NUM_ASCII_CHARS);
299
300        let mut modes: HashMap<u8, (NumberOfOccurrences, AdjacentFrequency)> =
301            HashMap::with_capacity(NUM_ASCII_CHARS);
302
303        for line in &sample_iter {
304            let mut line_frequency = HashMap::with_capacity(128);
305            for character in line.chars() {
306                let Ok(ascii_char) = u8::try_from(character) else {
307                    continue;
308                };
309                if !CANDIDATES.contains(&ascii_char) {
310                    continue;
311                }
312                *line_frequency.entry(ascii_char).or_default() += 1;
313            }
314            for (ascii_char, freq) in line_frequency {
315                let char_frequency = chars_frequency.entry(ascii_char).or_default();
316                *char_frequency.entry(freq).or_default() += 1;
317            }
318        }
319        for (&ascii_char, line_count_map) in &chars_frequency {
320            let Some((&mode_value, _)) = line_count_map
321                .iter()
322                .max_by_key(|&(_count, num_lines)| num_lines)
323            else {
324                continue; // skip empty maps, just in case
325            };
326
327            let mut adjusted_count = 0;
328            for delta in 0..=TOLERANCE {
329                for count in [mode_value.saturating_sub(delta), mode_value + delta] {
330                    if let Some(&lines) = line_count_map.get(&count) {
331                        adjusted_count += lines;
332                    }
333                }
334            }
335            if TOLERANCE > 0 {
336                if let Some(&lines) = line_count_map.get(&mode_value) {
337                    adjusted_count -= lines;
338                }
339            }
340
341            modes.insert(ascii_char, (mode_value, adjusted_count));
342        }
343        let top_candidates: Vec<u8> = modes
344            .iter()
345            .filter(|(_, (_, score))| *score > 0)
346            .sorted_by_key(|&(_, &(_, score))| std::cmp::Reverse(score)) // needs itertools or just sort
347            .take(6)
348            .map(|(&ch, _)| ch)
349            .collect();
350
351        let mut chains = vec![Chain::default(); NUM_ASCII_CHARS];
352
353        for line in sample_iter {
354            let mut freqs = [0; NUM_ASCII_CHARS];
355            for &chr in line.as_bytes() {
356                if chr < NUM_ASCII_CHARS as u8 {
357                    freqs[chr as usize] += 1;
358                }
359            }
360            for &ch in &top_candidates {
361                chains[ch as usize].add_observation(freqs[ch as usize]);
362            }
363        }
364
365        self.run_chains(chains)
366    }
367
368    // Updates delimiter (if not already known), delimiter frequency, number of preamble rows, and
369    // flexible boolean.
370    fn run_chains(&mut self, mut chains: Vec<Chain>) -> Result<()> {
371        // Find the 'best' delimiter: choose strict (non-flexible) delimiters over flexible ones,
372        // and choose the one that had the highest probability markov chain in the end.
373        //
374        // In the case where delim is already known, 'best_delim' will be incorrect (since it won't
375        // correspond with position in a vector of Chains), but we'll just ignore it when
376        // constructing our return value later. 'best_state' and 'path' are necessary, though, to
377        // compute the preamble rows.
378        let (best_delim, _, best_state, _, _) = chains.iter_mut().enumerate().fold(
379            (b',', 0, STATE_UNSTEADY, vec![], 0.0),
380            |acc, (i, ref mut chain)| {
381                let (_, _, best_state, _, best_state_prob) = acc;
382                let ViterbiResults {
383                    max_delim_freq,
384                    path,
385                } = chain.viterbi();
386                if path.is_empty() {
387                    return acc;
388                }
389                let (final_state, final_viter) = path[path.len() - 1];
390                if final_state < best_state
391                    || (final_state == best_state && final_viter.prob > best_state_prob)
392                {
393                    (i as u8, max_delim_freq, final_state, path, final_viter.prob)
394                } else {
395                    acc
396                }
397            },
398        );
399        // self.flexible = Some(match best_state {
400        //     STATE_STEADYSTRICT => false,
401        //     STATE_STEADYFLEX => true,
402        //     _ => {
403        //         return Err(SnifferError::SniffingFailed(
404        //             "unable to find valid delimiter".to_string(),
405        //         ));
406        //     }
407        // });
408
409        // Find the number of preamble rows (the number of rows during which the state fluctuated
410        // before getting to the final state).
411        // let mut num_preamble_rows = 0;
412        // since path has an extra state as the beginning, skip one
413        // for &(state, _) in path.iter().skip(2) {
414        //     if state == best_state {
415        //         break;
416        //     }
417        //     num_preamble_rows += 1;
418        // }
419        // if num_preamble_rows > 0 {
420        //     num_preamble_rows += 1;
421        // }
422        if self.delimiter.is_none() {
423            self.delimiter = Some(best_delim);
424        }
425        // self.num_preamble_rows = Some(num_preamble_rows);
426        Ok(())
427    }
428
429    // fn infer_types<R: Read + Seek>(&mut self, reader: &mut R) -> Result<()> {
430    //     // prerequisites for calling this function:
431    //     if self.delimiter_freq.is_none() {
432    //         // instead of assert, return error
433    //         // assert!(self.delimiter_freq.is_some());
434    //         return Err(SnifferError::SniffingFailed(
435    //             "delimiter frequency not known".to_string(),
436    //         ));
437    //     }
438    //     // safety: unwrap is safe as we just checked that delimiter_freq is Some
439    //     let field_count = self.delimiter_freq.unwrap() + 1;
440
441    //     let mut csv_reader = self.create_csv_reader(reader)?;
442    //     let mut records_iter = csv_reader.byte_records();
443    //     let mut n_bytes = 0;
444    //     let mut n_records = 0;
445    //     let sample_size = self.get_sample_size();
446
447    //     // Infer types for the top row. We'll save this set of types to check against the types
448    //     // of the remaining rows to see if this is part of the data or a separate header row.
449    //     let header_row_types = match records_iter.next() {
450    //         Some(record) => {
451    //             let byte_record = record?;
452    //             let str_record = StringRecord::from_byte_record_lossy(byte_record);
453    //             n_records += 1;
454    //             n_bytes += count_bytes(&str_record);
455    //             infer_record_types(&str_record)
456    //         }
457    //         None => {
458    //             return Err(SnifferError::SniffingFailed(
459    //                 "CSV empty (after preamble)".into(),
460    //             ));
461    //         }
462    //     };
463    //     let mut row_types = vec![TypeGuesses::all(); field_count];
464
465    //     for record in records_iter {
466    //         let record = record?;
467    //         for (i, field) in record.iter().enumerate() {
468    //             let str_field = String::from_utf8_lossy(field).to_string();
469    //             row_types[i] &= infer_types(&str_field);
470    //         }
471    //         n_records += 1;
472    //         n_bytes += record.as_slice().len();
473    //         // break if we pass sample size limits
474    //         match sample_size {
475    //             SampleSize::Records(recs) => {
476    //                 if n_records > recs {
477    //                     break;
478    //                 }
479    //             }
480    //             SampleSize::Bytes(bytes) => {
481    //                 if n_bytes > bytes {
482    //                     break;
483    //                 }
484    //             }
485    //             SampleSize::All => {}
486    //         }
487    //     }
488    //     if n_records == 1 {
489    //         // there's only one row in the whole data file (the top row already parsed),
490    //         // so we're going to assume it's a data row, not a header row.
491    //         self.has_header_row = Some(false);
492    //         self.types = get_best_types(&header_row_types);
493    //         self.avg_record_len = Some(n_bytes);
494    //         return Ok(());
495    //     }
496
497    //     if header_row_types
498    //         .iter()
499    //         .zip(&row_types)
500    //         .any(|(header, data)| !data.allows(*header))
501    //     {
502    //         self.has_header_row = Some(true);
503    //         // get field names in header
504    //         for field in csv_reader.byte_headers()? {
505    //             self.fields.push(String::from_utf8_lossy(field).to_string());
506    //         }
507    //     } else {
508    //         self.has_header_row = Some(false);
509    //     }
510
511    //     self.types = get_best_types(&row_types);
512    //     self.avg_record_len = Some(n_bytes / n_records);
513    //     Ok(())
514    // }
515
516    // fn create_csv_reader<'a, R: Read + Seek>(
517    //     &self,
518    //     mut reader: &'a mut R,
519    // ) -> Result<Reader<&'a mut R>> {
520    //     reader.seek(SeekFrom::Start(0))?;
521    //     if let Some(num_preamble_rows) = self.num_preamble_rows {
522    //         snip_preamble(&mut reader, num_preamble_rows)?;
523    //     }
524
525    //     let mut builder = csv::ReaderBuilder::new();
526    //     if let Some(delim) = self.delimiter {
527    //         builder.delimiter(delim);
528    //     }
529    //     if let Some(has_header_row) = self.has_header_row {
530    //         builder.has_headers(has_header_row);
531    //     }
532    //     match self.quote {
533    //         Some(Quote::Some(chr)) => {
534    //             builder.quoting(true);
535    //             builder.quote(chr);
536    //         }
537    //         Some(Quote::None) => {
538    //             builder.quoting(false);
539    //         }
540    //         _ => {}
541    //     }
542    //     if let Some(flexible) = self.flexible {
543    //         builder.flexible(flexible);
544    //     }
545
546    //     Ok(builder.from_reader(reader))
547    // }
548}
549
550fn quote_count<R: Read>(
551    sample_iter: &mut SampleIter<R>,
552    character: char,
553    delim: Option<u8>,
554) -> Result<Option<(usize, u8)>> {
555    // Build a regex that matches a quoted CSV cell,
556    // optionally followed by a delimiter.
557    // If delim is None, we try to capture a candidate delimiter.
558    let pattern = delim.map_or_else(
559        || {
560            // When delim is not provided, capture candidate delimiters in a group.
561            format!(
562                r#"(?<delim1>[^\w\n\"ֿ\'])(?: ?)(?:{character}).*?(?:{character})(?<delim2>[^\w\n\"\'])|
563                (?:^|\n)(?:{character}).*?(?:{character})(?<delim3>[^\w\n\"\'])(?: ?)|
564                (?<delim4>[^\w\n\"\'])(?: ?)(?:{character}).*?(?:{character})(?:$|\n)|
565                (?:^|\n)(?:{character}).*?(?:{character})(?:$|\n)"#
566            )
567        },
568        |delim| {
569            // When a delimiter is provided, enforce its presence if it appears.
570            format!(
571                r"{q}(?P<field>(?:[^{q}]|{q}{q})*){q}(?:\s*{d}\s*)?",
572                q = character,
573                d = delim as char
574            )
575        },
576    );
577    // Safety: unwrap is safe here because we control the regex pattern.
578    let re = Regex::new(&pattern).unwrap();
579
580    let mut delim_count_map: HashMap<u8, usize> = HashMap::new();
581    let mut count = 0;
582    for line in sample_iter {
583        let line = line?;
584        // Iterate through all quoted cell matches in the line.
585        for cap in re.captures_iter(&line) {
586            count += 1;
587
588            if let Some(delim) = get_delimiter(&cap) {
589                *delim_count_map.entry(delim).or_insert(0) += 1;
590            }
591        }
592    }
593    if count == 0 {
594        return Ok(None);
595    }
596
597    // If a delimiter was provided, just return it.
598    if let Some(delim) = delim {
599        return Ok(Some((count, delim)));
600    }
601
602    // Otherwise, select the candidate delimiter that was matched most frequently.
603    let (delim_count, delim) =
604        delim_count_map
605            .into_iter()
606            .fold((0, b'\0'), |acc, (delim, d_count)| {
607                if d_count > acc.0 {
608                    (d_count, delim)
609                } else {
610                    acc
611                }
612            });
613
614    if delim_count == 0 {
615        return Err(SnifferError::SniffingFailed(
616            "invalid regex match: no delimiter found".into(),
617        ));
618    }
619    Ok(Some((count, delim)))
620}
621
622fn get_delimiter(captures: &Captures<'_>) -> Option<u8> {
623    let mut counts: HashMap<char, usize> = HashMap::new();
624    // Check groups delim1 through delim4.
625    for i in 1..=4 {
626        let group_name = format!("delim{i}");
627        if let Some(matched) = captures.name(&group_name) {
628            if let Some(ch) = matched.as_str().chars().next() {
629                *counts.entry(ch).or_insert(0) += 1;
630            }
631        }
632    }
633
634    // If no candidates were found, return None.
635    if counts.is_empty() {
636        return None;
637    }
638
639    // Select the candidate with the highest frequency.
640    let (candidate, _) = counts.into_iter().max_by_key(|&(_, count)| count)?;
641    u8::try_from(candidate).ok()
642}