1use std::any::Any;
21use std::collections::{HashMap, HashSet};
22use std::fmt::{self, Debug};
23use std::sync::Arc;
24
25use crate::source::CsvSource;
26
27use arrow::array::RecordBatch;
28use arrow::csv::WriterBuilder;
29use arrow::datatypes::{DataType, Field, Fields, Schema, SchemaRef};
30use arrow::error::ArrowError;
31use datafusion_common::config::{ConfigField, ConfigFileType, CsvOptions};
32use datafusion_common::file_options::csv_writer::CsvWriterOptions;
33use datafusion_common::{
34 DEFAULT_CSV_EXTENSION, DataFusionError, GetExt, Result, Statistics, exec_err,
35 not_impl_err,
36};
37use datafusion_common_runtime::SpawnedTask;
38use datafusion_datasource::TableSchema;
39use datafusion_datasource::decoder::Decoder;
40use datafusion_datasource::display::FileGroupDisplay;
41use datafusion_datasource::file::FileSource;
42use datafusion_datasource::file_compression_type::FileCompressionType;
43use datafusion_datasource::file_format::{
44 DEFAULT_SCHEMA_INFER_MAX_RECORD, FileFormat, FileFormatFactory,
45};
46use datafusion_datasource::file_scan_config::{FileScanConfig, FileScanConfigBuilder};
47use datafusion_datasource::file_sink_config::{FileSink, FileSinkConfig};
48use datafusion_datasource::sink::{DataSink, DataSinkExec};
49use datafusion_datasource::write::BatchSerializer;
50use datafusion_datasource::write::demux::DemuxedStreamReceiver;
51use datafusion_datasource::write::orchestration::spawn_writer_tasks_and_join;
52use datafusion_execution::{SendableRecordBatchStream, TaskContext};
53use datafusion_expr::dml::InsertOp;
54use datafusion_physical_expr_common::sort_expr::LexRequirement;
55use datafusion_physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan};
56use datafusion_session::Session;
57
58use async_trait::async_trait;
59use bytes::{Buf, Bytes};
60use datafusion_datasource::source::DataSourceExec;
61use futures::stream::BoxStream;
62use futures::{Stream, StreamExt, TryStreamExt, pin_mut};
63use object_store::{
64 ObjectMeta, ObjectStore, ObjectStoreExt, delimited::newline_delimited_stream,
65};
66use regex::Regex;
67
/// Factory for creating [`CsvFormat`] instances, optionally pre-configured
/// with a set of [`CsvOptions`].
#[derive(Default)]
pub struct CsvFormatFactory {
    /// Explicit CSV options; when `None`, options are derived from the
    /// session's default table options at `create` time.
    pub options: Option<CsvOptions>,
}
74
75impl CsvFormatFactory {
76 pub fn new() -> Self {
78 Self { options: None }
79 }
80
81 pub fn new_with_options(options: CsvOptions) -> Self {
83 Self {
84 options: Some(options),
85 }
86 }
87}
88
impl Debug for CsvFormatFactory {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Renders as `CsvFormatFactory { options: ... }`.
        f.debug_struct("CsvFormatFactory")
            .field("options", &self.options)
            .finish()
    }
}
96
impl FileFormatFactory for CsvFormatFactory {
    /// Creates a [`CsvFormat`], resolving options in two steps: start from
    /// either this factory's explicit options or the session's default table
    /// options, then apply `format_options` as per-key string overrides.
    fn create(
        &self,
        state: &dyn Session,
        format_options: &HashMap<String, String>,
    ) -> Result<Arc<dyn FileFormat>> {
        let csv_options = match &self.options {
            None => {
                // No explicit options: derive from session defaults, then
                // apply all string overrides in bulk.
                let mut table_options = state.default_table_options();
                table_options.set_config_format(ConfigFileType::CSV);
                table_options.alter_with_string_hash_map(format_options)?;
                table_options.csv
            }
            Some(csv_options) => {
                // Explicit options: clone them and apply each override key.
                let mut csv_options = csv_options.clone();
                for (k, v) in format_options {
                    csv_options.set(k, v)?;
                }
                csv_options
            }
        };

        Ok(Arc::new(CsvFormat::default().with_options(csv_options)))
    }

