csv_nose/
sniffer.rs

//! Main Sniffer builder and sniff methods.
//!
//! This module provides the qsv-sniffer compatible API.

use std::fs::File;
use std::io::{Read, Seek};
use std::path::Path;

use crate::encoding::{detect_and_transcode, detect_encoding, skip_bom};
use crate::error::{Result, SnifferError};
use crate::field_type::Type;
use crate::metadata::{Dialect, Header, Metadata, Quote};
use crate::sample::{DatePreference, SampleSize};
use crate::tum::potential_dialects::{
    PotentialDialect, detect_line_terminator, generate_dialects_with_terminator,
};
use crate::tum::score::{DialectScore, find_best_dialect, score_all_dialects};
use crate::tum::table::parse_table;
use crate::tum::type_detection::infer_column_types;

/// CSV dialect sniffer using the Table Uniformity Method.
///
/// # Example
///
/// ```no_run
/// use csv_nose::{Sniffer, SampleSize};
///
/// let mut sniffer = Sniffer::new();
/// sniffer.sample_size(SampleSize::Records(100));
///
/// let metadata = sniffer.sniff_path("data.csv").unwrap();
/// println!("Delimiter: {}", metadata.dialect.delimiter as char);
/// println!("Has header: {}", metadata.dialect.header.has_header_row);
/// ```
#[derive(Debug, Clone)]
pub struct Sniffer {
    /// Sample size for sniffing.
    sample_size: SampleSize,
    /// Date format preference for ambiguous dates.
    date_preference: DatePreference,
    /// Optional forced delimiter.
    forced_delimiter: Option<u8>,
    /// Optional forced quote character.
    forced_quote: Option<Quote>,
}

impl Default for Sniffer {
    fn default() -> Self {
        Self::new()
    }
}

impl Sniffer {
    /// Create a new Sniffer with default settings.
    pub fn new() -> Self {
        Self {
            sample_size: SampleSize::Records(100),
            date_preference: DatePreference::MdyFormat,
            forced_delimiter: None,
            forced_quote: None,
        }
    }

    /// Set the sample size for sniffing.
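    ///
    /// # Example
    ///
    /// A minimal sketch that caps the sample at 64 KiB instead of a record count:
    ///
    /// ```no_run
    /// use csv_nose::{SampleSize, Sniffer};
    ///
    /// let mut sniffer = Sniffer::new();
    /// sniffer.sample_size(SampleSize::Bytes(64 * 1024));
    /// let metadata = sniffer.sniff_path("data.csv").unwrap();
    /// println!("{} fields detected", metadata.num_fields);
    /// ```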
    pub fn sample_size(&mut self, sample_size: SampleSize) -> &mut Self {
        self.sample_size = sample_size;
        self
    }

    /// Set the date preference for ambiguous date parsing.
    pub fn date_preference(&mut self, date_preference: DatePreference) -> &mut Self {
        self.date_preference = date_preference;
        self
    }

    /// Force a specific delimiter (skip delimiter detection).
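    ///
    /// # Example
    ///
    /// A minimal sketch, assuming a pipe-delimited file; with a forced delimiter
    /// only quote variants (paired with the detected line terminator) are scored:
    ///
    /// ```no_run
    /// use csv_nose::Sniffer;
    ///
    /// let mut sniffer = Sniffer::new();
    /// sniffer.delimiter(b'|');
    /// let metadata = sniffer.sniff_path("data.psv").unwrap();
    /// assert_eq!(metadata.dialect.delimiter, b'|');
    /// ```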
    pub fn delimiter(&mut self, delimiter: u8) -> &mut Self {
        self.forced_delimiter = Some(delimiter);
        self
    }

    /// Force a specific quote character.
    pub fn quote(&mut self, quote: Quote) -> &mut Self {
        self.forced_quote = Some(quote);
        self
    }

    /// Sniff a CSV file at the given path.
    pub fn sniff_path<P: AsRef<Path>>(&mut self, path: P) -> Result<Metadata> {
        let file = File::open(path.as_ref())?;
        let mut reader = std::io::BufReader::new(file);
        self.sniff_reader(&mut reader)
    }

    /// Sniff CSV data from a reader.
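    ///
    /// The reader must implement `Read + Seek`; an in-memory `std::io::Cursor`
    /// works. A minimal sketch:
    ///
    /// ```no_run
    /// use std::io::Cursor;
    /// use csv_nose::Sniffer;
    ///
    /// let cursor = Cursor::new(b"name,age,city\nAlice,30,NYC\nBob,25,LA\n".to_vec());
    /// let mut sniffer = Sniffer::new();
    /// let metadata = sniffer.sniff_reader(cursor).unwrap();
    /// assert_eq!(metadata.dialect.delimiter, b',');
    /// ```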
    pub fn sniff_reader<R: Read + Seek>(&mut self, reader: R) -> Result<Metadata> {
        let data = self.read_sample(reader)?;

        if data.is_empty() {
            return Err(SnifferError::EmptyData);
        }

        self.sniff_bytes(&data)
    }

    /// Sniff CSV data from bytes.
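    ///
    /// # Example
    ///
    /// A minimal in-memory sketch (the same data as the unit tests in this module):
    ///
    /// ```no_run
    /// use csv_nose::Sniffer;
    ///
    /// let data = b"name,age,city\nAlice,30,NYC\nBob,25,LA\n";
    /// let metadata = Sniffer::new().sniff_bytes(data).unwrap();
    /// assert_eq!(metadata.dialect.delimiter, b',');
    /// assert!(metadata.dialect.header.has_header_row);
    /// ```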
    pub fn sniff_bytes(&self, data: &[u8]) -> Result<Metadata> {
        if data.is_empty() {
            return Err(SnifferError::EmptyData);
        }

        // Detect the original encoding (reported in the metadata) before `data`
        // is shadowed by the transcoded bytes; otherwise `is_utf8` is always true.
        let encoding_info = detect_encoding(data);

        // Transcode to UTF-8 if necessary
        let (transcoded_data, was_transcoded) = detect_and_transcode(data);
        let data = &transcoded_data[..];
        let is_utf8 = !was_transcoded || encoding_info.is_utf8;

        // Skip BOM
        let data = skip_bom(data);

        // Skip comment/preamble lines (lines starting with #)
        let (preamble_rows, data) = skip_preamble(data);
        let _ = preamble_rows; // Will be used for metadata in future

        // Detect line terminator first to reduce search space
        let line_terminator = detect_line_terminator(data);

        // Generate potential dialects
        let dialects = if let Some(delim) = self.forced_delimiter {
            // If delimiter is forced, only test that delimiter with different quotes
            let quotes = if let Some(q) = self.forced_quote {
                vec![q]
            } else {
                vec![Quote::Some(b'"'), Quote::Some(b'\''), Quote::None]
            };

            quotes
                .into_iter()
                .map(|q| PotentialDialect::new(delim, q, line_terminator))
                .collect()
        } else {
            generate_dialects_with_terminator(line_terminator)
        };

        // Determine max rows for scoring
        let max_rows = match self.sample_size {
            SampleSize::Records(n) => n,
            SampleSize::Bytes(_) | SampleSize::All => 0, // Already limited by read_sample
        };

        // Score all dialects
        let scores = score_all_dialects(data, &dialects, max_rows);

        // Find the best dialect
        let best = find_best_dialect(&scores)
            .ok_or_else(|| SnifferError::NoDialectDetected("No valid dialect found".to_string()))?;

        // Build metadata from the best dialect
        self.build_metadata(data, best, is_utf8)
    }

    /// Read a sample of data from the reader based on sample_size settings.
    fn read_sample<R: Read + Seek>(&self, mut reader: R) -> Result<Vec<u8>> {
        match self.sample_size {
            SampleSize::Bytes(n) => {
                // A single `read` call may return fewer bytes than requested, so
                // use `take` + `read_to_end` to reliably fill up to n bytes.
                let mut buffer = Vec::with_capacity(n);
                (&mut reader).take(n as u64).read_to_end(&mut buffer)?;
                Ok(buffer)
            }
            SampleSize::All => {
                let mut buffer = Vec::new();
                reader.read_to_end(&mut buffer)?;
                Ok(buffer)
            }
            SampleSize::Records(n) => {
                // For records, we read enough to capture n records.
                // Estimate ~1KB per record as a starting point, with a minimum.
                let estimated_size = (n * 1024).max(8192);
                let mut buffer = Vec::with_capacity(estimated_size);
                (&mut reader).take(estimated_size as u64).read_to_end(&mut buffer)?;

                // If the initial estimate was filled completely, we may need more data
                if buffer.len() == estimated_size {
                    // Count newlines to see if we have enough records (newlines inside
                    // quoted fields are counted too, so this is only an estimate)
                    let newlines = bytecount::count(&buffer, b'\n');
                    if newlines < n {
                        // Read more data
                        let additional = (n - newlines) * 2048;
                        (&mut reader).take(additional as u64).read_to_end(&mut buffer)?;
                    }
                }

                Ok(buffer)
            }
        }
    }

    /// Build Metadata from the best scoring dialect.
    fn build_metadata(&self, data: &[u8], score: &DialectScore, is_utf8: bool) -> Result<Metadata> {
        // Parse the table with the best dialect
        let max_rows = match self.sample_size {
            SampleSize::Records(n) => n,
            _ => 0,
        };

        let table = parse_table(data, &score.dialect, max_rows);

        if table.is_empty() {
            return Err(SnifferError::EmptyData);
        }

        // Detect header
        let header = detect_header(&table, &score.dialect);

        // Get field names
        let fields = if header.has_header_row && !table.rows.is_empty() {
            table.rows[0].clone()
        } else {
            // Generate field names
            (0..score.num_fields)
                .map(|i| format!("field_{}", i + 1))
                .collect()
        };

        // Skip header row for type inference if present
        let data_table = if header.has_header_row && table.rows.len() > 1 {
            let mut dt = crate::tum::table::Table::new();
            dt.rows = table.rows[1..].to_vec();
            dt.field_counts = table.field_counts[1..].to_vec();
            dt
        } else {
            table.clone()
        };

        // Infer types for each column
        let types = infer_column_types(&data_table);

        // Build dialect
        let dialect = Dialect {
            delimiter: score.dialect.delimiter,
            header,
            quote: score.dialect.quote,
            flexible: !score.is_uniform,
            is_utf8,
        };

        // Calculate average record length
        let avg_record_len = calculate_avg_record_len(data, table.num_rows());

        Ok(Metadata {
            dialect,
            avg_record_len,
            num_fields: score.num_fields,
            fields,
            types,
        })
    }
}

/// Detect if the first row is likely a header row.
fn detect_header(table: &crate::tum::table::Table, _dialect: &PotentialDialect) -> Header {
    if table.rows.is_empty() {
        return Header::new(false, 0);
    }

    if table.rows.len() < 2 {
        // Can't determine header with only one row
        return Header::new(false, 0);
    }

    let first_row = &table.rows[0];
    let second_row = &table.rows[1];

    // Heuristics for header detection:
    // 1. First row has different types than subsequent rows
    // 2. First row values look like labels (text when data is numeric)
    // 3. First row has no duplicates (header columns should be unique)
    // 4. First row values tend to be shorter (headers are usually concise)

    let mut header_score = 0.0;
    let mut checks = 0;

    // Check 1: First row has more text cells than the second row
    let first_types: Vec<Type> = first_row
        .iter()
        .map(|s| crate::tum::type_detection::detect_cell_type(s))
        .collect();
    let second_types: Vec<Type> = second_row
        .iter()
        .map(|s| crate::tum::type_detection::detect_cell_type(s))
        .collect();

    let first_text_count = first_types.iter().filter(|&&t| t == Type::Text).count();
    let second_text_count = second_types.iter().filter(|&&t| t == Type::Text).count();

    if first_text_count > second_text_count {
        header_score += 1.0;
    }
    checks += 1;

    // Check 2: First row has more text than numeric
    let first_numeric_count = first_types.iter().filter(|&&t| t.is_numeric()).count();
    if first_text_count > first_numeric_count {
        header_score += 0.5;
    }
    checks += 1;

    // Check 3: No duplicates in first row
    let unique_count = {
        let mut seen = std::collections::HashSet::new();
        first_row.iter().filter(|s| seen.insert(s.as_str())).count()
    };
    if unique_count == first_row.len() {
        header_score += 0.5;
    }
    checks += 1;

    // Check 4: First row values are shorter (headers tend to be concise)
    let avg_first_len: f64 = first_row
        .iter()
        .map(std::string::String::len)
        .sum::<usize>() as f64
        / first_row.len().max(1) as f64;
    let avg_second_len: f64 = second_row
        .iter()
        .map(std::string::String::len)
        .sum::<usize>() as f64
        / second_row.len().max(1) as f64;

    if avg_first_len <= avg_second_len {
        header_score += 0.3;
    }
    checks += 1;

    // Threshold for header detection
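    // Worked example: for a `name,age,city` header over `Alice,30,NYC`, checks 1-3
    // contribute 1.0 + 0.5 + 0.5 and check 4 adds nothing (the header cells average
    // slightly longer), giving 2.0 / 4 = 0.5 > 0.4, so a header row is reported.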
    let has_header = (header_score / checks as f64) > 0.4;

    Header::new(has_header, 0)
}

/// Calculate average record length.
fn calculate_avg_record_len(data: &[u8], num_rows: usize) -> usize {
    if num_rows == 0 {
        return 0;
    }
    data.len() / num_rows
}

/// Skip preamble/comment lines at the start of data.
///
/// Detects lines starting with '#' at the beginning of the file and returns
/// the number of preamble rows and a slice starting after the preamble.
fn skip_preamble(data: &[u8]) -> (usize, &[u8]) {
    let mut preamble_rows = 0;
    let mut offset = 0;

    while offset < data.len() {
        // Skip leading whitespace on the line
        let mut line_start = offset;
        while line_start < data.len() && (data[line_start] == b' ' || data[line_start] == b'\t') {
            line_start += 1;
        }

        // Check if line starts with #
        if line_start < data.len() && data[line_start] == b'#' {
            // Find end of line
            let mut line_end = line_start;
            while line_end < data.len() && data[line_end] != b'\n' && data[line_end] != b'\r' {
                line_end += 1;
            }

            // Skip line terminator
            if line_end < data.len() && data[line_end] == b'\r' {
                line_end += 1;
            }
            if line_end < data.len() && data[line_end] == b'\n' {
                line_end += 1;
            }

            preamble_rows += 1;
            offset = line_end;
        } else {
            // Not a comment line, stop
            break;
        }
    }

    (preamble_rows, &data[offset..])
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_sniffer_builder() {
        let mut sniffer = Sniffer::new();
        sniffer
            .sample_size(SampleSize::Records(50))
            .date_preference(DatePreference::DmyFormat)
            .delimiter(b',');

        assert_eq!(sniffer.sample_size, SampleSize::Records(50));
        assert_eq!(sniffer.date_preference, DatePreference::DmyFormat);
        assert_eq!(sniffer.forced_delimiter, Some(b','));
    }

    #[test]
    fn test_sniff_bytes() {
        let data = b"name,age,city\nAlice,30,NYC\nBob,25,LA\n";
        let sniffer = Sniffer::new();

        let metadata = sniffer.sniff_bytes(data).unwrap();

        assert_eq!(metadata.dialect.delimiter, b',');
        assert!(metadata.dialect.header.has_header_row);
        assert_eq!(metadata.num_fields, 3);
        assert_eq!(metadata.fields, vec!["name", "age", "city"]);
    }

    #[test]
    fn test_sniff_tsv() {
        let data = b"name\tage\tcity\nAlice\t30\tNYC\nBob\t25\tLA\n";
        let sniffer = Sniffer::new();

        let metadata = sniffer.sniff_bytes(data).unwrap();

        assert_eq!(metadata.dialect.delimiter, b'\t');
        assert!(metadata.dialect.header.has_header_row);
    }

    #[test]
    fn test_sniff_semicolon() {
        let data = b"name;age;city\nAlice;30;NYC\nBob;25;LA\n";
        let sniffer = Sniffer::new();

        let metadata = sniffer.sniff_bytes(data).unwrap();

        assert_eq!(metadata.dialect.delimiter, b';');
    }

    #[test]
    fn test_sniff_no_header() {
        let data = b"1,2,3\n4,5,6\n7,8,9\n";
        let sniffer = Sniffer::new();

        let metadata = sniffer.sniff_bytes(data).unwrap();

        assert_eq!(metadata.dialect.delimiter, b',');
        // All numeric data - should not detect header
        assert!(!metadata.dialect.header.has_header_row);
    }

    #[test]
    fn test_sniff_with_quotes() {
        let data = b"\"name\",\"value\"\n\"hello, world\",123\n\"test\",456\n";
        let sniffer = Sniffer::new();

        let metadata = sniffer.sniff_bytes(data).unwrap();

        assert_eq!(metadata.dialect.delimiter, b',');
        assert_eq!(metadata.dialect.quote, Quote::Some(b'"'));
    }

    #[test]
    fn test_sniff_empty() {
        let data = b"";
        let sniffer = Sniffer::new();

        let result = sniffer.sniff_bytes(data);
        assert!(result.is_err());
    }

    #[test]
    fn test_skip_preamble() {
        // Test with comment lines
        let data = b"# This is a comment\n# Another comment\nname,age\nAlice,30\n";
        let (preamble_rows, remaining) = skip_preamble(data);
        assert_eq!(preamble_rows, 2);
        assert_eq!(remaining, b"name,age\nAlice,30\n");

        // Test without comment lines
        let data = b"name,age\nAlice,30\n";
        let (preamble_rows, remaining) = skip_preamble(data);
        assert_eq!(preamble_rows, 0);
        assert_eq!(remaining, b"name,age\nAlice,30\n");

        // Test with whitespace before #
        let data = b"  # Indented comment\nname,age\n";
        let (preamble_rows, remaining) = skip_preamble(data);
        assert_eq!(preamble_rows, 1);
        assert_eq!(remaining, b"name,age\n");
    }

    #[test]
    fn test_sniff_with_preamble() {
        let data = b"# LimeSurvey export\n# Generated 2024-01-01\nname,age,city\nAlice,30,NYC\nBob,25,LA\n";
        let sniffer = Sniffer::new();

        let metadata = sniffer.sniff_bytes(data).unwrap();

        assert_eq!(metadata.dialect.delimiter, b',');
        assert!(metadata.dialect.header.has_header_row);
        assert_eq!(metadata.num_fields, 3);
    }
}