use std::any::Any;
use std::collections::{HashMap, HashSet};
use std::fmt::{self, Debug};
use std::sync::Arc;

use crate::source::CsvSource;

use arrow::array::RecordBatch;
use arrow::csv::WriterBuilder;
use arrow::datatypes::{DataType, Field, Fields, Schema, SchemaRef};
use arrow::error::ArrowError;
use datafusion_common::config::{ConfigField, ConfigFileType, CsvOptions};
use datafusion_common::file_options::csv_writer::CsvWriterOptions;
use datafusion_common::{
    exec_err, not_impl_err, DataFusionError, GetExt, Result, Statistics,
    DEFAULT_CSV_EXTENSION,
};
use datafusion_common_runtime::SpawnedTask;
use datafusion_datasource::decoder::Decoder;
use datafusion_datasource::display::FileGroupDisplay;
use datafusion_datasource::file::FileSource;
use datafusion_datasource::file_compression_type::FileCompressionType;
use datafusion_datasource::file_format::{
    FileFormat, FileFormatFactory, DEFAULT_SCHEMA_INFER_MAX_RECORD,
};
use datafusion_datasource::file_scan_config::{FileScanConfig, FileScanConfigBuilder};
use datafusion_datasource::file_sink_config::{FileSink, FileSinkConfig};
use datafusion_datasource::sink::{DataSink, DataSinkExec};
use datafusion_datasource::write::demux::DemuxedStreamReceiver;
use datafusion_datasource::write::orchestration::spawn_writer_tasks_and_join;
use datafusion_datasource::write::BatchSerializer;
use datafusion_execution::{SendableRecordBatchStream, TaskContext};
use datafusion_expr::dml::InsertOp;
use datafusion_physical_expr_common::sort_expr::LexRequirement;
use datafusion_physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan};
use datafusion_session::Session;

use async_trait::async_trait;
use bytes::{Buf, Bytes};
use datafusion_datasource::source::DataSourceExec;
use futures::stream::BoxStream;
use futures::{pin_mut, Stream, StreamExt, TryStreamExt};
use object_store::{delimited::newline_delimited_stream, ObjectMeta, ObjectStore};
use regex::Regex;

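/// Factory used to create [`CsvFormat`] instances, optionally pre-configured
/// with [`CsvOptions`] that take precedence over session defaults.
///
/// Construction sketch (illustrative; assumes only the APIs in this module
/// and the public fields of `CsvOptions`):
///
/// ```ignore
/// use datafusion_common::config::CsvOptions;
///
/// // Seed the factory with explicit options instead of session defaults.
/// let options = CsvOptions {
///     delimiter: b';',
///     ..Default::default()
/// };
/// let factory = CsvFormatFactory::new_with_options(options);
/// ```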
#[derive(Default)]
pub struct CsvFormatFactory {
    /// The CSV options to use when creating the [`CsvFormat`], if overridden.
    pub options: Option<CsvOptions>,
}

impl CsvFormatFactory {
    /// Creates an instance of [`CsvFormatFactory`] with default options.
    pub fn new() -> Self {
        Self { options: None }
    }

    /// Creates an instance of [`CsvFormatFactory`] with the given [`CsvOptions`].
    pub fn new_with_options(options: CsvOptions) -> Self {
        Self {
            options: Some(options),
        }
    }
}

impl Debug for CsvFormatFactory {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("CsvFormatFactory")
            .field("options", &self.options)
            .finish()
    }
}

impl FileFormatFactory for CsvFormatFactory {
    fn create(
        &self,
        state: &dyn Session,
        format_options: &HashMap<String, String>,
    ) -> Result<Arc<dyn FileFormat>> {
        let csv_options = match &self.options {
            None => {
                let mut table_options = state.default_table_options();
                table_options.set_config_format(ConfigFileType::CSV);
                table_options.alter_with_string_hash_map(format_options)?;
                table_options.csv
            }
            Some(csv_options) => {
                let mut csv_options = csv_options.clone();
                for (k, v) in format_options {
                    csv_options.set(k, v)?;
                }
                csv_options
            }
        };

        Ok(Arc::new(CsvFormat::default().with_options(csv_options)))
    }

    fn default(&self) -> Arc<dyn FileFormat> {
        Arc::new(CsvFormat::default())
    }

    fn as_any(&self) -> &dyn Any {
        self
    }
}

impl GetExt for CsvFormatFactory {
    fn get_ext(&self) -> String {
        // Removes the leading dot, i.e. ".csv" -> "csv"
        DEFAULT_CSV_EXTENSION[1..].to_string()
    }
}

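/// Character-separated-values [`FileFormat`] implementation.
///
/// Configuration sketch (illustrative; every setting shown falls back to a
/// default or to session-level options when omitted):
///
/// ```ignore
/// let format = CsvFormat::default()
///     .with_has_header(true)
///     .with_delimiter(b'|')
///     .with_schema_infer_max_rec(1000);
/// ```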
#[derive(Debug, Default)]
pub struct CsvFormat {
    options: CsvOptions,
}

impl CsvFormat {
    /// Returns a stream of newline-delimited byte chunks for `object`, fetched
    /// from `store`. Errors opening the object are forwarded as a
    /// single-element error stream.
    async fn read_to_delimited_chunks<'a>(
        &self,
        store: &Arc<dyn ObjectStore>,
        object: &ObjectMeta,
    ) -> BoxStream<'a, Result<Bytes>> {
        let stream = store
            .get(&object.location)
            .await
            .map_err(|e| DataFusionError::ObjectStore(Box::new(e)));
        let stream = match stream {
            Ok(stream) => self
                .read_to_delimited_chunks_from_stream(
                    stream
                        .into_stream()
                        .map_err(|e| DataFusionError::ObjectStore(Box::new(e)))
                        .boxed(),
                )
                .await
                .map_err(DataFusionError::from)
                .left_stream(),
            Err(e) => {
                futures::stream::once(futures::future::ready(Err(e))).right_stream()
            }
        };
        stream.boxed()
    }

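    /// Converts a stream of raw CSV bytes into a stream of newline-delimited
    /// chunks, first decompressing it according to the configured
    /// [`FileCompressionType`]. Decompression and transport errors are
    /// forwarded on the returned stream.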
    pub async fn read_to_delimited_chunks_from_stream<'a>(
        &self,
        stream: BoxStream<'a, Result<Bytes>>,
    ) -> BoxStream<'a, Result<Bytes>> {
        let file_compression_type: FileCompressionType =
            self.options.compression.into();
        let decoder = file_compression_type.convert_stream(stream);
        let stream = match decoder {
            Ok(decoded_stream) => {
                newline_delimited_stream(decoded_stream.map_err(|e| match e {
                    DataFusionError::ObjectStore(e) => *e,
                    err => object_store::Error::Generic {
                        store: "read to delimited chunks failed",
                        source: Box::new(err),
                    },
                }))
                .map_err(DataFusionError::from)
                .left_stream()
            }
            Err(e) => {
                futures::stream::once(futures::future::ready(Err(e))).right_stream()
            }
        };
        stream.boxed()
    }

    /// Set the options for this [`CsvFormat`], returning `self`.
    pub fn with_options(mut self, options: CsvOptions) -> Self {
        self.options = options;
        self
    }

    /// Retrieve the options for this [`CsvFormat`].
    pub fn options(&self) -> &CsvOptions {
        &self.options
    }

    /// Set a limit in terms of records to scan when inferring the schema.
    /// Defaults to `DEFAULT_SCHEMA_INFER_MAX_RECORD`.
    pub fn with_schema_infer_max_rec(mut self, max_rec: usize) -> Self {
        self.options.schema_infer_max_rec = Some(max_rec);
        self
    }

    /// Set true to indicate that the first line is a header. When unset,
    /// the session's `datafusion.catalog.has_header` setting is used.
    pub fn with_has_header(mut self, has_header: bool) -> Self {
        self.options.has_header = Some(has_header);
        self
    }

    /// Set true to accept rows with fewer columns than the schema.
    pub fn with_truncated_rows(mut self, truncated_rows: bool) -> Self {
        self.options.truncated_rows = Some(truncated_rows);
        self
    }

    /// Set the regex that matches null values. `None` means no regex is used.
    pub fn with_null_regex(mut self, null_regex: Option<String>) -> Self {
        self.options.null_regex = null_regex;
        self
    }

    /// Returns `Some(true)` if the first line is a header, or `None` if not set.
    pub fn has_header(&self) -> Option<bool> {
        self.options.has_header
    }

    /// Set the comment character: lines beginning with this byte are ignored.
    pub fn with_comment(mut self, comment: Option<u8>) -> Self {
        self.options.comment = comment;
        self
    }

    /// The character used to delimit values. Defaults to `b','`.
    pub fn with_delimiter(mut self, delimiter: u8) -> Self {
        self.options.delimiter = delimiter;
        self
    }

    /// The quote character. Defaults to `b'"'`.
    pub fn with_quote(mut self, quote: u8) -> Self {
        self.options.quote = quote;
        self
    }

    /// The escape character. Defaults to `None` (no escaping).
    pub fn with_escape(mut self, escape: Option<u8>) -> Self {
        self.options.escape = escape;
        self
    }

    /// The character used to indicate the end of a row. Defaults to `None`.
    pub fn with_terminator(mut self, terminator: Option<u8>) -> Self {
        self.options.terminator = terminator;
        self
    }

    /// Specifies whether newlines in (quoted) values are supported.
    ///
    /// Parsing newlines in quoted values may be affected by execution
    /// behaviour such as parallel file scanning. Setting this to `true`
    /// ensures that newlines in values are parsed successfully, which may
    /// reduce performance.
    pub fn with_newlines_in_values(mut self, newlines_in_values: bool) -> Self {
        self.options.newlines_in_values = Some(newlines_in_values);
        self
    }

    /// Set the file compression type.
    pub fn with_file_compression_type(
        mut self,
        file_compression_type: FileCompressionType,
    ) -> Self {
        self.options.compression = file_compression_type.into();
        self
    }

    /// Alias for [`Self::with_truncated_rows`]: set whether rows with fewer
    /// columns than the schema are accepted.
    pub fn with_truncate_rows(mut self, truncate_rows: bool) -> Self {
        self.options.truncated_rows = Some(truncate_rows);
        self
    }

    /// The delimiter character.
    pub fn delimiter(&self) -> u8 {
        self.options.delimiter
    }

    /// The quote character.
    pub fn quote(&self) -> u8 {
        self.options.quote
    }

    /// The escape character, if any.
    pub fn escape(&self) -> Option<u8> {
        self.options.escape
    }
}

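/// Wrapper around an [`arrow::csv::reader::Decoder`] that implements the
/// datasource [`Decoder`] trait for incremental decoding of CSV byte streams.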
#[derive(Debug)]
pub struct CsvDecoder {
    inner: arrow::csv::reader::Decoder,
}

impl CsvDecoder {
    pub fn new(decoder: arrow::csv::reader::Decoder) -> Self {
        Self { inner: decoder }
    }
}

impl Decoder for CsvDecoder {
    fn decode(&mut self, buf: &[u8]) -> Result<usize, ArrowError> {
        self.inner.decode(buf)
    }

    fn flush(&mut self) -> Result<Option<RecordBatch>, ArrowError> {
        self.inner.flush()
    }

    fn can_flush_early(&self) -> bool {
        self.inner.capacity() == 0
    }
}

impl Debug for CsvSerializer {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("CsvSerializer")
            .field("header", &self.header)
            .finish()
    }
}

#[async_trait]
impl FileFormat for CsvFormat {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn get_ext(&self) -> String {
        CsvFormatFactory::new().get_ext()
    }

    fn get_ext_with_compression(
        &self,
        file_compression_type: &FileCompressionType,
    ) -> Result<String> {
        let ext = self.get_ext();
        Ok(format!("{}{}", ext, file_compression_type.get_ext()))
    }

    fn compression_type(&self) -> Option<FileCompressionType> {
        Some(self.options.compression.into())
    }

    async fn infer_schema(
        &self,
        state: &dyn Session,
        store: &Arc<dyn ObjectStore>,
        objects: &[ObjectMeta],
    ) -> Result<SchemaRef> {
        let mut schemas = vec![];

        let mut records_to_read = self
            .options
            .schema_infer_max_rec
            .unwrap_or(DEFAULT_SCHEMA_INFER_MAX_RECORD);

        for object in objects {
            let stream = self.read_to_delimited_chunks(store, object).await;
            let (schema, records_read) = self
                .infer_schema_from_stream(state, records_to_read, stream)
                .await
                .map_err(|err| {
                    DataFusionError::Context(
                        format!("Error when processing CSV file {}", &object.location),
                        Box::new(err),
                    )
                })?;
            records_to_read -= records_read;
            schemas.push(schema);
            if records_to_read == 0 {
                break;
            }
        }

        let merged_schema = Schema::try_merge(schemas)?;
        Ok(Arc::new(merged_schema))
    }

    async fn infer_stats(
        &self,
        _state: &dyn Session,
        _store: &Arc<dyn ObjectStore>,
        table_schema: SchemaRef,
        _object: &ObjectMeta,
    ) -> Result<Statistics> {
        Ok(Statistics::new_unknown(&table_schema))
    }

    async fn create_physical_plan(
        &self,
        state: &dyn Session,
        conf: FileScanConfig,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        // `has_header` and `newlines_in_values` fall back to the
        // session-level catalog defaults when not set explicitly.
        let has_header = self
            .options
            .has_header
            .unwrap_or_else(|| state.config_options().catalog.has_header);
        let newlines_in_values = self
            .options
            .newlines_in_values
            .unwrap_or_else(|| state.config_options().catalog.newlines_in_values);

        let conf_builder = FileScanConfigBuilder::from(conf)
            .with_file_compression_type(self.options.compression.into())
            .with_newlines_in_values(newlines_in_values);

        let truncated_rows = self.options.truncated_rows.unwrap_or(false);
        let source = Arc::new(
            CsvSource::new(has_header, self.options.delimiter, self.options.quote)
                .with_escape(self.options.escape)
                .with_terminator(self.options.terminator)
                .with_comment(self.options.comment)
                .with_truncate_rows(truncated_rows),
        );

        let config = conf_builder.with_source(source).build();

        Ok(DataSourceExec::from_data_source(config))
    }

    async fn create_writer_physical_plan(
        &self,
        input: Arc<dyn ExecutionPlan>,
        state: &dyn Session,
        conf: FileSinkConfig,
        order_requirements: Option<LexRequirement>,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        if conf.insert_op != InsertOp::Append {
            return not_impl_err!("Overwrites are not implemented yet for CSV");
        }

        // Resolve session-level defaults before building the writer options.
        let has_header = self
            .options()
            .has_header
            .unwrap_or_else(|| state.config_options().catalog.has_header);
        let newlines_in_values = self
            .options()
            .newlines_in_values
            .unwrap_or_else(|| state.config_options().catalog.newlines_in_values);

        let options = self
            .options()
            .clone()
            .with_has_header(has_header)
            .with_newlines_in_values(newlines_in_values);

        let writer_options = CsvWriterOptions::try_from(&options)?;

        let sink = Arc::new(CsvSink::new(conf, writer_options));

        Ok(Arc::new(DataSinkExec::new(input, sink, order_requirements)) as _)
    }

    fn file_source(&self) -> Arc<dyn FileSource> {
        Arc::new(CsvSource::default())
    }
}

impl CsvFormat {
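    /// Return the inferred schema from `stream`, reading up to
    /// `records_to_read` records, together with the number of records
    /// actually read.
    ///
    /// Type possibilities are accumulated per column across chunks and then
    /// merged by [`build_schema_helper`]: `Int64` and `Float64` widen to
    /// `Float64`, while other conflicts fall back to `Utf8`. When
    /// `truncated_rows` is enabled, later chunks may contain more columns
    /// than earlier ones; the extra columns are appended to the schema.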
    pub async fn infer_schema_from_stream(
        &self,
        state: &dyn Session,
        mut records_to_read: usize,
        stream: impl Stream<Item = Result<Bytes>>,
    ) -> Result<(Schema, usize)> {
        let mut total_records_read = 0;
        let mut column_names = vec![];
        let mut column_type_possibilities = vec![];
        let mut record_number = -1;

        pin_mut!(stream);

        while let Some(chunk) = stream.next().await.transpose()? {
            record_number += 1;
            let first_chunk = record_number == 0;
            let mut format = arrow::csv::reader::Format::default()
                .with_header(
                    first_chunk
                        && self
                            .options
                            .has_header
                            .unwrap_or_else(|| state.config_options().catalog.has_header),
                )
                .with_delimiter(self.options.delimiter)
                .with_quote(self.options.quote)
                .with_truncated_rows(self.options.truncated_rows.unwrap_or(false));

            if let Some(null_regex) = &self.options.null_regex {
                let regex = Regex::new(null_regex.as_str())
                    .expect("Unable to parse CSV null regex.");
                format = format.with_null_regex(regex);
            }

            if let Some(escape) = self.options.escape {
                format = format.with_escape(escape);
            }

            if let Some(comment) = self.options.comment {
                format = format.with_comment(comment);
            }

            let (Schema { fields, .. }, records_read) =
                format.infer_schema(chunk.reader(), Some(records_to_read))?;

            records_to_read -= records_read;
            total_records_read += records_read;

            if first_chunk {
                (column_names, column_type_possibilities) = fields
                    .into_iter()
                    .map(|field| {
                        let mut possibilities = HashSet::new();
                        if records_read > 0 {
                            // Only record a type when at least one data record
                            // was read; otherwise the inferred type is spurious.
                            possibilities.insert(field.data_type().clone());
                        }
                        (field.name().clone(), possibilities)
                    })
                    .unzip();
            } else {
                if fields.len() != column_type_possibilities.len()
                    && !self.options.truncated_rows.unwrap_or(false)
                {
                    return exec_err!(
                        "Encountered unequal lengths between records on CSV file whilst inferring schema. \
                         Expected {} fields, found {} fields at record {}",
                        column_type_possibilities.len(),
                        fields.len(),
                        record_number + 1
                    );
                }

                column_type_possibilities.iter_mut().zip(&fields).for_each(
                    |(possibilities, field)| {
                        possibilities.insert(field.data_type().clone());
                    },
                );

                // With truncated rows allowed, later chunks may reveal extra
                // columns; append them with fresh type-possibility sets.
                if fields.len() > column_type_possibilities.len() {
                    for field in fields.iter().skip(column_type_possibilities.len()) {
                        column_names.push(field.name().clone());
                        let mut possibilities = HashSet::new();
                        if records_read > 0 {
                            possibilities.insert(field.data_type().clone());
                        }
                        column_type_possibilities.push(possibilities);
                    }
                }
            }

            if records_to_read == 0 {
                break;
            }
        }

        let schema = build_schema_helper(column_names, column_type_possibilities);
        Ok((schema, total_records_read))
    }
}

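/// Builds a [`Schema`] from per-column sets of observed [`DataType`]s,
/// resolving conflicts as follows: a column with no non-null observations
/// becomes `Null`, a single observed type is kept, `{Int64, Float64}` widens
/// to `Float64`, and any other conflict falls back to `Utf8`. All fields are
/// nullable.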
fn build_schema_helper(names: Vec<String>, types: Vec<HashSet<DataType>>) -> Schema {
    let fields = names
        .into_iter()
        .zip(types)
        .map(|(field_name, mut data_type_possibilities)| {
            // Drop Null from the possibility set: if some records in a column
            // were null, any concrete type observed elsewhere wins.
            data_type_possibilities.remove(&DataType::Null);

            match data_type_possibilities.len() {
                // All sampled values were null: type the column as Null.
                0 => Field::new(field_name, DataType::Null, true),
                1 => Field::new(
                    field_name,
                    data_type_possibilities.iter().next().unwrap().clone(),
                    true,
                ),
                2 => {
                    if data_type_possibilities.contains(&DataType::Int64)
                        && data_type_possibilities.contains(&DataType::Float64)
                    {
                        // Integer and float observations widen to Float64.
                        Field::new(field_name, DataType::Float64, true)
                    } else {
                        // Any other conflicting pair falls back to Utf8.
                        Field::new(field_name, DataType::Utf8, true)
                    }
                }
                _ => Field::new(field_name, DataType::Utf8, true),
            }
        })
        .collect::<Fields>();
    Schema::new(fields)
}

impl Default for CsvSerializer {
    fn default() -> Self {
        Self::new()
    }
}

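/// Serializes [`RecordBatch`]es to CSV-encoded [`Bytes`].
///
/// Usage sketch (illustrative; `batch` stands for any `RecordBatch`, and the
/// `initial` flag controls whether the header row is emitted):
///
/// ```ignore
/// let serializer = CsvSerializer::new().with_header(true);
/// let bytes = serializer.serialize(batch, /* initial */ true)?;
/// ```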
pub struct CsvSerializer {
    /// Inner CSV writer builder used to construct the writer.
    builder: WriterBuilder,
    /// Whether to print a header row in the output.
    header: bool,
}

impl CsvSerializer {
    /// Constructor for the [`CsvSerializer`] object.
    pub fn new() -> Self {
        Self {
            builder: WriterBuilder::new(),
            header: true,
        }
    }

    /// Set the CSV writer builder.
    pub fn with_builder(mut self, builder: WriterBuilder) -> Self {
        self.builder = builder;
        self
    }

    /// Set whether a header row is written.
    pub fn with_header(mut self, header: bool) -> Self {
        self.header = header;
        self
    }
}

impl BatchSerializer for CsvSerializer {
    fn serialize(&self, batch: RecordBatch, initial: bool) -> Result<Bytes> {
        let mut buffer = Vec::with_capacity(4096);
        let builder = self.builder.clone();
        let header = self.header && initial;
        let mut writer = builder.with_header(header).build(&mut buffer);
        writer.write(&batch)?;
        // Drop the writer to release its mutable borrow on `buffer`.
        drop(writer);
        Ok(Bytes::from(buffer))
    }
}

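/// Implements [`DataSink`] for writing query results to one or more
/// partitioned CSV output files, using the configured [`CsvWriterOptions`].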
pub struct CsvSink {
    /// Config options for writing data.
    config: FileSinkConfig,
    /// Writer options for the produced CSV files.
    writer_options: CsvWriterOptions,
}

impl Debug for CsvSink {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("CsvSink").finish()
    }
}

impl DisplayAs for CsvSink {
    fn fmt_as(&self, t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match t {
            DisplayFormatType::Default | DisplayFormatType::Verbose => {
                write!(f, "CsvSink(file_groups=")?;
                FileGroupDisplay(&self.config.file_group).fmt_as(t, f)?;
                write!(f, ")")
            }
            DisplayFormatType::TreeRender => {
                writeln!(f, "format: csv")?;
                write!(f, "file={}", &self.config.original_url)
            }
        }
    }
}

impl CsvSink {
    /// Create a [`CsvSink`] from the given config and writer options.
    pub fn new(config: FileSinkConfig, writer_options: CsvWriterOptions) -> Self {
        Self {
            config,
            writer_options,
        }
    }

    /// Retrieve the writer options.
    pub fn writer_options(&self) -> &CsvWriterOptions {
        &self.writer_options
    }
}

#[async_trait]
impl FileSink for CsvSink {
    fn config(&self) -> &FileSinkConfig {
        &self.config
    }

    async fn spawn_writer_tasks_and_join(
        &self,
        context: &Arc<TaskContext>,
        demux_task: SpawnedTask<Result<()>>,
        file_stream_rx: DemuxedStreamReceiver,
        object_store: Arc<dyn ObjectStore>,
    ) -> Result<u64> {
        let builder = self.writer_options.writer_options.clone();
        let header = builder.header();
        let serializer = Arc::new(
            CsvSerializer::new()
                .with_builder(builder)
                .with_header(header),
        ) as _;
        spawn_writer_tasks_and_join(
            context,
            serializer,
            self.writer_options.compression.into(),
            object_store,
            demux_task,
            file_stream_rx,
        )
        .await
    }
}

#[async_trait]
impl DataSink for CsvSink {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn schema(&self) -> &SchemaRef {
        self.config.output_schema()
    }

    async fn write_all(
        &self,
        data: SendableRecordBatchStream,
        context: &Arc<TaskContext>,
    ) -> Result<u64> {
        FileSink::write_all(self, data, context).await
    }
}

#[cfg(test)]
mod tests {
    use super::build_schema_helper;
    use arrow::datatypes::DataType;
    use std::collections::HashSet;

    #[test]
    fn test_build_schema_helper_different_column_counts() {
        let mut column_names =
            vec!["col1".to_string(), "col2".to_string(), "col3".to_string()];

        // Simulate columns discovered only in later chunks.
        column_names.push("col4".to_string());
        column_names.push("col5".to_string());

        let column_type_possibilities = vec![
            HashSet::from([DataType::Int64]),
            HashSet::from([DataType::Utf8]),
            HashSet::from([DataType::Float64]),
            HashSet::from([DataType::Utf8]),
            HashSet::from([DataType::Utf8]),
        ];

        let schema = build_schema_helper(column_names, column_type_possibilities);

        assert_eq!(schema.fields().len(), 5);
        assert_eq!(schema.field(0).name(), "col1");
        assert_eq!(schema.field(1).name(), "col2");
        assert_eq!(schema.field(2).name(), "col3");
        assert_eq!(schema.field(3).name(), "col4");
        assert_eq!(schema.field(4).name(), "col5");

        // Every inferred field should be nullable.
        for field in schema.fields() {
            assert!(
                field.is_nullable(),
                "Field {} should be nullable",
                field.name()
            );
        }
    }

    #[test]
    fn test_build_schema_helper_type_merging() {
        let column_names = vec!["col1".to_string(), "col2".to_string()];

        let column_type_possibilities = vec![
            HashSet::from([DataType::Int64, DataType::Float64]),
            HashSet::from([DataType::Utf8]),
        ];

        let schema = build_schema_helper(column_names, column_type_possibilities);

        // Int64 and Float64 observations widen to Float64.
        assert_eq!(*schema.field(0).data_type(), DataType::Float64);

        assert_eq!(*schema.field(1).data_type(), DataType::Utf8);
    }

    #[test]
    fn test_build_schema_helper_conflicting_types() {
        let column_names = vec!["col1".to_string()];

        let column_type_possibilities = vec![HashSet::from([
            DataType::Boolean,
            DataType::Int64,
            DataType::Utf8,
        ])];

        let schema = build_schema_helper(column_names, column_type_possibilities);

        // Three or more conflicting types fall back to Utf8.
        assert_eq!(*schema.field(0).data_type(), DataType::Utf8);
    }
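
    // Sketch covering the all-null fallback: `build_schema_helper` removes
    // `Null` from the possibility set, leaving zero candidates, so the field
    // is typed as `DataType::Null` and remains nullable.
    #[test]
    fn test_build_schema_helper_all_null_column() {
        let column_names = vec!["col1".to_string()];
        let column_type_possibilities = vec![HashSet::from([DataType::Null])];

        let schema = build_schema_helper(column_names, column_type_possibilities);

        assert_eq!(*schema.field(0).data_type(), DataType::Null);
        assert!(schema.field(0).is_nullable());
    }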
}