reddb_server/storage/import/
csv.rs1use crate::storage::schema::types::Value;
38use crate::storage::Store;
39use crate::storage::{EntityData, EntityKind, RowData, UnifiedEntity};
40use std::collections::HashMap;
41use std::fs::File;
42use std::io::{BufRead, BufReader, Read};
43use std::path::Path;
44use std::sync::Arc;
45
#[derive(Debug, Clone)]
/// Settings that control how a CSV source is parsed and imported.
pub struct CsvConfig {
    /// Name of the collection that imported rows are inserted into.
    pub collection: String,
    /// When `true`, the first record supplies the column names and is not
    /// imported as data.
    pub has_header: bool,
    /// Byte that separates fields within a record.
    pub delimiter: u8,
    /// Byte that opens and closes a quoted field.
    pub quote: u8,
    /// When `true`, empty fields are stored as a null value instead of empty
    /// text.
    pub treat_empty_as_null: bool,
    /// Batch-size setting.
    ///
    /// NOTE(review): nothing in this file reads this field; confirm whether
    /// it is consumed elsewhere or is reserved for future batching.
    pub batch_size: usize,
    /// When `true`, rows that fail to insert are counted and skipped instead
    /// of aborting the import.
    pub skip_errors: bool,
    /// Upper bound on the number of rows imported; `None` means no limit.
    pub max_records: Option<usize>,
    /// Column names used when `has_header` is `false`. Not consulted when a
    /// header row is read.
    pub column_names: Option<Vec<String>>,
}

impl Default for CsvConfig {
    /// Comma-delimited, double-quoted CSV with a header row, imported into
    /// the `"imported"` collection, with empty fields treated as null,
    /// errors not skipped and no record limit.
    fn default() -> Self {
        CsvConfig {
            collection: String::from("imported"),
            has_header: true,
            delimiter: b',',
            quote: b'"',
            treat_empty_as_null: true,
            batch_size: 1_000,
            skip_errors: false,
            max_records: None,
            column_names: None,
        }
    }
}
88
#[derive(Debug, Clone, Default)]
/// Counters and timing reported by a CSV import run.
pub struct CsvImportStats {
    /// Number of data rows the import loop visited (the header row, when
    /// present, is not counted).
    pub lines_processed: usize,
    /// Number of rows that were successfully inserted into the store.
    pub records_imported: usize,
    /// Number of rows that failed to insert and were skipped because
    /// `skip_errors` was enabled.
    pub errors_skipped: usize,
    /// Wall-clock duration of the import, in milliseconds.
    pub duration_ms: u64,
}
97
#[derive(Debug)]
/// Errors that a CSV import can produce.
pub enum CsvError {
    /// Opening or reading the input failed; carries the underlying message.
    Io(String),
    /// The CSV text was malformed or a row could not be imported; carries a
    /// description of the problem.
    Parse(String),
}

impl std::fmt::Display for CsvError {
    /// Renders the error as `"<category>: <detail>"`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let (label, detail) = match self {
            CsvError::Io(detail) => ("I/O error", detail),
            CsvError::Parse(detail) => ("parse error", detail),
        };
        write!(f, "{}: {}", label, detail)
    }
}

