csv_scout/sniffer.rs
1use hashbrown::HashMap;
2use itertools::Itertools;
3use std::cell::RefCell;
4use std::fs::File;
5use std::io::{Read, Seek, SeekFrom};
6use std::path::Path;
7
8use csv::Reader;
9use regex::{Captures, Regex};
10
11use crate::{
12 chain::{Chain, STATE_UNSTEADY, ViterbiResults},
13 error::{Result, SnifferError},
14 // field_type::DatePreference,
15 metadata::{Dialect, Metadata, Quote},
16 sample::{SampleIter, SampleSize, take_sample_from_start},
17};
18
/// How many times a candidate delimiter occurs on a single sampled line.
type NumberOfOccurrences = u32;
/// How many sampled lines share a given per-line occurrence count.
type NumberOfLines = u32;
/// Number of sampled lines whose occurrence count lies within `TOLERANCE` of the most common
/// count; used as a candidate's score in `infer_delim_preamble`.
type AdjacentFrequency = u32;

/// Maximum distance from the most common per-line occurrence count that still contributes to a
/// candidate delimiter's score.
const TOLERANCE: u32 = 1;
/// Size of the ASCII range; sizes the per-line byte frequency table and the vector of chains
/// (one slot per ASCII byte).
const NUM_ASCII_CHARS: usize = 128;
/// Bytes that are considered as possible delimiters: tab, comma, semicolon, pipe and colon.
const CANDIDATES: &[u8] = b"\t,;|:";

// Thread-local UTF-8 flag. `Sniffer::sniff_reader` resets it to `true` before sampling and
// copies its value into `Sniffer::is_utf8` afterwards.
// NOTE(review): nothing in this file sets it to `false`; presumably the sampling code does so
// when it meets invalid UTF-8 — confirm against the `sample` module.
thread_local! (pub static IS_UTF8: RefCell<bool> = const { RefCell::new(true) });
// thread_local! (pub static DATE_PREFERENCE: RefCell<DatePreference> = const { RefCell::new(DatePreference::MdyFormat) });
29
/// A CSV sniffer.
///
/// The sniffer examines a CSV file, passed in either through a file or a reader.
///
/// Values set through the builder-style methods (`delimiter`, `quote`, `sample_size`) are
/// stored in the same fields that sniffing fills in, so a preset value is used as-is instead
/// of being inferred.
#[derive(Debug, Default)]
pub struct Sniffer {
    // CSV file dialect guesses
    // Field delimiter byte: preset by `delimiter()` or inferred while sniffing.
    delimiter: Option<u8>,
    // num_preamble_rows: Option<usize>,
    // has_header_row: Option<bool>,
    // Quoting style: preset by `quote()` or inferred while sniffing.
    quote: Option<Quote>,
    // Copy of the thread-local `IS_UTF8` flag, taken at the end of `sniff_reader`.
    is_utf8: Option<bool>,
    // Metadata guesses
    // delimiter_freq: Option<usize>,
    // fields: Vec<String>,
    // types: Vec<Type>,
    // avg_record_len: Option<usize>,