    /// Returns a [`CsvFormat`] with all-default options.
    fn default(&self) -> Arc<dyn FileFormat> {
        Arc::new(CsvFormat::default())
    }

    fn as_any(&self) -> &dyn Any {
        self
    }
}
130
131impl GetExt for CsvFormatFactory {
132 fn get_ext(&self) -> String {
133 DEFAULT_CSV_EXTENSION[1..].to_string()
135 }
136}
137
/// CSV [`FileFormat`] implementation: schema inference, scan planning, and
/// writer (sink) creation for CSV files.
#[derive(Debug, Default)]
pub struct CsvFormat {
    // All reader/writer behavior is driven by these options.
    options: CsvOptions,
}
143
impl CsvFormat {
    /// Fetches `object` from `store` and returns its contents as a stream of
    /// decompressed, newline-delimited byte chunks suitable for CSV parsing.
    ///
    /// Errors opening the object are not returned eagerly; they are emitted
    /// as the single element of the returned stream.
    async fn read_to_delimited_chunks<'a>(
        &self,
        store: &Arc<dyn ObjectStore>,
        object: &ObjectMeta,
    ) -> BoxStream<'a, Result<Bytes>> {
        let stream = store
            .get(&object.location)
            .await
            .map_err(|e| DataFusionError::ObjectStore(Box::new(e)));
        let stream = match stream {
            Ok(stream) => self
                .read_to_delimited_chunks_from_stream(
                    stream
                        .into_stream()
                        .map_err(|e| DataFusionError::ObjectStore(Box::new(e)))
                        .boxed(),
                )
                .await
                .map_err(DataFusionError::from)
                .left_stream(),
            Err(e) => {
                // Surface the "open" error through the stream itself.
                futures::stream::once(futures::future::ready(Err(e))).right_stream()
            }
        };
        stream.boxed()
    }

    /// Decompresses `stream` according to the configured compression type and
    /// re-chunks it on newline boundaries.
    ///
    /// If the decompression stream cannot be constructed, the error is
    /// emitted as the single element of the returned stream.
    pub async fn read_to_delimited_chunks_from_stream<'a>(
        &self,
        stream: BoxStream<'a, Result<Bytes>>,
    ) -> BoxStream<'a, Result<Bytes>> {
        let file_compression_type: FileCompressionType = self.options.compression.into();
        let decoder = file_compression_type.convert_stream(stream);
        let stream = match decoder {
            Ok(decoded_stream) => {
                newline_delimited_stream(decoded_stream.map_err(|e| match e {
                    // Unwrap object-store errors so they are not double-wrapped.
                    DataFusionError::ObjectStore(e) => *e,
                    err => object_store::Error::Generic {
                        store: "read to delimited chunks failed",
                        source: Box::new(err),
                    },
                }))
                .map_err(DataFusionError::from)
                .left_stream()
            }
            Err(e) => {
                futures::stream::once(futures::future::ready(Err(e))).right_stream()
            }
        };
        stream.boxed()
    }

    /// Replaces all CSV options with `options`.
    pub fn with_options(mut self, options: CsvOptions) -> Self {
        self.options = options;
        self
    }

    /// Returns the current CSV options.
    pub fn options(&self) -> &CsvOptions {
        &self.options
    }

    /// Sets the maximum number of records sampled during schema inference.
    pub fn with_schema_infer_max_rec(mut self, max_rec: usize) -> Self {
        self.options.schema_infer_max_rec = Some(max_rec);
        self
    }

    /// Sets whether the first row is treated as a header.
    pub fn with_has_header(mut self, has_header: bool) -> Self {
        self.options.has_header = Some(has_header);
        self
    }

    /// Sets whether rows whose field count differs from the schema are
    /// tolerated (passed through to arrow's `with_truncated_rows`).
    pub fn with_truncated_rows(mut self, truncated_rows: bool) -> Self {
        self.options.truncated_rows = Some(truncated_rows);
        self
    }

    /// Sets an optional regex; matching field values are read as null.
    pub fn with_null_regex(mut self, null_regex: Option<String>) -> Self {
        self.options.null_regex = null_regex;
        self
    }

    /// Returns the configured header flag; `None` means the session catalog
    /// default is used (see `create_physical_plan`).
    pub fn has_header(&self) -> Option<bool> {
        self.options.has_header
    }

    /// Sets an optional comment byte.
    pub fn with_comment(mut self, comment: Option<u8>) -> Self {
        self.options.comment = comment;
        self
    }

    /// Sets the field delimiter byte.
    pub fn with_delimiter(mut self, delimiter: u8) -> Self {
        self.options.delimiter = delimiter;
        self
    }

    /// Sets the quote byte.
    pub fn with_quote(mut self, quote: u8) -> Self {
        self.options.quote = quote;
        self
    }

    /// Sets an optional escape byte.
    pub fn with_escape(mut self, escape: Option<u8>) -> Self {
        self.options.escape = escape;
        self
    }

    /// Sets an optional line terminator byte.
    pub fn with_terminator(mut self, terminator: Option<u8>) -> Self {
        self.options.terminator = terminator;
        self
    }

    /// Sets whether quoted fields may contain embedded newlines.
    pub fn with_newlines_in_values(mut self, newlines_in_values: bool) -> Self {
        self.options.newlines_in_values = Some(newlines_in_values);
        self
    }

    /// Sets the file compression type used when reading/writing.
    pub fn with_file_compression_type(
        mut self,
        file_compression_type: FileCompressionType,
    ) -> Self {
        self.options.compression = file_compression_type.into();
        self
    }

    /// Alias of [`Self::with_truncated_rows`]; both set the same
    /// `truncated_rows` option (kept for API compatibility).
    pub fn with_truncate_rows(mut self, truncate_rows: bool) -> Self {
        self.options.truncated_rows = Some(truncate_rows);
        self
    }

    /// Returns the configured delimiter byte.
    pub fn delimiter(&self) -> u8 {
        self.options.delimiter
    }

    /// Returns the configured quote byte.
    pub fn quote(&self) -> u8 {
        self.options.quote
    }

    /// Returns the configured escape byte, if any.
    pub fn escape(&self) -> Option<u8> {
        self.options.escape
    }
}
329
/// [`Decoder`] adapter over arrow's streaming CSV decoder.
#[derive(Debug)]
pub struct CsvDecoder {
    // The underlying arrow decoder doing the actual parsing.
    inner: arrow::csv::reader::Decoder,
}
334
impl CsvDecoder {
    /// Wraps an arrow CSV decoder.
    pub fn new(decoder: arrow::csv::reader::Decoder) -> Self {
        Self { inner: decoder }
    }
}
340
impl Decoder for CsvDecoder {
    /// Feeds raw bytes to the inner decoder.
    fn decode(&mut self, buf: &[u8]) -> Result<usize, ArrowError> {
        self.inner.decode(buf)
    }

