use std::any::Any;
use std::collections::{HashMap, HashSet};
use std::fmt::{self, Debug};
use std::sync::Arc;

use crate::source::CsvSource;

use arrow::array::RecordBatch;
use arrow::csv::WriterBuilder;
use arrow::datatypes::{DataType, Field, Fields, Schema, SchemaRef};
use arrow::error::ArrowError;
use datafusion_common::config::{ConfigField, ConfigFileType, CsvOptions};
use datafusion_common::file_options::csv_writer::CsvWriterOptions;
use datafusion_common::{
    exec_err, not_impl_err, DataFusionError, GetExt, Result, Statistics,
    DEFAULT_CSV_EXTENSION,
};
use datafusion_common_runtime::SpawnedTask;
use datafusion_datasource::decoder::Decoder;
use datafusion_datasource::display::FileGroupDisplay;
use datafusion_datasource::file::FileSource;
use datafusion_datasource::file_compression_type::FileCompressionType;
use datafusion_datasource::file_format::{
    FileFormat, FileFormatFactory, DEFAULT_SCHEMA_INFER_MAX_RECORD,
};
use datafusion_datasource::file_scan_config::{FileScanConfig, FileScanConfigBuilder};
use datafusion_datasource::file_sink_config::{FileSink, FileSinkConfig};
use datafusion_datasource::sink::{DataSink, DataSinkExec};
use datafusion_datasource::write::demux::DemuxedStreamReceiver;
use datafusion_datasource::write::orchestration::spawn_writer_tasks_and_join;
use datafusion_datasource::write::BatchSerializer;
use datafusion_execution::{SendableRecordBatchStream, TaskContext};
use datafusion_expr::dml::InsertOp;
use datafusion_physical_expr_common::sort_expr::LexRequirement;
use datafusion_physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan};
use datafusion_session::Session;

use async_trait::async_trait;
use bytes::{Buf, Bytes};
use datafusion_datasource::source::DataSourceExec;
use futures::stream::BoxStream;
use futures::{pin_mut, Stream, StreamExt, TryStreamExt};
use object_store::{delimited::newline_delimited_stream, ObjectMeta, ObjectStore};
use regex::Regex;

/// Factory used to create instances of [`CsvFormat`]
#[derive(Default)]
pub struct CsvFormatFactory {
    /// The CSV options to use, if overridden; when `None`, options are
    /// resolved from the session configuration in [`Self::create`].
    pub options: Option<CsvOptions>,
}

impl CsvFormatFactory {
    /// Creates an instance of [`CsvFormatFactory`] with default options.
    pub fn new() -> Self {
        Self { options: None }
    }

    /// Creates an instance of [`CsvFormatFactory`] with the given [`CsvOptions`].
    pub fn new_with_options(options: CsvOptions) -> Self {
        Self {
            options: Some(options),
        }
    }
}
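
// A minimal sketch, using only items defined in this module, of the two ways
// of constructing a `CsvFormatFactory`: a factory built with `new` resolves
// its options from the session defaults inside `create`, while
// `new_with_options` pins them up front.
#[cfg(test)]
mod csv_format_factory_construction_sketch {
    use super::*;

    #[test]
    fn constructors_set_options_as_expected() {
        // `new` leaves options unset so `create` can fill them in later
        assert!(CsvFormatFactory::new().options.is_none());

        // `new_with_options` fixes the options at construction time
        let preset = CsvFormatFactory::new_with_options(CsvOptions::default());
        assert!(preset.options.is_some());
    }
}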

impl Debug for CsvFormatFactory {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("CsvFormatFactory")
            .field("options", &self.options)
            .finish()
    }
}

impl FileFormatFactory for CsvFormatFactory {
    fn create(
        &self,
        state: &dyn Session,
        format_options: &HashMap<String, String>,
    ) -> Result<Arc<dyn FileFormat>> {
        let csv_options = match &self.options {
            None => {
                let mut table_options = state.default_table_options();
                table_options.set_config_format(ConfigFileType::CSV);
                table_options.alter_with_string_hash_map(format_options)?;
                table_options.csv
            }
            Some(csv_options) => {
                let mut csv_options = csv_options.clone();
                for (k, v) in format_options {
                    csv_options.set(k, v)?;
                }
                csv_options
            }
        };

        Ok(Arc::new(CsvFormat::default().with_options(csv_options)))
    }

    fn default(&self) -> Arc<dyn FileFormat> {
        Arc::new(CsvFormat::default())
    }

    fn as_any(&self) -> &dyn Any {
        self
    }
}

impl GetExt for CsvFormatFactory {
    fn get_ext(&self) -> String {
        // `DEFAULT_CSV_EXTENSION` includes a leading dot; strip it here
        DEFAULT_CSV_EXTENSION[1..].to_string()
    }
}
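
// A small sketch of the extension handling above: `DEFAULT_CSV_EXTENSION`
// carries a leading dot, which `get_ext` strips off for callers.
#[cfg(test)]
mod csv_extension_sketch {
    use super::*;

    #[test]
    fn factory_extension_has_no_leading_dot() {
        let ext = CsvFormatFactory::new().get_ext();
        assert!(!ext.starts_with('.'));
        assert_eq!(format!(".{ext}"), DEFAULT_CSV_EXTENSION);
    }
}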

/// Character-separated-values [`FileFormat`] implementation.
#[derive(Debug, Default)]
pub struct CsvFormat {
    options: CsvOptions,
}

impl CsvFormat {
    /// Return a newline-delimited stream of bytes for the given object,
    /// decompressing it if necessary.
    async fn read_to_delimited_chunks<'a>(
        &self,
        store: &Arc<dyn ObjectStore>,
        object: &ObjectMeta,
    ) -> BoxStream<'a, Result<Bytes>> {
        let stream = store
            .get(&object.location)
            .await
            .map_err(DataFusionError::ObjectStore);
        let stream = match stream {
            Ok(stream) => self
                .read_to_delimited_chunks_from_stream(
                    stream
                        .into_stream()
                        .map_err(DataFusionError::ObjectStore)
                        .boxed(),
                )
                .await
                .map_err(DataFusionError::from)
                .left_stream(),
            Err(e) => {
                futures::stream::once(futures::future::ready(Err(e))).right_stream()
            }
        };
        stream.boxed()
    }

    /// Convert a stream of bytes into a stream of newline-delimited chunks,
    /// decompressing according to the configured compression type.
    pub async fn read_to_delimited_chunks_from_stream<'a>(
        &self,
        stream: BoxStream<'a, Result<Bytes>>,
    ) -> BoxStream<'a, Result<Bytes>> {
        let file_compression_type: FileCompressionType = self.options.compression.into();
        let decoder = file_compression_type.convert_stream(stream);
        let stream = match decoder {
            Ok(decoded_stream) => {
                newline_delimited_stream(decoded_stream.map_err(|e| match e {
                    DataFusionError::ObjectStore(e) => e,
                    err => object_store::Error::Generic {
                        store: "read to delimited chunks failed",
                        source: Box::new(err),
                    },
                }))
                .map_err(DataFusionError::from)
                .left_stream()
            }
            Err(e) => {
                futures::stream::once(futures::future::ready(Err(e))).right_stream()
            }
        };
        stream.boxed()
    }

    /// Set the options for this CSV format.
    pub fn with_options(mut self, options: CsvOptions) -> Self {
        self.options = options;
        self
    }

    /// Retrieve the options for this CSV format.
    pub fn options(&self) -> &CsvOptions {
        &self.options
    }

    /// Set the maximum number of records to scan when inferring the schema.
    /// When unset, [`DEFAULT_SCHEMA_INFER_MAX_RECORD`] is used.
    pub fn with_schema_infer_max_rec(mut self, max_rec: usize) -> Self {
        self.options.schema_infer_max_rec = Some(max_rec);
        self
    }

    /// Set whether the first row is treated as a header. When unset, the
    /// value from the session's catalog configuration is used.
    pub fn with_has_header(mut self, has_header: bool) -> Self {
        self.options.has_header = Some(has_header);
        self
    }

    /// Set the regex pattern used to identify null values while reading.
    pub fn with_null_regex(mut self, null_regex: Option<String>) -> Self {
        self.options.null_regex = null_regex;
        self
    }

    /// Returns `Some(true)` if the first row is a header, `Some(false)` if it
    /// is not, and `None` if it has not been set explicitly.
    pub fn has_header(&self) -> Option<bool> {
        self.options.has_header
    }

    /// Set the comment byte; rows starting with it are skipped while parsing.
    pub fn with_comment(mut self, comment: Option<u8>) -> Self {
        self.options.comment = comment;
        self
    }

    /// Set the field delimiter character.
    pub fn with_delimiter(mut self, delimiter: u8) -> Self {
        self.options.delimiter = delimiter;
        self
    }

    /// Set the quote character.
    pub fn with_quote(mut self, quote: u8) -> Self {
        self.options.quote = quote;
        self
    }

    /// Set the escape character, if any.
    pub fn with_escape(mut self, escape: Option<u8>) -> Self {
        self.options.escape = escape;
        self
    }

    /// Set the line terminator byte, if any.
    pub fn with_terminator(mut self, terminator: Option<u8>) -> Self {
        self.options.terminator = terminator;
        self
    }

    /// Set whether quoted values may contain embedded newlines. Enabling this
    /// allows such values to be parsed correctly, but scans may be less
    /// parallelizable. When unset, the value from the session's catalog
    /// configuration is used.
    pub fn with_newlines_in_values(mut self, newlines_in_values: bool) -> Self {
        self.options.newlines_in_values = Some(newlines_in_values);
        self
    }

    /// Set the file compression type.
    pub fn with_file_compression_type(
        mut self,
        file_compression_type: FileCompressionType,
    ) -> Self {
        self.options.compression = file_compression_type.into();
        self
    }

    /// The configured delimiter character.
    pub fn delimiter(&self) -> u8 {
        self.options.delimiter
    }

    /// The configured quote character.
    pub fn quote(&self) -> u8 {
        self.options.quote
    }

    /// The configured escape character, if any.
    pub fn escape(&self) -> Option<u8> {
        self.options.escape
    }
}
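
// A minimal sketch of chaining the builder methods above; the values used are
// arbitrary and only show that each setter is reflected by the corresponding
// getter.
#[cfg(test)]
mod csv_format_builder_sketch {
    use super::*;

    #[test]
    fn builder_methods_update_options() {
        let format = CsvFormat::default()
            .with_has_header(true)
            .with_delimiter(b'|')
            .with_quote(b'\'')
            .with_escape(Some(b'\\'));

        assert_eq!(format.has_header(), Some(true));
        assert_eq!(format.delimiter(), b'|');
        assert_eq!(format.quote(), b'\'');
        assert_eq!(format.escape(), Some(b'\\'));
    }
}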

#[derive(Debug)]
pub struct CsvDecoder {
    inner: arrow::csv::reader::Decoder,
}

impl CsvDecoder {
    pub fn new(decoder: arrow::csv::reader::Decoder) -> Self {
        Self { inner: decoder }
    }
}

impl Decoder for CsvDecoder {
    fn decode(&mut self, buf: &[u8]) -> Result<usize, ArrowError> {
        self.inner.decode(buf)
    }

    fn flush(&mut self) -> Result<Option<RecordBatch>, ArrowError> {
        self.inner.flush()
    }

    fn can_flush_early(&self) -> bool {
        self.inner.capacity() == 0
    }
}
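
// Sketch of how a `CsvDecoder` is meant to be driven. Construction of the
// inner `arrow::csv::reader::Decoder` (via arrow's `ReaderBuilder`) is an
// assumption here and depends on the arrow API in use:
//
//     let inner = arrow::csv::ReaderBuilder::new(schema).build_decoder();
//     let mut decoder = CsvDecoder::new(inner);
//     let _consumed = decoder.decode(b"1,foo\n2,bar\n")?;
//     let maybe_batch = decoder.flush()?;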

impl Debug for CsvSerializer {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("CsvSerializer")
            .field("header", &self.header)
            .finish()
    }
}

#[async_trait]
impl FileFormat for CsvFormat {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn get_ext(&self) -> String {
        CsvFormatFactory::new().get_ext()
    }

    fn get_ext_with_compression(
        &self,
        file_compression_type: &FileCompressionType,
    ) -> Result<String> {
        let ext = self.get_ext();
        Ok(format!("{}{}", ext, file_compression_type.get_ext()))
    }

    async fn infer_schema(
        &self,
        state: &dyn Session,
        store: &Arc<dyn ObjectStore>,
        objects: &[ObjectMeta],
    ) -> Result<SchemaRef> {
        let mut schemas = vec![];

        let mut records_to_read = self
            .options
            .schema_infer_max_rec
            .unwrap_or(DEFAULT_SCHEMA_INFER_MAX_RECORD);

        for object in objects {
            let stream = self.read_to_delimited_chunks(store, object).await;
            let (schema, records_read) = self
                .infer_schema_from_stream(state, records_to_read, stream)
                .await
                .map_err(|err| {
                    DataFusionError::Context(
                        format!("Error when processing CSV file {}", &object.location),
                        Box::new(err),
                    )
                })?;
            records_to_read -= records_read;
            schemas.push(schema);
            if records_to_read == 0 {
                break;
            }
        }

        let merged_schema = Schema::try_merge(schemas)?;
        Ok(Arc::new(merged_schema))
    }

    async fn infer_stats(
        &self,
        _state: &dyn Session,
        _store: &Arc<dyn ObjectStore>,
        table_schema: SchemaRef,
        _object: &ObjectMeta,
    ) -> Result<Statistics> {
        Ok(Statistics::new_unknown(&table_schema))
    }

    async fn create_physical_plan(
        &self,
        state: &dyn Session,
        conf: FileScanConfig,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        // Options that are unset fall back to the session's catalog configuration
        let has_header = self
            .options
            .has_header
            .unwrap_or_else(|| state.config_options().catalog.has_header);
        let newlines_in_values = self
            .options
            .newlines_in_values
            .unwrap_or_else(|| state.config_options().catalog.newlines_in_values);

        let conf_builder = FileScanConfigBuilder::from(conf)
            .with_file_compression_type(self.options.compression.into())
            .with_newlines_in_values(newlines_in_values);

        let source = Arc::new(
            CsvSource::new(has_header, self.options.delimiter, self.options.quote)
                .with_escape(self.options.escape)
                .with_terminator(self.options.terminator)
                .with_comment(self.options.comment),
        );

        let config = conf_builder.with_source(source).build();

        Ok(DataSourceExec::from_data_source(config))
    }

    async fn create_writer_physical_plan(
        &self,
        input: Arc<dyn ExecutionPlan>,
        state: &dyn Session,
        conf: FileSinkConfig,
        order_requirements: Option<LexRequirement>,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        if conf.insert_op != InsertOp::Append {
            return not_impl_err!("Overwrites are not implemented yet for CSV");
        }

        // Options that are unset fall back to the session's catalog configuration
        let has_header = self
            .options()
            .has_header
            .unwrap_or_else(|| state.config_options().catalog.has_header);
        let newlines_in_values = self
            .options()
            .newlines_in_values
            .unwrap_or_else(|| state.config_options().catalog.newlines_in_values);

        let options = self
            .options()
            .clone()
            .with_has_header(has_header)
            .with_newlines_in_values(newlines_in_values);

        let writer_options = CsvWriterOptions::try_from(&options)?;

        let sink = Arc::new(CsvSink::new(conf, writer_options));

        Ok(Arc::new(DataSinkExec::new(input, sink, order_requirements)) as _)
    }

    fn file_source(&self) -> Arc<dyn FileSource> {
        Arc::new(CsvSource::default())
    }
}
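
// Sketch of the extension/compression composition implemented by
// `get_ext_with_compression`: for an uncompressed file the result is just the
// base CSV extension.
#[cfg(test)]
mod csv_ext_with_compression_sketch {
    use super::*;

    #[test]
    fn uncompressed_extension_is_plain_csv() -> Result<()> {
        let format = CsvFormat::default();
        let ext =
            format.get_ext_with_compression(&FileCompressionType::UNCOMPRESSED)?;
        // No compression suffix is appended
        assert_eq!(ext, format.get_ext());
        Ok(())
    }
}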

impl CsvFormat {
    /// Return the schema inferred from a stream of newline-delimited chunks,
    /// reading up to `records_to_read` records, together with the number of
    /// records actually read.
    pub async fn infer_schema_from_stream(
        &self,
        state: &dyn Session,
        mut records_to_read: usize,
        stream: impl Stream<Item = Result<Bytes>>,
    ) -> Result<(Schema, usize)> {
        let mut total_records_read = 0;
        let mut column_names = vec![];
        let mut column_type_possibilities = vec![];
        let mut record_number = -1;

        pin_mut!(stream);

        while let Some(chunk) = stream.next().await.transpose()? {
            record_number += 1;
            let first_chunk = record_number == 0;
            let mut format = arrow::csv::reader::Format::default()
                .with_header(
                    first_chunk
                        && self
                            .options
                            .has_header
                            .unwrap_or_else(|| state.config_options().catalog.has_header),
                )
                .with_delimiter(self.options.delimiter)
                .with_quote(self.options.quote);

            if let Some(null_regex) = &self.options.null_regex {
                let regex = Regex::new(null_regex.as_str())
                    .expect("Unable to parse CSV null regex.");
                format = format.with_null_regex(regex);
            }

            if let Some(escape) = self.options.escape {
                format = format.with_escape(escape);
            }

            if let Some(comment) = self.options.comment {
                format = format.with_comment(comment);
            }

            let (Schema { fields, .. }, records_read) =
                format.infer_schema(chunk.reader(), Some(records_to_read))?;

            records_to_read -= records_read;
            total_records_read += records_read;

            if first_chunk {
                // Set up the initial candidate types from the first chunk
                (column_names, column_type_possibilities) = fields
                    .into_iter()
                    .map(|field| {
                        let mut possibilities = HashSet::new();
                        if records_read > 0 {
                            // Only add data types of fields for which records were read
                            possibilities.insert(field.data_type().clone());
                        }
                        (field.name().clone(), possibilities)
                    })
                    .unzip();
            } else {
                if fields.len() != column_type_possibilities.len() {
                    return exec_err!(
                        "Encountered unequal lengths between records on CSV file whilst inferring schema. \
                         Expected {} fields, found {} fields at record {}",
                        column_type_possibilities.len(),
                        fields.len(),
                        record_number + 1
                    );
                }

                column_type_possibilities.iter_mut().zip(&fields).for_each(
                    |(possibilities, field)| {
                        possibilities.insert(field.data_type().clone());
                    },
                );
            }

            if records_to_read == 0 {
                break;
            }
        }

        let schema = build_schema_helper(column_names, &column_type_possibilities);
        Ok((schema, total_records_read))
    }
}
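
// Sketch of calling `infer_schema_from_stream` directly: the input is any
// stream of `Bytes` chunks holding whole CSV records. `session` stands in for
// an available `dyn Session` implementation and is assumed, not provided by
// this module:
//
//     let chunks = futures::stream::iter(vec![Ok(Bytes::from("a,b\n1,2\n3,4\n"))]);
//     let (schema, records_read) = CsvFormat::default()
//         .infer_schema_from_stream(&session, 100, chunks)
//         .await?;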

fn build_schema_helper(names: Vec<String>, types: &[HashSet<DataType>]) -> Schema {
    let fields = names
        .into_iter()
        .zip(types)
        .map(|(field_name, data_type_possibilities)| {
            // Determine the field's data type from the set of types observed
            // for it across all inferred chunks
            match data_type_possibilities.len() {
                1 => Field::new(
                    field_name,
                    data_type_possibilities.iter().next().unwrap().clone(),
                    true,
                ),
                2 => {
                    if data_type_possibilities.contains(&DataType::Int64)
                        && data_type_possibilities.contains(&DataType::Float64)
                    {
                        // Both integers and floats were observed: widen to Float64
                        Field::new(field_name, DataType::Float64, true)
                    } else {
                        // Otherwise the types conflict: fall back to Utf8
                        Field::new(field_name, DataType::Utf8, true)
                    }
                }
                _ => Field::new(field_name, DataType::Utf8, true),
            }
        })
        .collect::<Fields>();
    Schema::new(fields)
}
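
// The merging rules above, exercised directly: a single inferred type is kept,
// Int64 + Float64 widen to Float64, and any other conflict falls back to Utf8.
#[cfg(test)]
mod build_schema_helper_sketch {
    use super::*;

    #[test]
    fn conflicting_types_are_merged() {
        let names = vec!["i".to_string(), "f".to_string(), "s".to_string()];
        let types = vec![
            HashSet::from([DataType::Int64]),
            HashSet::from([DataType::Int64, DataType::Float64]),
            HashSet::from([DataType::Int64, DataType::Boolean]),
        ];

        let schema = build_schema_helper(names, &types);

        assert_eq!(schema.field(0).data_type(), &DataType::Int64);
        assert_eq!(schema.field(1).data_type(), &DataType::Float64);
        assert_eq!(schema.field(2).data_type(), &DataType::Utf8);
    }
}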

impl Default for CsvSerializer {
    fn default() -> Self {
        Self::new()
    }
}

/// Serializes [`RecordBatch`]es into CSV-encoded bytes.
pub struct CsvSerializer {
    /// Builder for the underlying arrow CSV writer
    builder: WriterBuilder,
    /// Whether to emit a header row
    header: bool,
}

impl CsvSerializer {
    /// Create a new serializer that writes a header row by default.
    pub fn new() -> Self {
        Self {
            builder: WriterBuilder::new(),
            header: true,
        }
    }

    /// Set the arrow [`WriterBuilder`] used to construct the CSV writer.
    pub fn with_builder(mut self, builder: WriterBuilder) -> Self {
        self.builder = builder;
        self
    }

    /// Set whether a header row is written.
    pub fn with_header(mut self, header: bool) -> Self {
        self.header = header;
        self
    }
}

impl BatchSerializer for CsvSerializer {
    fn serialize(&self, batch: RecordBatch, initial: bool) -> Result<Bytes> {
        let mut buffer = Vec::with_capacity(4096);
        let builder = self.builder.clone();
        // Only emit the header for the first batch written to a file
        let header = self.header && initial;
        let mut writer = builder.with_header(header).build(&mut buffer);
        writer.write(&batch)?;
        drop(writer);
        Ok(Bytes::from(buffer))
    }
}
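
// A minimal sketch of `CsvSerializer` in isolation; the batch below is built
// inline and the exact byte-for-byte output depends on arrow's CSV writer
// defaults, so the assertions are kept loose.
#[cfg(test)]
mod csv_serializer_sketch {
    use super::*;
    use arrow::array::{ArrayRef, Int64Array};

    #[test]
    fn serializes_batch_with_header() -> Result<()> {
        let column: ArrayRef = Arc::new(Int64Array::from(vec![1, 2]));
        let batch = RecordBatch::try_from_iter(vec![("a", column)])?;

        // `initial == true`, so the configured header row is emitted
        let bytes = CsvSerializer::new().serialize(batch, true)?;
        let text = String::from_utf8_lossy(&bytes);
        assert!(text.starts_with("a"), "expected header row, got: {text}");
        assert!(text.contains('1') && text.contains('2'));
        Ok(())
    }
}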

/// Implements [`DataSink`] for writing [`RecordBatch`]es to CSV files.
pub struct CsvSink {
    /// Config options for the sink, such as the output file group
    config: FileSinkConfig,
    /// Options controlling the underlying CSV writer
    writer_options: CsvWriterOptions,
}

impl Debug for CsvSink {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("CsvSink").finish()
    }
}

impl DisplayAs for CsvSink {
    fn fmt_as(&self, t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match t {
            DisplayFormatType::Default | DisplayFormatType::Verbose => {
                write!(f, "CsvSink(file_groups=")?;
                FileGroupDisplay(&self.config.file_group).fmt_as(t, f)?;
                write!(f, ")")
            }
            DisplayFormatType::TreeRender => {
                writeln!(f, "format: csv")?;
                write!(f, "file={}", &self.config.original_url)
            }
        }
    }
}

impl CsvSink {
    /// Create a new `CsvSink` from the given config and writer options.
    pub fn new(config: FileSinkConfig, writer_options: CsvWriterOptions) -> Self {
        Self {
            config,
            writer_options,
        }
    }

    /// Retrieve the CSV writer options.
    pub fn writer_options(&self) -> &CsvWriterOptions {
        &self.writer_options
    }
}
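
// Sketch of how a `CsvSink` is assembled, mirroring `create_writer_physical_plan`
// above; `file_sink_config` and `input` stand in for a prepared
// `FileSinkConfig` and input `ExecutionPlan` and are assumed to exist:
//
//     let writer_options = CsvWriterOptions::try_from(&CsvOptions::default())?;
//     let sink = Arc::new(CsvSink::new(file_sink_config, writer_options));
//     let plan = Arc::new(DataSinkExec::new(input, sink, None));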

#[async_trait]
impl FileSink for CsvSink {
    fn config(&self) -> &FileSinkConfig {
        &self.config
    }

    async fn spawn_writer_tasks_and_join(
        &self,
        context: &Arc<TaskContext>,
        demux_task: SpawnedTask<Result<()>>,
        file_stream_rx: DemuxedStreamReceiver,
        object_store: Arc<dyn ObjectStore>,
    ) -> Result<u64> {
        let builder = self.writer_options.writer_options.clone();
        let header = builder.header();
        let serializer = Arc::new(
            CsvSerializer::new()
                .with_builder(builder)
                .with_header(header),
        ) as _;
        spawn_writer_tasks_and_join(
            context,
            serializer,
            self.writer_options.compression.into(),
            object_store,
            demux_task,
            file_stream_rx,
        )
        .await
    }
}

#[async_trait]
impl DataSink for CsvSink {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn schema(&self) -> &SchemaRef {
        self.config.output_schema()
    }

    async fn write_all(
        &self,
        data: SendableRecordBatchStream,
        context: &Arc<TaskContext>,
    ) -> Result<u64> {
        FileSink::write_all(self, data, context).await
    }
}