1use std::io::Write;
2
3use arrow::array::*;
4use arrow::datatypes::{DataType, SchemaRef, TimeUnit};
5use arrow::record_batch::RecordBatch;
6
7use crate::error::Result;
8
/// CSV output format: a header row of schema field names followed by one
/// comma-separated line per record.
#[derive(Debug, Clone, Copy, Default)]
pub struct CsvFormat;
10
/// Streaming CSV writer produced by `CsvFormat::create_writer`.
///
/// Pairs the destination sink with a running count of the bytes emitted to it,
/// which starts at the size of the header row.
pub struct CsvFormatWriter {
    /// Running total of bytes emitted so far, header row included.
    bytes_written: u64,
    /// Destination that CSV text is written into.
    writer: Box<dyn Send + Write>,
}
15
16impl super::Format for CsvFormat {
17 fn create_writer(
18 &self,
19 schema: &SchemaRef,
20 mut writer: Box<dyn Write + Send>,
21 ) -> Result<Box<dyn super::FormatWriter>> {
22 let header = schema
23 .fields()
24 .iter()
25 .map(|f| f.name().as_str())
26 .collect::<Vec<_>>()
27 .join(",");
28 let header_bytes = header.len() as u64 + 1; writeln!(writer, "{}", header)?;
30 Ok(Box::new(CsvFormatWriter {
31 writer,
32 bytes_written: header_bytes,
33 }))
34 }
35
36 fn file_extension(&self) -> &str {
37 "csv"
38 }
39}
40
41impl super::FormatWriter for CsvFormatWriter {
42 fn write_batch(&mut self, batch: &RecordBatch) -> Result<()> {
43 let before = self.bytes_written;
44 let _ = before; let mut buf = Vec::new();
47 for row_idx in 0..batch.num_rows() {
48 let mut first = true;
49 for col_idx in 0..batch.num_columns() {
50 if !first {
51 write!(buf, ",")?;
52 }
53 first = false;
54 write_csv_value(&mut buf, batch.column(col_idx), row_idx)?;
55 }
56 writeln!(buf)?;
57 }
58 self.bytes_written += buf.len() as u64;
59 self.writer.write_all(&buf)?;
60 Ok(())
61 }
62
63 fn finish(self: Box<Self>) -> Result<()> {
64 Ok(())
65 }
66
67 fn bytes_written(&self) -> u64 {
68 self.bytes_written
69 }
70}
71
72fn write_csv_value(writer: &mut dyn Write, array: &dyn Array, idx: usize) -> Result<()> {
73 if array.is_null(idx) {
74 return Ok(());
75 }
76
77 match array.data_type() {
78 DataType::Boolean => {
79 let arr = array
80 .as_any()
81 .downcast_ref::<BooleanArray>()
82 .expect("DataType/Array mismatch");
83 write!(writer, "{}", arr.value(idx))?;
84 }
85 DataType::Int16 => {
86 let arr = array
87 .as_any()
88 .downcast_ref::<Int16Array>()
89 .expect("DataType/Array mismatch");
90 write!(writer, "{}", arr.value(idx))?;
91 }
92 DataType::Int32 => {
93 let arr = array
94 .as_any()
95 .downcast_ref::<Int32Array>()
96 .expect("DataType/Array mismatch");
97 write!(writer, "{}", arr.value(idx))?;
98 }
99 DataType::Int64 => {
100 let arr = array
101 .as_any()
102 .downcast_ref::<Int64Array>()
103 .expect("DataType/Array mismatch");
104 write!(writer, "{}", arr.value(idx))?;
105 }
106 DataType::Float32 => {
107 let arr = array
108 .as_any()
109 .downcast_ref::<Float32Array>()
110 .expect("DataType/Array mismatch");
111 write!(writer, "{}", arr.value(idx))?;
112 }
113 DataType::Float64 => {
114 let arr = array
115 .as_any()
116 .downcast_ref::<Float64Array>()
117 .expect("DataType/Array mismatch");
118 write!(writer, "{}", arr.value(idx))?;
119 }
120 DataType::Utf8 => {
121 let arr = array
122 .as_any()
123 .downcast_ref::<StringArray>()
124 .expect("DataType/Array mismatch");
125 let val = arr.value(idx);
126 if val.contains(',') || val.contains('"') || val.contains('\n') {
127 write!(writer, "\"{}\"", val.replace('"', "\"\""))?;
128 } else {
129 write!(writer, "{}", val)?;
130 }
131 }
132 DataType::Binary => {
133 let arr = array
134 .as_any()
135 .downcast_ref::<BinaryArray>()
136 .expect("DataType/Array mismatch");
137 let val = arr.value(idx);
138 for byte in val {
139 write!(writer, "{:02x}", byte)?;
140 }
141 }
142 DataType::Date32 => {
143 let arr = array
144 .as_any()
145 .downcast_ref::<Date32Array>()
146 .expect("DataType/Array mismatch");
147 let days = arr.value(idx);
148 let date = chrono::NaiveDate::from_ymd_opt(1970, 1, 1).expect("epoch is valid")
149 + chrono::Duration::days(days as i64);
150 write!(writer, "{}", date)?;
151 }
152 DataType::Timestamp(TimeUnit::Microsecond, _) => {
153 let arr = array
154 .as_any()
155 .downcast_ref::<TimestampMicrosecondArray>()
156 .expect("DataType/Array mismatch");
157 let micros = arr.value(idx);
158 let secs = micros / 1_000_000;
159 let nsecs = ((micros % 1_000_000) * 1_000) as u32;
160 if let Some(dt) = chrono::DateTime::from_timestamp(secs, nsecs) {
161 write!(writer, "{}", dt.format("%Y-%m-%dT%H:%M:%S%.6f"))?;
162 }
163 }
164 other => {
165 log::warn!("CSV: unhandled Arrow type {:?}, skipping value", other);
166 }
167 }
168
169 Ok(())
170}