Skip to main content

reddb_server/storage/import/
csv.rs

1//! CSV Importer (Phase 1.5 PG parity)
2//!
3//! Imports data from CSV (comma-separated-values) files into Store.
4//! Implements a minimal RFC 4180 subset: quoted fields, escaped quotes
5//! (`""` → `"`), configurable delimiter, and optional header row.
6//!
7//! # Format examples
8//!
9//! ```text
10//! id,name,age
11//! 1,Alice,30
12//! 2,"Bob, Jr.",25
13//! 3,"Say ""hi""",40
14//! ```
15//!
16//! # Type coercion
17//!
18//! Each field is parsed with the following precedence:
19//! 1. Empty string + `treat_empty_as_null=true` → Null
20//! 2. Exact integer (`-?\d+`)  → Value::Integer
21//! 3. Floating point (contains `.` or `e/E`) → Value::Float
22//! 4. Boolean literal (`true`/`false`, case-insensitive) → Value::Boolean
23//! 5. Fallback → Value::Text
24//!
25//! # Usage
26//!
27//! ```rust,ignore
28//! let importer = CsvImporter::new(CsvConfig {
29//!     collection: "users".to_string(),
30//!     has_header: true,
31//!     delimiter: b',',
32//!     ..Default::default()
33//! });
34//! let stats = importer.import_file("users.csv", &store)?;
35//! ```
36
37use crate::storage::schema::types::Value;
38use crate::storage::Store;
39use crate::storage::{EntityData, EntityKind, RowData, UnifiedEntity};
40use std::collections::HashMap;
41use std::fs::File;
42use std::io::{BufRead, BufReader, Read};
43use std::path::Path;
44use std::sync::Arc;
45
/// Settings that control how a CSV source is read and where its rows go.
#[derive(Debug, Clone)]
pub struct CsvConfig {
    /// Name of the destination collection/table.
    pub collection: String,
    /// `true` when the first record holds the column names.
    /// Without a header (and without `column_names`) columns are called
    /// `c0`, `c1`, ...
    pub has_header: bool,
    /// Byte that separates fields. Defaults to `,`; `;` and `\t` are
    /// common alternatives.
    pub delimiter: u8,
    /// Byte that wraps fields containing the delimiter or newlines.
    /// Defaults to `"`. Two consecutive quote bytes inside a quoted field
    /// stand for one literal quote.
    pub quote: u8,
    /// When `true`, empty fields are stored as `Value::Null`.
    /// NOTE(review): the parser in this file hands back plain strings, so a
    /// quoted empty field (`""`) is not distinguishable from an unquoted
    /// one at coercion time — confirm whether that distinction is required.
    pub treat_empty_as_null: bool,
    /// Intended number of records per bulk-insert chunk.
    /// NOTE(review): not read by the importer code in this file — confirm
    /// where (or whether) it is consumed.
    pub batch_size: usize,
    /// Keep going when a row cannot be inserted instead of aborting.
    pub skip_errors: bool,
    /// Upper bound on imported records (`None` imports everything).
    pub max_records: Option<usize>,
    /// Caller-supplied column names, consulted only when `has_header` is
    /// false. They take precedence over the generated `c0`, `c1`, ...
    pub column_names: Option<Vec<String>>,
}

impl Default for CsvConfig {
    /// Comma-delimited, double-quoted input with a header row, imported
    /// into the `imported` collection without a record limit.
    fn default() -> Self {
        CsvConfig {
            collection: String::from("imported"),
            has_header: true,
            delimiter: b',',
            quote: b'"',
            treat_empty_as_null: true,
            batch_size: 1000,
            skip_errors: false,
            max_records: None,
            column_names: None,
        }
    }
}
88
/// Counters and timing reported by an import run.
#[derive(Debug, Clone, Default)]
pub struct CsvImportStats {
    /// Data records visited by the import loop (the header row is not counted).
    pub lines_processed: usize,
    /// Records that were inserted successfully.
    pub records_imported: usize,
    /// Records whose insertion failed and were skipped (`skip_errors`).
    pub errors_skipped: usize,
    /// Wall-clock duration of the import in milliseconds.
    pub duration_ms: u64,
}
97
/// Failure modes reported by the CSV importer.
#[derive(Debug)]
pub enum CsvError {
    /// Opening or reading the source failed; carries the underlying error text.
    Io(String),
    /// The CSV text was malformed, or a row could not be inserted.
    Parse(String),
}

