Skip to main content

rivet_cli/format/
csv.rs

1use std::io::Write;
2
3use arrow::array::*;
4use arrow::datatypes::{DataType, SchemaRef, TimeUnit};
5use arrow::record_batch::RecordBatch;
6
7use crate::error::Result;
8
/// Stateless CSV output format.
///
/// Carries no configuration; it only acts as the factory for
/// [`CsvFormatWriter`] through its `super::Format` implementation.
#[derive(Debug, Clone, Copy, Default)]
pub struct CsvFormat;
10
/// Writer state for one CSV output stream created by `CsvFormat`.
///
/// Pairs the destination stream with a running count of the bytes that have
/// been emitted to it; the count starts at the size of the header row.
pub struct CsvFormatWriter {
    /// Total bytes emitted to `writer` so far (header row included).
    bytes_written: u64,
    /// Sink that receives the rendered CSV text.
    writer: Box<dyn Write + Send>,
}
15
16impl super::Format for CsvFormat {
17    fn create_writer(
18        &self,
19        schema: &SchemaRef,
20        mut writer: Box<dyn Write + Send>,
21    ) -> Result<Box<dyn super::FormatWriter>> {
22        let header = schema
23            .fields()
24            .iter()
25            .map(|f| f.name().as_str())
26            .collect::<Vec<_>>()
27            .join(",");
28        let header_bytes = header.len() as u64 + 1; // +1 for newline
29        writeln!(writer, "{}", header)?;
30        Ok(Box::new(CsvFormatWriter {
31            writer,
32            bytes_written: header_bytes,
33        }))
34    }
35
36    fn file_extension(&self) -> &str {
37        "csv"
38    }
39}
40
41impl super::FormatWriter for CsvFormatWriter {
42    fn write_batch(&mut self, batch: &RecordBatch) -> Result<()> {
43        let before = self.bytes_written;
44        let _ = before; // suppress unused warning, we count after
45
46        let mut buf = Vec::new();
47        for row_idx in 0..batch.num_rows() {
48            let mut first = true;
49            for col_idx in 0..batch.num_columns() {
50                if !first {
51                    write!(buf, ",")?;
52                }
53                first = false;
54                write_csv_value(&mut buf, batch.column(col_idx), row_idx)?;
55            }
56            writeln!(buf)?;
57        }
58        self.bytes_written += buf.len() as u64;
59        self.writer.write_all(&buf)?;
60        Ok(())
61    }
62
63    fn finish(self: Box<Self>) -> Result<()> {
64        Ok(())
65    }
66
67    fn bytes_written(&self) -> u64 {
68        self.bytes_written
69    }
70}
71
72fn write_csv_value(writer: &mut dyn Write, array: &dyn Array, idx: usize) -> Result<()> {
73    if array.is_null(idx) {
74        return Ok(());
75    }
76
77    match array.data_type() {
78        DataType::Boolean => {
79            let arr = array
80                .as_any()
81                .downcast_ref::<BooleanArray>()
82                .expect("DataType/Array mismatch");
83            write!(writer, "{}", arr.value(idx))?;
84        }
85        DataType::Int16 => {
86            let arr = array
87                .as_any()
88                .downcast_ref::<Int16Array>()
89                .expect("DataType/Array mismatch");
90            write!(writer, "{}", arr.value(idx))?;
91        }
92        DataType::Int32 => {
93            let arr = array
94                .as_any()
95                .downcast_ref::<Int32Array>()
96                .expect("DataType/Array mismatch");
97            write!(writer, "{}", arr.value(idx))?;
98        }
99        DataType::Int64 => {
100            let arr = array
101                .as_any()
102                .downcast_ref::<Int64Array>()
103                .expect("DataType/Array mismatch");
104            write!(writer, "{}", arr.value(idx))?;
105        }
106        DataType::Float32 => {
107            let arr = array
108                .as_any()
109                .downcast_ref::<Float32Array>()
110                .expect("DataType/Array mismatch");
111            write!(writer, "{}", arr.value(idx))?;
112        }
113        DataType::Float64 => {
114            let arr = array
115                .as_any()
116                .downcast_ref::<Float64Array>()
117                .expect("DataType/Array mismatch");
118            write!(writer, "{}", arr.value(idx))?;
119        }
120        DataType::Utf8 => {
121            let arr = array
122                .as_any()
123                .downcast_ref::<StringArray>()
124                .expect("DataType/Array mismatch");
125            let val = arr.value(idx);
126            if val.contains(',') || val.contains('"') || val.contains('\n') {
127                write!(writer, "\"{}\"", val.replace('"', "\"\""))?;
128            } else {
129                write!(writer, "{}", val)?;
130            }
131        }
132        DataType::Binary => {
133            let arr = array
134                .as_any()
135                .downcast_ref::<BinaryArray>()
136                .expect("DataType/Array mismatch");
137            let val = arr.value(idx);
138            for byte in val {
139                write!(writer, "{:02x}", byte)?;
140            }
141        }
142        DataType::Date32 => {
143            let arr = array
144                .as_any()
145                .downcast_ref::<Date32Array>()
146                .expect("DataType/Array mismatch");
147            let days = arr.value(idx);
148            let date = chrono::NaiveDate::from_ymd_opt(1970, 1, 1).expect("epoch is valid")
149                + chrono::Duration::days(days as i64);
150            write!(writer, "{}", date)?;
151        }
152        DataType::Timestamp(TimeUnit::Microsecond, _) => {
153            let arr = array
154                .as_any()
155                .downcast_ref::<TimestampMicrosecondArray>()
156                .expect("DataType/Array mismatch");
157            let micros = arr.value(idx);
158            let secs = micros / 1_000_000;
159            let nsecs = ((micros % 1_000_000) * 1_000) as u32;
160            if let Some(dt) = chrono::DateTime::from_timestamp(secs, nsecs) {
161                write!(writer, "{}", dt.format("%Y-%m-%dT%H:%M:%S%.6f"))?;
162            }
163        }
164        other => {
165            log::warn!("CSV: unhandled Arrow type {:?}, skipping value", other);
166        }
167    }
168
169    Ok(())
170}