Skip to main content

rivet/format/
csv.rs

1use std::io::Write;
2
3use arrow::array::Time64MicrosecondArray;
4use arrow::array::types::Decimal128Type;
5use arrow::array::*;
6use arrow::datatypes::{DataType, SchemaRef, TimeUnit};
7use arrow::record_batch::RecordBatch;
8
9use crate::error::Result;
10use crate::types::decimal::scaled_i128_to_decimal_str;
11
/// CSV output format: a header row of column names followed by one
/// comma-separated, `\n`-terminated line per record, produced by
/// [`CsvFormatWriter`].
#[derive(Debug, Clone, Copy, Default)]
pub struct CsvFormat;

/// Writer returned by `CsvFormat::create_writer`; appends CSV data rows to the
/// wrapped sink one record batch at a time.
pub struct CsvFormatWriter {
    /// Destination for the CSV text (header already written at creation).
    writer: Box<dyn Write + Send>,
    /// Running total of bytes handed to `writer`, including the header row.
    bytes_written: u64,
}
18
19impl super::Format for CsvFormat {
20    fn create_writer(
21        &self,
22        schema: &SchemaRef,
23        mut writer: Box<dyn Write + Send>,
24    ) -> Result<Box<dyn super::FormatWriter + Send>> {
25        // Fail loud: arrays and other nested/wide Arrow types have no CSV cell
26        // representation. Reject them up front, naming the column, instead of
27        // silently writing empty values for every row — `format: parquet` or
28        // excluding the column from the query is the fix.
29        if let Some(field) = schema
30            .fields()
31            .iter()
32            .find(|f| !csv_serializable(f.data_type()))
33        {
34            anyhow::bail!(
35                "CSV cannot serialize column '{}' (Arrow type {:?}); use `format: parquet` \
36                 or drop the column from the query",
37                field.name(),
38                field.data_type()
39            );
40        }
41        let header = schema
42            .fields()
43            .iter()
44            .map(|f| f.name().as_str())
45            .collect::<Vec<_>>()
46            .join(",");
47        let header_bytes = header.len() as u64 + 1; // +1 for newline
48        writeln!(writer, "{}", header)?;
49        Ok(Box::new(CsvFormatWriter {
50            writer,
51            bytes_written: header_bytes,
52        }))
53    }
54
55    fn file_extension(&self) -> &str {
56        "csv"
57    }
58}
59
60impl super::FormatWriter for CsvFormatWriter {
61    fn write_batch(&mut self, batch: &RecordBatch) -> Result<()> {
62        let mut buf = Vec::with_capacity(batch.num_rows() * batch.num_columns() * 8);
63        for row_idx in 0..batch.num_rows() {
64            for col_idx in 0..batch.num_columns() {
65                if col_idx > 0 {
66                    buf.push(b',');
67                }
68                write_csv_value(&mut buf, batch.column(col_idx), row_idx)?;
69            }
70            buf.push(b'\n');
71        }
72        self.bytes_written += buf.len() as u64;
73        self.writer.write_all(&buf)?;
74        Ok(())
75    }
76
77    fn finish(self: Box<Self>) -> Result<()> {
78        Ok(())
79    }
80
81    fn bytes_written(&self) -> u64 {
82        self.bytes_written
83    }
84}
85
86/// Arrow types `write_csv_value` can serialize. Everything else — lists,
87/// structs, maps, `Decimal256`, non-UUID fixed binary, … — has no CSV cell
88/// representation and is rejected at writer creation rather than silently
89/// emitted as an empty value.
90pub(crate) fn csv_serializable(dt: &DataType) -> bool {
91    matches!(
92        dt,
93        DataType::Boolean
94            | DataType::Int16
95            | DataType::Int32
96            | DataType::Int64
97            | DataType::UInt64
98            | DataType::Decimal128(_, _)
99            | DataType::Float32
100            | DataType::Float64
101            | DataType::Utf8
102            | DataType::Binary
103            | DataType::FixedSizeBinary(16)
104            | DataType::Date32
105            | DataType::Time64(TimeUnit::Microsecond)
106            | DataType::Timestamp(TimeUnit::Microsecond, _)
107    )
108}
109
110fn write_csv_value(writer: &mut dyn Write, array: &dyn Array, idx: usize) -> Result<()> {
111    if array.is_null(idx) {
112        return Ok(());
113    }
114
115    match array.data_type() {
116        DataType::Boolean => {
117            let arr = array
118                .as_any()
119                .downcast_ref::<BooleanArray>()
120                .expect("DataType/Array mismatch");
121            write!(writer, "{}", arr.value(idx))?;
122        }
123        DataType::Int16 => {
124            let arr = array
125                .as_any()
126                .downcast_ref::<Int16Array>()
127                .expect("DataType/Array mismatch");
128            write!(writer, "{}", arr.value(idx))?;
129        }
130        DataType::Int32 => {
131            let arr = array
132                .as_any()
133                .downcast_ref::<Int32Array>()
134                .expect("DataType/Array mismatch");
135            write!(writer, "{}", arr.value(idx))?;
136        }
137        DataType::Int64 => {
138            let arr = array
139                .as_any()
140                .downcast_ref::<Int64Array>()
141                .expect("DataType/Array mismatch");
142            write!(writer, "{}", arr.value(idx))?;
143        }
144        DataType::UInt64 => {
145            let arr = array
146                .as_any()
147                .downcast_ref::<UInt64Array>()
148                .expect("DataType/Array mismatch");
149            write!(writer, "{}", arr.value(idx))?;
150        }
151        DataType::Decimal128(_, scale) => {
152            let arr = array.as_primitive::<Decimal128Type>();
153            let text = scaled_i128_to_decimal_str(arr.value(idx), *scale);
154            writer.write_all(text.as_bytes())?;
155        }
156        DataType::Float32 => {
157            let arr = array
158                .as_any()
159                .downcast_ref::<Float32Array>()
160                .expect("DataType/Array mismatch");
161            write!(writer, "{}", arr.value(idx))?;
162        }
163        DataType::Float64 => {
164            let arr = array
165                .as_any()
166                .downcast_ref::<Float64Array>()
167                .expect("DataType/Array mismatch");
168            write!(writer, "{}", arr.value(idx))?;
169        }
170        DataType::Utf8 => {
171            let arr = array
172                .as_any()
173                .downcast_ref::<StringArray>()
174                .expect("DataType/Array mismatch");
175            let val = arr.value(idx);
176            if val.contains(',') || val.contains('"') || val.contains('\n') {
177                writer.write_all(b"\"")?;
178                let mut rest = val;
179                while let Some(pos) = rest.find('"') {
180                    writer.write_all(&rest.as_bytes()[..pos])?;
181                    writer.write_all(b"\"\"")?;
182                    rest = &rest[pos + 1..];
183                }
184                writer.write_all(rest.as_bytes())?;
185                writer.write_all(b"\"")?;
186            } else {
187                writer.write_all(val.as_bytes())?;
188            }
189        }
190        DataType::Binary => {
191            let arr = array
192                .as_any()
193                .downcast_ref::<BinaryArray>()
194                .expect("DataType/Array mismatch");
195            let val = arr.value(idx);
196            for byte in val {
197                write!(writer, "{:02x}", byte)?;
198            }
199        }
200        // FixedSizeBinary today only carries 16-byte UUIDs (see
201        // `RivetType::Uuid` → `DataType::FixedSizeBinary(16)` in
202        // `src/types/mapping.rs`). CSV has no native binary cell; emit the
203        // canonical hyphenated lowercase form so downstream readers can
204        // recognise it as a UUID rather than 16 bytes of mojibake. Any
205        // future FixedSizeBinary use that is not a UUID should branch on
206        // the size argument before reaching this arm.
207        DataType::FixedSizeBinary(16) => {
208            let arr = array
209                .as_any()
210                .downcast_ref::<FixedSizeBinaryArray>()
211                .expect("DataType/Array mismatch");
212            let val = arr.value(idx);
213            let mut bytes = [0u8; 16];
214            bytes.copy_from_slice(val);
215            write!(writer, "{}", uuid::Uuid::from_bytes(bytes).to_hyphenated())?;
216        }
217        DataType::Date32 => {
218            let arr = array
219                .as_any()
220                .downcast_ref::<Date32Array>()
221                .expect("DataType/Array mismatch");
222            let days = arr.value(idx);
223            // `Date32` is "days since 1970-01-01"; a pathological value near
224            // i32::MAX overflows `NaiveDate + Duration` and panics in chrono.
225            // Fall back to checked arithmetic and emit an empty cell on
226            // overflow — matches the null-cell convention for unserialisable
227            // values elsewhere in this writer.
228            let epoch = chrono::NaiveDate::from_ymd_opt(1970, 1, 1).expect("epoch is valid");
229            let date =
230                chrono::Duration::try_days(days as i64).and_then(|d| epoch.checked_add_signed(d));
231            if let Some(date) = date {
232                write!(writer, "{}", date)?;
233            }
234        }
235        DataType::Time64(TimeUnit::Microsecond) => {
236            let arr = array
237                .as_any()
238                .downcast_ref::<Time64MicrosecondArray>()
239                .expect("DataType/Array mismatch");
240            let micros = arr.value(idx);
241            let secs = micros / 1_000_000;
242            let frac_us = micros % 1_000_000;
243            write!(
244                writer,
245                "{:02}:{:02}:{:02}.{:06}",
246                secs / 3600,
247                (secs % 3600) / 60,
248                secs % 60,
249                frac_us
250            )?;
251        }
252        DataType::Timestamp(TimeUnit::Microsecond, _) => {
253            let arr = array
254                .as_any()
255                .downcast_ref::<TimestampMicrosecondArray>()
256                .expect("DataType/Array mismatch");
257            let micros = arr.value(idx);
258            let secs = micros / 1_000_000;
259            let nsecs = ((micros % 1_000_000) * 1_000) as u32;
260            if let Some(dt) = chrono::DateTime::from_timestamp(secs, nsecs) {
261                write!(writer, "{}", dt.format("%Y-%m-%dT%H:%M:%S%.6f"))?;
262            }
263        }
264        other => {
265            // Defensive: `create_writer` rejects unsupported types up front, so
266            // this should be unreachable. Bail rather than silently skip if a
267            // new type slips through.
268            anyhow::bail!(
269                "CSV: no serializer for Arrow type {other:?} (column should have been rejected at writer creation)"
270            );
271        }
272    }
273
274    Ok(())
275}
276
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::datatypes::{DataType, Field, Schema};
    use std::sync::{Arc, Mutex};

    // Helper: render one cell to a String using write_csv_value.
    fn cell<A: Array + 'static>(array: A, idx: usize) -> String {
        let mut buf = Vec::new();
        write_csv_value(&mut buf, &array, idx).unwrap();
        String::from_utf8(buf).unwrap()
    }

    // Helper: render a null cell from any typed array.
    fn null_cell(dt: DataType) -> String {
        use arrow::array::new_null_array;
        let arr = new_null_array(&dt, 1);
        let mut buf = Vec::new();
        write_csv_value(&mut buf, arr.as_ref(), 0).unwrap();
        String::from_utf8(buf).unwrap()
    }

    /// In-memory sink that can be handed to `create_writer` as an owned
    /// `Box<dyn Write + Send>` while the test keeps a clone to read back
    /// exactly what was written.
    #[derive(Clone, Default)]
    struct SharedBuf(Arc<Mutex<Vec<u8>>>);

    impl SharedBuf {
        fn contents(&self) -> String {
            String::from_utf8(self.0.lock().unwrap().clone()).unwrap()
        }
    }

    impl std::io::Write for SharedBuf {
        fn write(&mut self, data: &[u8]) -> std::io::Result<usize> {
            self.0.lock().unwrap().extend_from_slice(data);
            Ok(data.len())
        }

        fn flush(&mut self) -> std::io::Result<()> {
            Ok(())
        }
    }

    // ── null handling ────────────────────────────────────────────────────────

    #[test]
    fn null_value_writes_empty_string() {
        assert_eq!(null_cell(DataType::Int64), "");
        assert_eq!(null_cell(DataType::Utf8), "");
        assert_eq!(null_cell(DataType::Boolean), "");
    }

    // ── scalars ─────────────────────────────────────────────────────────────

    #[test]
    fn bool_true_writes_true() {
        assert_eq!(cell(BooleanArray::from(vec![true]), 0), "true");
    }

    #[test]
    fn bool_false_writes_false() {
        assert_eq!(cell(BooleanArray::from(vec![false]), 0), "false");
    }

    #[test]
    fn int16_value() {
        assert_eq!(cell(Int16Array::from(vec![42i16]), 0), "42");
    }

    #[test]
    fn int32_negative() {
        assert_eq!(cell(Int32Array::from(vec![-7i32]), 0), "-7");
    }

    #[test]
    fn decimal128_writes_exact_text() {
        let arr = Decimal128Array::from(vec![10i128])
            .with_precision_and_scale(18, 2)
            .unwrap();
        assert_eq!(cell(arr, 0), "0.10");
        let scaled =
            crate::types::decimal::decimal_str_to_scaled_i128("999999999999.99", 2).unwrap();
        let arr = Decimal128Array::from(vec![scaled])
            .with_precision_and_scale(18, 2)
            .unwrap();
        assert_eq!(cell(arr, 0), "999999999999.99");
    }

    #[test]
    fn int64_large() {
        assert_eq!(
            cell(Int64Array::from(vec![9_999_999_999i64]), 0),
            "9999999999"
        );
    }

    #[test]
    fn uint64_max_is_not_truncated() {
        assert_eq!(
            cell(UInt64Array::from(vec![u64::MAX]), 0),
            "18446744073709551615"
        );
    }

    #[test]
    fn float32_value() {
        assert_eq!(cell(Float32Array::from(vec![1.5f32]), 0), "1.5");
    }

    #[test]
    fn float64_value() {
        assert_eq!(
            cell(Float64Array::from(vec![std::f64::consts::PI]), 0),
            "3.141592653589793"
        );
    }

    // Characterization: float NaN/±Infinity are IEEE-754 values a float column
    // can legitimately hold (unlike `decimal`, whose Arrow `Decimal128` has no
    // NaN bit pattern — see the NUMERIC NaN/infinity reject in
    // `postgres::arrow_convert::build_array`). They are preserved natively in
    // Parquet; in CSV we emit the Rust float literal (`NaN` / `inf` / `-inf`)
    // rather than an empty cell, because writing empty would silently conflate
    // a real NaN/Inf with NULL — corruption — whereas a recognizable literal
    // round-trips into every major loader's float parser. This pins that
    // contract so a future arrow/std change can't silently alter it. The CSV
    // literal is documented in docs/type-mapping.md.
    #[test]
    fn float_special_values_emit_literals_not_empty() {
        assert_eq!(cell(Float64Array::from(vec![f64::NAN]), 0), "NaN");
        assert_eq!(cell(Float64Array::from(vec![f64::INFINITY]), 0), "inf");
        assert_eq!(cell(Float64Array::from(vec![f64::NEG_INFINITY]), 0), "-inf");
        assert_eq!(cell(Float32Array::from(vec![f32::NAN]), 0), "NaN");
        assert_eq!(cell(Float32Array::from(vec![f32::INFINITY]), 0), "inf");
        // -0.0 keeps its sign (a real IEEE-754 distinction), never becomes "0".
        assert_eq!(cell(Float64Array::from(vec![-0.0f64]), 0), "-0");
    }

    // ── string escaping ──────────────────────────────────────────────────────

    #[test]
    fn plain_string_no_quoting() {
        assert_eq!(cell(StringArray::from(vec!["hello"]), 0), "hello");
    }

    #[test]
    fn string_with_comma_is_quoted() {
        assert_eq!(cell(StringArray::from(vec!["a,b"]), 0), "\"a,b\"");
    }

    #[test]
    fn string_with_double_quote_is_escaped() {
        // say "hi" → opening " + say  + "" + hi + "" + closing " = "say ""hi"""
        let result = cell(StringArray::from(vec![r#"say "hi""#]), 0);
        assert_eq!(result, r#""say ""hi""""#);
    }

    #[test]
    fn string_with_newline_is_quoted() {
        let result = cell(StringArray::from(vec!["line1\nline2"]), 0);
        assert_eq!(result, "\"line1\nline2\"");
    }

    // ── binary ───────────────────────────────────────────────────────────────

    #[test]
    fn binary_is_written_as_hex() {
        let arr = BinaryArray::from_vec(vec![&[0xDE, 0xAD, 0xBE, 0xEF][..]]);
        assert_eq!(cell(arr, 0), "deadbeef");
    }

    #[test]
    fn binary_empty_writes_empty() {
        let arr = BinaryArray::from_vec(vec![&[][..]]);
        assert_eq!(cell(arr, 0), "");
    }

    #[test]
    fn fixed_size_binary_16_writes_hyphenated_uuid() {
        let bytes: Vec<u8> = (0u8..16).collect();
        let arr = FixedSizeBinaryArray::try_from_iter(std::iter::once(bytes)).unwrap();
        assert_eq!(cell(arr, 0), "00010203-0405-0607-0809-0a0b0c0d0e0f");
    }

    // ── Date32 ───────────────────────────────────────────────────────────────

    #[test]
    fn date32_epoch_is_1970_01_01() {
        assert_eq!(cell(Date32Array::from(vec![0i32]), 0), "1970-01-01");
    }

    #[test]
    fn date32_positive_offset() {
        // 365 days after epoch = 1971-01-01
        assert_eq!(cell(Date32Array::from(vec![365i32]), 0), "1971-01-01");
    }

    #[test]
    fn date32_before_epoch() {
        assert_eq!(cell(Date32Array::from(vec![-1i32]), 0), "1969-12-31");
    }

    #[test]
    fn date32_out_of_chrono_range_writes_empty_instead_of_panicking() {
        assert_eq!(cell(Date32Array::from(vec![i32::MAX]), 0), "");
    }

    // ── Time64(Microsecond) ──────────────────────────────────────────────────

    #[test]
    fn time64_micros_formats_as_hh_mm_ss_micros() {
        // 1 h 2 min 3 s + 42 µs
        let micros = 3_723_000_042i64;
        assert_eq!(
            cell(Time64MicrosecondArray::from(vec![micros]), 0),
            "01:02:03.000042"
        );
    }

    // ── Timestamp(Microsecond) ───────────────────────────────────────────────

    #[test]
    fn timestamp_micros_formats_as_iso() {
        // 2023-01-01T00:00:00.000000 = 1672531200_000000 micros since epoch
        let micros: i64 = 1_672_531_200 * 1_000_000;
        let arr = TimestampMicrosecondArray::from(vec![micros]);
        assert_eq!(cell(arr, 0), "2023-01-01T00:00:00.000000");
    }

    #[test]
    fn timestamp_micros_keeps_fractional_seconds() {
        let arr = TimestampMicrosecondArray::from(vec![1_500_000i64]);
        assert_eq!(cell(arr, 0), "1970-01-01T00:00:01.500000");
    }

    #[test]
    fn timestamp_micros_before_epoch_whole_second() {
        let arr = TimestampMicrosecondArray::from(vec![-1_000_000i64]);
        assert_eq!(cell(arr, 0), "1969-12-31T23:59:59.000000");
    }

    // ── write_batch via CsvFormat ────────────────────────────────────────────

    #[test]
    fn csv_format_write_batch_writes_rows_and_tracks_bytes() {
        use crate::format::Format;

        let schema = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int64, false),
            Field::new("name", DataType::Utf8, true),
        ]));
        let batch = arrow::record_batch::RecordBatch::try_new(
            schema.clone(),
            vec![
                Arc::new(Int64Array::from(vec![1i64, 2])),
                Arc::new(StringArray::from(vec![Some("alice"), None])),
            ],
        )
        .unwrap();

        let sink = SharedBuf::default();
        let mut writer = CsvFormat
            .create_writer(&schema, Box::new(sink.clone()))
            .unwrap();
        writer.write_batch(&batch).unwrap();

        // Header "id,name\n" (8) + "1,alice\n" (8) + "2,\n" (3) = 19 bytes.
        let expected = "id,name\n1,alice\n2,\n";
        assert_eq!(sink.contents(), expected);
        assert_eq!(writer.bytes_written(), expected.len() as u64);
        writer.finish().unwrap();
    }

    // ── fail loud on types CSV can't represent ───────────────────────────────

    #[test]
    fn csv_rejects_array_columns_loudly() {
        use crate::format::Format;
        let schema = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int64, false),
            Field::new(
                "tags",
                DataType::List(Arc::new(Field::new("item", DataType::Utf8, true))),
                true,
            ),
        ]));
        let Err(err) = CsvFormat.create_writer(&schema, Box::new(Vec::<u8>::new())) else {
            panic!("CSV must reject array columns, not silently drop them");
        };
        let msg = format!("{err:#}");
        assert!(msg.contains("tags"), "error must name the column: {msg}");
        assert!(msg.to_lowercase().contains("csv"), "{msg}");
    }

    /// Consistency guard: every type `csv_serializable` admits must actually be
    /// handled by `write_csv_value` (not hit its `other => bail` fallthrough).
    /// Keeps the whitelist and the writer in lock-step so one can't drift.
    #[test]
    fn every_serializable_type_is_actually_written() {
        use crate::format::Format;
        let cols: Vec<(&str, ArrayRef)> = vec![
            ("b", Arc::new(BooleanArray::from(vec![true]))),
            ("i16", Arc::new(Int16Array::from(vec![1i16]))),
            ("i32", Arc::new(Int32Array::from(vec![1i32]))),
            ("i64", Arc::new(Int64Array::from(vec![1i64]))),
            ("u64", Arc::new(UInt64Array::from(vec![1u64]))),
            (
                "dec",
                Arc::new(
                    Decimal128Array::from(vec![100i128])
                        .with_precision_and_scale(18, 2)
                        .unwrap(),
                ),
            ),
            ("f32", Arc::new(Float32Array::from(vec![1.0f32]))),
            ("f64", Arc::new(Float64Array::from(vec![1.0f64]))),
            ("s", Arc::new(StringArray::from(vec!["x"]))),
            ("bin", Arc::new(BinaryArray::from_vec(vec![&[1u8][..]]))),
            (
                "uuid",
                Arc::new(
                    FixedSizeBinaryArray::try_from_iter(std::iter::once(vec![0u8; 16])).unwrap(),
                ),
            ),
            ("d", Arc::new(Date32Array::from(vec![0i32]))),
            ("t", Arc::new(Time64MicrosecondArray::from(vec![0i64]))),
            ("ts", Arc::new(TimestampMicrosecondArray::from(vec![0i64]))),
        ];
        let fields: Vec<Field> = cols
            .iter()
            .map(|(n, a)| Field::new(*n, a.data_type().clone(), true))
            .collect();
        // Sanity: each column's type is on the whitelist.
        for f in &fields {
            assert!(
                csv_serializable(f.data_type()),
                "test type {:?} not in csv_serializable",
                f.data_type()
            );
        }
        let schema = Arc::new(Schema::new(fields));
        let arrays: Vec<ArrayRef> = cols.into_iter().map(|(_, a)| a).collect();
        let batch = RecordBatch::try_new(schema.clone(), arrays).unwrap();
        let mut w = CsvFormat
            .create_writer(&schema, Box::new(Vec::<u8>::new()))
            .unwrap();
        w.write_batch(&batch)
            .expect("every serializable type must write without hitting the fallthrough");
    }
}
568}