impl std::error::Error for CsvError {}
115
116pub struct CsvImporter {
118 config: CsvConfig,
119}
120
121impl CsvImporter {
122 pub fn new(config: CsvConfig) -> Self {
123 Self { config }
124 }
125
126 pub fn with_defaults() -> Self {
127 Self::new(CsvConfig::default())
128 }
129
130 pub fn import_file<P: AsRef<Path>>(
132 &self,
133 path: P,
134 store: &Store,
135 ) -> Result<CsvImportStats, CsvError> {
136 let file = File::open(path.as_ref()).map_err(|e| CsvError::Io(e.to_string()))?;
137 let reader = BufReader::new(file);
138 self.import_reader(reader, store)
139 }
140
141 pub fn import_reader<R: BufRead>(
143 &self,
144 mut reader: R,
145 store: &Store,
146 ) -> Result<CsvImportStats, CsvError> {
147 let start = std::time::Instant::now();
148 let mut stats = CsvImportStats::default();
149 let mut buf = String::new();
150 reader
151 .read_to_string(&mut buf)
152 .map_err(|e| CsvError::Io(e.to_string()))?;
153
154 let records = parse_records(&buf, self.config.delimiter, self.config.quote)
155 .map_err(CsvError::Parse)?;
156 let mut iter = records.into_iter();
157
158 let headers: Vec<String> = if self.config.has_header {
160 match iter.next() {
161 Some(row) => row,
162 None => {
163 stats.duration_ms = start.elapsed().as_millis() as u64;
164 return Ok(stats);
165 }
166 }
167 } else if let Some(names) = &self.config.column_names {
168 names.clone()
169 } else {
170 Vec::new()
171 };
172
173 for (row_idx, row) in iter.enumerate() {
174 stats.lines_processed += 1;
175 if let Some(max) = self.config.max_records {
176 if stats.records_imported >= max {
177 break;
178 }
179 }
180
181 let column_names: Vec<String> = if headers.is_empty() {
182 (0..row.len()).map(|i| format!("c{i}")).collect()
183 } else {
184 headers.clone()
185 };
186
187 match self.insert_row(&column_names, row, store) {
188 Ok(()) => stats.records_imported += 1,
189 Err(e) => {
190 if self.config.skip_errors {
191 stats.errors_skipped += 1;
192 continue;
193 }
194 return Err(CsvError::Parse(format!("row {}: {}", row_idx + 1, e)));
195 }
196 }
197 }
198
199 stats.duration_ms = start.elapsed().as_millis() as u64;
200 Ok(stats)
201 }
202
203 fn insert_row(
204 &self,
205 columns: &[String],
206 values: Vec<String>,
207 store: &Store,
208 ) -> Result<(), String> {
209 let mut named: HashMap<String, Value> = HashMap::with_capacity(values.len());
210 for (i, raw) in values.into_iter().enumerate() {
211 let name = columns.get(i).cloned().unwrap_or_else(|| format!("c{i}"));
212 named.insert(name, coerce_field(&raw, self.config.treat_empty_as_null));
213 }
214
215 let entity_id = store.next_entity_id();
216 let row_id = entity_id.0;
217 let entity = UnifiedEntity::new(
218 entity_id,
219 EntityKind::TableRow {
220 table: Arc::from(self.config.collection.as_str()),
221 row_id,
222 },
223 EntityData::Row(RowData {
224 columns: Vec::new(),
225 named: Some(named),
226 schema: None,
227 }),
228 );
229 store
230 .insert(&self.config.collection, entity)
231 .map(|_| ())
232 .map_err(|e| format!("insert failed: {:?}", e))
233 }
234}
235
/// Splits CSV text into records, each a list of field strings.
///
/// Fields are separated by `delimiter`; records end at `\n`, `\r\n` or a lone
/// `\r`. A field that starts with `quote` is read verbatim (including
/// delimiters and newlines) up to the closing `quote`, and a doubled `quote`
/// inside it stands for one literal quote byte. A final record without a
/// trailing newline is kept, including one that consists only of an empty
/// quoted field.
///
/// Field content is accumulated as raw bytes and decoded as UTF-8 when the
/// field ends, so multi-byte characters pass through unchanged.
///
/// # Errors
///
/// Returns a message if the input ends inside a quoted field, or if a
/// non-ASCII `delimiter`/`quote` byte splits a multi-byte character and leaves
/// a field that is not valid UTF-8.
fn parse_records(input: &str, delimiter: u8, quote: u8) -> Result<Vec<Vec<String>>, String> {
    // Takes the accumulated bytes of a field and decodes them, leaving the
    // buffer empty for the next field.
    fn finish_field(raw: &mut Vec<u8>) -> Result<String, String> {
        String::from_utf8(std::mem::take(raw)).map_err(|_| {
            "field is not valid UTF-8 (delimiter or quote byte splits a character)".to_string()
        })
    }

    let bytes = input.as_bytes();
    let len = bytes.len();
    let mut records: Vec<Vec<String>> = Vec::new();
    let mut current_row: Vec<String> = Vec::new();
    let mut field: Vec<u8> = Vec::new();
    // Set once the current field has opened a quote, so that an empty quoted
    // field at end of input still counts as a field.
    let mut field_quoted = false;
    let mut in_quotes = false;
    let mut i = 0;

    while i < len {
        let b = bytes[i];
        if in_quotes {
            if b != quote {
                field.push(b);
                i += 1;
            } else if bytes.get(i + 1) == Some(&quote) {
                // Doubled quote: one literal quote byte.
                field.push(quote);
                i += 2;
            } else {
                in_quotes = false;
                i += 1;
            }
        } else if b == quote && field.is_empty() {
            in_quotes = true;
            field_quoted = true;
            i += 1;
        } else if b == delimiter {
            current_row.push(finish_field(&mut field)?);
            field_quoted = false;
            i += 1;
        } else if b == b'\r' || b == b'\n' {
            current_row.push(finish_field(&mut field)?);
            field_quoted = false;
            records.push(std::mem::take(&mut current_row));
            i += 1;
            // Treat CRLF as a single record terminator.
            if b == b'\r' && bytes.get(i) == Some(&b'\n') {
                i += 1;
            }
        } else {
            field.push(b);
            i += 1;
        }
    }

    if in_quotes {
        return Err("unterminated quoted field".to_string());
    }
    if field_quoted || !field.is_empty() || !current_row.is_empty() {
        current_row.push(finish_field(&mut field)?);
        records.push(current_row);
    }
    Ok(records)
}
301
302fn coerce_field(raw: &str, treat_empty_as_null: bool) -> Value {
304 if treat_empty_as_null && raw.is_empty() {
305 return Value::Null;
306 }
307 if let Ok(n) = raw.parse::<i64>() {
309 if !raw.contains('.') && !raw.contains('e') && !raw.contains('E') {
310 return Value::Integer(n);
311 }
312 }
313 if let Ok(f) = raw.parse::<f64>() {
315 if raw.contains('.') || raw.contains('e') || raw.contains('E') {
316 return Value::Float(f);
317 }
318 }
319 if raw.eq_ignore_ascii_case("true") {
321 return Value::Boolean(true);
322 }
323 if raw.eq_ignore_ascii_case("false") {
324 return Value::Boolean(false);
325 }
326 Value::text(raw.to_string())
328}
329
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `input` with the given delimiter and the standard `"` quote,
    /// panicking if parsing fails.
    fn parse_with(input: &str, delimiter: u8) -> Vec<Vec<String>> {
        parse_records(input, delimiter, b'"').unwrap()
    }

    /// Plain comma-separated input yields one record per line.
    #[test]
    fn parse_simple_csv() {
        let rows = parse_with("id,name,age\n1,Alice,30\n2,Bob,25\n", b',');
        assert_eq!(rows.len(), 3);
        assert_eq!(rows[0], vec!["id", "name", "age"]);
        assert_eq!(rows[1], vec!["1", "Alice", "30"]);
        assert_eq!(rows[2], vec!["2", "Bob", "25"]);
    }

    /// Quoted fields may contain the delimiter and doubled quotes.
    #[test]
    fn parse_quoted_and_escaped_fields() {
        let rows = parse_with("id,note\n1,\"hello, world\"\n2,\"say \"\"hi\"\"\"\n", b',');
        assert_eq!(rows.len(), 3);
        assert_eq!(rows[1], vec!["1", "hello, world"]);
        assert_eq!(rows[2], vec!["2", "say \"hi\""]);
    }

    /// A non-comma delimiter splits fields the same way.
    #[test]
    fn parse_alternate_delimiter() {
        let rows = parse_with("a;b;c\n1;2;3\n", b';');
        assert_eq!(rows[1], vec!["1", "2", "3"]);
    }

    /// CRLF counts as a single record terminator.
    #[test]
    fn parse_crlf_newlines() {
        let rows = parse_with("a,b\r\n1,2\r\n", b',');
        assert_eq!(rows.len(), 2);
    }

    /// The final record is kept even without a trailing newline.
    #[test]
    fn parse_no_trailing_newline() {
        let rows = parse_with("a,b\n1,2", b',');
        assert_eq!(rows.len(), 2);
        assert_eq!(rows[1], vec!["1", "2"]);
    }

    /// Input that ends inside a quoted field is rejected.
    #[test]
    fn parse_unterminated_quote_errors() {
        let outcome = parse_records("a,\"unclosed\n", b',', b'"');
        assert!(outcome.is_err());
    }

    /// Each coercion rule maps to the expected value variant.
    #[test]
    fn coerce_int_float_bool_text_null() {
        let null_on_empty = true;
        assert_eq!(coerce_field("42", null_on_empty), Value::Integer(42));
        assert_eq!(coerce_field("-17", null_on_empty), Value::Integer(-17));
        assert_eq!(coerce_field("3.14", null_on_empty), Value::Float(3.14));
        assert_eq!(coerce_field("1e3", null_on_empty), Value::Float(1000.0));
        assert_eq!(coerce_field("TRUE", null_on_empty), Value::Boolean(true));
        assert_eq!(coerce_field("False", null_on_empty), Value::Boolean(false));
        assert_eq!(
            coerce_field("hello", null_on_empty),
            Value::text("hello".to_string())
        );
        assert_eq!(coerce_field("", null_on_empty), Value::Null);
        assert_eq!(coerce_field("", false), Value::text(String::new()));
    }
}