    /// Flushes any fully-decoded rows as a [`RecordBatch`].
    fn flush(&mut self) -> Result<Option<RecordBatch>, ArrowError> {
        self.inner.flush()
    }

    // NOTE(review): relies on `capacity() == 0` meaning the inner decoder
    // holds no pending data — confirm against arrow's `Decoder` docs.
    fn can_flush_early(&self) -> bool {
        self.inner.capacity() == 0
    }
}
354
impl Debug for CsvSerializer {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Only `header` is shown; the arrow `WriterBuilder` is omitted.
        f.debug_struct("CsvSerializer")
            .field("header", &self.header)
            .finish()
    }
}
362
#[async_trait]
impl FileFormat for CsvFormat {
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// Returns the CSV file extension, delegating to
    /// [`CsvFormatFactory::get_ext`].
    fn get_ext(&self) -> String {
        CsvFormatFactory::new().get_ext()
    }

    /// Appends the compression extension to the base CSV extension.
    fn get_ext_with_compression(
        &self,
        file_compression_type: &FileCompressionType,
    ) -> Result<String> {
        let ext = self.get_ext();
        Ok(format!("{}{}", ext, file_compression_type.get_ext()))
    }

    /// Returns the configured file compression type.
    fn compression_type(&self) -> Option<FileCompressionType> {
        Some(self.options.compression.into())
    }

