use std::any::Any;
use std::collections::{HashMap, HashSet};
use std::fmt::{self, Debug};
use std::sync::Arc;

use crate::source::CsvSource;

use arrow::array::RecordBatch;
use arrow::csv::WriterBuilder;
use arrow::datatypes::{DataType, Field, Fields, Schema, SchemaRef};
use arrow::error::ArrowError;
use datafusion_common::config::{ConfigField, ConfigFileType, CsvOptions};
use datafusion_common::file_options::csv_writer::CsvWriterOptions;
use datafusion_common::{
    DEFAULT_CSV_EXTENSION, DataFusionError, GetExt, Result, Statistics, exec_err,
    not_impl_err,
};
use datafusion_common_runtime::SpawnedTask;
use datafusion_datasource::TableSchema;
use datafusion_datasource::decoder::Decoder;
use datafusion_datasource::display::FileGroupDisplay;
use datafusion_datasource::file::FileSource;
use datafusion_datasource::file_compression_type::FileCompressionType;
use datafusion_datasource::file_format::{
    DEFAULT_SCHEMA_INFER_MAX_RECORD, FileFormat, FileFormatFactory,
};
use datafusion_datasource::file_scan_config::{FileScanConfig, FileScanConfigBuilder};
use datafusion_datasource::file_sink_config::{FileSink, FileSinkConfig};
use datafusion_datasource::sink::{DataSink, DataSinkExec};
use datafusion_datasource::write::BatchSerializer;
use datafusion_datasource::write::demux::DemuxedStreamReceiver;
use datafusion_datasource::write::orchestration::spawn_writer_tasks_and_join;
use datafusion_execution::{SendableRecordBatchStream, TaskContext};
use datafusion_expr::dml::InsertOp;
use datafusion_physical_expr_common::sort_expr::LexRequirement;
use datafusion_physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan};
use datafusion_session::Session;

use async_trait::async_trait;
use bytes::{Buf, Bytes};
use datafusion_datasource::source::DataSourceExec;
use futures::stream::BoxStream;
use futures::{Stream, StreamExt, TryStreamExt, pin_mut};
use object_store::{ObjectMeta, ObjectStore, delimited::newline_delimited_stream};
use regex::Regex;

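/// Factory for creating [`CsvFormat`] instances, optionally seeded with
/// explicit [`CsvOptions`]. A minimal usage sketch (the option values are
/// illustrative, shown as an `ignore` doc example rather than a compiled
/// doc test):
///
/// ```ignore
/// use datafusion_common::config::CsvOptions;
///
/// let options = CsvOptions {
///     delimiter: b'|',
///     ..Default::default()
/// };
/// let factory = CsvFormatFactory::new_with_options(options);
/// ```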
#[derive(Default)]
pub struct CsvFormatFactory {
    /// Explicit CSV options; when `None`, options are derived from the
    /// session's default table options.
    pub options: Option<CsvOptions>,
}

impl CsvFormatFactory {
    /// Creates an instance of [`CsvFormatFactory`] with default options.
    pub fn new() -> Self {
        Self { options: None }
    }

    /// Creates an instance of [`CsvFormatFactory`] with the given options.
    pub fn new_with_options(options: CsvOptions) -> Self {
        Self {
            options: Some(options),
        }
    }
}

impl Debug for CsvFormatFactory {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("CsvFormatFactory")
            .field("options", &self.options)
            .finish()
    }
}

impl FileFormatFactory for CsvFormatFactory {
    fn create(
        &self,
        state: &dyn Session,
        format_options: &HashMap<String, String>,
    ) -> Result<Arc<dyn FileFormat>> {
        let csv_options = match &self.options {
            None => {
                let mut table_options = state.default_table_options();
                table_options.set_config_format(ConfigFileType::CSV);
                table_options.alter_with_string_hash_map(format_options)?;
                table_options.csv
            }
            Some(csv_options) => {
                let mut csv_options = csv_options.clone();
                for (k, v) in format_options {
                    csv_options.set(k, v)?;
                }
                csv_options
            }
        };

        Ok(Arc::new(CsvFormat::default().with_options(csv_options)))
    }

    fn default(&self) -> Arc<dyn FileFormat> {
        Arc::new(CsvFormat::default())
    }

    fn as_any(&self) -> &dyn Any {
        self
    }
}

impl GetExt for CsvFormatFactory {
    fn get_ext(&self) -> String {
        // Remove the leading "." from DEFAULT_CSV_EXTENSION (".csv").
        DEFAULT_CSV_EXTENSION[1..].to_string()
    }
}

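/// Character-separated-values [`FileFormat`] implementation. A minimal
/// configuration sketch (the builder methods below are real; the chosen
/// values are illustrative):
///
/// ```ignore
/// let format = CsvFormat::default()
///     .with_has_header(true)
///     .with_delimiter(b',')
///     .with_schema_infer_max_rec(1000);
/// ```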
#[derive(Debug, Default)]
pub struct CsvFormat {
    options: CsvOptions,
}

impl CsvFormat {
    /// Returns a stream of newline-delimited chunks of the object's bytes,
    /// decompressed according to the configured [`FileCompressionType`].
    async fn read_to_delimited_chunks<'a>(
        &self,
        store: &Arc<dyn ObjectStore>,
        object: &ObjectMeta,
    ) -> BoxStream<'a, Result<Bytes>> {
        let stream = store
            .get(&object.location)
            .await
            .map_err(|e| DataFusionError::ObjectStore(Box::new(e)));
        let stream = match stream {
            Ok(stream) => self
                .read_to_delimited_chunks_from_stream(
                    stream
                        .into_stream()
                        .map_err(|e| DataFusionError::ObjectStore(Box::new(e)))
                        .boxed(),
                )
                .await
                .map_err(DataFusionError::from)
                .left_stream(),
            Err(e) => {
                futures::stream::once(futures::future::ready(Err(e))).right_stream()
            }
        };
        stream.boxed()
    }

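    /// Converts an arbitrary stream of bytes into a stream of
    /// newline-delimited chunks, decompressing the input first if the
    /// configured compression requires it.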
    pub async fn read_to_delimited_chunks_from_stream<'a>(
        &self,
        stream: BoxStream<'a, Result<Bytes>>,
    ) -> BoxStream<'a, Result<Bytes>> {
        let file_compression_type: FileCompressionType = self.options.compression.into();
        let decoder = file_compression_type.convert_stream(stream);
        let stream = match decoder {
            Ok(decoded_stream) => {
                newline_delimited_stream(decoded_stream.map_err(|e| match e {
                    DataFusionError::ObjectStore(e) => *e,
                    err => object_store::Error::Generic {
                        store: "read to delimited chunks failed",
                        source: Box::new(err),
                    },
                }))
                .map_err(DataFusionError::from)
                .left_stream()
            }
            Err(e) => {
                futures::stream::once(futures::future::ready(Err(e))).right_stream()
            }
        };
        stream.boxed()
    }

    /// Set the options for this CSV format.
    pub fn with_options(mut self, options: CsvOptions) -> Self {
        self.options = options;
        self
    }

    /// Retrieve the options for this CSV format.
    pub fn options(&self) -> &CsvOptions {
        &self.options
    }

    /// Set the maximum number of records to scan when inferring the schema.
    /// Defaults to `DEFAULT_SCHEMA_INFER_MAX_RECORD` when unset.
    pub fn with_schema_infer_max_rec(mut self, max_rec: usize) -> Self {
        self.options.schema_infer_max_rec = Some(max_rec);
        self
    }

    /// Set whether the first line is a header. When unset, the session
    /// configuration is consulted at plan time.
    pub fn with_has_header(mut self, has_header: bool) -> Self {
        self.options.has_header = Some(has_header);
        self
    }

    /// Set whether rows with fewer fields than the schema are allowed.
    pub fn with_truncated_rows(mut self, truncated_rows: bool) -> Self {
        self.options.truncated_rows = Some(truncated_rows);
        self
    }

    /// Set the regex pattern for values that should be parsed as null.
    pub fn with_null_regex(mut self, null_regex: Option<String>) -> Self {
        self.options.null_regex = null_regex;
        self
    }

    /// Returns `Some(true)` if the first line is a header, `None` if unset.
    pub fn has_header(&self) -> Option<bool> {
        self.options.has_header
    }

    /// Lines beginning with this byte are ignored.
    pub fn with_comment(mut self, comment: Option<u8>) -> Self {
        self.options.comment = comment;
        self
    }

    /// Set the byte separating values within a row. Defaults to `,`.
    pub fn with_delimiter(mut self, delimiter: u8) -> Self {
        self.options.delimiter = delimiter;
        self
    }

    /// Set the quote byte. Defaults to `"`.
    pub fn with_quote(mut self, quote: u8) -> Self {
        self.options.quote = quote;
        self
    }

    /// Set the escape byte. Defaults to `None`.
    pub fn with_escape(mut self, escape: Option<u8>) -> Self {
        self.options.escape = escape;
        self
    }

    /// Set the line terminator byte. Defaults to `None`.
    pub fn with_terminator(mut self, terminator: Option<u8>) -> Self {
        self.options.terminator = terminator;
        self
    }

    /// Set whether quoted values may contain embedded newlines. Enabling this
    /// ensures such values parse correctly but may reduce scan parallelism.
    /// When unset, the session configuration is consulted at plan time.
    pub fn with_newlines_in_values(mut self, newlines_in_values: bool) -> Self {
        self.options.newlines_in_values = Some(newlines_in_values);
        self
    }

    /// Set the file compression type.
    pub fn with_file_compression_type(
        mut self,
        file_compression_type: FileCompressionType,
    ) -> Self {
        self.options.compression = file_compression_type.into();
        self
    }

    /// Equivalent to [`Self::with_truncated_rows`]: sets whether rows with
    /// fewer fields than the schema are allowed.
    pub fn with_truncate_rows(mut self, truncate_rows: bool) -> Self {
        self.options.truncated_rows = Some(truncate_rows);
        self
    }

    /// The delimiter byte.
    pub fn delimiter(&self) -> u8 {
        self.options.delimiter
    }

    /// The quote byte.
    pub fn quote(&self) -> u8 {
        self.options.quote
    }

    /// The escape byte, if any.
    pub fn escape(&self) -> Option<u8> {
        self.options.escape
    }
}

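/// Adapter implementing the datasource [`Decoder`] trait on top of
/// [`arrow::csv::reader::Decoder`].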
#[derive(Debug)]
pub struct CsvDecoder {
    inner: arrow::csv::reader::Decoder,
}

impl CsvDecoder {
    pub fn new(decoder: arrow::csv::reader::Decoder) -> Self {
        Self { inner: decoder }
    }
}

impl Decoder for CsvDecoder {
    fn decode(&mut self, buf: &[u8]) -> Result<usize, ArrowError> {
        self.inner.decode(buf)
    }

    fn flush(&mut self) -> Result<Option<RecordBatch>, ArrowError> {
        self.inner.flush()
    }

    fn can_flush_early(&self) -> bool {
        // The inner decoder can be flushed early once it has no remaining
        // capacity, i.e. a full batch has been buffered.
        self.inner.capacity() == 0
    }
}

impl Debug for CsvSerializer {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("CsvSerializer")
            .field("header", &self.header)
            .finish()
    }
}

#[async_trait]
impl FileFormat for CsvFormat {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn get_ext(&self) -> String {
        CsvFormatFactory::new().get_ext()
    }

    fn get_ext_with_compression(
        &self,
        file_compression_type: &FileCompressionType,
    ) -> Result<String> {
        let ext = self.get_ext();
        Ok(format!("{}{}", ext, file_compression_type.get_ext()))
    }

    fn compression_type(&self) -> Option<FileCompressionType> {
        Some(self.options.compression.into())
    }

    async fn infer_schema(
        &self,
        state: &dyn Session,
        store: &Arc<dyn ObjectStore>,
        objects: &[ObjectMeta],
    ) -> Result<SchemaRef> {
        let mut schemas = vec![];

        let mut records_to_read = self
            .options
            .schema_infer_max_rec
            .unwrap_or(DEFAULT_SCHEMA_INFER_MAX_RECORD);

        for object in objects {
            let stream = self.read_to_delimited_chunks(store, object).await;
            let (schema, records_read) = self
                .infer_schema_from_stream(state, records_to_read, stream)
                .await
                .map_err(|err| {
                    DataFusionError::Context(
                        format!("Error when processing CSV file {}", &object.location),
                        Box::new(err),
                    )
                })?;
            records_to_read -= records_read;
            schemas.push(schema);
            if records_to_read == 0 {
                break;
            }
        }

        let merged_schema = Schema::try_merge(schemas)?;
        Ok(Arc::new(merged_schema))
    }

    async fn infer_stats(
        &self,
        _state: &dyn Session,
        _store: &Arc<dyn ObjectStore>,
        table_schema: SchemaRef,
        _object: &ObjectMeta,
    ) -> Result<Statistics> {
        Ok(Statistics::new_unknown(&table_schema))
    }

    async fn create_physical_plan(
        &self,
        state: &dyn Session,
        conf: FileScanConfig,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        // Fall back to the session configuration for settings that were not
        // explicitly overridden on this format.
        let has_header = self
            .options
            .has_header
            .unwrap_or_else(|| state.config_options().catalog.has_header);
        let newlines_in_values = self
            .options
            .newlines_in_values
            .unwrap_or_else(|| state.config_options().catalog.newlines_in_values);

        let mut csv_options = self.options.clone();
        csv_options.has_header = Some(has_header);
        csv_options.newlines_in_values = Some(newlines_in_values);

        let csv_source = conf
            .file_source
            .as_any()
            .downcast_ref::<CsvSource>()
            .expect("file_source should be a CsvSource");
        let source = Arc::new(csv_source.clone().with_csv_options(csv_options));

        let config = FileScanConfigBuilder::from(conf)
            .with_file_compression_type(self.options.compression.into())
            .with_source(source)
            .build();

        Ok(DataSourceExec::from_data_source(config))
    }

    async fn create_writer_physical_plan(
        &self,
        input: Arc<dyn ExecutionPlan>,
        state: &dyn Session,
        conf: FileSinkConfig,
        order_requirements: Option<LexRequirement>,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        if conf.insert_op != InsertOp::Append {
            return not_impl_err!("Overwrites are not implemented yet for CSV");
        }

        let has_header = self
            .options()
            .has_header
            .unwrap_or_else(|| state.config_options().catalog.has_header);
        let newlines_in_values = self
            .options()
            .newlines_in_values
            .unwrap_or_else(|| state.config_options().catalog.newlines_in_values);

        let options = self
            .options()
            .clone()
            .with_has_header(has_header)
            .with_newlines_in_values(newlines_in_values);

        let writer_options = CsvWriterOptions::try_from(&options)?;

        let sink = Arc::new(CsvSink::new(conf, writer_options));

        Ok(Arc::new(DataSinkExec::new(input, sink, order_requirements)) as _)
    }

    fn file_source(&self, table_schema: TableSchema) -> Arc<dyn FileSource> {
        let mut csv_options = self.options.clone();
        if csv_options.has_header.is_none() {
            csv_options.has_header = Some(true);
        }
        Arc::new(CsvSource::new(table_schema).with_csv_options(csv_options))
    }
}

impl CsvFormat {
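    /// Infers the schema from a stream of newline-delimited chunks, reading
    /// up to `records_to_read` records. Returns the inferred schema and the
    /// number of records actually read. The first chunk may carry a header
    /// row; candidate data types are merged across chunks by
    /// [`build_schema_helper`].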
    pub async fn infer_schema_from_stream(
        &self,
        state: &dyn Session,
        mut records_to_read: usize,
        stream: impl Stream<Item = Result<Bytes>>,
    ) -> Result<(Schema, usize)> {
        let mut total_records_read = 0;
        let mut column_names = vec![];
        let mut column_type_possibilities = vec![];
        let mut record_number = -1;
        let initial_records_to_read = records_to_read;

        pin_mut!(stream);

        while let Some(chunk) = stream.next().await.transpose()? {
            record_number += 1;
            let first_chunk = record_number == 0;
            let mut format = arrow::csv::reader::Format::default()
                .with_header(
                    first_chunk
                        && self
                            .options
                            .has_header
                            .unwrap_or_else(|| state.config_options().catalog.has_header),
                )
                .with_delimiter(self.options.delimiter)
                .with_quote(self.options.quote)
                .with_truncated_rows(self.options.truncated_rows.unwrap_or(false));

            if let Some(null_regex) = &self.options.null_regex {
                let regex = Regex::new(null_regex.as_str())
                    .expect("Unable to parse CSV null regex.");
                format = format.with_null_regex(regex);
            }

            if let Some(escape) = self.options.escape {
                format = format.with_escape(escape);
            }

            if let Some(comment) = self.options.comment {
                format = format.with_comment(comment);
            }

            let (Schema { fields, .. }, records_read) =
                format.infer_schema(chunk.reader(), Some(records_to_read))?;

            records_to_read -= records_read;
            total_records_read += records_read;

            if first_chunk {
                // Set up initial candidate data types from the first chunk.
                (column_names, column_type_possibilities) = fields
                    .into_iter()
                    .map(|field| {
                        let mut possibilities = HashSet::new();
                        if records_read > 0 {
                            // Only record a candidate type when data rows were
                            // read; a header-only chunk yields no evidence.
                            possibilities.insert(field.data_type().clone());
                        }
                        (field.name().clone(), possibilities)
                    })
                    .unzip();
            } else {
                if fields.len() != column_type_possibilities.len()
                    && !self.options.truncated_rows.unwrap_or(false)
                {
                    return exec_err!(
                        "Encountered unequal lengths between records on CSV file whilst inferring schema. \
                         Expected {} fields, found {} fields at record {}",
                        column_type_possibilities.len(),
                        fields.len(),
                        record_number + 1
                    );
                }

                column_type_possibilities.iter_mut().zip(&fields).for_each(
                    |(possibilities, field)| {
                        possibilities.insert(field.data_type().clone());
                    },
                );

                // When truncated rows are allowed, a later chunk may reveal
                // columns that earlier chunks did not have.
                if fields.len() > column_type_possibilities.len() {
                    for field in fields.iter().skip(column_type_possibilities.len()) {
                        column_names.push(field.name().clone());
                        let mut possibilities = HashSet::new();
                        if records_read > 0 {
                            possibilities.insert(field.data_type().clone());
                        }
                        column_type_possibilities.push(possibilities);
                    }
                }
            }

            if records_to_read == 0 {
                break;
            }
        }

        let schema = build_schema_helper(
            column_names,
            column_type_possibilities,
            initial_records_to_read == 0,
        );
        Ok((schema, total_records_read))
    }
}

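/// Builds the final [`Schema`] from the per-column candidate type sets
/// collected during inference. `DataType::Null` is discarded from each set; a
/// single remaining candidate is used as-is, the pair {Int64, Float64} widens
/// to Float64, and any other conflict falls back to Utf8. When
/// `disable_inference` is true (zero records were requested), columns with no
/// candidates default to Utf8 instead of Null.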
fn build_schema_helper(
    names: Vec<String>,
    types: Vec<HashSet<DataType>>,
    disable_inference: bool,
) -> Schema {
    let fields = names
        .into_iter()
        .zip(types)
        .map(|(field_name, mut data_type_possibilities)| {
            // Drop the Null candidate: if a column is all-null in one chunk
            // but typed in another, the non-null type should win.
            data_type_possibilities.remove(&DataType::Null);

            match data_type_possibilities.len() {
                0 => {
                    // No candidates remain. If inference was disabled, default
                    // to Utf8 so the column stays usable; otherwise keep Null.
                    if disable_inference {
                        Field::new(field_name, DataType::Utf8, true)
                    } else {
                        Field::new(field_name, DataType::Null, true)
                    }
                }
                1 => Field::new(
                    field_name,
                    data_type_possibilities.iter().next().unwrap().clone(),
                    true,
                ),
                2 => {
                    if data_type_possibilities.contains(&DataType::Int64)
                        && data_type_possibilities.contains(&DataType::Float64)
                    {
                        // Widen mixed Int64/Float64 columns to Float64.
                        Field::new(field_name, DataType::Float64, true)
                    } else {
                        // Any other pair of types conflicts; fall back to Utf8.
                        Field::new(field_name, DataType::Utf8, true)
                    }
                }
                _ => Field::new(field_name, DataType::Utf8, true),
            }
        })
        .collect::<Fields>();
    Schema::new(fields)
}

impl Default for CsvSerializer {
    fn default() -> Self {
        Self::new()
    }
}

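/// Serializes [`RecordBatch`]es to CSV bytes. A minimal usage sketch
/// (`batch` is an assumed in-scope `RecordBatch`):
///
/// ```ignore
/// let serializer = CsvSerializer::new().with_header(false);
/// let bytes = serializer.serialize(batch, true)?;
/// ```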
pub struct CsvSerializer {
    /// Inner arrow writer builder used to configure the CSV output.
    builder: WriterBuilder,
    /// Whether to print a header for the first serialized batch.
    header: bool,
}

impl CsvSerializer {
    /// Constructor; the header defaults to `true`.
    pub fn new() -> Self {
        Self {
            builder: WriterBuilder::new(),
            header: true,
        }
    }

    /// Set the CSV writer builder.
    pub fn with_builder(mut self, builder: WriterBuilder) -> Self {
        self.builder = builder;
        self
    }

    /// Set whether a header row is written.
    pub fn with_header(mut self, header: bool) -> Self {
        self.header = header;
        self
    }
}

impl BatchSerializer for CsvSerializer {
    fn serialize(&self, batch: RecordBatch, initial: bool) -> Result<Bytes> {
        let mut buffer = Vec::with_capacity(4096);
        let builder = self.builder.clone();
        // Only the first serialized batch gets a header row.
        let header = self.header && initial;
        let mut writer = builder.with_header(header).build(&mut buffer);
        writer.write(&batch)?;
        // Drop the writer to release its borrow on `buffer`.
        drop(writer);
        Ok(Bytes::from(buffer))
    }
}

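/// Implements [`DataSink`] for writing query results to CSV files.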
pub struct CsvSink {
    /// Config options for writing data.
    config: FileSinkConfig,
    /// Options for the underlying CSV writer.
    writer_options: CsvWriterOptions,
}

impl Debug for CsvSink {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("CsvSink").finish()
    }
}

impl DisplayAs for CsvSink {
    fn fmt_as(&self, t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match t {
            DisplayFormatType::Default | DisplayFormatType::Verbose => {
                write!(f, "CsvSink(file_groups=")?;
                FileGroupDisplay(&self.config.file_group).fmt_as(t, f)?;
                write!(f, ")")
            }
            DisplayFormatType::TreeRender => {
                writeln!(f, "format: csv")?;
                write!(f, "file={}", &self.config.original_url)
            }
        }
    }
}

impl CsvSink {
    /// Create from config.
    pub fn new(config: FileSinkConfig, writer_options: CsvWriterOptions) -> Self {
        Self {
            config,
            writer_options,
        }
    }

    /// Retrieve the writer options.
    pub fn writer_options(&self) -> &CsvWriterOptions {
        &self.writer_options
    }
}

#[async_trait]
impl FileSink for CsvSink {
    fn config(&self) -> &FileSinkConfig {
        &self.config
    }

    async fn spawn_writer_tasks_and_join(
        &self,
        context: &Arc<TaskContext>,
        demux_task: SpawnedTask<Result<()>>,
        file_stream_rx: DemuxedStreamReceiver,
        object_store: Arc<dyn ObjectStore>,
    ) -> Result<u64> {
        let builder = self.writer_options.writer_options.clone();
        let header = builder.header();
        let serializer = Arc::new(
            CsvSerializer::new()
                .with_builder(builder)
                .with_header(header),
        ) as _;
        spawn_writer_tasks_and_join(
            context,
            serializer,
            self.writer_options.compression.into(),
            self.writer_options.compression_level,
            object_store,
            demux_task,
            file_stream_rx,
        )
        .await
    }
}

#[async_trait]
impl DataSink for CsvSink {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn schema(&self) -> &SchemaRef {
        self.config.output_schema()
    }

    async fn write_all(
        &self,
        data: SendableRecordBatchStream,
        context: &Arc<TaskContext>,
    ) -> Result<u64> {
        FileSink::write_all(self, data, context).await
    }
}

#[cfg(test)]
mod tests {
    use super::build_schema_helper;
    use arrow::datatypes::DataType;
    use std::collections::HashSet;

    #[test]
    fn test_build_schema_helper_different_column_counts() {
        let mut column_names =
            vec!["col1".to_string(), "col2".to_string(), "col3".to_string()];

        // Simulate extra columns discovered in later chunks.
        column_names.push("col4".to_string());
        column_names.push("col5".to_string());

        let column_type_possibilities = vec![
            HashSet::from([DataType::Int64]),
            HashSet::from([DataType::Utf8]),
            HashSet::from([DataType::Float64]),
            HashSet::from([DataType::Utf8]),
            HashSet::from([DataType::Utf8]),
        ];

        let schema = build_schema_helper(column_names, column_type_possibilities, false);

        assert_eq!(schema.fields().len(), 5);
        assert_eq!(schema.field(0).name(), "col1");
        assert_eq!(schema.field(1).name(), "col2");
        assert_eq!(schema.field(2).name(), "col3");
        assert_eq!(schema.field(3).name(), "col4");
        assert_eq!(schema.field(4).name(), "col5");

        // All inferred CSV columns are nullable.
        for field in schema.fields() {
            assert!(
                field.is_nullable(),
                "Field {} should be nullable",
                field.name()
            );
        }
    }

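    // Illustrative round-trip check for `CsvSerializer`, added as a sketch.
    // It assumes arrow's CSV writer emits the column name only when a header
    // is requested for the initial batch.
    #[test]
    fn test_csv_serializer_header_handling() {
        use super::CsvSerializer;
        use arrow::array::{ArrayRef, Int64Array, RecordBatch};
        use datafusion_datasource::write::BatchSerializer;
        use std::sync::Arc;

        let column: ArrayRef = Arc::new(Int64Array::from(vec![1, 2]));
        let batch = RecordBatch::try_from_iter(vec![("c1", column)]).unwrap();

        // Initial batch with header enabled: output should contain the column name.
        let serializer = CsvSerializer::new().with_header(true);
        let bytes = serializer.serialize(batch.clone(), true).unwrap();
        let text = String::from_utf8(bytes.to_vec()).unwrap();
        assert!(text.contains("c1"));

        // Subsequent batches never repeat the header, even when enabled.
        let bytes = serializer.serialize(batch, false).unwrap();
        let text = String::from_utf8(bytes.to_vec()).unwrap();
        assert!(!text.contains("c1"));
    }
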
    #[test]
    fn test_build_schema_helper_type_merging() {
        let column_names = vec!["col1".to_string(), "col2".to_string()];

        let column_type_possibilities = vec![
            // Int64 + Float64 should widen to Float64.
            HashSet::from([DataType::Int64, DataType::Float64]),
            HashSet::from([DataType::Utf8]),
        ];

        let schema = build_schema_helper(column_names, column_type_possibilities, false);

        assert_eq!(*schema.field(0).data_type(), DataType::Float64);
        assert_eq!(*schema.field(1).data_type(), DataType::Utf8);
    }

    #[test]
    fn test_build_schema_helper_conflicting_types() {
        let column_names = vec!["col1".to_string()];

        let column_type_possibilities = vec![
            // Three incompatible candidates should fall back to Utf8.
            HashSet::from([DataType::Boolean, DataType::Int64, DataType::Utf8]),
        ];

        let schema = build_schema_helper(column_names, column_type_possibilities, false);

        assert_eq!(*schema.field(0).data_type(), DataType::Utf8);
    }
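
    // Additional sketch covering the Null-candidate and `disable_inference`
    // branches of `build_schema_helper`; the assertions follow the fallback
    // rules implemented above.
    #[test]
    fn test_build_schema_helper_null_handling() {
        let column_names = vec!["col1".to_string(), "col2".to_string()];
        let column_type_possibilities = vec![
            // Null is discarded, leaving Int64 as the sole candidate.
            HashSet::from([DataType::Null, DataType::Int64]),
            // No candidates at all.
            HashSet::new(),
        ];

        // With inference enabled, an empty candidate set stays Null.
        let schema = build_schema_helper(
            column_names.clone(),
            column_type_possibilities.clone(),
            false,
        );
        assert_eq!(*schema.field(0).data_type(), DataType::Int64);
        assert_eq!(*schema.field(1).data_type(), DataType::Null);

        // With inference disabled, empty candidate sets default to Utf8.
        let schema = build_schema_helper(column_names, column_type_possibilities, true);
        assert_eq!(*schema.field(1).data_type(), DataType::Utf8);
    }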
}