Skip to main content

rivet/format/
csv.rs

1use std::io::Write;
2
3use arrow::array::*;
4use arrow::datatypes::{DataType, SchemaRef, TimeUnit};
5use arrow::record_batch::RecordBatch;
6
7use crate::error::Result;
8
/// CSV output format: a header row of column names followed by one
/// comma-separated line per record, each terminated by `\n`.
#[derive(Debug, Clone, Copy, Default)]
pub struct CsvFormat;
10
/// Row writer produced by `CsvFormat::create_writer`.
///
/// Owns the output sink and tracks how many bytes of CSV text have been
/// produced so far.
pub struct CsvFormatWriter {
    /// Byte total reported by `bytes_written()`; starts at the header length.
    bytes_written: u64,
    /// Destination for the encoded CSV text.
    writer: Box<dyn Write + Send>,
}
15
16impl super::Format for CsvFormat {
17    fn create_writer(
18        &self,
19        schema: &SchemaRef,
20        mut writer: Box<dyn Write + Send>,
21    ) -> Result<Box<dyn super::FormatWriter>> {
22        let header = schema
23            .fields()
24            .iter()
25            .map(|f| f.name().as_str())
26            .collect::<Vec<_>>()
27            .join(",");
28        let header_bytes = header.len() as u64 + 1; // +1 for newline
29        writeln!(writer, "{}", header)?;
30        Ok(Box::new(CsvFormatWriter {
31            writer,
32            bytes_written: header_bytes,
33        }))
34    }
35
36    fn file_extension(&self) -> &str {
37        "csv"
38    }
39}
40
41impl super::FormatWriter for CsvFormatWriter {
42    fn write_batch(&mut self, batch: &RecordBatch) -> Result<()> {
43        let mut buf = Vec::with_capacity(batch.num_rows() * batch.num_columns() * 8);
44        for row_idx in 0..batch.num_rows() {
45            for col_idx in 0..batch.num_columns() {
46                if col_idx > 0 {
47                    buf.push(b',');
48                }
49                write_csv_value(&mut buf, batch.column(col_idx), row_idx)?;
50            }
51            buf.push(b'\n');
52        }
53        self.bytes_written += buf.len() as u64;
54        self.writer.write_all(&buf)?;
55        Ok(())
56    }
57
58    fn finish(self: Box<Self>) -> Result<()> {
59        Ok(())
60    }
61
62    fn bytes_written(&self) -> u64 {
63        self.bytes_written
64    }
65}
66
67fn write_csv_value(writer: &mut dyn Write, array: &dyn Array, idx: usize) -> Result<()> {
68    if array.is_null(idx) {
69        return Ok(());
70    }
71
72    match array.data_type() {
73        DataType::Boolean => {
74            let arr = array
75                .as_any()
76                .downcast_ref::<BooleanArray>()
77                .expect("DataType/Array mismatch");
78            write!(writer, "{}", arr.value(idx))?;
79        }
80        DataType::Int16 => {
81            let arr = array
82                .as_any()
83                .downcast_ref::<Int16Array>()
84                .expect("DataType/Array mismatch");
85            write!(writer, "{}", arr.value(idx))?;
86        }
87        DataType::Int32 => {
88            let arr = array
89                .as_any()
90                .downcast_ref::<Int32Array>()
91                .expect("DataType/Array mismatch");
92            write!(writer, "{}", arr.value(idx))?;
93        }
94        DataType::Int64 => {
95            let arr = array
96                .as_any()
97                .downcast_ref::<Int64Array>()
98                .expect("DataType/Array mismatch");
99            write!(writer, "{}", arr.value(idx))?;
100        }
101        DataType::Float32 => {
102            let arr = array
103                .as_any()
104                .downcast_ref::<Float32Array>()
105                .expect("DataType/Array mismatch");
106            write!(writer, "{}", arr.value(idx))?;
107        }
108        DataType::Float64 => {
109            let arr = array
110                .as_any()
111                .downcast_ref::<Float64Array>()
112                .expect("DataType/Array mismatch");
113            write!(writer, "{}", arr.value(idx))?;
114        }
115        DataType::Utf8 => {
116            let arr = array
117                .as_any()
118                .downcast_ref::<StringArray>()
119                .expect("DataType/Array mismatch");
120            let val = arr.value(idx);
121            if val.contains(',') || val.contains('"') || val.contains('\n') {
122                writer.write_all(b"\"")?;
123                let mut rest = val;
124                while let Some(pos) = rest.find('"') {
125                    writer.write_all(&rest.as_bytes()[..pos])?;
126                    writer.write_all(b"\"\"")?;
127                    rest = &rest[pos + 1..];
128                }
129                writer.write_all(rest.as_bytes())?;
130                writer.write_all(b"\"")?;
131            } else {
132                writer.write_all(val.as_bytes())?;
133            }
134        }
135        DataType::Binary => {
136            let arr = array
137                .as_any()
138                .downcast_ref::<BinaryArray>()
139                .expect("DataType/Array mismatch");
140            let val = arr.value(idx);
141            for byte in val {
142                write!(writer, "{:02x}", byte)?;
143            }
144        }
145        DataType::Date32 => {
146            let arr = array
147                .as_any()
148                .downcast_ref::<Date32Array>()
149                .expect("DataType/Array mismatch");
150            let days = arr.value(idx);
151            // `Date32` is "days since 1970-01-01"; a pathological value near
152            // i32::MAX overflows `NaiveDate + Duration` and panics in chrono.
153            // Fall back to checked arithmetic and emit an empty cell on
154            // overflow — matches the null-cell convention for unserialisable
155            // values elsewhere in this writer.
156            let epoch = chrono::NaiveDate::from_ymd_opt(1970, 1, 1).expect("epoch is valid");
157            let date =
158                chrono::Duration::try_days(days as i64).and_then(|d| epoch.checked_add_signed(d));
159            if let Some(date) = date {
160                write!(writer, "{}", date)?;
161            }
162        }
163        DataType::Timestamp(TimeUnit::Microsecond, _) => {
164            let arr = array
165                .as_any()
166                .downcast_ref::<TimestampMicrosecondArray>()
167                .expect("DataType/Array mismatch");
168            let micros = arr.value(idx);
169            let secs = micros / 1_000_000;
170            let nsecs = ((micros % 1_000_000) * 1_000) as u32;
171            if let Some(dt) = chrono::DateTime::from_timestamp(secs, nsecs) {
172                write!(writer, "{}", dt.format("%Y-%m-%dT%H:%M:%S%.6f"))?;
173            }
174        }
175        other => {
176            log::warn!("CSV: unhandled Arrow type {:?}, skipping value", other);
177        }
178    }
179
180    Ok(())
181}
182
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::datatypes::{DataType, Field, Schema};
    use std::io::Write;
    use std::sync::{Arc, Mutex};

    // Helper: render one cell to a String using write_csv_value.
    fn cell<A: Array + 'static>(array: A, idx: usize) -> String {
        let mut buf = Vec::new();
        write_csv_value(&mut buf, &array, idx).unwrap();
        String::from_utf8(buf).unwrap()
    }

    // Helper: render a null cell from any typed array.
    fn null_cell(dt: DataType) -> String {
        use arrow::array::new_null_array;
        let arr = new_null_array(&dt, 1);
        let mut buf = Vec::new();
        write_csv_value(&mut buf, arr.as_ref(), 0).unwrap();
        String::from_utf8(buf).unwrap()
    }

    // In-memory sink whose bytes stay readable after a clone of it has been
    // boxed into a `CsvFormatWriter`.
    #[derive(Clone, Default)]
    struct SharedBuf(Arc<Mutex<Vec<u8>>>);

    impl SharedBuf {
        fn contents(&self) -> String {
            String::from_utf8(self.0.lock().unwrap().clone()).unwrap()
        }
    }

    impl Write for SharedBuf {
        fn write(&mut self, data: &[u8]) -> std::io::Result<usize> {
            self.0.lock().unwrap().extend_from_slice(data);
            Ok(data.len())
        }

        fn flush(&mut self) -> std::io::Result<()> {
            Ok(())
        }
    }

    // ── null handling ────────────────────────────────────────────────────────

    #[test]
    fn null_value_writes_empty_string() {
        assert_eq!(null_cell(DataType::Int64), "");
        assert_eq!(null_cell(DataType::Utf8), "");
        assert_eq!(null_cell(DataType::Boolean), "");
    }

    // ── scalars ─────────────────────────────────────────────────────────────

    #[test]
    fn bool_true_writes_true() {
        assert_eq!(cell(BooleanArray::from(vec![true]), 0), "true");
    }

    #[test]
    fn bool_false_writes_false() {
        assert_eq!(cell(BooleanArray::from(vec![false]), 0), "false");
    }

    #[test]
    fn int16_value() {
        assert_eq!(cell(Int16Array::from(vec![42i16]), 0), "42");
    }

    #[test]
    fn int32_negative() {
        assert_eq!(cell(Int32Array::from(vec![-7i32]), 0), "-7");
    }

    #[test]
    fn int64_large() {
        assert_eq!(
            cell(Int64Array::from(vec![9_999_999_999i64]), 0),
            "9999999999"
        );
    }

    #[test]
    fn float32_value() {
        let result = cell(Float32Array::from(vec![1.5f32]), 0);
        assert!(result.starts_with("1.5"), "got: {result}");
    }

    #[test]
    fn float64_value() {
        let result = cell(Float64Array::from(vec![std::f64::consts::PI]), 0);
        assert!(result.starts_with("3.14"), "got: {result}");
    }

    // ── string escaping ──────────────────────────────────────────────────────

    #[test]
    fn plain_string_no_quoting() {
        assert_eq!(cell(StringArray::from(vec!["hello"]), 0), "hello");
    }

    #[test]
    fn string_with_comma_is_quoted() {
        assert_eq!(cell(StringArray::from(vec!["a,b"]), 0), "\"a,b\"");
    }

    #[test]
    fn string_with_double_quote_is_escaped() {
        // say "hi" → opening " + say  + "" + hi + "" + closing " = "say ""hi"""
        let result = cell(StringArray::from(vec![r#"say "hi""#]), 0);
        assert_eq!(result, r#""say ""hi""""#);
    }

    #[test]
    fn string_with_newline_is_quoted() {
        let result = cell(StringArray::from(vec!["line1\nline2"]), 0);
        assert_eq!(result, "\"line1\nline2\"");
    }

    // ── binary ───────────────────────────────────────────────────────────────

    #[test]
    fn binary_is_written_as_hex() {
        let arr = BinaryArray::from_vec(vec![&[0xDE, 0xAD, 0xBE, 0xEF][..]]);
        assert_eq!(cell(arr, 0), "deadbeef");
    }

    #[test]
    fn binary_empty_writes_empty() {
        let arr = BinaryArray::from_vec(vec![&[][..]]);
        assert_eq!(cell(arr, 0), "");
    }

    // ── Date32 ───────────────────────────────────────────────────────────────

    #[test]
    fn date32_epoch_is_1970_01_01() {
        assert_eq!(cell(Date32Array::from(vec![0i32]), 0), "1970-01-01");
    }

    #[test]
    fn date32_positive_offset() {
        // 365 days after epoch = 1971-01-01
        assert_eq!(cell(Date32Array::from(vec![365i32]), 0), "1971-01-01");
    }

    // ── Timestamp(Microsecond) ───────────────────────────────────────────────

    #[test]
    fn timestamp_micros_formats_as_iso() {
        // 2023-01-01T00:00:00 UTC = 1_672_531_200 s since epoch.
        let micros: i64 = 1_672_531_200 * 1_000_000;
        let arr = TimestampMicrosecondArray::from(vec![micros]);
        assert_eq!(cell(arr, 0), "2023-01-01T00:00:00.000000");
    }

    // ── write_batch via CsvFormat ────────────────────────────────────────────

    #[test]
    fn csv_format_write_batch_tracks_bytes_and_succeeds() {
        use crate::format::Format;

        let schema = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int64, false),
            Field::new("name", DataType::Utf8, true),
        ]));
        let batch = arrow::record_batch::RecordBatch::try_new(
            schema.clone(),
            vec![
                Arc::new(Int64Array::from(vec![1i64, 2])),
                Arc::new(StringArray::from(vec![Some("alice"), None])),
            ],
        )
        .unwrap();

        // The writer takes ownership of its sink; a shared buffer lets the
        // test read back exactly what was written.
        let sink = SharedBuf::default();
        let fmt = CsvFormat;
        let mut writer = fmt
            .create_writer(&schema, Box::new(sink.clone()))
            .unwrap();
        writer.write_batch(&batch).unwrap();

        // "id,name\n" (8) + "1,alice\n" (8) + "2,\n" (3) = 19 bytes; the null
        // name renders as an empty cell.
        assert_eq!(sink.contents(), "id,name\n1,alice\n2,\n");
        assert_eq!(writer.bytes_written(), 19);
        writer.finish().unwrap();
    }
}