    /// Infers a schema by sampling records from each object and merging the
    /// per-file schemas.
    ///
    /// At most `schema_infer_max_rec` records are read in total across all
    /// objects; the loop stops early once that budget is spent.
    async fn infer_schema(
        &self,
        state: &dyn Session,
        store: &Arc<dyn ObjectStore>,
        objects: &[ObjectMeta],
    ) -> Result<SchemaRef> {
        let mut schemas = vec![];

        let mut records_to_read = self
            .options
            .schema_infer_max_rec
            .unwrap_or(DEFAULT_SCHEMA_INFER_MAX_RECORD);

        for object in objects {
            let stream = self.read_to_delimited_chunks(store, object).await;
            let (schema, records_read) = self
                .infer_schema_from_stream(state, records_to_read, stream)
                .await
                .map_err(|err| {
                    // Attach the failing file's location for diagnostics.
                    DataFusionError::Context(
                        format!("Error when processing CSV file {}", &object.location),
                        Box::new(err),
                    )
                })?;
            records_to_read -= records_read;
            schemas.push(schema);
            if records_to_read == 0 {
                break;
            }
        }

        let merged_schema = Schema::try_merge(schemas)?;
        Ok(Arc::new(merged_schema))
    }

    /// CSV provides no precomputed statistics; always returns unknown.
    async fn infer_stats(
        &self,
        _state: &dyn Session,
        _store: &Arc<dyn ObjectStore>,
        table_schema: SchemaRef,
        _object: &ObjectMeta,
    ) -> Result<Statistics> {
        Ok(Statistics::new_unknown(&table_schema))
    }

    /// Builds the physical scan plan for CSV files.
    ///
    /// Options left unset on this format (`has_header`, `newlines_in_values`)
    /// are resolved from the session's catalog configuration.
    async fn create_physical_plan(
        &self,
        state: &dyn Session,
        conf: FileScanConfig,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        let has_header = self
            .options
            .has_header
            .unwrap_or_else(|| state.config_options().catalog.has_header);
        let newlines_in_values = self
            .options
            .newlines_in_values
            .unwrap_or_else(|| state.config_options().catalog.newlines_in_values);

        let mut csv_options = self.options.clone();
        csv_options.has_header = Some(has_header);
        csv_options.newlines_in_values = Some(newlines_in_values);

        // `file_source()` below always installs a `CsvSource` for CSV scans,
        // so this downcast is expected to succeed.
        let csv_source = conf
            .file_source
            .as_any()
            .downcast_ref::<CsvSource>()
            .expect("file_source should be a CsvSource");
        let source = Arc::new(csv_source.clone().with_csv_options(csv_options));

        let config = FileScanConfigBuilder::from(conf)
            .with_file_compression_type(self.options.compression.into())
            .with_source(source)
            .build();

        Ok(DataSourceExec::from_data_source(config))
    }

    /// Builds the physical plan that writes query output as CSV.
    ///
    /// # Errors
    /// Only `InsertOp::Append` is supported; other insert ops return a
    /// not-implemented error.
    async fn create_writer_physical_plan(
        &self,
        input: Arc<dyn ExecutionPlan>,
        state: &dyn Session,
        conf: FileSinkConfig,
        order_requirements: Option<LexRequirement>,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        if conf.insert_op != InsertOp::Append {
            return not_impl_err!("Overwrites are not implemented yet for CSV");
        }

        // Resolve session-level defaults the same way the scan path does.
        let has_header = self
            .options()
            .has_header
            .unwrap_or_else(|| state.config_options().catalog.has_header);
        let newlines_in_values = self
            .options()
            .newlines_in_values
            .unwrap_or_else(|| state.config_options().catalog.newlines_in_values);

        let options = self
            .options()
            .clone()
            .with_has_header(has_header)
            .with_newlines_in_values(newlines_in_values);

        let writer_options = CsvWriterOptions::try_from(&options)?;

        let sink = Arc::new(CsvSink::new(conf, writer_options));

        Ok(Arc::new(DataSinkExec::new(input, sink, order_requirements)) as _)
    }

