1use std::io::Write;
2
3use arrow::array::Time64MicrosecondArray;
4use arrow::array::types::Decimal128Type;
5use arrow::array::*;
6use arrow::datatypes::{DataType, SchemaRef, TimeUnit};
7use arrow::record_batch::RecordBatch;
8
9use crate::error::Result;
10use crate::types::decimal::scaled_i128_to_decimal_str;
11
/// The CSV output format; a stateless factory for [`CsvFormatWriter`]s.
#[derive(Debug, Clone, Copy, Default)]
pub struct CsvFormat;
13
/// Streaming CSV writer produced by [`CsvFormat`].
///
/// Each record batch is rendered into an in-memory buffer and then handed to
/// the underlying sink in one write.
pub struct CsvFormatWriter {
    /// Running total of bytes passed to `writer`, starting with the header row.
    bytes_written: u64,
    /// Destination for the encoded CSV text.
    writer: Box<dyn Write + Send>,
}
18
19impl super::Format for CsvFormat {
20 fn create_writer(
21 &self,
22 schema: &SchemaRef,
23 mut writer: Box<dyn Write + Send>,
24 ) -> Result<Box<dyn super::FormatWriter + Send>> {
25 if let Some(field) = schema
30 .fields()
31 .iter()
32 .find(|f| !csv_serializable(f.data_type()))
33 {
34 anyhow::bail!(
35 "CSV cannot serialize column '{}' (Arrow type {:?}); use `format: parquet` \
36 or drop the column from the query",
37 field.name(),
38 field.data_type()
39 );
40 }
41 let header = schema
42 .fields()
43 .iter()
44 .map(|f| f.name().as_str())
45 .collect::<Vec<_>>()
46 .join(",");
47 let header_bytes = header.len() as u64 + 1; writeln!(writer, "{}", header)?;
49 Ok(Box::new(CsvFormatWriter {
50 writer,
51 bytes_written: header_bytes,
52 }))
53 }
54
55 fn file_extension(&self) -> &str {
56 "csv"
57 }
58}
59
60impl super::FormatWriter for CsvFormatWriter {
61 fn write_batch(&mut self, batch: &RecordBatch) -> Result<()> {
62 let mut buf = Vec::with_capacity(batch.num_rows() * batch.num_columns() * 8);
63 for row_idx in 0..batch.num_rows() {
64 for col_idx in 0..batch.num_columns() {
65 if col_idx > 0 {
66 buf.push(b',');
67 }
68 write_csv_value(&mut buf, batch.column(col_idx), row_idx)?;
69 }
70 buf.push(b'\n');
71 }
72 self.bytes_written += buf.len() as u64;
73 self.writer.write_all(&buf)?;
74 Ok(())
75 }
76
77 fn finish(self: Box<Self>) -> Result<()> {
78 Ok(())
79 }
80
81 fn bytes_written(&self) -> u64 {
82 self.bytes_written
83 }
84}
85
86pub(crate) fn csv_serializable(dt: &DataType) -> bool {
91 matches!(
92 dt,
93 DataType::Boolean
94 | DataType::Int16
95 | DataType::Int32
96 | DataType::Int64
97 | DataType::UInt64
98 | DataType::Decimal128(_, _)
99 | DataType::Float32
100 | DataType::Float64
101 | DataType::Utf8
102 | DataType::Binary
103 | DataType::FixedSizeBinary(16)
104 | DataType::Date32
105 | DataType::Time64(TimeUnit::Microsecond)
106 | DataType::Timestamp(TimeUnit::Microsecond, _)
107 )
108}
109
110fn write_csv_value(writer: &mut dyn Write, array: &dyn Array, idx: usize) -> Result<()> {
111 if array.is_null(idx) {
112 return Ok(());
113 }
114
115 match array.data_type() {
116 DataType::Boolean => {
117 let arr = array
118 .as_any()
119 .downcast_ref::<BooleanArray>()
120 .expect("DataType/Array mismatch");
121 write!(writer, "{}", arr.value(idx))?;
122 }
123 DataType::Int16 => {
124 let arr = array
125 .as_any()
126 .downcast_ref::<Int16Array>()
127 .expect("DataType/Array mismatch");
128 write!(writer, "{}", arr.value(idx))?;
129 }
130 DataType::Int32 => {
131 let arr = array
132 .as_any()
133 .downcast_ref::<Int32Array>()
134 .expect("DataType/Array mismatch");
135 write!(writer, "{}", arr.value(idx))?;
136 }
137 DataType::Int64 => {
138 let arr = array
139 .as_any()
140 .downcast_ref::<Int64Array>()
141 .expect("DataType/Array mismatch");
142 write!(writer, "{}", arr.value(idx))?;
143 }
144 DataType::UInt64 => {
145 let arr = array
146 .as_any()
147 .downcast_ref::<UInt64Array>()
148 .expect("DataType/Array mismatch");
149 write!(writer, "{}", arr.value(idx))?;
150 }
151 DataType::Decimal128(_, scale) => {
152 let arr = array.as_primitive::<Decimal128Type>();
153 let text = scaled_i128_to_decimal_str(arr.value(idx), *scale);
154 writer.write_all(text.as_bytes())?;
155 }
156 DataType::Float32 => {
157 let arr = array
158 .as_any()
159 .downcast_ref::<Float32Array>()
160 .expect("DataType/Array mismatch");
161 write!(writer, "{}", arr.value(idx))?;
162 }
163 DataType::Float64 => {
164 let arr = array
165 .as_any()
166 .downcast_ref::<Float64Array>()
167 .expect("DataType/Array mismatch");
168 write!(writer, "{}", arr.value(idx))?;
169 }
170 DataType::Utf8 => {
171 let arr = array
172 .as_any()
173 .downcast_ref::<StringArray>()
174 .expect("DataType/Array mismatch");
175 let val = arr.value(idx);
176 if val.contains(',') || val.contains('"') || val.contains('\n') {
177 writer.write_all(b"\"")?;
178 let mut rest = val;
179 while let Some(pos) = rest.find('"') {
180 writer.write_all(&rest.as_bytes()[..pos])?;
181 writer.write_all(b"\"\"")?;
182 rest = &rest[pos + 1..];
183 }
184 writer.write_all(rest.as_bytes())?;
185 writer.write_all(b"\"")?;
186 } else {
187 writer.write_all(val.as_bytes())?;
188 }
189 }
190 DataType::Binary => {
191 let arr = array
192 .as_any()
193 .downcast_ref::<BinaryArray>()
194 .expect("DataType/Array mismatch");
195 let val = arr.value(idx);
196 for byte in val {
197 write!(writer, "{:02x}", byte)?;
198 }
199 }
200 DataType::FixedSizeBinary(16) => {
208 let arr = array
209 .as_any()
210 .downcast_ref::<FixedSizeBinaryArray>()
211 .expect("DataType/Array mismatch");
212 let val = arr.value(idx);
213 let mut bytes = [0u8; 16];
214 bytes.copy_from_slice(val);
215 write!(writer, "{}", uuid::Uuid::from_bytes(bytes).to_hyphenated())?;
216 }
217 DataType::Date32 => {
218 let arr = array
219 .as_any()
220 .downcast_ref::<Date32Array>()
221 .expect("DataType/Array mismatch");
222 let days = arr.value(idx);
223 let epoch = chrono::NaiveDate::from_ymd_opt(1970, 1, 1).expect("epoch is valid");
229 let date =
230 chrono::Duration::try_days(days as i64).and_then(|d| epoch.checked_add_signed(d));
231 if let Some(date) = date {
232 write!(writer, "{}", date)?;
233 }
234 }
235 DataType::Time64(TimeUnit::Microsecond) => {
236 let arr = array
237 .as_any()
238 .downcast_ref::<Time64MicrosecondArray>()
239 .expect("DataType/Array mismatch");
240 let micros = arr.value(idx);
241 let secs = micros / 1_000_000;
242 let frac_us = micros % 1_000_000;
243 write!(
244 writer,
245 "{:02}:{:02}:{:02}.{:06}",
246 secs / 3600,
247 (secs % 3600) / 60,
248 secs % 60,
249 frac_us
250 )?;
251 }
252 DataType::Timestamp(TimeUnit::Microsecond, _) => {
253 let arr = array
254 .as_any()
255 .downcast_ref::<TimestampMicrosecondArray>()
256 .expect("DataType/Array mismatch");
257 let micros = arr.value(idx);
258 let secs = micros / 1_000_000;
259 let nsecs = ((micros % 1_000_000) * 1_000) as u32;
260 if let Some(dt) = chrono::DateTime::from_timestamp(secs, nsecs) {
261 write!(writer, "{}", dt.format("%Y-%m-%dT%H:%M:%S%.6f"))?;
262 }
263 }
264 other => {
265 anyhow::bail!(
269 "CSV: no serializer for Arrow type {other:?} (column should have been rejected at writer creation)"
270 );
271 }
272 }
273
274 Ok(())
275}
276
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::datatypes::{DataType, Field, Schema};
    use std::sync::Arc;

    /// Renders `array[idx]` through `write_csv_value` and returns the text.
    fn cell<A: Array + 'static>(array: A, idx: usize) -> String {
        let mut buf = Vec::new();
        write_csv_value(&mut buf, &array, idx).unwrap();
        String::from_utf8(buf).unwrap()
    }

    /// Renders the single slot of a one-element all-null array of type `dt`.
    fn null_cell(dt: DataType) -> String {
        use arrow::array::new_null_array;
        let arr = new_null_array(&dt, 1);
        let mut buf = Vec::new();
        write_csv_value(&mut buf, arr.as_ref(), 0).unwrap();
        String::from_utf8(buf).unwrap()
    }

    /// Nulls of any type serialize as an empty field.
    #[test]
    fn null_value_writes_empty_string() {
        assert_eq!(null_cell(DataType::Int64), "");
        assert_eq!(null_cell(DataType::Utf8), "");
        assert_eq!(null_cell(DataType::Boolean), "");
    }

    #[test]
    fn bool_true_writes_true() {
        assert_eq!(cell(BooleanArray::from(vec![true]), 0), "true");
    }

    #[test]
    fn bool_false_writes_false() {
        assert_eq!(cell(BooleanArray::from(vec![false]), 0), "false");
    }

    #[test]
    fn int16_value() {
        assert_eq!(cell(Int16Array::from(vec![42i16]), 0), "42");
    }

    #[test]
    fn int32_negative() {
        assert_eq!(cell(Int32Array::from(vec![-7i32]), 0), "-7");
    }

    #[test]
    fn decimal128_writes_exact_text() {
        let arr = Decimal128Array::from(vec![10i128])
            .with_precision_and_scale(18, 2)
            .unwrap();
        assert_eq!(cell(arr, 0), "0.10");
        let scaled =
            crate::types::decimal::decimal_str_to_scaled_i128("999999999999.99", 2).unwrap();
        let arr = Decimal128Array::from(vec![scaled])
            .with_precision_and_scale(18, 2)
            .unwrap();
        assert_eq!(cell(arr, 0), "999999999999.99");
    }

    #[test]
    fn int64_large() {
        assert_eq!(
            cell(Int64Array::from(vec![9_999_999_999i64]), 0),
            "9999999999"
        );
    }

    #[test]
    fn float32_value() {
        let result = cell(Float32Array::from(vec![1.5f32]), 0);
        assert!(result.starts_with("1.5"), "got: {result}");
    }

    #[test]
    fn float64_value() {
        let result = cell(Float64Array::from(vec![std::f64::consts::PI]), 0);
        assert!(result.starts_with("3.14"), "got: {result}");
    }

    /// NaN, infinities and negative zero must produce text, not an empty
    /// (null-looking) field.
    #[test]
    fn float_special_values_emit_literals_not_empty() {
        assert_eq!(cell(Float64Array::from(vec![f64::NAN]), 0), "NaN");
        assert_eq!(cell(Float64Array::from(vec![f64::INFINITY]), 0), "inf");
        assert_eq!(cell(Float64Array::from(vec![f64::NEG_INFINITY]), 0), "-inf");
        assert_eq!(cell(Float32Array::from(vec![f32::NAN]), 0), "NaN");
        assert_eq!(cell(Float32Array::from(vec![f32::INFINITY]), 0), "inf");
        assert_eq!(cell(Float64Array::from(vec![-0.0f64]), 0), "-0");
    }

    #[test]
    fn plain_string_no_quoting() {
        assert_eq!(cell(StringArray::from(vec!["hello"]), 0), "hello");
    }

    #[test]
    fn string_with_comma_is_quoted() {
        assert_eq!(cell(StringArray::from(vec!["a,b"]), 0), "\"a,b\"");
    }

    #[test]
    fn string_with_double_quote_is_escaped() {
        let result = cell(StringArray::from(vec![r#"say "hi""#]), 0);
        assert_eq!(result, r#""say ""hi""""#);
    }

    #[test]
    fn string_with_newline_is_quoted() {
        let result = cell(StringArray::from(vec!["line1\nline2"]), 0);
        assert!(
            result.starts_with('"') && result.ends_with('"'),
            "got: {result}"
        );
        assert!(result.contains("line1\nline2"), "got: {result}");
    }

    /// Binary columns are emitted as lowercase hex with no separators.
    #[test]
    fn binary_is_written_as_hex() {
        let arr = BinaryArray::from_vec(vec![&[0xDE, 0xAD, 0xBE, 0xEF][..]]);
        assert_eq!(cell(arr, 0), "deadbeef");
    }

    #[test]
    fn binary_empty_writes_empty() {
        let arr = BinaryArray::from_vec(vec![&[][..]]);
        assert_eq!(cell(arr, 0), "");
    }

    /// Date32 counts days from the Unix epoch.
    #[test]
    fn date32_epoch_is_1970_01_01() {
        assert_eq!(cell(Date32Array::from(vec![0i32]), 0), "1970-01-01");
    }

    #[test]
    fn date32_positive_offset() {
        assert_eq!(cell(Date32Array::from(vec![365i32]), 0), "1971-01-01");
    }

    /// Timestamps render as ISO-8601 with microsecond precision and no offset.
    #[test]
    fn timestamp_micros_formats_as_iso() {
        // 1_672_531_200 s is 2023-01-01T00:00:00Z.
        let micros: i64 = 1_672_531_200 * 1_000_000;
        let arr = TimestampMicrosecondArray::from(vec![micros]);
        assert_eq!(cell(arr, 0), "2023-01-01T00:00:00.000000");
    }

    /// End-to-end: header plus two rows, with the byte count checked exactly.
    #[test]
    fn csv_format_write_batch_tracks_bytes_and_succeeds() {
        use crate::format::Format;

        let schema = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int64, false),
            Field::new("name", DataType::Utf8, true),
        ]));
        let batch = arrow::record_batch::RecordBatch::try_new(
            schema.clone(),
            vec![
                Arc::new(Int64Array::from(vec![1i64, 2])),
                Arc::new(StringArray::from(vec![Some("alice"), None])),
            ],
        )
        .unwrap();

        let fmt = CsvFormat;
        let mut writer = fmt
            .create_writer(&schema, Box::new(Vec::<u8>::new()))
            .unwrap();
        writer.write_batch(&batch).unwrap();
        // "id,name\n" (8) + "1,alice\n" (8) + "2,\n" (3)
        assert_eq!(writer.bytes_written(), 19);
        writer.finish().unwrap();
    }

    /// Unsupported column types must fail at writer creation and name the column.
    #[test]
    fn csv_rejects_array_columns_loudly() {
        use crate::format::Format;
        let schema = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int64, false),
            Field::new(
                "tags",
                DataType::List(Arc::new(Field::new("item", DataType::Utf8, true))),
                true,
            ),
        ]));
        let Err(err) = CsvFormat.create_writer(&schema, Box::new(Vec::<u8>::new())) else {
            panic!("CSV must reject array columns, not silently drop them");
        };
        let msg = format!("{err:#}");
        assert!(msg.contains("tags"), "error must name the column: {msg}");
        assert!(msg.to_lowercase().contains("csv"), "{msg}");
    }

    /// Keeps `csv_serializable` and `write_csv_value` in sync: every type the
    /// former accepts must be handled by the latter without hitting the fallthrough.
    #[test]
    fn every_serializable_type_is_actually_written() {
        use crate::format::Format;
        let cols: Vec<(&str, ArrayRef)> = vec![
            ("b", Arc::new(BooleanArray::from(vec![true]))),
            ("i16", Arc::new(Int16Array::from(vec![1i16]))),
            ("i32", Arc::new(Int32Array::from(vec![1i32]))),
            ("i64", Arc::new(Int64Array::from(vec![1i64]))),
            ("u64", Arc::new(UInt64Array::from(vec![1u64]))),
            (
                "dec",
                Arc::new(
                    Decimal128Array::from(vec![100i128])
                        .with_precision_and_scale(18, 2)
                        .unwrap(),
                ),
            ),
            ("f32", Arc::new(Float32Array::from(vec![1.0f32]))),
            ("f64", Arc::new(Float64Array::from(vec![1.0f64]))),
            ("s", Arc::new(StringArray::from(vec!["x"]))),
            ("bin", Arc::new(BinaryArray::from_vec(vec![&[1u8][..]]))),
            (
                "uuid",
                Arc::new(
                    FixedSizeBinaryArray::try_from_iter(std::iter::once(vec![0u8; 16])).unwrap(),
                ),
            ),
            ("d", Arc::new(Date32Array::from(vec![0i32]))),
            ("t", Arc::new(Time64MicrosecondArray::from(vec![0i64]))),
            ("ts", Arc::new(TimestampMicrosecondArray::from(vec![0i64]))),
        ];
        let fields: Vec<Field> = cols
            .iter()
            .map(|(n, a)| Field::new(*n, a.data_type().clone(), true))
            .collect();
        for f in &fields {
            assert!(
                csv_serializable(f.data_type()),
                "test type {:?} not in csv_serializable",
                f.data_type()
            );
        }
        let schema = Arc::new(Schema::new(fields));
        let arrays: Vec<ArrayRef> = cols.into_iter().map(|(_, a)| a).collect();
        let batch = RecordBatch::try_new(schema.clone(), arrays).unwrap();
        let mut w = CsvFormat
            .create_writer(&schema, Box::new(Vec::<u8>::new()))
            .unwrap();
        w.write_batch(&batch)
            .expect("every serializable type must write without hitting the fallthrough");
    }
}