1use std::collections::{HashMap, HashSet};
15use std::io::{self, BufRead, Cursor};
16use std::sync::Arc;
17
18use minarrow::ffi::arrow_dtype::CategoricalIndexType;
19use minarrow::{
20 Array, ArrowType, Bitmask, Buffer, Field, FieldArray, FloatArray, IntegerArray, NumericArray,
21 Table, TextArray, Vec64, vec64,
22};
23
24#[derive(Debug, Clone)]
26pub struct CsvDecodeOptions {
27 pub delimiter: u8,
29 pub nulls: Vec<&'static str>,
31 pub quote: u8,
33 pub has_header: bool,
35 pub schema: Option<Vec<Field>>,
37 pub all_as_text: bool,
39 pub categorical_cols: HashSet<String>,
41}
42
43impl Default for CsvDecodeOptions {
44 fn default() -> Self {
45 CsvDecodeOptions {
46 delimiter: b',',
47 nulls: vec!["", "NA", "null", "NULL"],
48 quote: b'"',
49 has_header: true,
50 schema: None,
51 all_as_text: false,
52 categorical_cols: HashSet::new(),
53 }
54 }
55}
56
57pub fn decode_csv_batch<R: BufRead>(
61 reader: &mut R,
62 options: &CsvDecodeOptions,
63 batch_size: usize,
64) -> io::Result<Option<Table>> {
65 let opts = options.clone();
66 let need_header = opts.has_header;
67 let mut buf = Vec::new();
68 let mut chunk = Vec::new();
69 let mut saw_any = false;
70 let mut lines_to_read = batch_size;
71 if need_header {
72 lines_to_read += 1;
74 }
75
76 for _ in 0..lines_to_read {
77 buf.clear();
78 let n = reader.read_until(b'\n', &mut buf)?;
79 if n == 0 {
80 break;
81 }
82 if buf.ends_with(b"\r\n") {
84 buf.truncate(buf.len() - 2);
85 } else if buf.ends_with(b"\n") {
86 buf.truncate(buf.len() - 1);
87 }
88 if buf.is_empty() && !saw_any {
90 continue;
91 }
92 saw_any = true;
93 chunk.extend_from_slice(&buf);
94 chunk.push(b'\n');
95 }
96
97 if !saw_any {
98 return Ok(None);
100 }
101
102 let table = decode_csv(Cursor::new(chunk), &opts)?;
104 Ok(Some(table))
105}
106
107pub fn decode_csv<R: BufRead>(mut reader: R, options: &CsvDecodeOptions) -> io::Result<Table> {
118 let CsvDecodeOptions {
119 delimiter,
120 nulls,
121 quote,
122 has_header,
123 schema,
124 all_as_text,
125 categorical_cols,
126 } = options.clone();
127
128 let mut header: Vec<String> = Vec::new();
129 let mut rows: Vec<Vec<String>> = Vec::new();
130 let mut buf = Vec::new();
131
132 let mut first_row_is_header = false;
134 let mut col_count = 0;
135 loop {
136 buf.clear();
137 let n = reader.read_until(b'\n', &mut buf)?;
138 if n == 0 {
139 break;
140 }
141 let mut quote_balance = buf.iter().filter(|&&b| b == quote).count() % 2;
142 while quote_balance == 1
143 {
145 let m = reader.read_until(b'\n', &mut buf)?;
146 if m == 0 {
147 break;
148 } quote_balance ^= buf[n..].iter().filter(|&&b| b == quote).count() % 2;
150 }
151
152 let line = {
154 let l = if let Some(&b'\r') = buf.get(buf.len().saturating_sub(2)) {
155 &buf[..buf.len() - 2]
156 } else if buf.last() == Some(&b'\n') {
157 &buf[..buf.len() - 1]
158 } else {
159 &buf[..]
160 };
161 l
162 };
163
164 if line.is_empty() && rows.is_empty() {
165 continue;
166 } let fields = parse_csv_line(line, delimiter, quote);
169 if fields.is_empty() {
170 continue;
171 }
172
173 if header.is_empty() && has_header {
174 header = fields;
176 col_count = header.len();
177 first_row_is_header = true;
178 } else {
179 if col_count == 0 {
181 col_count = fields.len();
182 }
183 if fields.len() != col_count {
184 return Err(io::Error::new(
185 io::ErrorKind::InvalidData,
186 "inconsistent row length",
187 ));
188 }
189 rows.push(fields);
190 }
191 }
192
193 let col_names: Vec<String> = if first_row_is_header {
195 header
196 } else {
197 (0..col_count).map(|i| format!("col{}", i + 1)).collect()
198 };
199
200 let n_rows = rows.len();
201
202 let schema: Vec<Field> = if let Some(fields) = schema {
204 fields
205 } else if all_as_text {
206 col_names
207 .iter()
208 .map(|name| Field {
209 name: name.clone(),
210 dtype: ArrowType::String,
211 nullable: true,
212 metadata: Default::default(),
213 })
214 .collect()
215 } else {
216 infer_schema(&rows, &col_names, &categorical_cols, &nulls)
217 };
218
219 let mut cols: Vec<FieldArray> = Vec::with_capacity(col_count);
221 for (col_idx, field) in schema.iter().enumerate() {
222 let mut null_mask = vec![true; n_rows]; let mut str_values: Vec<Option<&str>> = Vec::with_capacity(n_rows);
224
225 for row in 0..n_rows {
226 let val = rows[row][col_idx].trim();
227 let is_null = nulls.iter().any(|n| n.eq_ignore_ascii_case(val));
228 if is_null {
229 null_mask[row] = false; str_values.push(None);
231 } else {
232 str_values.push(Some(val));
233 }
234 }
235
236 let array = match &field.dtype {
237 ArrowType::Int32 => parse_numeric_column::<i32>(&str_values, &null_mask)?,
238 ArrowType::Int64 => parse_numeric_column::<i64>(&str_values, &null_mask)?,
239 ArrowType::UInt32 => parse_numeric_column::<u32>(&str_values, &null_mask)?,
240 ArrowType::UInt64 => parse_numeric_column::<u64>(&str_values, &null_mask)?,
241 ArrowType::Float32 => parse_numeric_column::<f32>(&str_values, &null_mask)?,
242 ArrowType::Float64 => parse_numeric_column::<f64>(&str_values, &null_mask)?,
243 ArrowType::Boolean => parse_bool_column(&str_values, &null_mask)?,
244 ArrowType::String => parse_string_column(&str_values, &null_mask)?,
245 ArrowType::Dictionary(_) => {
246 parse_categorical_column(&str_values, &null_mask)?
248 }
249 _ => {
250 parse_string_column(&str_values, &null_mask)?
252 }
253 };
254
255 let null_count = null_mask.iter().filter(|x| !**x).count();
257
258 cols.push(FieldArray {
259 field: Arc::new(field.clone()),
260 array,
261 null_count,
262 });
263 }
264
265 Ok(Table {
266 name: "csv".to_string(),
267 cols,
268 n_rows,
269 })
270}
271
/// Splits one CSV record into its fields.
///
/// Outside a quoted section a `quote` byte opens a quoted section and a
/// `delimiter` byte ends the current field. Inside a quoted section a doubled
/// `quote` yields one literal quote byte, a single `quote` closes the section,
/// and every other byte (delimiters included) is kept verbatim. Field bytes are
/// converted to text lossily, so invalid UTF-8 becomes replacement characters.
///
/// The result always contains at least one field; an empty `line` yields a
/// single empty string.
#[inline]
fn parse_csv_line(line: &[u8], delimiter: u8, quote: u8) -> Vec<String> {
    let mut out = Vec::new();
    let mut current: Vec<u8> = Vec::with_capacity(32);
    let mut quoted = false;

    let mut bytes = line.iter().copied().peekable();
    while let Some(byte) = bytes.next() {
        match (quoted, byte) {
            (true, b) if b == quote => {
                if bytes.peek() == Some(&quote) {
                    // Escaped quote: emit one and consume its twin.
                    current.push(quote);
                    bytes.next();
                } else {
                    quoted = false;
                }
            }
            (true, b) => current.push(b),
            // The quote check comes before the delimiter check on purpose.
            (false, b) if b == quote => quoted = true,
            (false, b) if b == delimiter => {
                out.push(String::from_utf8_lossy(&current).into_owned());
                current.clear();
            }
            (false, b) => current.push(b),
        }
    }

    // Whatever remains after the last delimiter is the final field.
    out.push(String::from_utf8_lossy(&current).into_owned());
    out
}
306
307fn infer_schema(
310 rows: &[Vec<String>],
311 col_names: &[String],
312 categorical_cols: &HashSet<String>,
313 nulls: &[&'static str],
314) -> Vec<Field> {
315 let n_cols = col_names.len();
316 let mut types: Vec<ArrowType> = vec![ArrowType::String; n_cols];
317 for col in 0..n_cols {
318 let mut is_bool = true;
319 let mut is_i32 = true;
320 let mut is_i64 = true;
321 let mut is_u32 = true;
322 let mut is_u64 = true;
323 let mut is_f32 = true;
324 let mut is_f64 = true;
325 let is_cat = categorical_cols.contains(&col_names[col]);
326
327 for row in rows {
328 let val = row[col].trim();
329 if nulls.iter().any(|n| n.eq_ignore_ascii_case(val)) {
330 continue;
331 }
332 if is_bool && !matches!(val, "true" | "false" | "1" | "0" | "t" | "f" | "T" | "F") {
333 is_bool = false;
334 }
335 if is_i32 && val.parse::<i32>().is_err() {
336 is_i32 = false;
337 }
338 if is_i64 && val.parse::<i64>().is_err() {
339 is_i64 = false;
340 }
341 if is_u32 && val.parse::<u32>().is_err() {
342 is_u32 = false;
343 }
344 if is_u64 && val.parse::<u64>().is_err() {
345 is_u64 = false;
346 }
347 if is_f32 && val.parse::<f32>().is_err() {
348 is_f32 = false;
349 }
350 if is_f64 && val.parse::<f64>().is_err() {
351 is_f64 = false;
352 }
353 }
354
355 types[col] = if is_bool {
356 ArrowType::Boolean
357 } else if is_i32 {
358 ArrowType::Int32
359 } else if is_i64 {
360 ArrowType::Int64
361 } else if is_u32 {
362 ArrowType::UInt32
363 } else if is_u64 {
364 ArrowType::UInt64
365 } else if is_f32 {
366 ArrowType::Float32
367 } else if is_f64 {
368 ArrowType::Float64
369 } else if is_cat {
370 ArrowType::Dictionary(CategoricalIndexType::UInt32)
371 } else {
372 ArrowType::String
373 };
374 }
375
376 col_names
377 .iter()
378 .enumerate()
379 .map(|(i, name)| Field {
380 name: name.clone(),
381 dtype: types[i].clone(),
382 nullable: true,
383 metadata: Default::default(),
384 })
385 .collect()
386}
387
388fn mask_to_bitmask(mask: &[bool]) -> Bitmask {
391 Bitmask::from_bools(mask)
392}
393
394fn parse_numeric_column<T: std::str::FromStr + Copy + Default + 'static>(
396 values: &[Option<&str>],
397 null_mask: &[bool],
398) -> std::io::Result<Array> {
399 let mut out = vec64![T::default(); values.len()];
400 for (i, v) in values.iter().enumerate() {
401 if null_mask[i] {
402 out[i] = v.unwrap().parse::<T>().map_err(|_| {
404 std::io::Error::new(std::io::ErrorKind::InvalidData, "failed to parse number")
405 })?;
406 }
407 }
408
409 let arr = if std::any::TypeId::of::<T>() == std::any::TypeId::of::<i32>() {
411 Array::NumericArray(NumericArray::Int32(
412 IntegerArray {
413 data: Buffer::from(
414 unsafe { std::mem::transmute::<Vec64<T>, Vec64<i32>>(out) },
416 ),
417 null_mask: Some(mask_to_bitmask(null_mask)),
418 }
419 .into(),
420 ))
421 } else if std::any::TypeId::of::<T>() == std::any::TypeId::of::<i64>() {
422 Array::NumericArray(NumericArray::Int64(
423 IntegerArray {
424 data: Buffer::from(unsafe { std::mem::transmute::<Vec64<T>, Vec64<i64>>(out) }),
425 null_mask: Some(mask_to_bitmask(null_mask)),
426 }
427 .into(),
428 ))
429 } else if std::any::TypeId::of::<T>() == std::any::TypeId::of::<u32>() {
430 Array::NumericArray(NumericArray::UInt32(
431 IntegerArray {
432 data: Buffer::from(unsafe { std::mem::transmute::<Vec64<T>, Vec64<u32>>(out) }),
433 null_mask: Some(mask_to_bitmask(null_mask)),
434 }
435 .into(),
436 ))
437 } else if std::any::TypeId::of::<T>() == std::any::TypeId::of::<u64>() {
438 Array::NumericArray(NumericArray::UInt64(
439 IntegerArray {
440 data: Buffer::from(unsafe { std::mem::transmute::<Vec64<T>, Vec64<u64>>(out) }),
441 null_mask: Some(mask_to_bitmask(null_mask)),
442 }
443 .into(),
444 ))
445 } else if std::any::TypeId::of::<T>() == std::any::TypeId::of::<f32>() {
446 Array::NumericArray(NumericArray::Float32(
447 FloatArray {
448 data: Buffer::from(unsafe { std::mem::transmute::<Vec64<T>, Vec64<f32>>(out) }),
449 null_mask: Some(mask_to_bitmask(null_mask)),
450 }
451 .into(),
452 ))
453 } else if std::any::TypeId::of::<T>() == std::any::TypeId::of::<f64>() {
454 Array::NumericArray(NumericArray::Float64(
455 FloatArray {
456 data: Buffer::from(unsafe { std::mem::transmute::<Vec64<T>, Vec64<f64>>(out) }),
457 null_mask: Some(mask_to_bitmask(null_mask)),
458 }
459 .into(),
460 ))
461 } else {
462 return Err(std::io::Error::new(
463 std::io::ErrorKind::InvalidInput,
464 "unsupported numeric type",
465 ));
466 };
467
468 Ok(arr)
469}
470
471fn parse_bool_column(values: &[Option<&str>], null_mask: &[bool]) -> std::io::Result<Array> {
473 let mut out = vec64![false; values.len()];
474 for (i, v) in values.iter().enumerate() {
475 if null_mask[i] {
476 let s = v.unwrap().to_ascii_lowercase();
478 out[i] = s == "true" || s == "1" || s == "t";
479 }
480 }
481 Ok(Array::BooleanArray(
482 minarrow::BooleanArray::new(Bitmask::from_bools(&out), Some(mask_to_bitmask(null_mask)))
483 .into(),
484 ))
485}
486
487fn parse_string_column(values: &[Option<&str>], null_mask: &[bool]) -> io::Result<Array> {
488 let mut offsets = vec64![0u32; values.len() + 1];
489 let mut data = Vec64::with_capacity(values.len() * 8);
490 let mut pos = 0u32;
491 for (i, v) in values.iter().enumerate() {
492 if null_mask[i] {
493 let s = v.unwrap().as_bytes();
495 data.extend_from_slice(s);
496 pos += s.len() as u32;
497 }
498 offsets[i + 1] = pos;
499 }
500 Ok(Array::TextArray(TextArray::String32(
501 minarrow::StringArray {
502 offsets: Buffer::from(offsets),
503 data: Buffer::from(data),
504 null_mask: Some(mask_to_bitmask(null_mask)),
505 }
506 .into(),
507 )))
508}
509
510fn parse_categorical_column(values: &[Option<&str>], null_mask: &[bool]) -> io::Result<Array> {
511 let mut uniques: Vec<String> = Vec::new();
512 let mut dict: HashMap<&str, u32> = HashMap::new();
513 let mut codes = vec64![0u32; values.len()];
514
515 for (i, v) in values.iter().enumerate() {
516 if !null_mask[i] {
517 continue;
519 }
520 let s = v.unwrap();
521 let code = if let Some(&idx) = dict.get(s) {
522 idx
523 } else {
524 let idx = uniques.len() as u32;
525 dict.insert(s, idx);
526 uniques.push(s.to_string());
527 idx
528 };
529 codes[i] = code;
530 }
531 Ok(Array::TextArray(TextArray::Categorical32(
532 minarrow::CategoricalArray {
533 data: Buffer::from(codes),
534 unique_values: uniques.into(),
535 null_mask: Some(mask_to_bitmask(null_mask)),
536 }
537 .into(),
538 )))
539}
540
#[cfg(test)]
mod tests {
    // Everything under test is reached through `super::*`, so the tests keep
    // working wherever this file lives in the crate's module tree.
    use std::io::Cursor;