    /// Returns the [`FileSource`] used for scans; `has_header` defaults to
    /// `true` when it was not explicitly configured.
    fn file_source(&self, table_schema: TableSchema) -> Arc<dyn FileSource> {
        let mut csv_options = self.options.clone();
        if csv_options.has_header.is_none() {
            csv_options.has_header = Some(true);
        }
        Arc::new(CsvSource::new(table_schema).with_csv_options(csv_options))
    }
}
511
512impl CsvFormat {
513 pub async fn infer_schema_from_stream(
530 &self,
531 state: &dyn Session,
532 mut records_to_read: usize,
533 stream: impl Stream<Item = Result<Bytes>>,
534 ) -> Result<(Schema, usize)> {
535 let mut total_records_read = 0;
536 let mut column_names = vec![];
537 let mut column_type_possibilities = vec![];
538 let mut record_number = -1;
539 let initial_records_to_read = records_to_read;
540
541 pin_mut!(stream);
542
543 while let Some(chunk) = stream.next().await.transpose()? {
544 record_number += 1;
545 let first_chunk = record_number == 0;
546 let mut format = arrow::csv::reader::Format::default()
547 .with_header(
548 first_chunk
549 && self
550 .options
551 .has_header
552 .unwrap_or_else(|| state.config_options().catalog.has_header),
553 )
554 .with_delimiter(self.options.delimiter)
555 .with_quote(self.options.quote)
556 .with_truncated_rows(self.options.truncated_rows.unwrap_or(false));
557
558 if let Some(null_regex) = &self.options.null_regex {
559 let regex = Regex::new(null_regex.as_str())
560 .expect("Unable to parse CSV null regex.");
561 format = format.with_null_regex(regex);
562 }
563
564 if let Some(escape) = self.options.escape {
565 format = format.with_escape(escape);
566 }
567
568 if let Some(comment) = self.options.comment {
569 format = format.with_comment(comment);
570 }
571
572 let (Schema { fields, .. }, records_read) =
573 format.infer_schema(chunk.reader(), Some(records_to_read))?;
574
575 records_to_read -= records_read;
576 total_records_read += records_read;
577
578 if first_chunk {
579 (column_names, column_type_possibilities) = fields
581 .into_iter()
582 .map(|field| {
583 let mut possibilities = HashSet::new();
584 if records_read > 0 {
585 possibilities.insert(field.data_type().clone());
587 }
588 (field.name().clone(), possibilities)
589 })
590 .unzip();
591 } else {
592 if fields.len() != column_type_possibilities.len()
593 && !self.options.truncated_rows.unwrap_or(false)
594 {
595 return exec_err!(
596 "Encountered unequal lengths between records on CSV file whilst inferring schema. \
597 Expected {} fields, found {} fields at record {}",
598 column_type_possibilities.len(),
599 fields.len(),
600 record_number + 1
601 );
602 }
603
604 column_type_possibilities.iter_mut().zip(&fields).for_each(
606 |(possibilities, field)| {
607 possibilities.insert(field.data_type().clone());
608 },
609 );
610
611 if fields.len() > column_type_possibilities.len() {
613 for field in fields.iter().skip(column_type_possibilities.len()) {
615 column_names.push(field.name().clone());
616 let mut possibilities = HashSet::new();
617 if records_read > 0 {
618 possibilities.insert(field.data_type().clone());
619 }
620 column_type_possibilities.push(possibilities);
621 }
622 }
623 }
624
625 if records_to_read == 0 {
626 break;
627 }
628 }
629
630 let schema = build_schema_helper(
631 column_names,
632 column_type_possibilities,
633 initial_records_to_read == 0,
634 );
635 Ok((schema, total_records_read))
636 }
637}
638
639fn build_schema_helper(
651 names: Vec<String>,
652 types: Vec<HashSet<DataType>>,
653 disable_inference: bool,
654) -> Schema {
655 let fields = names
656 .into_iter()
657 .zip(types)
658 .map(|(field_name, mut data_type_possibilities)| {
659 data_type_possibilities.remove(&DataType::Null);
665
666 match data_type_possibilities.len() {
667 0 => {
672 if disable_inference {
673 Field::new(field_name, DataType::Utf8, true)
674 } else {
675 Field::new(field_name, DataType::Null, true)
676 }
677 }
678 1 => Field::new(
679 field_name,
680 data_type_possibilities.iter().next().unwrap().clone(),
681 true,
682 ),
683 2 => {
684 if data_type_possibilities.contains(&DataType::Int64)
685 && data_type_possibilities.contains(&DataType::Float64)
686 {
687 Field::new(field_name, DataType::Float64, true)
689 } else {
690 Field::new(field_name, DataType::Utf8, true)
692 }
693 }
694 _ => Field::new(field_name, DataType::Utf8, true),
695 }
696 })
697 .collect::<Fields>();
698 Schema::new(fields)
699}
700
impl Default for CsvSerializer {
    /// Equivalent to [`CsvSerializer::new`].
    fn default() -> Self {
        Self::new()
    }
}
706
/// Serializes [`RecordBatch`]es into CSV bytes.
pub struct CsvSerializer {
    /// Builder used to construct the underlying arrow CSV writer.
    builder: WriterBuilder,
    /// Whether to emit a header row for the first serialized batch.
    header: bool,
}
714
impl CsvSerializer {
    /// Creates a serializer with a default writer builder and headers enabled.
    pub fn new() -> Self {
        Self {
            builder: WriterBuilder::new(),
            header: true,
        }
    }

