csv_scout/
sniffer.rs

1use hashbrown::HashMap;
2use itertools::Itertools;
3use std::cell::RefCell;
4use std::fs::File;
5use std::io::{Read, Seek, SeekFrom};
6use std::path::Path;
7
8use csv::Reader;
9use regex::{Captures, Regex};
10
11use crate::{
12    chain::{Chain, STATE_UNSTEADY, ViterbiResults},
13    error::{Result, SnifferError},
14    // field_type::DatePreference,
15    metadata::{Dialect, Metadata, Quote},
16    sample::{SampleIter, SampleSize, take_sample_from_start},
17};
18
/// Number of times a character occurs within a single sampled line.
type NumberOfOccurrences = u32;
/// Number of sampled lines that share a given per-line occurrence count.
type NumberOfLines = u32;
/// Number of sampled lines whose occurrence count lies within `TOLERANCE` of the most common
/// occurrence count.
type AdjacentFrequency = u32;

/// How far a line's occurrence count may deviate from the most common count and still be
/// credited to a delimiter candidate.
const TOLERANCE: u32 = 1;
/// Size of the 7-bit ASCII range; used to size the per-character tables.
const NUM_ASCII_CHARS: usize = 128;
/// Bytes considered as potential field delimiters: tab, comma, semicolon, pipe and colon.
const CANDIDATES: &[u8] = b"\t,;|:";

// Thread-local flag that `Sniffer::sniff_reader` resets to `true` before sniffing and reads back
// once inference is done. NOTE(review): nothing in this file ever clears it; presumably the
// sampling code does so when it meets invalid UTF-8 -- confirm against that module.
thread_local! (pub static IS_UTF8: RefCell<bool> = const { RefCell::new(true) });
28// thread_local! (pub static DATE_PREFERENCE: RefCell<DatePreference> = const { RefCell::new(DatePreference::MdyFormat) });
29
/// A CSV sniffer.
///
/// The sniffer examines CSV data, passed in either as a path or as a reader, and guesses its
/// dialect. Values supplied through the builder methods seed the corresponding guesses.
#[derive(Debug, Default)]
pub struct Sniffer {
    // CSV file dialect guesses
    // Field delimiter byte; `None` until it is specified by the caller or inferred.
    delimiter: Option<u8>,
    // num_preamble_rows: Option<usize>,
    // has_header_row: Option<bool>,
    // Quoting convention; `None` until it is specified by the caller or inferred.
    quote: Option<Quote>,
    // Value of the thread-local `IS_UTF8` flag recorded at the end of the last sniff.
    is_utf8: Option<bool>,
    // Metadata guesses (currently disabled)
    // delimiter_freq: Option<usize>,
    // fields: Vec<String>,
    // types: Vec<Type>,
    // avg_record_len: Option<usize>,

