Skip to main content

rivet/format/
csv.rs

1use std::io::Write;
2
3use arrow::array::Time64MicrosecondArray;
4use arrow::array::types::Decimal128Type;
5use arrow::array::*;
6use arrow::datatypes::{DataType, SchemaRef, TimeUnit};
7use arrow::record_batch::RecordBatch;
8
9use crate::error::Result;
10use crate::types::decimal::scaled_i128_to_decimal_str;
11
/// CSV output format.
///
/// A zero-sized handle: its `Format` implementation in this file writes the
/// header row and hands back a `CsvFormatWriter` for the data rows.
pub struct CsvFormat;
13
/// Streaming CSV writer produced by `CsvFormat::create_writer`.
pub struct CsvFormatWriter {
    /// Destination sink; every rendered batch is written to it in one call.
    writer: Box<dyn Write + Send>,
    /// Running count of bytes handed to `writer`, starting with the header
    /// line (including its trailing newline).
    bytes_written: u64,
}
18
19impl super::Format for CsvFormat {
20    fn create_writer(
21        &self,
22        schema: &SchemaRef,
23        mut writer: Box<dyn Write + Send>,
24    ) -> Result<Box<dyn super::FormatWriter + Send>> {
25        let header = schema
26            .fields()
27            .iter()
28            .map(|f| f.name().as_str())
29            .collect::<Vec<_>>()
30            .join(",");
31        let header_bytes = header.len() as u64 + 1; // +1 for newline
32        writeln!(writer, "{}", header)?;
33        Ok(Box::new(CsvFormatWriter {
34            writer,
35            bytes_written: header_bytes,
36        }))
37    }
38
39    fn file_extension(&self) -> &str {
40        "csv"
41    }
42}
43
44impl super::FormatWriter for CsvFormatWriter {
45    fn write_batch(&mut self, batch: &RecordBatch) -> Result<()> {
46        let mut buf = Vec::with_capacity(batch.num_rows() * batch.num_columns() * 8);
47        for row_idx in 0..batch.num_rows() {
48            for col_idx in 0..batch.num_columns() {
49                if col_idx > 0 {
50                    buf.push(b',');
51                }
52                write_csv_value(&mut buf, batch.column(col_idx), row_idx)?;
53            }
54            buf.push(b'\n');
55        }
56        self.bytes_written += buf.len() as u64;
57        self.writer.write_all(&buf)?;
58        Ok(())
59    }
60
61    fn finish(self: Box<Self>) -> Result<()> {
62        Ok(())
63    }
64
65    fn bytes_written(&self) -> u64 {
66        self.bytes_written
67    }
68}
69
70fn write_csv_value(writer: &mut dyn Write, array: &dyn Array, idx: usize) -> Result<()> {
71    if array.is_null(idx) {
72        return Ok(());
73    }
74
75    match array.data_type() {
76        DataType::Boolean => {
77            let arr = array
78                .as_any()
79                .downcast_ref::<BooleanArray>()
80                .expect("DataType/Array mismatch");
81            write!(writer, "{}", arr.value(idx))?;
82        }
83        DataType::Int16 => {
84            let arr = array
85                .as_any()
86                .downcast_ref::<Int16Array>()
87                .expect("DataType/Array mismatch");
88            write!(writer, "{}", arr.value(idx))?;
89        }
90        DataType::Int32 => {
91            let arr = array
92                .as_any()
93                .downcast_ref::<Int32Array>()
94                .expect("DataType/Array mismatch");
95            write!(writer, "{}", arr.value(idx))?;
96        }
97        DataType::Int64 => {
98            let arr = array
99                .as_any()
100                .downcast_ref::<Int64Array>()
101                .expect("DataType/Array mismatch");
102            write!(writer, "{}", arr.value(idx))?;
103        }
104        DataType::UInt64 => {
105            let arr = array
106                .as_any()
107                .downcast_ref::<UInt64Array>()
108                .expect("DataType/Array mismatch");
109            write!(writer, "{}", arr.value(idx))?;
110        }
111        DataType::Decimal128(_, scale) => {
112            let arr = array.as_primitive::<Decimal128Type>();
113            let text = scaled_i128_to_decimal_str(arr.value(idx), *scale);
114            writer.write_all(text.as_bytes())?;
115        }
116        DataType::Float32 => {
117            let arr = array
118                .as_any()
119                .downcast_ref::<Float32Array>()
120                .expect("DataType/Array mismatch");
121            write!(writer, "{}", arr.value(idx))?;
122        }
123        DataType::Float64 => {
124            let arr = array
125                .as_any()
126                .downcast_ref::<Float64Array>()
127                .expect("DataType/Array mismatch");
128            write!(writer, "{}", arr.value(idx))?;
129        }
130        DataType::Utf8 => {
131            let arr = array
132                .as_any()
133                .downcast_ref::<StringArray>()
134                .expect("DataType/Array mismatch");
135            let val = arr.value(idx);
136            if val.contains(',') || val.contains('"') || val.contains('\n') {
137                writer.write_all(b"\"")?;
138                let mut rest = val;
139                while let Some(pos) = rest.find('"') {
140                    writer.write_all(&rest.as_bytes()[..pos])?;
141                    writer.write_all(b"\"\"")?;
142                    rest = &rest[pos + 1..];
143                }
144                writer.write_all(rest.as_bytes())?;
145                writer.write_all(b"\"")?;
146            } else {
147                writer.write_all(val.as_bytes())?;
148            }
149        }
150        DataType::Binary => {
151            let arr = array
152                .as_any()
153                .downcast_ref::<BinaryArray>()
154                .expect("DataType/Array mismatch");
155            let val = arr.value(idx);
156            for byte in val {
157                write!(writer, "{:02x}", byte)?;
158            }
159        }
160        // FixedSizeBinary today only carries 16-byte UUIDs (see
161        // `RivetType::Uuid` → `DataType::FixedSizeBinary(16)` in
162        // `src/types/mapping.rs`). CSV has no native binary cell; emit the
163        // canonical hyphenated lowercase form so downstream readers can
164        // recognise it as a UUID rather than 16 bytes of mojibake. Any
165        // future FixedSizeBinary use that is not a UUID should branch on
166        // the size argument before reaching this arm.
167        DataType::FixedSizeBinary(16) => {
168            let arr = array
169                .as_any()
170                .downcast_ref::<FixedSizeBinaryArray>()
171                .expect("DataType/Array mismatch");
172            let val = arr.value(idx);
173            let mut bytes = [0u8; 16];
174            bytes.copy_from_slice(val);
175            write!(writer, "{}", uuid::Uuid::from_bytes(bytes).to_hyphenated())?;
176        }
177        DataType::Date32 => {
178            let arr = array
179                .as_any()
180                .downcast_ref::<Date32Array>()
181                .expect("DataType/Array mismatch");
182            let days = arr.value(idx);
183            // `Date32` is "days since 1970-01-01"; a pathological value near
184            // i32::MAX overflows `NaiveDate + Duration` and panics in chrono.
185            // Fall back to checked arithmetic and emit an empty cell on
186            // overflow — matches the null-cell convention for unserialisable
187            // values elsewhere in this writer.
188            let epoch = chrono::NaiveDate::from_ymd_opt(1970, 1, 1).expect("epoch is valid");
189            let date =
190                chrono::Duration::try_days(days as i64).and_then(|d| epoch.checked_add_signed(d));
191            if let Some(date) = date {
192                write!(writer, "{}", date)?;
193            }
194        }
195        DataType::Time64(TimeUnit::Microsecond) => {
196            let arr = array
197                .as_any()
198                .downcast_ref::<Time64MicrosecondArray>()
199                .expect("DataType/Array mismatch");
200            let micros = arr.value(idx);
201            let secs = micros / 1_000_000;
202            let frac_us = micros % 1_000_000;
203            write!(
204                writer,
205                "{:02}:{:02}:{:02}.{:06}",
206                secs / 3600,
207                (secs % 3600) / 60,
208                secs % 60,
209                frac_us
210            )?;
211        }
212        DataType::Timestamp(TimeUnit::Microsecond, _) => {
213            let arr = array
214                .as_any()
215                .downcast_ref::<TimestampMicrosecondArray>()
216                .expect("DataType/Array mismatch");
217            let micros = arr.value(idx);
218            let secs = micros / 1_000_000;
219            let nsecs = ((micros % 1_000_000) * 1_000) as u32;
220            if let Some(dt) = chrono::DateTime::from_timestamp(secs, nsecs) {
221                write!(writer, "{}", dt.format("%Y-%m-%dT%H:%M:%S%.6f"))?;
222            }
223        }
224        other => {
225            log::warn!("CSV: unhandled Arrow type {:?}, skipping value", other);
226        }
227    }
228
229    Ok(())
230}
231
#[cfg(test)]
mod tests {
    //! Unit tests for CSV cell rendering (`write_csv_value`) and for the
    //! `CsvFormat` writer's byte accounting.