    /// Replaces the arrow CSV writer builder.
    pub fn with_builder(mut self, builder: WriterBuilder) -> Self {
        self.builder = builder;
        self
    }

    /// Sets whether a header row is written for the first batch.
    pub fn with_header(mut self, header: bool) -> Self {
        self.header = header;
        self
    }
}
736
737impl BatchSerializer for CsvSerializer {
738 fn serialize(&self, batch: RecordBatch, initial: bool) -> Result<Bytes> {
739 let mut buffer = Vec::with_capacity(4096);
740 let builder = self.builder.clone();
741 let header = self.header && initial;
742 let mut writer = builder.with_header(header).build(&mut buffer);
743 writer.write(&batch)?;
744 drop(writer);
745 Ok(Bytes::from(buffer))
746 }
747}
748
/// [`DataSink`] that writes query results as CSV files.
pub struct CsvSink {
    /// Destination and output-schema configuration for the write.
    config: FileSinkConfig,
    /// CSV-specific writer options (header flag, compression, level).
    writer_options: CsvWriterOptions,
}
755
impl Debug for CsvSink {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Fields are intentionally omitted from the debug output.
        f.debug_struct("CsvSink").finish()
    }
}
761
762impl DisplayAs for CsvSink {
763 fn fmt_as(&self, t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
764 match t {
765 DisplayFormatType::Default | DisplayFormatType::Verbose => {
766 write!(f, "CsvSink(file_groups=",)?;
767 FileGroupDisplay(&self.config.file_group).fmt_as(t, f)?;
768 write!(f, ")")
769 }
770 DisplayFormatType::TreeRender => {
771 writeln!(f, "format: csv")?;
772 write!(f, "file={}", &self.config.original_url)
773 }
774 }
775 }
776}
777
impl CsvSink {
    /// Creates a sink for the given destination config and writer options.
    pub fn new(config: FileSinkConfig, writer_options: CsvWriterOptions) -> Self {
        Self {
            config,
            writer_options,
        }
    }