    // sample size to sniff; `None` selects the default chosen by `get_sample_size`
    sample_size: Option<SampleSize>,
    // date format preference
    // date_preference: Option<DatePreference>,
}
52impl Sniffer {
53    /// Create a new CSV sniffer.
54    pub fn new() -> Self {
55        Self::default()
56    }
57    /// Specify the delimiter character.
58    pub fn delimiter(&mut self, delimiter: u8) -> &mut Self {
59        self.delimiter = Some(delimiter);
60        self
61    }
    // Header configuration is currently disabled. The description below is deliberately a plain
    // comment rather than a doc comment, so that rustdoc does not attach it to `quote`.
    //
    // Specify the header type (whether the CSV file has a header row, and where the data starts).
    // pub fn header(&mut self, header: &Header) -> &mut Self {
    //     self.num_preamble_rows = Some(header.num_preamble_rows);
    //     self.has_header_row = Some(header.has_header_row);
    //     self
    // }

    /// Specify the quote character (`Quote::Some(byte)`), or `Quote::None` if fields are not
    /// quoted.
    ///
    /// Returns `self` so calls can be chained.
    pub fn quote(&mut self, quote: Quote) -> &mut Self {
        self.quote = Some(quote);
        self
    }
74
    /// The size of the sample to examine while sniffing. If using `SampleSize::Records`, the
    /// sniffer will use the `Terminator::CRLF` as record separator.
    ///
    /// The sample size defaults to `SampleSize::Bytes(16384)` (16 KiB).
    ///
    /// Returns `self` so calls can be chained.
    pub fn sample_size(&mut self, sample_size: SampleSize) -> &mut Self {
        self.sample_size = Some(sample_size);
        self
    }
83
84    fn get_sample_size(&self) -> SampleSize {
85        self.sample_size.unwrap_or(SampleSize::Bytes(1 << 14))
86    }
87
88    // The date format preference when sniffing.
89    //
90    // The date format preference defaults to `DatePreference::MDY`.
91    // pub fn date_preference(&mut self, date_preference: DatePreference) -> &mut Self {
92    //     DATE_PREFERENCE.with(|preference| {
93    //         *preference.borrow_mut() = date_preference;
94    //     });
95    //     self.date_preference = Some(date_preference);
96    //     self
97    // }
98
99    /// Sniff the CSV file located at the provided path, and return a `Reader` (from the
100    /// [`csv`](https://docs.rs/csv) crate) ready to ready the file.
101    ///
102    /// Fails on file opening or rendering errors, or on an error examining the file.
103    pub fn open_path<P: AsRef<Path>>(&mut self, path: P) -> Result<Reader<File>> {
104        self.open_reader(File::open(path)?)
105    }
106    /// Sniff the CSV file provided by the reader, and return a [`csv`](https://docs.rs/csv)
107    /// `Reader` object.
108    ///
109    /// Fails on file opening or rendering errors, or on an error examining the file.
110    pub fn open_reader<R: Read + Seek>(&mut self, mut reader: R) -> Result<Reader<R>> {
111        let metadata = self.sniff_reader(&mut reader)?;
112        reader.seek(SeekFrom::Start(0))?;
113        metadata.dialect.open_reader(reader)
114    }
115
116    /// Sniff the CSV file located at the provided path, and return a
117    /// [`Metadata`](struct.Metadata.html) object containing information about the CSV file.
118    ///
119    /// Fails on file opening or rendering errors, or on an error examining the file.
120    pub fn sniff_path<P: AsRef<Path>>(&mut self, path: P) -> Result<Metadata> {
121        let file = File::open(path)?;
122        self.sniff_reader(&file)
123    }
124    /// Sniff the CSV file provider by the reader, and return a
125    /// [`Metadata`](struct.Metadata.html) object containing information about the CSV file.
126    ///
127    /// Fails on file opening or readering errors, or on an error examining the file.
128    pub fn sniff_reader<R: Read + Seek>(&mut self, mut reader: R) -> Result<Metadata> {
129        // init IS_UTF8 global var to true
130        IS_UTF8.with(|flag| {
131            *flag.borrow_mut() = true;
132        });
133        // guess quotes & delim
134        self.infer_quotes_delim(&mut reader)?;
135
136        // if we have a delimiter, we just need to search for num_preamble_rows and check for
137        // flexible. Otherwise, we need to guess a delimiter as well.
138        if self.delimiter.is_none() {
139            self.infer_delim_preamble(&mut reader)?;
140        }
141
142        // self.infer_types(&mut reader)?;
143        self.is_utf8 = Some(IS_UTF8.with(|flag| *flag.borrow()));
144
145        // as this point of the process, we should have all these filled in.
146        // assert!(
147        //     self.delimiter.is_some()
148        //         && self.num_preamble_rows.is_some()
149        //         && self.quote.is_some()
150        //         && self.flexible.is_some()
151        //         && self.is_utf8.is_some()
152        //         && self.delimiter_freq.is_some()
153        //         && self.has_header_row.is_some()
154        //         && self.avg_record_len.is_some()
155        //         && self.delimiter_freq.is_some()
156        // );
157        if !(
158            self.delimiter.is_some()
159            // && self.num_preamble_rows.is_some()
160            && self.quote.is_some()
161            && self.is_utf8.is_some()
162            // && self.has_header_row.is_some()
163            // && self.avg_record_len.is_some()
164        ) {
165            return Err(SnifferError::SniffingFailed(format!(
166                "Failed to infer all metadata: {self:?}"
167            )));
168        }
169        // safety: we just checked that all these are Some, so it's safe to unwrap
170        Ok(Metadata {
171            dialect: Dialect {
172                delimiter: self.delimiter.unwrap(),
173                // header: Header {
174                //     num_preamble_rows: self.num_preamble_rows.unwrap(),
175                //     has_header_row: self.has_header_row.unwrap(),
176                // },
177                quote: self.quote.clone().unwrap(),
178                // flexible: self.flexible.unwrap(),
179                // is_utf8: self.is_utf8.unwrap(),
180            },
181            // avg_record_len: self.avg_record_len.unwrap(),
182            // num_fields: self.delimiter_freq.unwrap() + 1,
183            // fields: self.fields.clone(),
184            // types: self.types.clone(),
185        })
186    }
187
188    // Infers quotes and delimiter from quoted (or possibly quoted) files. If quotes detected,
189    // updates self.quote and self.delimiter. If quotes not detected, updates self.quote to
190    // Quote::None. Only valid quote characters: " (double-quote), ' (single-quote), ` (back-tick).
191    fn infer_quotes_delim<R: Read + Seek>(&mut self, reader: &mut R) -> Result<()> {
192        if let (&Some(_), &Some(_)) = (&self.quote, &self.delimiter) {
193            // nothing left to infer!
194            return Ok(());
195        }
196        let quote_guesses = match self.quote {
197            Some(Quote::Some(chr)) => vec![chr],
198            Some(Quote::None) => {
199                // this function only checks quoted (or possibly quoted) files, nothing left to
200                // do if we know there are no quotes
201                return Ok(());
202            }
203            None => vec![b'\'', b'"', b'`'],
204        };
205        let (quote_chr, (quote_cnt, delim_guess)) = quote_guesses.iter().try_fold(
206            (b'"', (0, b'\0')),
207            |acc, &chr| -> Result<(u8, (usize, u8))> {
208                let mut sample_reader = take_sample_from_start(reader, self.get_sample_size())?;
209                if let Some((cnt, delim_chr)) =
210                    quote_count(&mut sample_reader, char::from(chr), self.delimiter)?
211                {
212                    Ok(if cnt > acc.1.0 {
213                        (chr, (cnt, delim_chr))
214                    } else {
215                        acc
216                    })
217                } else {
218                    Ok(acc)
219                }
220            },
221        )?;
222        if quote_cnt == 0 {
223            self.quote = Some(Quote::None);
224        } else {
225            self.quote = Some(Quote::Some(quote_chr));
226            self.delimiter = Some(delim_guess);
227        };
228        Ok(())
229    }
230
231    // Updates delimiter, delimiter frequency, number of preamble rows, and flexible boolean.
232    fn infer_delim_preamble<R: Read + Seek>(&mut self, reader: &mut R) -> Result<()> {
233        let sample_iter =
234            take_sample_from_start(reader, self.get_sample_size())?.collect::<Result<Vec<_>>>()?;
235
236        let mut chars_frequency: HashMap<u8, HashMap<NumberOfOccurrences, NumberOfLines>> =
237            HashMap::with_capacity(NUM_ASCII_CHARS);
238
239        let mut modes: HashMap<u8, (NumberOfOccurrences, AdjacentFrequency)> =
240            HashMap::with_capacity(NUM_ASCII_CHARS);
241
242        for line in &sample_iter {
243            let mut line_frequency = HashMap::with_capacity(128);
244            for character in line.chars() {
245                let Ok(ascii_char) = u8::try_from(character) else {
246                    continue;
247                };
248                if !CANDIDATES.contains(&ascii_char) {
249                    continue;
250                }
251                *line_frequency.entry(ascii_char).or_default() += 1;
252            }
253            for (ascii_char, freq) in line_frequency {
254                let char_frequency = chars_frequency.entry(ascii_char).or_default();
255                *char_frequency.entry(freq).or_default() += 1;
256            }
257        }
258        for (&ascii_char, line_count_map) in &chars_frequency {
259            let Some((&mode_value, _)) = line_count_map
260                .iter()
261                .max_by_key(|&(_count, num_lines)| num_lines)
262            else {
263                continue; // skip empty maps, just in case
264            };
265
266            let mut adjusted_count = 0;
267            for delta in 0..=TOLERANCE {
268                for count in [mode_value.saturating_sub(delta), mode_value + delta] {
269                    if let Some(&lines) = line_count_map.get(&count) {
270                        adjusted_count += lines;
271                    }
272                }
273            }
274            if TOLERANCE > 0 {
275                if let Some(&lines) = line_count_map.get(&mode_value) {
276                    adjusted_count -= lines;
277                }
278            }
279
280            modes.insert(ascii_char, (mode_value, adjusted_count));
281        }
282        let top_candidates: Vec<u8> = modes
283            .iter()
284            .filter(|(_, (_, score))| *score > 0)
285            .sorted_by_key(|&(_, &(_, score))| std::cmp::Reverse(score)) // needs itertools or just sort
286            .take(6)
287            .map(|(&ch, _)| ch)
288            .collect();
289
290        let mut chains = vec![Chain::default(); NUM_ASCII_CHARS];
291
292        for line in sample_iter {
293            let mut freqs = [0; NUM_ASCII_CHARS];
294            for &chr in line.as_bytes() {
295                if chr < NUM_ASCII_CHARS as u8 {
296                    freqs[chr as usize] += 1;
297                }
298            }
299            for &ch in &top_candidates {
300                chains[ch as usize].add_observation(freqs[ch as usize]);
301            }
302        }
303
304        self.run_chains(chains)
305    }
306
307    // Updates delimiter (if not already known), delimiter frequency, number of preamble rows, and
308    // flexible boolean.
309    fn run_chains(&mut self, mut chains: Vec<Chain>) -> Result<()> {
310        // Find the 'best' delimiter: choose strict (non-flexible) delimiters over flexible ones,
311        // and choose the one that had the highest probability markov chain in the end.
312        //
313        // In the case where delim is already known, 'best_delim' will be incorrect (since it won't
314        // correspond with position in a vector of Chains), but we'll just ignore it when
315        // constructing our return value later. 'best_state' and 'path' are necessary, though, to
316        // compute the preamble rows.
317        let (best_delim, _, _, _, _) = chains.iter_mut().enumerate().fold(
318            (b',', 0, STATE_UNSTEADY, vec![], 0.0),
319            |acc, (i, ref mut chain)| {
320                let (_, _, best_state, _, best_state_prob) = acc;
321                let ViterbiResults {
322                    max_delim_freq,
323                    path,
324                } = chain.viterbi();
325                if path.is_empty() {
326                    return acc;
327                }
328                let (final_state, final_viter) = path[path.len() - 1];
329                if final_state < best_state
330                    || (final_state == best_state && final_viter.prob > best_state_prob)
331                {
332                    (i as u8, max_delim_freq, final_state, path, final_viter.prob)
333                } else {
334                    acc
335                }
336            },
337        );
338        // self.flexible = Some(match best_state {
339        //     STATE_STEADYSTRICT => false,
340        //     STATE_STEADYFLEX => true,
341        //     _ => {
342        //         return Err(SnifferError::SniffingFailed(
343        //             "unable to find valid delimiter".to_string(),
344        //         ));
345        //     }
346        // });
347
348        // Find the number of preamble rows (the number of rows during which the state fluctuated
349        // before getting to the final state).
350        // let mut num_preamble_rows = 0;
351        // since path has an extra state as the beginning, skip one
352        // for &(state, _) in path.iter().skip(2) {
353        //     if state == best_state {
354        //         break;
355        //     }
356        //     num_preamble_rows += 1;
357        // }
358        // if num_preamble_rows > 0 {
359        //     num_preamble_rows += 1;
360        // }
361        if self.delimiter.is_none() {
362            self.delimiter = Some(best_delim);
363        }
364        // self.num_preamble_rows = Some(num_preamble_rows);
365        Ok(())
366    }
367
368    // fn infer_types<R: Read + Seek>(&mut self, reader: &mut R) -> Result<()> {
369    //     // prerequisites for calling this function:
370    //     if self.delimiter_freq.is_none() {
371    //         // instead of assert, return error
372    //         // assert!(self.delimiter_freq.is_some());
373    //         return Err(SnifferError::SniffingFailed(
374    //             "delimiter frequency not known".to_string(),
375    //         ));
376    //     }
377    //     // safety: unwrap is safe as we just checked that delimiter_freq is Some
378    //     let field_count = self.delimiter_freq.unwrap() + 1;
379
380    //     let mut csv_reader = self.create_csv_reader(reader)?;
381    //     let mut records_iter = csv_reader.byte_records();
382    //     let mut n_bytes = 0;
383    //     let mut n_records = 0;
384    //     let sample_size = self.get_sample_size();
385
386    //     // Infer types for the top row. We'll save this set of types to check against the types
387    //     // of the remaining rows to see if this is part of the data or a separate header row.
388    //     let header_row_types = match records_iter.next() {
389    //         Some(record) => {
390    //             let byte_record = record?;
391    //             let str_record = StringRecord::from_byte_record_lossy(byte_record);
392    //             n_records += 1;
393    //             n_bytes += count_bytes(&str_record);
394    //             infer_record_types(&str_record)
395    //         }
396    //         None => {
397    //             return Err(SnifferError::SniffingFailed(
398    //                 "CSV empty (after preamble)".into(),
399    //             ));
400    //         }
401    //     };
402    //     let mut row_types = vec![TypeGuesses::all(); field_count];
403
404    //     for record in records_iter {
405    //         let record = record?;
406    //         for (i, field) in record.iter().enumerate() {
407    //             let str_field = String::from_utf8_lossy(field).to_string();
408    //             row_types[i] &= infer_types(&str_field);
409    //         }
410    //         n_records += 1;
411    //         n_bytes += record.as_slice().len();
412    //         // break if we pass sample size limits
413    //         match sample_size {
414    //             SampleSize::Records(recs) => {
415    //                 if n_records > recs {
416    //                     break;
417    //                 }
418    //             }
419    //             SampleSize::Bytes(bytes) => {
420    //                 if n_bytes > bytes {
421    //                     break;
422    //                 }
423    //             }
424    //             SampleSize::All => {}
425    //         }
426    //     }
427    //     if n_records == 1 {
428    //         // there's only one row in the whole data file (the top row already parsed),
429    //         // so we're going to assume it's a data row, not a header row.
430    //         self.has_header_row = Some(false);
431    //         self.types = get_best_types(&header_row_types);
432    //         self.avg_record_len = Some(n_bytes);
433    //         return Ok(());
434    //     }
435
436    //     if header_row_types
437    //         .iter()
438    //         .zip(&row_types)
439    //         .any(|(header, data)| !data.allows(*header))
440    //     {
441    //         self.has_header_row = Some(true);
442    //         // get field names in header
443    //         for field in csv_reader.byte_headers()? {
444    //             self.fields.push(String::from_utf8_lossy(field).to_string());
445    //         }
446    //     } else {
447    //         self.has_header_row = Some(false);
448    //     }
449
450    //     self.types = get_best_types(&row_types);
451    //     self.avg_record_len = Some(n_bytes / n_records);
452    //     Ok(())
453    // }
454
455    // fn create_csv_reader<'a, R: Read + Seek>(
456    //     &self,
457    //     mut reader: &'a mut R,
458    // ) -> Result<Reader<&'a mut R>> {
459    //     reader.seek(SeekFrom::Start(0))?;
460    //     if let Some(num_preamble_rows) = self.num_preamble_rows {
461    //         snip_preamble(&mut reader, num_preamble_rows)?;
462    //     }
463
464    //     let mut builder = csv::ReaderBuilder::new();
465    //     if let Some(delim) = self.delimiter {
466    //         builder.delimiter(delim);
467    //     }
468    //     if let Some(has_header_row) = self.has_header_row {
469    //         builder.has_headers(has_header_row);
470    //     }
471    //     match self.quote {
472    //         Some(Quote::Some(chr)) => {
473    //             builder.quoting(true);
474    //             builder.quote(chr);
475    //         }
476    //         Some(Quote::None) => {
477    //             builder.quoting(false);
478    //         }
479    //         _ => {}
480    //     }
481    //     if let Some(flexible) = self.flexible {
482    //         builder.flexible(flexible);
483    //     }
484
485    //     Ok(builder.from_reader(reader))
486    // }
487}
488
489fn quote_count<R: Read>(
490    sample_iter: &mut SampleIter<R>,
491    character: char,
492    delim: Option<u8>,
493) -> Result<Option<(usize, u8)>> {
494    // Collect all lines into a single string to handle multi-line quoted fields
495    let mut sample_content = String::new();
496    for line in sample_iter {
497        let line = line?;
498        if !sample_content.is_empty() {
499            sample_content.push('\n');
500        }
501        sample_content.push_str(&line);
502    }
503
504    // Build a regex that matches a quoted CSV cell.
505    // For multi-line support, use DOTALL mode to match newlines within quotes.
506    let pattern = delim.map_or_else(
507        || {
508            // When delim is not provided, look for quoted fields and capture surrounding delimiters
509            // Make the pattern more restrictive to avoid matching apostrophes in text
510            format!(
511                r"(?:(?<delim1>[,;|\t:])\s*{character}(?:(?:{character}{character})|(?s:[^{character}]))*?{character}(?:\s*(?<delim2>[,;|\t:]))?)|(?:^{character}(?:(?:{character}{character})|(?s:[^{character}]))*?{character}(?:\s*(?<delim3>[,;|\t:])|$))"
512            )
513        },
514        |delim| {
515            // When a delimiter is provided, enforce its presence if it appears.
516            format!(
517                r"{q}(?P<field>(?s:(?:[^{q}]|{q}{q})*)){q}(?:\s*{d}\s*)?",
518                q = character,
519                d = delim as char
520            )
521        },
522    );
523
524    // Safety: unwrap is safe here because we control the regex pattern.
525    let re = Regex::new(&pattern).unwrap();
526
527    let mut delim_count_map: HashMap<u8, usize> = HashMap::new();
528    let mut count = 0;
529
530    // Apply regex to the entire concatenated sample content
531    for cap in re.captures_iter(&sample_content) {
532        count += 1;
533
534        if let Some(delim) = get_delimiter(&cap) {
535            *delim_count_map.entry(delim).or_insert(0) += 1;
536        }
537    }
538
539    if count == 0 {
540        return Ok(None);
541    }
542
543    // If a delimiter was provided, just return it.
544    if let Some(delim) = delim {
545        return Ok(Some((count, delim)));
546    }
547
548    // Otherwise, select the candidate delimiter that was matched most frequently.
549    let (delim_count, delim) =
550        delim_count_map
551            .into_iter()
552            .fold((0, b'\0'), |acc, (delim, d_count)| {
553                if d_count > acc.0 {
554                    (d_count, delim)
555                } else {
556                    acc
557                }
558            });
559
560    if delim_count == 0 {
561        return Err(SnifferError::SniffingFailed(
562            "invalid regex match: no delimiter found".into(),
563        ));
564    }
565    Ok(Some((count, delim)))
566}
567
568fn get_delimiter(captures: &Captures<'_>) -> Option<u8> {
569    let mut counts: HashMap<char, usize> = HashMap::new();
570    // Check groups delim1 through delim3.
571    for i in 1..=3 {
572        let group_name = format!("delim{i}");
573        if let Some(matched) = captures.name(&group_name) {
574            if let Some(ch) = matched.as_str().chars().next() {
575                *counts.entry(ch).or_insert(0) += 1;
576            }
577        }
578    }
579
580    // If no candidates were found, return None.
581    if counts.is_empty() {
582        return None;
583    }
584
585    let mut candidates: Vec<(char, usize)> = counts.into_iter().collect();
586    candidates.sort_by(|a, b| {
587        b.1.cmp(&a.1)
588            .then_with(|| delimiter_priority(a.0).cmp(&delimiter_priority(b.0))) // Priority ascending (lower number wins)
589            .then_with(|| a.0.cmp(&b.0))
590    });
591
592    let (candidate, _) = candidates.into_iter().next()?;
593    u8::try_from(candidate).ok()
594}
595
/// Rank used to break ties between delimiter candidates; a lower value means a more preferred
/// delimiter.
///
/// Known delimiters rank, from most to least preferred: comma, semicolon, tab, pipe, colon
/// (0 through 4). Any other character gets `u8::MAX` (255), the lowest preference.
const fn delimiter_priority(delimiter: char) -> u8 {
    // Position in this table is the priority.
    const RANKED: [char; 5] = [',', ';', '\t', '|', ':'];

    let mut rank = 0;
    while rank < RANKED.len() {
        if RANKED[rank] == delimiter {
            // The table has five entries, so the cast cannot truncate.
            return rank as u8;
        }
        rank += 1;
    }
    u8::MAX
}