    use super::*;

    /// Header row, inferred int/string/bool columns, and one null cell.
    #[test]
    fn test_decode_basic_csv() {
        let csv = b"ints,strings,bools\n1,hello,true\n2,,false\n3,world,1\n4,rust,0\n";
        let opts = CsvDecodeOptions::default();
        let table = decode_csv(Cursor::new(&csv[..]), &opts).unwrap();

        assert_eq!(table.n_rows, 4);
        assert_eq!(table.cols.len(), 3);
        assert_eq!(table.cols[0].field.name, "ints");
        assert_eq!(table.cols[1].field.name, "strings");

        match &table.cols[0].array {
            Array::NumericArray(NumericArray::Int32(arr)) => {
                let vals: Vec64<_> = arr.data.as_ref().iter().copied().collect();
                assert_eq!(vals, vec64![1, 2, 3, 4]);
            }
            _ => panic!("wrong type"),
        }

        match &table.cols[2].array {
            Array::BooleanArray(arr) => {
                let actual: Vec<bool> = (0..arr.data.len).map(|i| arr.data.get(i)).collect();
                assert_eq!(actual, vec![true, false, true, false]);
            }
            _ => panic!("wrong type"),
        }

        match &table.cols[1].array {
            Array::TextArray(TextArray::String32(arr)) => {
                // Three of the four rows hold a value; the empty cell is null.
                assert_eq!(arr.null_mask.as_ref().unwrap().count_ones(), 3);
                assert_eq!(table.cols[1].null_count, 1);
            }
            _ => panic!("wrong type"),
        }
    }