impl std::fmt::Display for CsvError {
    /// Render the error as `<category>: <detail>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Io(detail) => write!(f, "I/O error: {}", detail),
            Self::Parse(detail) => write!(f, "parse error: {}", detail),
        }
    }
}

impl std::error::Error for CsvError {}
115
116/// CSV importer
117pub struct CsvImporter {
118    config: CsvConfig,
119}
120
121impl CsvImporter {
122    pub fn new(config: CsvConfig) -> Self {
123        Self { config }
124    }
125
126    pub fn with_defaults() -> Self {
127        Self::new(CsvConfig::default())
128    }
129
130    /// Import from a file path.
131    pub fn import_file<P: AsRef<Path>>(
132        &self,
133        path: P,
134        store: &Store,
135    ) -> Result<CsvImportStats, CsvError> {
136        let file = File::open(path.as_ref()).map_err(|e| CsvError::Io(e.to_string()))?;
137        let reader = BufReader::new(file);
138        self.import_reader(reader, store)
139    }
140
141    /// Import from any BufRead implementation.
142    pub fn import_reader<R: BufRead>(
143        &self,
144        mut reader: R,
145        store: &Store,
146    ) -> Result<CsvImportStats, CsvError> {
147        let start = std::time::Instant::now();
148        let mut stats = CsvImportStats::default();
149        let mut buf = String::new();
150        reader
151            .read_to_string(&mut buf)
152            .map_err(|e| CsvError::Io(e.to_string()))?;
153
154        let records = parse_records(&buf, self.config.delimiter, self.config.quote)
155            .map_err(CsvError::Parse)?;
156        let mut iter = records.into_iter();
157
158        // Resolve column names.
159        let headers: Vec<String> = if self.config.has_header {
160            match iter.next() {
161                Some(row) => row,
162                None => {
163                    stats.duration_ms = start.elapsed().as_millis() as u64;
164                    return Ok(stats);
165                }
166            }
167        } else if let Some(names) = &self.config.column_names {
168            names.clone()
169        } else {
170            Vec::new()
171        };
172
173        for (row_idx, row) in iter.enumerate() {
174            stats.lines_processed += 1;
175            if let Some(max) = self.config.max_records {
176                if stats.records_imported >= max {
177                    break;
178                }
179            }
180
181            let column_names: Vec<String> = if headers.is_empty() {
182                (0..row.len()).map(|i| format!("c{i}")).collect()
183            } else {
184                headers.clone()
185            };
186
187            match self.insert_row(&column_names, row, store) {
188                Ok(()) => stats.records_imported += 1,
189                Err(e) => {
190                    if self.config.skip_errors {
191                        stats.errors_skipped += 1;
192                        continue;
193                    }
194                    return Err(CsvError::Parse(format!("row {}: {}", row_idx + 1, e)));
195                }
196            }
197        }
198
199        stats.duration_ms = start.elapsed().as_millis() as u64;
200        Ok(stats)
201    }
202
203    fn insert_row(
204        &self,
205        columns: &[String],
206        values: Vec<String>,
207        store: &Store,
208    ) -> Result<(), String> {
209        let mut named: HashMap<String, Value> = HashMap::with_capacity(values.len());
210        for (i, raw) in values.into_iter().enumerate() {
211            let name = columns.get(i).cloned().unwrap_or_else(|| format!("c{i}"));
212            named.insert(name, coerce_field(&raw, self.config.treat_empty_as_null));
213        }
214
215        let entity_id = store.next_entity_id();
216        let row_id = entity_id.0;
217        let entity = UnifiedEntity::new(
218            entity_id,
219            EntityKind::TableRow {
220                table: Arc::from(self.config.collection.as_str()),
221                row_id,
222            },
223            EntityData::Row(RowData {
224                columns: Vec::new(),
225                named: Some(named),
226                schema: None,
227            }),
228        );
229        store
230            .insert(&self.config.collection, entity)
231            .map(|_| ())
232            .map_err(|e| format!("insert failed: {:?}", e))
233    }
234}
235
/// Parse an entire CSV buffer into records.
///
/// Handles RFC 4180 quoting: a field wrapped in `quote` may contain the
/// delimiter, newlines, and a literal quote written as two quote bytes.
/// Records end at `\n`, `\r\n`, or a lone `\r`; a final record without a
/// trailing newline is still returned.
///
/// Fields are accumulated as raw bytes and decoded once per field, so
/// multi-byte UTF-8 text is preserved exactly.
///
/// # Errors
///
/// * `"unterminated quoted field"` when the input ends inside a quoted field.
/// * An invalid-UTF-8 message when splitting produced a field that is not
///   valid UTF-8, which can only happen if `delimiter` or `quote` is a
///   non-ASCII byte.
fn parse_records(input: &str, delimiter: u8, quote: u8) -> Result<Vec<Vec<String>>, String> {
    /// Decode the accumulated field bytes, leaving the buffer empty.
    fn finish_field(raw: &mut Vec<u8>) -> Result<String, String> {
        String::from_utf8(std::mem::take(raw)).map_err(|_| {
            "field is not valid UTF-8 (delimiter and quote must be ASCII bytes)".to_string()
        })
    }

    let bytes = input.as_bytes();
    let len = bytes.len();
    let mut records: Vec<Vec<String>> = Vec::new();
    let mut current_row: Vec<String> = Vec::new();
    let mut field: Vec<u8> = Vec::new();
    let mut in_quotes = false;
    // True once anything belonging to the current record has been consumed;
    // lets a trailing `""` record survive even though it produces no bytes.
    let mut record_started = false;
    let mut i = 0;

    while i < len {
        let b = bytes[i];
        if in_quotes {
            if b != quote {
                field.push(b);
                i += 1;
            } else if bytes.get(i + 1) == Some(&quote) {
                // Doubled quote inside a quoted field: one literal quote.
                field.push(quote);
                i += 2;
            } else {
                in_quotes = false;
                i += 1;
            }
        } else if b == quote && field.is_empty() {
            // A quote only opens a quoted section at the start of a field.
            in_quotes = true;
            record_started = true;
            i += 1;
        } else if b == delimiter {
            current_row.push(finish_field(&mut field)?);
            record_started = true;
            i += 1;
        } else if b == b'\r' || b == b'\n' {
            current_row.push(finish_field(&mut field)?);
            records.push(std::mem::take(&mut current_row));
            record_started = false;
            i += 1;
            // Treat \r\n as a single record terminator.
            if b == b'\r' && bytes.get(i) == Some(&b'\n') {
                i += 1;
            }
        } else {
            field.push(b);
            record_started = true;
            i += 1;
        }
    }

    if in_quotes {
        return Err("unterminated quoted field".to_string());
    }
    // Flush trailing record (no final newline).
    if record_started {
        current_row.push(finish_field(&mut field)?);
        records.push(current_row);
    }
    Ok(records)
}
301
302/// Coerce a raw CSV field string into the best-matching Value.
303fn coerce_field(raw: &str, treat_empty_as_null: bool) -> Value {
304    if treat_empty_as_null && raw.is_empty() {
305        return Value::Null;
306    }
307    // Integer first — must not have decimal or exponent.
308    if let Ok(n) = raw.parse::<i64>() {
309        if !raw.contains('.') && !raw.contains('e') && !raw.contains('E') {
310            return Value::Integer(n);
311        }
312    }
313    // Float.
314    if let Ok(f) = raw.parse::<f64>() {
315        if raw.contains('.') || raw.contains('e') || raw.contains('E') {
316            return Value::Float(f);
317        }
318    }
319    // Boolean literal.
320    if raw.eq_ignore_ascii_case("true") {
321        return Value::Boolean(true);
322    }
323    if raw.eq_ignore_ascii_case("false") {
324        return Value::Boolean(false);
325    }
326    // Fallback.
327    Value::text(raw.to_string())
328}
329
#[cfg(test)]
mod tests {
    use super::*;

