1use std::io::Write;
2
3use arrow::array::Time64MicrosecondArray;
4use arrow::array::types::Decimal128Type;
5use arrow::array::*;
6use arrow::datatypes::{DataType, SchemaRef, TimeUnit};
7use arrow::record_batch::RecordBatch;
8
9use crate::error::Result;
10use crate::types::decimal::scaled_i128_to_decimal_str;
11
/// CSV output format: one header line of field names followed by one line per row.
///
/// Stateless marker type; all per-stream state lives in [`CsvFormatWriter`].
#[derive(Debug, Clone, Copy, Default)]
pub struct CsvFormat;
13
/// Per-stream CSV writer created by [`CsvFormat`].
///
/// Owns the destination sink and keeps a running tally of the CSV bytes
/// produced for it, starting with the header line.
pub struct CsvFormatWriter {
    /// Running count of encoded bytes (header line included).
    bytes_written: u64,
    /// Sink that receives the encoded CSV text.
    writer: Box<dyn Write + Send>,
}
18
19impl super::Format for CsvFormat {
20 fn create_writer(
21 &self,
22 schema: &SchemaRef,
23 mut writer: Box<dyn Write + Send>,
24 ) -> Result<Box<dyn super::FormatWriter + Send>> {
25 let header = schema
26 .fields()
27 .iter()
28 .map(|f| f.name().as_str())
29 .collect::<Vec<_>>()
30 .join(",");
31 let header_bytes = header.len() as u64 + 1; writeln!(writer, "{}", header)?;
33 Ok(Box::new(CsvFormatWriter {
34 writer,
35 bytes_written: header_bytes,
36 }))
37 }
38
39 fn file_extension(&self) -> &str {
40 "csv"
41 }
42}
43
44impl super::FormatWriter for CsvFormatWriter {
45 fn write_batch(&mut self, batch: &RecordBatch) -> Result<()> {
46 let mut buf = Vec::with_capacity(batch.num_rows() * batch.num_columns() * 8);
47 for row_idx in 0..batch.num_rows() {
48 for col_idx in 0..batch.num_columns() {
49 if col_idx > 0 {
50 buf.push(b',');
51 }
52 write_csv_value(&mut buf, batch.column(col_idx), row_idx)?;
53 }
54 buf.push(b'\n');
55 }
56 self.bytes_written += buf.len() as u64;
57 self.writer.write_all(&buf)?;
58 Ok(())
59 }
60
61 fn finish(self: Box<Self>) -> Result<()> {
62 Ok(())
63 }
64
65 fn bytes_written(&self) -> u64 {
66 self.bytes_written
67 }
68}
69
70fn write_csv_value(writer: &mut dyn Write, array: &dyn Array, idx: usize) -> Result<()> {
71 if array.is_null(idx) {
72 return Ok(());
73 }
74
75 match array.data_type() {
76 DataType::Boolean => {
77 let arr = array
78 .as_any()
79 .downcast_ref::<BooleanArray>()
80 .expect("DataType/Array mismatch");
81 write!(writer, "{}", arr.value(idx))?;
82 }
83 DataType::Int16 => {
84 let arr = array
85 .as_any()
86 .downcast_ref::<Int16Array>()
87 .expect("DataType/Array mismatch");
88 write!(writer, "{}", arr.value(idx))?;
89 }
90 DataType::Int32 => {
91 let arr = array
92 .as_any()
93 .downcast_ref::<Int32Array>()
94 .expect("DataType/Array mismatch");
95 write!(writer, "{}", arr.value(idx))?;
96 }
97 DataType::Int64 => {
98 let arr = array
99 .as_any()
100 .downcast_ref::<Int64Array>()
101 .expect("DataType/Array mismatch");
102 write!(writer, "{}", arr.value(idx))?;
103 }
104 DataType::UInt64 => {
105 let arr = array
106 .as_any()
107 .downcast_ref::<UInt64Array>()
108 .expect("DataType/Array mismatch");
109 write!(writer, "{}", arr.value(idx))?;
110 }
111 DataType::Decimal128(_, scale) => {
112 let arr = array.as_primitive::<Decimal128Type>();
113 let text = scaled_i128_to_decimal_str(arr.value(idx), *scale);
114 writer.write_all(text.as_bytes())?;
115 }
116 DataType::Float32 => {
117 let arr = array
118 .as_any()
119 .downcast_ref::<Float32Array>()
120 .expect("DataType/Array mismatch");
121 write!(writer, "{}", arr.value(idx))?;
122 }
123 DataType::Float64 => {
124 let arr = array
125 .as_any()
126 .downcast_ref::<Float64Array>()
127 .expect("DataType/Array mismatch");
128 write!(writer, "{}", arr.value(idx))?;
129 }
130 DataType::Utf8 => {
131 let arr = array
132 .as_any()
133 .downcast_ref::<StringArray>()
134 .expect("DataType/Array mismatch");
135 let val = arr.value(idx);
136 if val.contains(',') || val.contains('"') || val.contains('\n') {
137 writer.write_all(b"\"")?;
138 let mut rest = val;
139 while let Some(pos) = rest.find('"') {
140 writer.write_all(&rest.as_bytes()[..pos])?;
141 writer.write_all(b"\"\"")?;
142 rest = &rest[pos + 1..];
143 }
144 writer.write_all(rest.as_bytes())?;
145 writer.write_all(b"\"")?;
146 } else {
147 writer.write_all(val.as_bytes())?;
148 }
149 }
150 DataType::Binary => {
151 let arr = array
152 .as_any()
153 .downcast_ref::<BinaryArray>()
154 .expect("DataType/Array mismatch");
155 let val = arr.value(idx);
156 for byte in val {
157 write!(writer, "{:02x}", byte)?;
158 }
159 }
160 DataType::FixedSizeBinary(16) => {
168 let arr = array
169 .as_any()
170 .downcast_ref::<FixedSizeBinaryArray>()
171 .expect("DataType/Array mismatch");
172 let val = arr.value(idx);
173 let mut bytes = [0u8; 16];
174 bytes.copy_from_slice(val);
175 write!(writer, "{}", uuid::Uuid::from_bytes(bytes).to_hyphenated())?;
176 }
177 DataType::Date32 => {
178 let arr = array
179 .as_any()
180 .downcast_ref::<Date32Array>()
181 .expect("DataType/Array mismatch");
182 let days = arr.value(idx);
183 let epoch = chrono::NaiveDate::from_ymd_opt(1970, 1, 1).expect("epoch is valid");
189 let date =
190 chrono::Duration::try_days(days as i64).and_then(|d| epoch.checked_add_signed(d));
191 if let Some(date) = date {
192 write!(writer, "{}", date)?;
193 }
194 }
195 DataType::Time64(TimeUnit::Microsecond) => {
196 let arr = array
197 .as_any()
198 .downcast_ref::<Time64MicrosecondArray>()
199 .expect("DataType/Array mismatch");
200 let micros = arr.value(idx);
201 let secs = micros / 1_000_000;
202 let frac_us = micros % 1_000_000;
203 write!(
204 writer,
205 "{:02}:{:02}:{:02}.{:06}",
206 secs / 3600,
207 (secs % 3600) / 60,
208 secs % 60,
209 frac_us
210 )?;
211 }
212 DataType::Timestamp(TimeUnit::Microsecond, _) => {
213 let arr = array
214 .as_any()
215 .downcast_ref::<TimestampMicrosecondArray>()
216 .expect("DataType/Array mismatch");
217 let micros = arr.value(idx);
218 let secs = micros / 1_000_000;
219 let nsecs = ((micros % 1_000_000) * 1_000) as u32;
220 if let Some(dt) = chrono::DateTime::from_timestamp(secs, nsecs) {
221 write!(writer, "{}", dt.format("%Y-%m-%dT%H:%M:%S%.6f"))?;
222 }
223 }
224 other => {
225 log::warn!("CSV: unhandled Arrow type {:?}, skipping value", other);
226 }
227 }
228
229 Ok(())
230}
231
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::datatypes::{DataType, Field, Schema};
    use std::sync::Arc;

    /// Renders `array[idx]` through `write_csv_value` and returns the text.
    fn cell<A: Array + 'static>(array: A, idx: usize) -> String {
        let mut buf = Vec::new();
        write_csv_value(&mut buf, &array, idx).unwrap();
        String::from_utf8(buf).unwrap()
    }

    /// Renders the single null slot of a one-element null array of type `dt`.
    fn null_cell(dt: DataType) -> String {
        use arrow::array::new_null_array;
        let arr = new_null_array(&dt, 1);
        let mut buf = Vec::new();
        write_csv_value(&mut buf, arr.as_ref(), 0).unwrap();
        String::from_utf8(buf).unwrap()
    }

    #[test]
    fn null_value_writes_empty_string() {
        assert_eq!(null_cell(DataType::Int64), "");
        assert_eq!(null_cell(DataType::Utf8), "");
        assert_eq!(null_cell(DataType::Boolean), "");
    }

    #[test]
    fn bool_true_writes_true() {
        assert_eq!(cell(BooleanArray::from(vec![true]), 0), "true");
    }

    #[test]
    fn bool_false_writes_false() {
        assert_eq!(cell(BooleanArray::from(vec![false]), 0), "false");
    }

    #[test]
    fn int16_value() {
        assert_eq!(cell(Int16Array::from(vec![42i16]), 0), "42");
    }

    #[test]
    fn int32_negative() {
        assert_eq!(cell(Int32Array::from(vec![-7i32]), 0), "-7");
    }

    #[test]
    fn decimal128_writes_exact_text() {
        let arr = Decimal128Array::from(vec![10i128])
            .with_precision_and_scale(18, 2)
            .unwrap();
        assert_eq!(cell(arr, 0), "0.10");
        let scaled =
            crate::types::decimal::decimal_str_to_scaled_i128("999999999999.99", 2).unwrap();
        let arr = Decimal128Array::from(vec![scaled])
            .with_precision_and_scale(18, 2)
            .unwrap();
        assert_eq!(cell(arr, 0), "999999999999.99");
    }

    #[test]
    fn int64_large() {
        assert_eq!(
            cell(Int64Array::from(vec![9_999_999_999i64]), 0),
            "9999999999"
        );
    }

    #[test]
    fn float32_value() {
        assert_eq!(cell(Float32Array::from(vec![1.5f32]), 0), "1.5");
    }

    #[test]
    fn float64_value() {
        assert_eq!(
            cell(Float64Array::from(vec![std::f64::consts::PI]), 0),
            "3.141592653589793"
        );
    }

    #[test]
    fn plain_string_no_quoting() {
        assert_eq!(cell(StringArray::from(vec!["hello"]), 0), "hello");
    }

    #[test]
    fn string_with_comma_is_quoted() {
        assert_eq!(cell(StringArray::from(vec!["a,b"]), 0), "\"a,b\"");
    }

    #[test]
    fn string_with_double_quote_is_escaped() {
        let result = cell(StringArray::from(vec![r#"say "hi""#]), 0);
        assert_eq!(result, r#""say ""hi""""#);
    }

    #[test]
    fn string_with_newline_is_quoted() {
        assert_eq!(
            cell(StringArray::from(vec!["line1\nline2"]), 0),
            "\"line1\nline2\""
        );
    }

    #[test]
    fn binary_is_written_as_hex() {
        let arr = BinaryArray::from_vec(vec![&[0xDE, 0xAD, 0xBE, 0xEF][..]]);
        assert_eq!(cell(arr, 0), "deadbeef");
    }

    #[test]
    fn binary_empty_writes_empty() {
        let arr = BinaryArray::from_vec(vec![&[][..]]);
        assert_eq!(cell(arr, 0), "");
    }

    #[test]
    fn date32_epoch_is_1970_01_01() {
        assert_eq!(cell(Date32Array::from(vec![0i32]), 0), "1970-01-01");
    }

    #[test]
    fn date32_positive_offset() {
        assert_eq!(cell(Date32Array::from(vec![365i32]), 0), "1971-01-01");
    }

    #[test]
    fn time64_micros_formats_as_hms_fraction() {
        // 1h 2m 3s plus 4 microseconds.
        let micros: i64 = 3_723 * 1_000_000 + 4;
        assert_eq!(
            cell(Time64MicrosecondArray::from(vec![micros]), 0),
            "01:02:03.000004"
        );
    }

    #[test]
    fn timestamp_micros_formats_as_iso() {
        // 2023-01-01T00:00:00Z
        let micros: i64 = 1_672_531_200 * 1_000_000;
        let arr = TimestampMicrosecondArray::from(vec![micros]);
        assert_eq!(cell(arr, 0), "2023-01-01T00:00:00.000000");
    }

    #[test]
    fn csv_format_write_batch_tracks_bytes_and_succeeds() {
        use crate::format::Format;

        let schema = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int64, false),
            Field::new("name", DataType::Utf8, true),
        ]));
        let batch = arrow::record_batch::RecordBatch::try_new(
            schema.clone(),
            vec![
                Arc::new(Int64Array::from(vec![1i64, 2])),
                Arc::new(StringArray::from(vec![Some("alice"), None])),
            ],
        )
        .unwrap();

        let fmt = CsvFormat;
        let mut writer = fmt
            .create_writer(&schema, Box::new(Vec::<u8>::new()))
            .unwrap();
        writer.write_batch(&batch).unwrap();
        // "id,name\n" (8) + "1,alice\n" (8) + "2,\n" (3)
        assert_eq!(writer.bytes_written(), 19);
        writer.finish().unwrap();
    }
}