    // sample size to sniff; `get_sample_size` supplies the default when this is `None`
    sample_size: Option<SampleSize>,
    // date format preference
    // date_preference: Option<DatePreference>,
}
52impl Sniffer {
53 /// Create a new CSV sniffer.
54 pub fn new() -> Self {
55 Self::default()
56 }
57 /// Specify the delimiter character.
58 pub fn delimiter(&mut self, delimiter: u8) -> &mut Self {
59 self.delimiter = Some(delimiter);
60 self
61 }
// Specify the header type (whether the CSV file has a header row, and where the data starts).
// (Disabled; kept as a plain comment so it does not become part of the docs of `quote`.)
// pub fn header(&mut self, header: &Header) -> &mut Self {
// self.num_preamble_rows = Some(header.num_preamble_rows);
// self.has_header_row = Some(header.has_header_row);
// self
// }
/// Specify the quote character (if any), and whether two quotes in a row are to be interpreted
/// as an escaped quote.
///
/// Returns `self` so calls can be chained.
pub fn quote(&mut self, quote: Quote) -> &mut Self {
    self.quote = Some(quote);
    self
}
74
/// The size of the sample to examine while sniffing. If using `SampleSize::Records`, the
/// sniffer will use the `Terminator::CRLF` as record separator.
///
/// The sample size defaults to `SampleSize::Bytes(16384)` (that is, `1 << 14` bytes).
///
/// Returns `self` so calls can be chained.
pub fn sample_size(&mut self, sample_size: SampleSize) -> &mut Self {
    self.sample_size = Some(sample_size);
    self
}
83
84 fn get_sample_size(&self) -> SampleSize {
85 self.sample_size.unwrap_or(SampleSize::Bytes(1 << 14))
86 }
87
88 // The date format preference when sniffing.
89 //
90 // The date format preference defaults to `DatePreference::MDY`.
91 // pub fn date_preference(&mut self, date_preference: DatePreference) -> &mut Self {
92 // DATE_PREFERENCE.with(|preference| {
93 // *preference.borrow_mut() = date_preference;
94 // });
95 // self.date_preference = Some(date_preference);
96 // self
97 // }
98
99 /// Sniff the CSV file located at the provided path, and return a `Reader` (from the
100 /// [`csv`](https://docs.rs/csv) crate) ready to ready the file.
101 ///
102 /// Fails on file opening or rendering errors, or on an error examining the file.
103 pub fn open_path<P: AsRef<Path>>(&mut self, path: P) -> Result<Reader<File>> {
104 self.open_reader(File::open(path)?)
105 }
106 /// Sniff the CSV file provided by the reader, and return a [`csv`](https://docs.rs/csv)
107 /// `Reader` object.
108 ///
109 /// Fails on file opening or rendering errors, or on an error examining the file.
110 pub fn open_reader<R: Read + Seek>(&mut self, mut reader: R) -> Result<Reader<R>> {
111 let metadata = self.sniff_reader(&mut reader)?;
112 reader.seek(SeekFrom::Start(0))?;
113 metadata.dialect.open_reader(reader)
114 }
115
116 /// Sniff the CSV file located at the provided path, and return a
117 /// [`Metadata`](struct.Metadata.html) object containing information about the CSV file.
118 ///
119 /// Fails on file opening or rendering errors, or on an error examining the file.
120 pub fn sniff_path<P: AsRef<Path>>(&mut self, path: P) -> Result<Metadata> {
121 let file = File::open(path)?;
122 self.sniff_reader(&file)
123 }
124 /// Sniff the CSV file provider by the reader, and return a
125 /// [`Metadata`](struct.Metadata.html) object containing information about the CSV file.
126 ///
127 /// Fails on file opening or readering errors, or on an error examining the file.
128 pub fn sniff_reader<R: Read + Seek>(&mut self, mut reader: R) -> Result<Metadata> {
129 // init IS_UTF8 global var to true
130 IS_UTF8.with(|flag| {
131 *flag.borrow_mut() = true;
132 });
133 // guess quotes & delim
134 self.infer_quotes_delim(&mut reader)?;
135
136 // if we have a delimiter, we just need to search for num_preamble_rows and check for
137 // flexible. Otherwise, we need to guess a delimiter as well.
138 if self.delimiter.is_none() {
139 self.infer_delim_preamble(&mut reader)?;
140 }
141
142 // self.infer_types(&mut reader)?;
143 self.is_utf8 = Some(IS_UTF8.with(|flag| *flag.borrow()));
144
145 // as this point of the process, we should have all these filled in.
146 // assert!(
147 // self.delimiter.is_some()
148 // && self.num_preamble_rows.is_some()
149 // && self.quote.is_some()
150 // && self.flexible.is_some()
151 // && self.is_utf8.is_some()
152 // && self.delimiter_freq.is_some()
153 // && self.has_header_row.is_some()
154 // && self.avg_record_len.is_some()
155 // && self.delimiter_freq.is_some()
156 // );
157 if !(
158 self.delimiter.is_some()
159 // && self.num_preamble_rows.is_some()
160 && self.quote.is_some()
161 && self.is_utf8.is_some()
162 // && self.has_header_row.is_some()
163 // && self.avg_record_len.is_some()
164 ) {
165 return Err(SnifferError::SniffingFailed(format!(
166 "Failed to infer all metadata: {self:?}"
167 )));
168 }
169 // safety: we just checked that all these are Some, so it's safe to unwrap
170 Ok(Metadata {
171 dialect: Dialect {
172 delimiter: self.delimiter.unwrap(),
173 // header: Header {
174 // num_preamble_rows: self.num_preamble_rows.unwrap(),
175 // has_header_row: self.has_header_row.unwrap(),
176 // },
177 quote: self.quote.clone().unwrap(),
178 // flexible: self.flexible.unwrap(),
179 // is_utf8: self.is_utf8.unwrap(),
180 },
181 // avg_record_len: self.avg_record_len.unwrap(),
182 // num_fields: self.delimiter_freq.unwrap() + 1,
183 // fields: self.fields.clone(),
184 // types: self.types.clone(),
185 })
186 }
187
188 // Infers quotes and delimiter from quoted (or possibly quoted) files. If quotes detected,
189 // updates self.quote and self.delimiter. If quotes not detected, updates self.quote to
190 // Quote::None. Only valid quote characters: " (double-quote), ' (single-quote), ` (back-tick).
191 fn infer_quotes_delim<R: Read + Seek>(&mut self, reader: &mut R) -> Result<()> {
192 if let (&Some(_), &Some(_)) = (&self.quote, &self.delimiter) {
193 // nothing left to infer!
194 return Ok(());
195 }
196 let quote_guesses = match self.quote {
197 Some(Quote::Some(chr)) => vec![chr],
198 Some(Quote::None) => {
199 // this function only checks quoted (or possibly quoted) files, nothing left to
200 // do if we know there are no quotes
201 return Ok(());
202 }
203 None => vec![b'\'', b'"', b'`'],
204 };
205 let (quote_chr, (quote_cnt, delim_guess)) = quote_guesses.iter().try_fold(
206 (b'"', (0, b'\0')),
207 |acc, &chr| -> Result<(u8, (usize, u8))> {
208 let mut sample_reader = take_sample_from_start(reader, self.get_sample_size())?;
209 if let Some((cnt, delim_chr)) =
210 quote_count(&mut sample_reader, char::from(chr), self.delimiter)?
211 {
212 Ok(if cnt > acc.1.0 {
213 (chr, (cnt, delim_chr))
214 } else {
215 acc
216 })
217 } else {
218 Ok(acc)
219 }
220 },
221 )?;
222 if quote_cnt == 0 {
223 self.quote = Some(Quote::None);
224 } else {
225 self.quote = Some(Quote::Some(quote_chr));
226 self.delimiter = Some(delim_guess);
227 };
228 Ok(())
229 }
230
231 // Updates delimiter, delimiter frequency, number of preamble rows, and flexible boolean.
232 fn infer_delim_preamble<R: Read + Seek>(&mut self, reader: &mut R) -> Result<()> {
233 let sample_iter =
234 take_sample_from_start(reader, self.get_sample_size())?.collect::<Result<Vec<_>>>()?;
235
236 let mut chars_frequency: HashMap<u8, HashMap<NumberOfOccurrences, NumberOfLines>> =
237 HashMap::with_capacity(NUM_ASCII_CHARS);
238
239 let mut modes: HashMap<u8, (NumberOfOccurrences, AdjacentFrequency)> =
240 HashMap::with_capacity(NUM_ASCII_CHARS);
241
242 for line in &sample_iter {
243 let mut line_frequency = HashMap::with_capacity(128);
244 for character in line.chars() {
245 let Ok(ascii_char) = u8::try_from(character) else {
246 continue;
247 };
248 if !CANDIDATES.contains(&ascii_char) {
249 continue;
250 }
251 *line_frequency.entry(ascii_char).or_default() += 1;
252 }
253 for (ascii_char, freq) in line_frequency {
254 let char_frequency = chars_frequency.entry(ascii_char).or_default();
255 *char_frequency.entry(freq).or_default() += 1;
256 }
257 }
258 for (&ascii_char, line_count_map) in &chars_frequency {
259 let Some((&mode_value, _)) = line_count_map
260 .iter()
261 .max_by_key(|&(_count, num_lines)| num_lines)
262 else {
263 continue; // skip empty maps, just in case
264 };
265
266 let mut adjusted_count = 0;
267 for delta in 0..=TOLERANCE {
268 for count in [mode_value.saturating_sub(delta), mode_value + delta] {
269 if let Some(&lines) = line_count_map.get(&count) {
270 adjusted_count += lines;
271 }
272 }
273 }
274 if TOLERANCE > 0 {
275 if let Some(&lines) = line_count_map.get(&mode_value) {
276 adjusted_count -= lines;
277 }
278 }
279
280 modes.insert(ascii_char, (mode_value, adjusted_count));
281 }
282 let top_candidates: Vec<u8> = modes
283 .iter()
284 .filter(|(_, (_, score))| *score > 0)
285 .sorted_by_key(|&(_, &(_, score))| std::cmp::Reverse(score)) // needs itertools or just sort
286 .take(6)
287 .map(|(&ch, _)| ch)
288 .collect();
289
290 let mut chains = vec![Chain::default(); NUM_ASCII_CHARS];
291
292 for line in sample_iter {
293 let mut freqs = [0; NUM_ASCII_CHARS];
294 for &chr in line.as_bytes() {
295 if chr < NUM_ASCII_CHARS as u8 {
296 freqs[chr as usize] += 1;
297 }
298 }
299 for &ch in &top_candidates {
300 chains[ch as usize].add_observation(freqs[ch as usize]);
301 }
302 }
303
304 self.run_chains(chains)
305 }
306
307 // Updates delimiter (if not already known), delimiter frequency, number of preamble rows, and
308 // flexible boolean.
309 fn run_chains(&mut self, mut chains: Vec<Chain>) -> Result<()> {
310 // Find the 'best' delimiter: choose strict (non-flexible) delimiters over flexible ones,
311 // and choose the one that had the highest probability markov chain in the end.
312 //
313 // In the case where delim is already known, 'best_delim' will be incorrect (since it won't
314 // correspond with position in a vector of Chains), but we'll just ignore it when
315 // constructing our return value later. 'best_state' and 'path' are necessary, though, to
316 // compute the preamble rows.
317 let (best_delim, _, _, _, _) = chains.iter_mut().enumerate().fold(
318 (b',', 0, STATE_UNSTEADY, vec![], 0.0),
319 |acc, (i, ref mut chain)| {
320 let (_, _, best_state, _, best_state_prob) = acc;
321 let ViterbiResults {
322 max_delim_freq,
323 path,
324 } = chain.viterbi();
325 if path.is_empty() {
326 return acc;
327 }
328 let (final_state, final_viter) = path[path.len() - 1];
329 if final_state < best_state
330 || (final_state == best_state && final_viter.prob > best_state_prob)
331 {
332 (i as u8, max_delim_freq, final_state, path, final_viter.prob)
333 } else {
334 acc
335 }
336 },
337 );
338 // self.flexible = Some(match best_state {
339 // STATE_STEADYSTRICT => false,
340 // STATE_STEADYFLEX => true,
341 // _ => {
342 // return Err(SnifferError::SniffingFailed(
343 // "unable to find valid delimiter".to_string(),
344 // ));
345 // }
346 // });
347
348 // Find the number of preamble rows (the number of rows during which the state fluctuated
349 // before getting to the final state).
350 // let mut num_preamble_rows = 0;
351 // since path has an extra state as the beginning, skip one
352 // for &(state, _) in path.iter().skip(2) {
353 // if state == best_state {
354 // break;
355 // }
356 // num_preamble_rows += 1;
357 // }
358 // if num_preamble_rows > 0 {
359 // num_preamble_rows += 1;
360 // }
361 if self.delimiter.is_none() {
362 self.delimiter = Some(best_delim);
363 }
364 // self.num_preamble_rows = Some(num_preamble_rows);
365 Ok(())
366 }
367
368 // fn infer_types<R: Read + Seek>(&mut self, reader: &mut R) -> Result<()> {
369 // // prerequisites for calling this function:
370 // if self.delimiter_freq.is_none() {
371 // // instead of assert, return error
372 // // assert!(self.delimiter_freq.is_some());
373 // return Err(SnifferError::SniffingFailed(
374 // "delimiter frequency not known".to_string(),
375 // ));
376 // }
377 // // safety: unwrap is safe as we just checked that delimiter_freq is Some
378 // let field_count = self.delimiter_freq.unwrap() + 1;
379
380 // let mut csv_reader = self.create_csv_reader(reader)?;
381 // let mut records_iter = csv_reader.byte_records();
382 // let mut n_bytes = 0;
383 // let mut n_records = 0;
384 // let sample_size = self.get_sample_size();
385
386 // // Infer types for the top row. We'll save this set of types to check against the types
387 // // of the remaining rows to see if this is part of the data or a separate header row.
388 // let header_row_types = match records_iter.next() {
389 // Some(record) => {
390 // let byte_record = record?;
391 // let str_record = StringRecord::from_byte_record_lossy(byte_record);
392 // n_records += 1;
393 // n_bytes += count_bytes(&str_record);
394 // infer_record_types(&str_record)
395 // }
396 // None => {
397 // return Err(SnifferError::SniffingFailed(
398 // "CSV empty (after preamble)".into(),
399 // ));
400 // }
401 // };
402 // let mut row_types = vec![TypeGuesses::all(); field_count];
403
404 // for record in records_iter {
405 // let record = record?;
406 // for (i, field) in record.iter().enumerate() {
407 // let str_field = String::from_utf8_lossy(field).to_string();
408 // row_types[i] &= infer_types(&str_field);
409 // }
410 // n_records += 1;
411 // n_bytes += record.as_slice().len();
412 // // break if we pass sample size limits
413 // match sample_size {
414 // SampleSize::Records(recs) => {
415 // if n_records > recs {
416 // break;
417 // }
418 // }
419 // SampleSize::Bytes(bytes) => {
420 // if n_bytes > bytes {
421 // break;
422 // }
423 // }
424 // SampleSize::All => {}
425 // }
426 // }
427 // if n_records == 1 {
428 // // there's only one row in the whole data file (the top row already parsed),
429 // // so we're going to assume it's a data row, not a header row.
430 // self.has_header_row = Some(false);
431 // self.types = get_best_types(&header_row_types);
432 // self.avg_record_len = Some(n_bytes);
433 // return Ok(());
434 // }
435
436 // if header_row_types
437 // .iter()
438 // .zip(&row_types)
439 // .any(|(header, data)| !data.allows(*header))
440 // {
441 // self.has_header_row = Some(true);
442 // // get field names in header
443 // for field in csv_reader.byte_headers()? {
444 // self.fields.push(String::from_utf8_lossy(field).to_string());
445 // }
446 // } else {
447 // self.has_header_row = Some(false);
448 // }
449
450 // self.types = get_best_types(&row_types);
451 // self.avg_record_len = Some(n_bytes / n_records);
452 // Ok(())
453 // }
454
455 // fn create_csv_reader<'a, R: Read + Seek>(
456 // &self,
457 // mut reader: &'a mut R,
458 // ) -> Result<Reader<&'a mut R>> {
459 // reader.seek(SeekFrom::Start(0))?;
460 // if let Some(num_preamble_rows) = self.num_preamble_rows {
461 // snip_preamble(&mut reader, num_preamble_rows)?;
462 // }
463
464 // let mut builder = csv::ReaderBuilder::new();
465 // if let Some(delim) = self.delimiter {
466 // builder.delimiter(delim);
467 // }
468 // if let Some(has_header_row) = self.has_header_row {
469 // builder.has_headers(has_header_row);
470 // }
471 // match self.quote {
472 // Some(Quote::Some(chr)) => {
473 // builder.quoting(true);
474 // builder.quote(chr);
475 // }
476 // Some(Quote::None) => {
477 // builder.quoting(false);
478 // }
479 // _ => {}
480 // }
481 // if let Some(flexible) = self.flexible {
482 // builder.flexible(flexible);
483 // }
484
485 // Ok(builder.from_reader(reader))
486 // }
487}
488
489fn quote_count<R: Read>(
490 sample_iter: &mut SampleIter<R>,
491 character: char,
492 delim: Option<u8>,
493) -> Result<Option<(usize, u8)>> {
494 // Collect all lines into a single string to handle multi-line quoted fields
495 let mut sample_content = String::new();
496 for line in sample_iter {
497 let line = line?;
498 if !sample_content.is_empty() {
499 sample_content.push('\n');
500 }
501 sample_content.push_str(&line);
502 }
503
504 // Build a regex that matches a quoted CSV cell.
505 // For multi-line support, use DOTALL mode to match newlines within quotes.
506 let pattern = delim.map_or_else(
507 || {
508 // When delim is not provided, look for quoted fields and capture surrounding delimiters
509 // Make the pattern more restrictive to avoid matching apostrophes in text
510 format!(
511 r"(?:(?<delim1>[,;|\t:])\s*{character}(?:(?:{character}{character})|(?s:[^{character}]))*?{character}(?:\s*(?<delim2>[,;|\t:]))?)|(?:^{character}(?:(?:{character}{character})|(?s:[^{character}]))*?{character}(?:\s*(?<delim3>[,;|\t:])|$))"
512 )
513 },
514 |delim| {
515 // When a delimiter is provided, enforce its presence if it appears.
516 format!(
517 r"{q}(?P<field>(?s:(?:[^{q}]|{q}{q})*)){q}(?:\s*{d}\s*)?",
518 q = character,
519 d = delim as char
520 )
521 },
522 );
523
524 // Safety: unwrap is safe here because we control the regex pattern.
525 let re = Regex::new(&pattern).unwrap();
526
527 let mut delim_count_map: HashMap<u8, usize> = HashMap::new();
528 let mut count = 0;
529
530 // Apply regex to the entire concatenated sample content
531 for cap in re.captures_iter(&sample_content) {
532 count += 1;
533
534 if let Some(delim) = get_delimiter(&cap) {
535 *delim_count_map.entry(delim).or_insert(0) += 1;
536 }
537 }
538
539 if count == 0 {
540 return Ok(None);
541 }
542
543 // If a delimiter was provided, just return it.
544 if let Some(delim) = delim {
545 return Ok(Some((count, delim)));
546 }
547
548 // Otherwise, select the candidate delimiter that was matched most frequently.
549 let (delim_count, delim) =
550 delim_count_map
551 .into_iter()
552 .fold((0, b'\0'), |acc, (delim, d_count)| {
553 if d_count > acc.0 {
554 (d_count, delim)
555 } else {
556 acc
557 }
558 });
559
560 if delim_count == 0 {
561 return Err(SnifferError::SniffingFailed(
562 "invalid regex match: no delimiter found".into(),
563 ));
564 }
565 Ok(Some((count, delim)))
566}
567
568fn get_delimiter(captures: &Captures<'_>) -> Option<u8> {
569 let mut counts: HashMap<char, usize> = HashMap::new();
570 // Check groups delim1 through delim3.
571 for i in 1..=3 {
572 let group_name = format!("delim{i}");
573 if let Some(matched) = captures.name(&group_name) {
574 if let Some(ch) = matched.as_str().chars().next() {
575 *counts.entry(ch).or_insert(0) += 1;
576 }
577 }
578 }
579
580 // If no candidates were found, return None.
581 if counts.is_empty() {
582 return None;
583 }
584
585 let mut candidates: Vec<(char, usize)> = counts.into_iter().collect();
586 candidates.sort_by(|a, b| {
587 b.1.cmp(&a.1)
588 .then_with(|| delimiter_priority(a.0).cmp(&delimiter_priority(b.0))) // Priority ascending (lower number wins)
589 .then_with(|| a.0.cmp(&b.0))
590 });
591
592 let (candidate, _) = candidates.into_iter().next()?;
593 u8::try_from(candidate).ok()
594}
595
/// Ranks a delimiter for tie-breaking; a lower number means a more preferred delimiter.
///
/// Comma, semicolon, tab, pipe and colon rank 0 through 4 in that order; every other
/// character gets the lowest priority, 255.
const fn delimiter_priority(delimiter: char) -> u8 {
    // Most preferred first; a delimiter's rank is its position in this table.
    const BY_PRIORITY: [char; 5] = [',', ';', '\t', '|', ':'];

    let mut rank = 0;
    while rank < BY_PRIORITY.len() {
        if BY_PRIORITY[rank] == delimiter {
            // The table has five entries, so the rank always fits in a `u8`.
            return rank as u8;
        }
        rank += 1;
    }
    u8::MAX
}