    /// A custom delimiter inside a quoted field must not split the field.
    #[test]
    fn test_decode_csv_custom_delim_and_quotes() {
        let csv = b"i|s|b\n1|\"h|ello\"|T\n2||f\n";
        let mut opts = CsvDecodeOptions::default();
        opts.delimiter = b'|';
        let table = decode_csv(Cursor::new(&csv[..]), &opts).unwrap();
        assert_eq!(table.n_rows, 2);
        match &table.cols[1].array {
            Array::TextArray(TextArray::String32(arr)) => {
                let s = std::str::from_utf8(&arr.data.as_ref()[..]).unwrap();
                assert!(s.contains("h|ello"));
            }
            _ => panic!("wrong type"),
        }
    }

    /// Streams three data rows in batches of two from a single reader.
    #[test]
    fn test_decode_csv_batch_basic() {
        let csv = b"col1,col2\n10,A\n20,B\n30,C\n";
        let mut reader = Cursor::new(&csv[..]);
        let mut opts = CsvDecodeOptions::default();

        let batch1 = decode_csv_batch(&mut reader, &opts, 2)
            .unwrap()
            .expect("first batch should be Some");
        assert_eq!(batch1.n_rows, 2);
        assert_eq!(batch1.cols[0].field.name, "col1");
        assert_eq!(batch1.cols[1].field.name, "col2");
        match &batch1.cols[0].array {
            Array::NumericArray(NumericArray::Int32(arr)) => {
                let v: Vec<i32> = arr.data.as_ref().iter().copied().collect();
                assert_eq!(v, vec![10, 20]);
            }
            _ => panic!("wrong type for col1"),
        }
        match &batch1.cols[1].array {
            Array::TextArray(TextArray::String32(arr)) => {
                let s = std::str::from_utf8(&arr.data.as_ref()[..]).unwrap();
                assert!(s.starts_with("AB"));
            }
            _ => panic!("wrong type for col2"),
        }

        // The header was consumed by the first batch.
        opts.has_header = false;
        let batch2 = decode_csv_batch(&mut reader, &opts, 2)
            .unwrap()
            .expect("second batch should be Some");
        assert_eq!(batch2.n_rows, 1);
        match &batch2.cols[0].array {
            Array::NumericArray(NumericArray::Int32(arr)) => {
                assert_eq!(arr.data.as_ref()[0], 30);
            }
            _ => panic!("wrong type for col1 in second batch"),
        }

        let batch3 = decode_csv_batch(&mut reader, &opts, 2).unwrap();
        assert!(batch3.is_none());
    }

