1use std::io::Write;
2
3use arrow::array::*;
4use arrow::datatypes::{DataType, SchemaRef, TimeUnit};
5use arrow::record_batch::RecordBatch;
6
7use crate::error::Result;
8
/// Output format that serializes Arrow record batches as comma-separated text.
///
/// The type carries no state; it only acts as a factory for
/// [`CsvFormatWriter`] instances and reports the `csv` file extension.
#[derive(Debug, Clone, Copy, Default)]
pub struct CsvFormat;
10
/// Streaming CSV writer produced by `CsvFormat::create_writer`.
pub struct CsvFormatWriter {
    /// Destination that receives the header line and every encoded row.
    writer: Box<dyn Write + Send>,
    /// Running byte count: seeded with the length of the header line
    /// (including its newline) and increased by the encoded size of each batch.
    bytes_written: u64,
}
15
16impl super::Format for CsvFormat {
17 fn create_writer(
18 &self,
19 schema: &SchemaRef,
20 mut writer: Box<dyn Write + Send>,
21 ) -> Result<Box<dyn super::FormatWriter>> {
22 let header = schema
23 .fields()
24 .iter()
25 .map(|f| f.name().as_str())
26 .collect::<Vec<_>>()
27 .join(",");
28 let header_bytes = header.len() as u64 + 1; writeln!(writer, "{}", header)?;
30 Ok(Box::new(CsvFormatWriter {
31 writer,
32 bytes_written: header_bytes,
33 }))
34 }
35
36 fn file_extension(&self) -> &str {
37 "csv"
38 }
39}
40
41impl super::FormatWriter for CsvFormatWriter {
42 fn write_batch(&mut self, batch: &RecordBatch) -> Result<()> {
43 let mut buf = Vec::with_capacity(batch.num_rows() * batch.num_columns() * 8);
44 for row_idx in 0..batch.num_rows() {
45 for col_idx in 0..batch.num_columns() {
46 if col_idx > 0 {
47 buf.push(b',');
48 }
49 write_csv_value(&mut buf, batch.column(col_idx), row_idx)?;
50 }
51 buf.push(b'\n');
52 }
53 self.bytes_written += buf.len() as u64;
54 self.writer.write_all(&buf)?;
55 Ok(())
56 }
57
58 fn finish(self: Box<Self>) -> Result<()> {
59 Ok(())
60 }
61
62 fn bytes_written(&self) -> u64 {
63 self.bytes_written
64 }
65}
66
67fn write_csv_value(writer: &mut dyn Write, array: &dyn Array, idx: usize) -> Result<()> {
68 if array.is_null(idx) {
69 return Ok(());
70 }
71
72 match array.data_type() {
73 DataType::Boolean => {
74 let arr = array
75 .as_any()
76 .downcast_ref::<BooleanArray>()
77 .expect("DataType/Array mismatch");
78 write!(writer, "{}", arr.value(idx))?;
79 }
80 DataType::Int16 => {
81 let arr = array
82 .as_any()
83 .downcast_ref::<Int16Array>()
84 .expect("DataType/Array mismatch");
85 write!(writer, "{}", arr.value(idx))?;
86 }
87 DataType::Int32 => {
88 let arr = array
89 .as_any()
90 .downcast_ref::<Int32Array>()
91 .expect("DataType/Array mismatch");
92 write!(writer, "{}", arr.value(idx))?;
93 }
94 DataType::Int64 => {
95 let arr = array
96 .as_any()
97 .downcast_ref::<Int64Array>()
98 .expect("DataType/Array mismatch");
99 write!(writer, "{}", arr.value(idx))?;
100 }
101 DataType::Float32 => {
102 let arr = array
103 .as_any()
104 .downcast_ref::<Float32Array>()
105 .expect("DataType/Array mismatch");
106 write!(writer, "{}", arr.value(idx))?;
107 }
108 DataType::Float64 => {
109 let arr = array
110 .as_any()
111 .downcast_ref::<Float64Array>()
112 .expect("DataType/Array mismatch");
113 write!(writer, "{}", arr.value(idx))?;
114 }
115 DataType::Utf8 => {
116 let arr = array
117 .as_any()
118 .downcast_ref::<StringArray>()
119 .expect("DataType/Array mismatch");
120 let val = arr.value(idx);
121 if val.contains(',') || val.contains('"') || val.contains('\n') {
122 writer.write_all(b"\"")?;
123 let mut rest = val;
124 while let Some(pos) = rest.find('"') {
125 writer.write_all(&rest.as_bytes()[..pos])?;
126 writer.write_all(b"\"\"")?;
127 rest = &rest[pos + 1..];
128 }
129 writer.write_all(rest.as_bytes())?;
130 writer.write_all(b"\"")?;
131 } else {
132 writer.write_all(val.as_bytes())?;
133 }
134 }
135 DataType::Binary => {
136 let arr = array
137 .as_any()
138 .downcast_ref::<BinaryArray>()
139 .expect("DataType/Array mismatch");
140 let val = arr.value(idx);
141 for byte in val {
142 write!(writer, "{:02x}", byte)?;
143 }
144 }
145 DataType::Date32 => {
146 let arr = array
147 .as_any()
148 .downcast_ref::<Date32Array>()
149 .expect("DataType/Array mismatch");
150 let days = arr.value(idx);
151 let epoch = chrono::NaiveDate::from_ymd_opt(1970, 1, 1).expect("epoch is valid");
157 let date =
158 chrono::Duration::try_days(days as i64).and_then(|d| epoch.checked_add_signed(d));
159 if let Some(date) = date {
160 write!(writer, "{}", date)?;
161 }
162 }
163 DataType::Timestamp(TimeUnit::Microsecond, _) => {
164 let arr = array
165 .as_any()
166 .downcast_ref::<TimestampMicrosecondArray>()
167 .expect("DataType/Array mismatch");
168 let micros = arr.value(idx);
169 let secs = micros / 1_000_000;
170 let nsecs = ((micros % 1_000_000) * 1_000) as u32;
171 if let Some(dt) = chrono::DateTime::from_timestamp(secs, nsecs) {
172 write!(writer, "{}", dt.format("%Y-%m-%dT%H:%M:%S%.6f"))?;
173 }
174 }
175 other => {
176 log::warn!("CSV: unhandled Arrow type {:?}, skipping value", other);
177 }
178 }
179
180 Ok(())
181}
182
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::datatypes::{DataType, Field, Schema};
    use std::sync::Arc;

    /// Encodes element `idx` of `array` and returns the resulting cell text.
    fn cell<A: Array + 'static>(array: A, idx: usize) -> String {
        let mut buf = Vec::new();
        write_csv_value(&mut buf, &array, idx).unwrap();
        String::from_utf8(buf).unwrap()
    }

    /// Encodes a single null element of type `dt` and returns the cell text.
    fn null_cell(dt: DataType) -> String {
        use arrow::array::new_null_array;
        let arr = new_null_array(&dt, 1);
        let mut buf = Vec::new();
        write_csv_value(&mut buf, arr.as_ref(), 0).unwrap();
        String::from_utf8(buf).unwrap()
    }

    // --- nulls ---

    #[test]
    fn null_value_writes_empty_string() {
        assert_eq!(null_cell(DataType::Int64), "");
        assert_eq!(null_cell(DataType::Utf8), "");
        assert_eq!(null_cell(DataType::Boolean), "");
    }

    // --- booleans and numbers ---

    #[test]
    fn bool_true_writes_true() {
        assert_eq!(cell(BooleanArray::from(vec![true]), 0), "true");
    }

    #[test]
    fn bool_false_writes_false() {
        assert_eq!(cell(BooleanArray::from(vec![false]), 0), "false");
    }

    #[test]
    fn int16_value() {
        assert_eq!(cell(Int16Array::from(vec![42i16]), 0), "42");
    }

    #[test]
    fn int32_negative() {
        assert_eq!(cell(Int32Array::from(vec![-7i32]), 0), "-7");
    }

    #[test]
    fn int64_large() {
        assert_eq!(
            cell(Int64Array::from(vec![9_999_999_999i64]), 0),
            "9999999999"
        );
    }

    #[test]
    fn float32_value() {
        let result = cell(Float32Array::from(vec![1.5f32]), 0);
        assert!(result.starts_with("1.5"), "got: {result}");
    }

    #[test]
    fn float64_value() {
        let result = cell(Float64Array::from(vec![std::f64::consts::PI]), 0);
        assert!(result.starts_with("3.14"), "got: {result}");
    }

    // --- strings ---

    #[test]
    fn plain_string_no_quoting() {
        assert_eq!(cell(StringArray::from(vec!["hello"]), 0), "hello");
    }

    #[test]
    fn string_with_comma_is_quoted() {
        assert_eq!(cell(StringArray::from(vec!["a,b"]), 0), "\"a,b\"");
    }

    #[test]
    fn string_with_double_quote_is_escaped() {
        let result = cell(StringArray::from(vec![r#"say "hi""#]), 0);
        assert_eq!(result, r#""say ""hi""""#);
    }

    #[test]
    fn string_with_comma_and_quote_is_quoted_and_escaped() {
        let result = cell(StringArray::from(vec![r#"a,"b""#]), 0);
        assert_eq!(result, r#""a,""b""""#);
    }

    #[test]
    fn string_with_newline_is_quoted() {
        let result = cell(StringArray::from(vec!["line1\nline2"]), 0);
        assert_eq!(result, "\"line1\nline2\"");
    }

    // --- binary ---

    #[test]
    fn binary_is_written_as_hex() {
        let arr = BinaryArray::from_vec(vec![&[0xDE, 0xAD, 0xBE, 0xEF][..]]);
        assert_eq!(cell(arr, 0), "deadbeef");
    }

    #[test]
    fn binary_empty_writes_empty() {
        let arr = BinaryArray::from_vec(vec![&[][..]]);
        assert_eq!(cell(arr, 0), "");
    }

    // --- dates and timestamps ---

    #[test]
    fn date32_epoch_is_1970_01_01() {
        assert_eq!(cell(Date32Array::from(vec![0i32]), 0), "1970-01-01");
    }

    #[test]
    fn date32_positive_offset() {
        assert_eq!(cell(Date32Array::from(vec![365i32]), 0), "1971-01-01");
    }

    #[test]
    fn date32_negative_offset() {
        assert_eq!(cell(Date32Array::from(vec![-1i32]), 0), "1969-12-31");
    }

    #[test]
    fn timestamp_micros_formats_as_iso() {
        // 2023-01-01T00:00:00Z expressed in microseconds since the epoch.
        let micros: i64 = 1_672_531_200 * 1_000_000;
        let arr = TimestampMicrosecondArray::from(vec![micros]);
        assert_eq!(cell(arr, 0), "2023-01-01T00:00:00.000000");
    }

    // --- end-to-end through the Format / FormatWriter traits ---

    #[test]
    fn csv_format_write_batch_tracks_bytes_and_succeeds() {
        use crate::format::Format;

        let schema = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int64, false),
            Field::new("name", DataType::Utf8, true),
        ]));
        let batch = arrow::record_batch::RecordBatch::try_new(
            schema.clone(),
            vec![
                Arc::new(Int64Array::from(vec![1i64, 2])),
                Arc::new(StringArray::from(vec![Some("alice"), None])),
            ],
        )
        .unwrap();

        let fmt = CsvFormat;
        let mut writer = fmt
            .create_writer(&schema, Box::new(Vec::<u8>::new()))
            .unwrap();
        writer.write_batch(&batch).unwrap();
        // "id,name\n" (8) + "1,alice\n" (8) + "2,\n" (3)
        assert_eq!(writer.bytes_written(), 19);
        writer.finish().unwrap();
    }
}