    use super::*;
    use arrow::datatypes::{DataType, Field, Schema, TimeUnit};
    use std::sync::Arc;

    // Helper: render one cell to a String using write_csv_value.
    fn cell<A: Array + 'static>(array: A, idx: usize) -> String {
        let mut buf = Vec::new();
        write_csv_value(&mut buf, &array, idx).unwrap();
        String::from_utf8(buf).unwrap()
    }

    // Helper: render a null cell from any typed array.
    fn null_cell(dt: DataType) -> String {
        use arrow::array::new_null_array;
        let arr = new_null_array(&dt, 1);
        let mut buf = Vec::new();
        write_csv_value(&mut buf, arr.as_ref(), 0).unwrap();
        String::from_utf8(buf).unwrap()
    }

    // ── null handling ────────────────────────────────────────────────────────

    #[test]
    fn null_value_writes_empty_string() {
        assert_eq!(null_cell(DataType::Int64), "");
        assert_eq!(null_cell(DataType::Utf8), "");
        assert_eq!(null_cell(DataType::Boolean), "");
        assert_eq!(
            null_cell(DataType::Timestamp(TimeUnit::Microsecond, None)),
            ""
        );
    }

    // ── scalars ─────────────────────────────────────────────────────────────

    #[test]
    fn bool_true_writes_true() {
        assert_eq!(cell(BooleanArray::from(vec![true]), 0), "true");
    }

    #[test]
    fn bool_false_writes_false() {
        assert_eq!(cell(BooleanArray::from(vec![false]), 0), "false");
    }

    #[test]
    fn int16_value() {
        assert_eq!(cell(Int16Array::from(vec![42i16]), 0), "42");
    }

    #[test]
    fn int32_negative() {
        assert_eq!(cell(Int32Array::from(vec![-7i32]), 0), "-7");
    }

    #[test]
    fn decimal128_writes_exact_text() {
        let arr = Decimal128Array::from(vec![10i128])
            .with_precision_and_scale(18, 2)
            .unwrap();
        assert_eq!(cell(arr, 0), "0.10");
        let scaled =
            crate::types::decimal::decimal_str_to_scaled_i128("999999999999.99", 2).unwrap();
        let arr = Decimal128Array::from(vec![scaled])
            .with_precision_and_scale(18, 2)
            .unwrap();
        assert_eq!(cell(arr, 0), "999999999999.99");
    }

    #[test]
    fn int64_large() {
        assert_eq!(
            cell(Int64Array::from(vec![9_999_999_999i64]), 0),
            "9999999999"
        );
    }

    #[test]
    fn float32_value() {
        // 1.5 is exactly representable, so `Display` prints it without noise.
        assert_eq!(cell(Float32Array::from(vec![1.5f32]), 0), "1.5");
    }

    #[test]
    fn float64_value() {
        let result = cell(Float64Array::from(vec![std::f64::consts::PI]), 0);
        assert!(result.starts_with("3.14"), "got: {result}");
    }

    // ── string escaping ──────────────────────────────────────────────────────

    #[test]
    fn plain_string_no_quoting() {
        assert_eq!(cell(StringArray::from(vec!["hello"]), 0), "hello");
    }

    #[test]
    fn string_with_comma_is_quoted() {
        assert_eq!(cell(StringArray::from(vec!["a,b"]), 0), "\"a,b\"");
    }

    #[test]
    fn string_with_double_quote_is_escaped() {
        // say "hi" → opening " + say  + "" + hi + "" + closing " = "say ""hi"""
        let result = cell(StringArray::from(vec![r#"say "hi""#]), 0);
        assert_eq!(result, r#""say ""hi""""#);
    }

    #[test]
    fn string_with_newline_is_quoted() {
        // The newline is kept verbatim inside the surrounding quotes.
        let result = cell(StringArray::from(vec!["line1\nline2"]), 0);
        assert_eq!(result, "\"line1\nline2\"");
    }

    // ── binary ───────────────────────────────────────────────────────────────

    #[test]
    fn binary_is_written_as_hex() {
        let arr = BinaryArray::from_vec(vec![&[0xDE, 0xAD, 0xBE, 0xEF][..]]);
        assert_eq!(cell(arr, 0), "deadbeef");
    }

    #[test]
    fn binary_empty_writes_empty() {
        let arr = BinaryArray::from_vec(vec![&[][..]]);
        assert_eq!(cell(arr, 0), "");
    }

    // ── Date32 ───────────────────────────────────────────────────────────────

    #[test]
    fn date32_epoch_is_1970_01_01() {
        assert_eq!(cell(Date32Array::from(vec![0i32]), 0), "1970-01-01");
    }

    #[test]
    fn date32_positive_offset() {
        // 365 days after epoch = 1971-01-01
        assert_eq!(cell(Date32Array::from(vec![365i32]), 0), "1971-01-01");
    }

    // ── Time64(Microsecond) ──────────────────────────────────────────────────

    #[test]
    fn time64_micros_formats_as_hms_with_fraction() {
        // 1h 2m 3s + 1µs = 3_723 s * 1_000_000 + 1
        let arr = Time64MicrosecondArray::from(vec![3_723_000_001i64]);
        assert_eq!(cell(arr, 0), "01:02:03.000001");
    }

    // ── Timestamp(Microsecond) ───────────────────────────────────────────────

    #[test]
    fn timestamp_micros_formats_as_iso() {
        // 2023-01-01T00:00:00.000000 = 1672531200_000000 micros since epoch
        let micros: i64 = 1_672_531_200 * 1_000_000;
        let arr = TimestampMicrosecondArray::from(vec![micros]);
        assert_eq!(cell(arr, 0), "2023-01-01T00:00:00.000000");
    }

    // ── write_batch via CsvFormat ────────────────────────────────────────────

    #[test]
    fn csv_format_write_batch_tracks_bytes_and_succeeds() {
        use crate::format::Format;

        let schema = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int64, false),
            Field::new("name", DataType::Utf8, true),
        ]));
        let batch = arrow::record_batch::RecordBatch::try_new(
            schema.clone(),
            vec![
                Arc::new(Int64Array::from(vec![1i64, 2])),
                Arc::new(StringArray::from(vec![Some("alice"), None])),
            ],
        )
        .unwrap();

        // Pass Vec by value — avoids the &mut T 'static lifetime requirement.
        let fmt = CsvFormat;
        let mut writer = fmt
            .create_writer(&schema, Box::new(Vec::<u8>::new()))
            .unwrap();
        writer.write_batch(&batch).unwrap();
        // Header "id,name\n" (8) + rows "1,alice\n" (8) + "2,\n" (3) = 19 bytes
        assert_eq!(writer.bytes_written(), 19);
        writer.finish().unwrap();
    }
}