    /// A doubled quote inside a quoted field decodes to one literal quote.
    #[test]
    fn decode_escaped_quotes() {
        let csv = b"id,msg\n1,\"She said \"\"hi\"\" yesterday\"\n";
        let opts = CsvDecodeOptions::default();
        let table = decode_csv(Cursor::new(&csv[..]), &opts).unwrap();
        match &table.cols[1].array {
            Array::TextArray(TextArray::String32(arr)) => {
                let text = std::str::from_utf8(&arr.data.as_ref()[..]).unwrap();
                assert_eq!(text, "She said \"hi\" yesterday");
            }
            _ => panic!("wrong type for msg"),
        }
    }

    /// A line break inside a quoted field stays part of the field.
    #[test]
    fn decode_embedded_newline() {
        let csv = b"id,comment\n1,\"line1\nline2\"\n";
        let opts = CsvDecodeOptions::default();
        let tbl = decode_csv(Cursor::new(&csv[..]), &opts).unwrap();
        match &tbl.cols[1].array {
            Array::TextArray(TextArray::String32(arr)) => {
                let text = std::str::from_utf8(&arr.data.as_ref()[..]).unwrap();
                assert_eq!(text, "line1\nline2");
            }
            _ => panic!("wrong type for comment"),
        }
    }

    /// An explicit schema overrides inference: `001` stays text.
    #[test]
    fn decode_with_explicit_schema() {
        let csv = b"a,b\n001,1.23\n";
        let schema = vec![
            Field::new("a", ArrowType::String, false, None),
            Field::new("b", ArrowType::Float64, false, None),
        ];
        let opts = CsvDecodeOptions {
            schema: Some(schema),
            ..Default::default()
        };
        let tbl = decode_csv(Cursor::new(&csv[..]), &opts).unwrap();
        assert_eq!(tbl.cols[0].field.dtype, ArrowType::String);
        assert_eq!(tbl.cols[1].field.dtype, ArrowType::Float64);

        match &tbl.cols[0].array {
            Array::TextArray(TextArray::String32(arr)) => {
                let text = std::str::from_utf8(&arr.data.as_ref()[..]).unwrap();
                assert_eq!(text, "001");
            }
            _ => panic!("wrong type for a"),
        }
        match &tbl.cols[1].array {
            Array::NumericArray(NumericArray::Float64(arr)) => {
                assert_eq!(arr.data.as_ref()[0], 1.23);
            }
            _ => panic!("wrong type for b"),
        }
    }

    /// Without a header, column names are generated and every line is data.
    #[test]
    fn decode_no_header() {
        let csv = b"10,20\n30,40\n";
        let opts = CsvDecodeOptions {
            has_header: false,
            ..Default::default()
        };
        let t = decode_csv(Cursor::new(&csv[..]), &opts).unwrap();
        assert_eq!(t.cols[0].field.name, "col1");
        assert_eq!(t.n_rows, 2);
    }
}