    /// Returns the CSV writer options for this sink.
    pub fn writer_options(&self) -> &CsvWriterOptions {
        &self.writer_options
    }
}
792
#[async_trait]
impl FileSink for CsvSink {
    fn config(&self) -> &FileSinkConfig {
        &self.config
    }

    /// Builds one shared [`CsvSerializer`] from the configured writer options
    /// and delegates per-file writer orchestration to the generic
    /// `spawn_writer_tasks_and_join` helper, returning its `u64` count
    /// (presumably rows written — confirm against the `FileSink` contract).
    async fn spawn_writer_tasks_and_join(
        &self,
        context: &Arc<TaskContext>,
        demux_task: SpawnedTask<Result<()>>,
        file_stream_rx: DemuxedStreamReceiver,
        object_store: Arc<dyn ObjectStore>,
    ) -> Result<u64> {
        let builder = self.writer_options.writer_options.clone();
        let header = builder.header();
        let serializer = Arc::new(
            CsvSerializer::new()
                .with_builder(builder)
                .with_header(header),
        ) as _;
        spawn_writer_tasks_and_join(
            context,
            serializer,
            self.writer_options.compression.into(),
            self.writer_options.compression_level,
            object_store,
            demux_task,
            file_stream_rx,
        )
        .await
    }
}
825
#[async_trait]
impl DataSink for CsvSink {
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// Schema of the batches this sink expects to receive.
    fn schema(&self) -> &SchemaRef {
        self.config.output_schema()
    }

    /// Delegates to the [`FileSink`] implementation of `write_all`.
    async fn write_all(
        &self,
        data: SendableRecordBatchStream,
        context: &Arc<TaskContext>,
    ) -> Result<u64> {
        FileSink::write_all(self, data, context).await
    }
}
844
#[cfg(test)]
mod tests {
    use super::build_schema_helper;
    use arrow::datatypes::DataType;
    use std::collections::HashSet;

    /// Columns beyond the initially-seen set must still be named and typed.
    #[test]
    fn test_build_schema_helper_different_column_counts() {
        let column_names: Vec<String> = (1..=5).map(|i| format!("col{i}")).collect();

        let column_type_possibilities = vec![
            HashSet::from([DataType::Int64]),
            HashSet::from([DataType::Utf8]),
            HashSet::from([DataType::Float64]),
            HashSet::from([DataType::Utf8]),
            HashSet::from([DataType::Utf8]),
        ];

        let schema = build_schema_helper(column_names, column_type_possibilities, false);

        assert_eq!(schema.fields().len(), 5);
        for (idx, expected) in ["col1", "col2", "col3", "col4", "col5"]
            .into_iter()
            .enumerate()
        {
            assert_eq!(schema.field(idx).name(), expected);
        }

        // Every inferred CSV column is nullable.
        for field in schema.fields() {
            assert!(
                field.is_nullable(),
                "Field {} should be nullable",
                field.name()
            );
        }
    }

    /// An Int64 + Float64 candidate pair widens to Float64.
    #[test]
    fn test_build_schema_helper_type_merging() {
        let column_names = vec!["col1".to_string(), "col2".to_string()];

        let column_type_possibilities = vec![
            HashSet::from([DataType::Int64, DataType::Float64]),
            HashSet::from([DataType::Utf8]),
        ];

        let schema = build_schema_helper(column_names, column_type_possibilities, false);

        assert_eq!(*schema.field(0).data_type(), DataType::Float64);
        assert_eq!(*schema.field(1).data_type(), DataType::Utf8);
    }

    /// Three or more conflicting candidates fall back to Utf8.
    #[test]
    fn test_build_schema_helper_conflicting_types() {
        let column_names = vec!["col1".to_string()];

        let column_type_possibilities = vec![HashSet::from([
            DataType::Boolean,
            DataType::Int64,
            DataType::Utf8,
        ])];

        let schema = build_schema_helper(column_names, column_type_possibilities, false);

        assert_eq!(*schema.field(0).data_type(), DataType::Utf8);
    }
}