    /// Parse `text` with the default comma delimiter and double-quote byte.
    fn parse_default(text: &str) -> Vec<Vec<String>> {
        parse_records(text, b',', b'"').unwrap()
    }

    /// Plain unquoted input yields one record per line, header included.
    #[test]
    fn parse_simple_csv() {
        let rows = parse_default("id,name,age\n1,Alice,30\n2,Bob,25\n");
        assert_eq!(rows.len(), 3);
        assert_eq!(rows[0], vec!["id", "name", "age"]);
        assert_eq!(rows[1], vec!["1", "Alice", "30"]);
        assert_eq!(rows[2], vec!["2", "Bob", "25"]);
    }

    /// Quoted fields may contain the delimiter and doubled quotes.
    #[test]
    fn parse_quoted_and_escaped_fields() {
        let rows = parse_default("id,note\n1,\"hello, world\"\n2,\"say \"\"hi\"\"\"\n");
        assert_eq!(rows.len(), 3);
        assert_eq!(rows[1], vec!["1", "hello, world"]);
        assert_eq!(rows[2], vec!["2", "say \"hi\""]);
    }

    /// A non-default delimiter byte splits fields the same way.
    #[test]
    fn parse_alternate_delimiter() {
        let rows = parse_records("a;b;c\n1;2;3\n", b';', b'"').unwrap();
        assert_eq!(rows[1], vec!["1", "2", "3"]);
    }

