1use std::io::Write;
2
3use arrow::array::Time64MicrosecondArray;
4use arrow::array::types::Decimal128Type;
5use arrow::array::*;
6use arrow::datatypes::{DataType, SchemaRef, TimeUnit};
7use arrow::record_batch::RecordBatch;
8
9use crate::error::Result;
10use crate::types::decimal::scaled_i128_to_decimal_str;
11
/// Output format that serializes Arrow record batches as CSV text.
///
/// The type carries no state; it exists to implement `super::Format`, whose
/// `create_writer` validates the schema, emits the header row and returns a
/// [`CsvFormatWriter`].
pub struct CsvFormat;
13
/// Row writer returned by [`CsvFormat`]'s `create_writer`.
pub struct CsvFormatWriter {
    // Destination sink that receives the header and every serialized batch.
    writer: Box<dyn Write + Send>,
    // Running count of bytes emitted so far; it is seeded with the size of the
    // header row (including its trailing newline) when the writer is created.
    bytes_written: u64,
}
18
19impl super::Format for CsvFormat {
20 fn create_writer(
21 &self,
22 schema: &SchemaRef,
23 mut writer: Box<dyn Write + Send>,
24 ) -> Result<Box<dyn super::FormatWriter + Send>> {
25 if let Some(field) = schema
30 .fields()
31 .iter()
32 .find(|f| !csv_serializable(f.data_type()))
33 {
34 anyhow::bail!(
35 "CSV cannot serialize column '{}' (Arrow type {:?}); use `format: parquet` \
36 or drop the column from the query",
37 field.name(),
38 field.data_type()
39 );
40 }
41 let header = schema
42 .fields()
43 .iter()
44 .map(|f| f.name().as_str())
45 .collect::<Vec<_>>()
46 .join(",");
47 let header_bytes = header.len() as u64 + 1; writeln!(writer, "{}", header)?;
49 Ok(Box::new(CsvFormatWriter {
50 writer,
51 bytes_written: header_bytes,
52 }))
53 }
54
55 fn file_extension(&self) -> &str {
56 "csv"
57 }
58}
59
60impl super::FormatWriter for CsvFormatWriter {
61 fn write_batch(&mut self, batch: &RecordBatch) -> Result<()> {
62 let mut buf = Vec::with_capacity(batch.num_rows() * batch.num_columns() * 8);
63 for row_idx in 0..batch.num_rows() {
64 for col_idx in 0..batch.num_columns() {
65 if col_idx > 0 {
66 buf.push(b',');
67 }
68 write_csv_value(&mut buf, batch.column(col_idx), row_idx)?;
69 }
70 buf.push(b'\n');
71 }
72 self.bytes_written += buf.len() as u64;
73 self.writer.write_all(&buf)?;
74 Ok(())
75 }
76
77 fn finish(self: Box<Self>) -> Result<()> {
78 Ok(())
79 }
80
81 fn bytes_written(&self) -> u64 {
82 self.bytes_written
83 }
84}
85
86pub(crate) fn csv_serializable(dt: &DataType) -> bool {
91 matches!(
92 dt,
93 DataType::Boolean
94 | DataType::Int16
95 | DataType::Int32
96 | DataType::Int64
97 | DataType::UInt64
98 | DataType::Decimal128(_, _)
99 | DataType::Float32
100 | DataType::Float64
101 | DataType::Utf8
102 | DataType::Binary
103 | DataType::FixedSizeBinary(16)
104 | DataType::Date32
105 | DataType::Time64(TimeUnit::Microsecond)
106 | DataType::Timestamp(TimeUnit::Microsecond, _)
107 )
108}
109
/// Writes `bytes` to `writer` as lowercase hexadecimal, two digits per byte.
///
/// The input is encoded through a fixed stack buffer, 512 input bytes (1024
/// hex digits) per `write_all` call, so no heap allocation is needed. Empty
/// input performs no write at all.
///
/// # Errors
/// Returns the first error reported by `writer`.
fn write_lower_hex(writer: &mut dyn Write, bytes: &[u8]) -> Result<()> {
    const DIGITS: &[u8; 16] = b"0123456789abcdef";
    // Number of input bytes encoded per write; each expands to two digits.
    const BYTES_PER_WRITE: usize = 512;

    let mut encoded = [0u8; BYTES_PER_WRITE * 2];
    for piece in bytes.chunks(BYTES_PER_WRITE) {
        for (pair, &byte) in encoded.chunks_exact_mut(2).zip(piece) {
            pair[0] = DIGITS[usize::from(byte >> 4)];
            pair[1] = DIGITS[usize::from(byte & 0x0f)];
        }
        writer.write_all(&encoded[..piece.len() * 2])?;
    }
    Ok(())
}
129
130fn write_csv_value(writer: &mut dyn Write, array: &dyn Array, idx: usize) -> Result<()> {
131 if array.is_null(idx) {
132 return Ok(());
133 }
134
135 match array.data_type() {
136 DataType::Boolean => {
137 let arr = array
138 .as_any()
139 .downcast_ref::<BooleanArray>()
140 .expect("DataType/Array mismatch");
141 write!(writer, "{}", arr.value(idx))?;
142 }
143 DataType::Int16 => {
144 let arr = array
145 .as_any()
146 .downcast_ref::<Int16Array>()
147 .expect("DataType/Array mismatch");
148 write!(writer, "{}", arr.value(idx))?;
149 }
150 DataType::Int32 => {
151 let arr = array
152 .as_any()
153 .downcast_ref::<Int32Array>()
154 .expect("DataType/Array mismatch");
155 write!(writer, "{}", arr.value(idx))?;
156 }
157 DataType::Int64 => {
158 let arr = array
159 .as_any()
160 .downcast_ref::<Int64Array>()
161 .expect("DataType/Array mismatch");
162 write!(writer, "{}", arr.value(idx))?;
163 }
164 DataType::UInt64 => {
165 let arr = array
166 .as_any()
167 .downcast_ref::<UInt64Array>()
168 .expect("DataType/Array mismatch");
169 write!(writer, "{}", arr.value(idx))?;
170 }
171 DataType::Decimal128(_, scale) => {
172 let arr = array.as_primitive::<Decimal128Type>();
173 let text = scaled_i128_to_decimal_str(arr.value(idx), *scale);
174 writer.write_all(text.as_bytes())?;
175 }
176 DataType::Float32 => {
177 let arr = array
178 .as_any()
179 .downcast_ref::<Float32Array>()
180 .expect("DataType/Array mismatch");
181 write!(writer, "{}", arr.value(idx))?;
182 }
183 DataType::Float64 => {
184 let arr = array
185 .as_any()
186 .downcast_ref::<Float64Array>()
187 .expect("DataType/Array mismatch");
188 write!(writer, "{}", arr.value(idx))?;
189 }
190 DataType::Utf8 => {
191 let arr = array
192 .as_any()
193 .downcast_ref::<StringArray>()
194 .expect("DataType/Array mismatch");
195 let val = arr.value(idx);
196 if val
200 .bytes()
201 .any(|b| matches!(b, b',' | b'"' | b'\n' | b'\r'))
202 {
203 writer.write_all(b"\"")?;
204 let mut rest = val;
205 while let Some(pos) = rest.find('"') {
206 writer.write_all(&rest.as_bytes()[..pos])?;
207 writer.write_all(b"\"\"")?;
208 rest = &rest[pos + 1..];
209 }
210 writer.write_all(rest.as_bytes())?;
211 writer.write_all(b"\"")?;
212 } else {
213 writer.write_all(val.as_bytes())?;
214 }
215 }
216 DataType::Binary => {
217 let arr = array
218 .as_any()
219 .downcast_ref::<BinaryArray>()
220 .expect("DataType/Array mismatch");
221 write_lower_hex(writer, arr.value(idx))?;
222 }
223 DataType::FixedSizeBinary(16) => {
231 let arr = array
232 .as_any()
233 .downcast_ref::<FixedSizeBinaryArray>()
234 .expect("DataType/Array mismatch");
235 let val = arr.value(idx);
236 let mut bytes = [0u8; 16];
237 bytes.copy_from_slice(val);
238 write!(writer, "{}", uuid::Uuid::from_bytes(bytes).to_hyphenated())?;
239 }
240 DataType::Date32 => {
241 let arr = array
242 .as_any()
243 .downcast_ref::<Date32Array>()
244 .expect("DataType/Array mismatch");
245 let days = arr.value(idx);
246 let epoch = chrono::NaiveDate::from_ymd_opt(1970, 1, 1).expect("epoch is valid");
252 let date =
253 chrono::Duration::try_days(days as i64).and_then(|d| epoch.checked_add_signed(d));
254 if let Some(date) = date {
255 write!(writer, "{}", date)?;
256 }
257 }
258 DataType::Time64(TimeUnit::Microsecond) => {
259 let arr = array
260 .as_any()
261 .downcast_ref::<Time64MicrosecondArray>()
262 .expect("DataType/Array mismatch");
263 let micros = arr.value(idx);
264 let secs = micros / 1_000_000;
265 let frac_us = micros % 1_000_000;
266 write!(
267 writer,
268 "{:02}:{:02}:{:02}.{:06}",
269 secs / 3600,
270 (secs % 3600) / 60,
271 secs % 60,
272 frac_us
273 )?;
274 }
275 DataType::Timestamp(TimeUnit::Microsecond, _) => {
276 let arr = array
277 .as_any()
278 .downcast_ref::<TimestampMicrosecondArray>()
279 .expect("DataType/Array mismatch");
280 let micros = arr.value(idx);
281 let secs = micros / 1_000_000;
282 let nsecs = ((micros % 1_000_000) * 1_000) as u32;
283 if let Some(dt) = chrono::DateTime::from_timestamp(secs, nsecs) {
284 use chrono::{Datelike as _, Timelike as _};
285 let y = dt.year();
286 if (0..=9999).contains(&y) {
292 write!(
293 writer,
294 "{:04}-{:02}-{:02}T{:02}:{:02}:{:02}.{:06}",
295 y,
296 dt.month(),
297 dt.day(),
298 dt.hour(),
299 dt.minute(),
300 dt.second(),
301 dt.nanosecond() / 1_000
302 )?;
303 } else {
304 write!(writer, "{}", dt.format("%Y-%m-%dT%H:%M:%S%.6f"))?;
305 }
306 }
307 }
308 other => {
309 anyhow::bail!(
313 "CSV: no serializer for Arrow type {other:?} (column should have been rejected at writer creation)"
314 );
315 }
316 }
317
318 Ok(())
319}
320
321#[cfg(test)]
322mod tests {
323 use super::*;
324 use arrow::datatypes::{DataType, Field, Schema, TimeUnit};
325 use std::sync::Arc;
326
327 fn cell<A: Array + 'static>(array: A, idx: usize) -> String {
329 let mut buf = Vec::new();
330 write_csv_value(&mut buf, &array, idx).unwrap();
331 String::from_utf8(buf).unwrap()
332 }
333
334 fn null_cell(dt: DataType) -> String {
336 use arrow::array::new_null_array;
337 let arr = new_null_array(&dt, 1);
338 let mut buf = Vec::new();
339 write_csv_value(&mut buf, arr.as_ref(), 0).unwrap();
340 String::from_utf8(buf).unwrap()
341 }
342
343 #[test]
346 fn null_value_writes_empty_string() {
347 assert_eq!(null_cell(DataType::Int64), "");
348 assert_eq!(null_cell(DataType::Utf8), "");
349 assert_eq!(null_cell(DataType::Boolean), "");
350 }
351
352 #[test]
355 fn bool_true_writes_true() {
356 assert_eq!(cell(BooleanArray::from(vec![true]), 0), "true");
357 }
358
359 #[test]
360 fn bool_false_writes_false() {
361 assert_eq!(cell(BooleanArray::from(vec![false]), 0), "false");
362 }
363
364 #[test]
365 fn int16_value() {
366 assert_eq!(cell(Int16Array::from(vec![42i16]), 0), "42");
367 }
368
369 #[test]
370 fn int32_negative() {
371 assert_eq!(cell(Int32Array::from(vec![-7i32]), 0), "-7");
372 }
373
374 #[test]
375 fn decimal128_writes_exact_text() {
376 let arr = Decimal128Array::from(vec![10i128])
377 .with_precision_and_scale(18, 2)
378 .unwrap();
379 assert_eq!(cell(arr, 0), "0.10");
380 let scaled =
381 crate::types::decimal::decimal_str_to_scaled_i128("999999999999.99", 2).unwrap();
382 let arr = Decimal128Array::from(vec![scaled])
383 .with_precision_and_scale(18, 2)
384 .unwrap();
385 assert_eq!(cell(arr, 0), "999999999999.99");
386 }
387
388 #[test]
389 fn int64_large() {
390 assert_eq!(
391 cell(Int64Array::from(vec![9_999_999_999i64]), 0),
392 "9999999999"
393 );
394 }
395
396 #[test]
397 fn float32_value() {
398 let result = cell(Float32Array::from(vec![1.5f32]), 0);
399 assert!(result.starts_with("1.5"), "got: {result}");
400 }
401
402 #[test]
403 fn float64_value() {
404 let result = cell(Float64Array::from(vec![std::f64::consts::PI]), 0);
405 assert!(result.starts_with("3.14"), "got: {result}");
406 }
407
408 #[test]
419 fn float_special_values_emit_literals_not_empty() {
420 assert_eq!(cell(Float64Array::from(vec![f64::NAN]), 0), "NaN");
421 assert_eq!(cell(Float64Array::from(vec![f64::INFINITY]), 0), "inf");
422 assert_eq!(cell(Float64Array::from(vec![f64::NEG_INFINITY]), 0), "-inf");
423 assert_eq!(cell(Float32Array::from(vec![f32::NAN]), 0), "NaN");
424 assert_eq!(cell(Float32Array::from(vec![f32::INFINITY]), 0), "inf");
425 assert_eq!(cell(Float64Array::from(vec![-0.0f64]), 0), "-0");
427 }
428
429 #[test]
432 fn plain_string_no_quoting() {
433 assert_eq!(cell(StringArray::from(vec!["hello"]), 0), "hello");
434 }
435
436 #[test]
437 fn string_with_comma_is_quoted() {
438 assert_eq!(cell(StringArray::from(vec!["a,b"]), 0), "\"a,b\"");
439 }
440
441 #[test]
442 fn string_with_double_quote_is_escaped() {
443 let result = cell(StringArray::from(vec![r#"say "hi""#]), 0);
445 assert_eq!(result, r#""say ""hi""""#);
446 }
447
448 #[test]
449 fn string_with_newline_is_quoted() {
450 let result = cell(StringArray::from(vec!["line1\nline2"]), 0);
451 assert!(
452 result.starts_with('"') && result.ends_with('"'),
453 "got: {result}"
454 );
455 assert!(result.contains("line1\nline2"), "got: {result}");
456 }
457
458 #[test]
464 fn roast_string_with_carriage_return_is_quoted() {
465 let result = cell(StringArray::from(vec!["a\rb"]), 0);
466 assert_eq!(
467 result, "\"a\rb\"",
468 "lone CR must force quoting per RFC 4180, but got unquoted cell {result:?}"
469 );
470 }
471
472 #[test]
475 fn binary_is_written_as_hex() {
476 let arr = BinaryArray::from_vec(vec![&[0xDE, 0xAD, 0xBE, 0xEF][..]]);
477 assert_eq!(cell(arr, 0), "deadbeef");
478 }
479
480 #[test]
481 fn binary_empty_writes_empty() {
482 let arr = BinaryArray::from_vec(vec![&[][..]]);
483 assert_eq!(cell(arr, 0), "");
484 }
485
486 #[test]
489 fn date32_epoch_is_1970_01_01() {
490 assert_eq!(cell(Date32Array::from(vec![0i32]), 0), "1970-01-01");
491 }
492
493 #[test]
494 fn date32_positive_offset() {
495 assert_eq!(cell(Date32Array::from(vec![365i32]), 0), "1971-01-01");
497 }
498
499 #[test]
502 fn timestamp_micros_formats_as_iso() {
503 let micros: i64 = 1_672_531_200 * 1_000_000;
505 let _schema = Arc::new(Schema::new(vec![Field::new(
506 "ts",
507 DataType::Timestamp(TimeUnit::Microsecond, None),
508 true,
509 )]));
510 let arr = TimestampMicrosecondArray::from(vec![micros]);
511 let result = cell(arr, 0);
512 assert!(result.starts_with("2023-01-01T"), "got: {result}");
513 assert!(result.contains("00:00:00"), "got: {result}");
514 }
515
516 #[test]
519 fn csv_format_write_batch_tracks_bytes_and_succeeds() {
520 use crate::format::Format;
521
522 let schema = Arc::new(Schema::new(vec![
523 Field::new("id", DataType::Int64, false),
524 Field::new("name", DataType::Utf8, true),
525 ]));
526 let batch = arrow::record_batch::RecordBatch::try_new(
527 schema.clone(),
528 vec![
529 Arc::new(Int64Array::from(vec![1i64, 2])),
530 Arc::new(StringArray::from(vec![Some("alice"), None])),
531 ],
532 )
533 .unwrap();
534
535 let fmt = CsvFormat;
537 let mut writer = fmt
538 .create_writer(&schema, Box::new(Vec::<u8>::new()))
539 .unwrap();
540 writer.write_batch(&batch).unwrap();
541 assert!(
543 writer.bytes_written() > 10,
544 "expected >10 bytes, got {}",
545 writer.bytes_written()
546 );
547 writer.finish().unwrap();
548 }
549
550 #[test]
553 fn csv_rejects_array_columns_loudly() {
554 use crate::format::Format;
555 let schema = Arc::new(Schema::new(vec![
556 Field::new("id", DataType::Int64, false),
557 Field::new(
558 "tags",
559 DataType::List(Arc::new(Field::new("item", DataType::Utf8, true))),
560 true,
561 ),
562 ]));
563 let Err(err) = CsvFormat.create_writer(&schema, Box::new(Vec::<u8>::new())) else {
564 panic!("CSV must reject array columns, not silently drop them");
565 };
566 let msg = format!("{err:#}");
567 assert!(msg.contains("tags"), "error must name the column: {msg}");
568 assert!(msg.to_lowercase().contains("csv"), "{msg}");
569 }
570
571 #[test]
575 fn every_serializable_type_is_actually_written() {
576 use crate::format::Format;
577 let cols: Vec<(&str, ArrayRef)> = vec![
578 ("b", Arc::new(BooleanArray::from(vec![true]))),
579 ("i16", Arc::new(Int16Array::from(vec![1i16]))),
580 ("i32", Arc::new(Int32Array::from(vec![1i32]))),
581 ("i64", Arc::new(Int64Array::from(vec![1i64]))),
582 ("u64", Arc::new(UInt64Array::from(vec![1u64]))),
583 (
584 "dec",
585 Arc::new(
586 Decimal128Array::from(vec![100i128])
587 .with_precision_and_scale(18, 2)
588 .unwrap(),
589 ),
590 ),
591 ("f32", Arc::new(Float32Array::from(vec![1.0f32]))),
592 ("f64", Arc::new(Float64Array::from(vec![1.0f64]))),
593 ("s", Arc::new(StringArray::from(vec!["x"]))),
594 ("bin", Arc::new(BinaryArray::from_vec(vec![&[1u8][..]]))),
595 (
596 "uuid",
597 Arc::new(
598 FixedSizeBinaryArray::try_from_iter(std::iter::once(vec![0u8; 16])).unwrap(),
599 ),
600 ),
601 ("d", Arc::new(Date32Array::from(vec![0i32]))),
602 ("t", Arc::new(Time64MicrosecondArray::from(vec![0i64]))),
603 ("ts", Arc::new(TimestampMicrosecondArray::from(vec![0i64]))),
604 ];
605 let fields: Vec<Field> = cols
606 .iter()
607 .map(|(n, a)| Field::new(*n, a.data_type().clone(), true))
608 .collect();
609 for f in &fields {
611 assert!(
612 csv_serializable(f.data_type()),
613 "test type {:?} not in csv_serializable",
614 f.data_type()
615 );
616 }
617 let schema = Arc::new(Schema::new(fields));
618 let arrays: Vec<ArrayRef> = cols.into_iter().map(|(_, a)| a).collect();
619 let batch = RecordBatch::try_new(schema.clone(), arrays).unwrap();
620 let mut w = CsvFormat
621 .create_writer(&schema, Box::new(Vec::<u8>::new()))
622 .unwrap();
623 w.write_batch(&batch)
624 .expect("every serializable type must write without hitting the fallthrough");
625 }
626
627 #[test]
632 fn binary_hex_matches_per_byte_format_for_all_byte_values() {
633 let all: Vec<u8> = (0..=255u8).collect();
636 for case in [&all[..], &[][..], &[0x00, 0xff, 0x10, 0x0a]] {
637 let expected: String = case.iter().map(|b| format!("{b:02x}")).collect();
638 let got = cell(BinaryArray::from_vec(vec![case]), 0);
639 assert_eq!(got, expected, "hex mismatch for {case:?}");
640 }
641 }
642
643 #[test]
644 fn binary_hex_spans_chunk_boundary() {
645 let big: Vec<u8> = (0..2000u32).map(|i| (i % 256) as u8).collect();
647 let expected: String = big.iter().map(|b| format!("{b:02x}")).collect();
648 let got = cell(BinaryArray::from_vec(vec![&big[..]]), 0);
649 assert_eq!(got, expected);
650 }
651
652 #[test]
653 fn timestamp_fast_path_matches_chrono_format() {
654 let cases: [i64; 6] = [
658 0,
659 1_700_000_000_123_456, 1_000_000 * 86_399 + 999_999, -1, 253_402_300_799_000_000, 300_000_000_000_000_000, ];
665 for micros in cases {
666 let got = cell(TimestampMicrosecondArray::from(vec![micros]), 0);
667 let secs = micros / 1_000_000;
668 let nsecs = ((micros % 1_000_000) * 1_000) as u32;
669 let expected = match chrono::DateTime::from_timestamp(secs, nsecs) {
670 Some(dt) => format!("{}", dt.format("%Y-%m-%dT%H:%M:%S%.6f")),
671 None => String::new(),
672 };
673 assert_eq!(got, expected, "timestamp mismatch for micros={micros}");
674 }
675 }
676}