    /// `\r\n` counts as a single record terminator.
    #[test]
    fn parse_crlf_newlines() {
        let rows = parse_default("a,b\r\n1,2\r\n");
        assert_eq!(rows.len(), 2);
    }

    /// The last record is returned even without a trailing newline.
    #[test]
    fn parse_no_trailing_newline() {
        let rows = parse_default("a,b\n1,2");
        assert_eq!(rows.len(), 2);
        assert_eq!(rows[1], vec!["1", "2"]);
    }

    /// Input that ends inside a quoted field is rejected.
    #[test]
    fn parse_unterminated_quote_errors() {
        let outcome = parse_records("a,\"unclosed\n", b',', b'"');
        assert!(outcome.is_err());
    }

    /// Each coercion rule produces the expected `Value` variant.
    #[test]
    fn coerce_int_float_bool_text_null() {
        // Integers.
        assert_eq!(coerce_field("42", true), Value::Integer(42));
        assert_eq!(coerce_field("-17", true), Value::Integer(-17));
        // Floats: decimal point or exponent.
        assert_eq!(coerce_field("3.14", true), Value::Float(3.14));
        assert_eq!(coerce_field("1e3", true), Value::Float(1000.0));
        // Booleans, case-insensitive.
        assert_eq!(coerce_field("TRUE", true), Value::Boolean(true));
        assert_eq!(coerce_field("False", true), Value::Boolean(false));
        // Text fallback.
        assert_eq!(
            coerce_field("hello", true),
            Value::text("hello".to_string())
        );
        // Empty input depends on the null flag.
        assert_eq!(coerce_field("", true), Value::Null);
        assert_eq!(coerce_field("", false), Value::text(String::new()));
    }
}