//! CSV [`FileFormat`] abstractions: [`CsvFormat`], [`CsvFormatFactory`],
//! [`CsvSerializer`] and [`CsvSink`].

use std::any::Any;
use std::collections::{HashMap, HashSet};
use std::fmt::{self, Debug};
use std::sync::Arc;

use crate::source::CsvSource;

use arrow::array::RecordBatch;
use arrow::csv::WriterBuilder;
use arrow::datatypes::{DataType, Field, Fields, Schema, SchemaRef};
use arrow::error::ArrowError;
use datafusion_common::config::{ConfigField, ConfigFileType, CsvOptions};
use datafusion_common::file_options::csv_writer::CsvWriterOptions;
use datafusion_common::{
    exec_err, not_impl_err, DataFusionError, GetExt, Result, Statistics,
    DEFAULT_CSV_EXTENSION,
};
use datafusion_common_runtime::SpawnedTask;
use datafusion_datasource::decoder::Decoder;
use datafusion_datasource::display::FileGroupDisplay;
use datafusion_datasource::file::FileSource;
use datafusion_datasource::file_compression_type::FileCompressionType;
use datafusion_datasource::file_format::{
    FileFormat, FileFormatFactory, DEFAULT_SCHEMA_INFER_MAX_RECORD,
};
use datafusion_datasource::file_scan_config::{FileScanConfig, FileScanConfigBuilder};
use datafusion_datasource::file_sink_config::{FileSink, FileSinkConfig};
use datafusion_datasource::sink::{DataSink, DataSinkExec};
use datafusion_datasource::write::demux::DemuxedStreamReceiver;
use datafusion_datasource::write::orchestration::spawn_writer_tasks_and_join;
use datafusion_datasource::write::BatchSerializer;
use datafusion_execution::{SendableRecordBatchStream, TaskContext};
use datafusion_expr::dml::InsertOp;
use datafusion_physical_expr_common::sort_expr::LexRequirement;
use datafusion_physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan};
use datafusion_session::Session;

use async_trait::async_trait;
use bytes::{Buf, Bytes};
use datafusion_datasource::source::DataSourceExec;
use futures::stream::BoxStream;
use futures::{pin_mut, Stream, StreamExt, TryStreamExt};
use object_store::{delimited::newline_delimited_stream, ObjectMeta, ObjectStore};
use regex::Regex;

/// Factory used to create [`CsvFormat`] instances
#[derive(Default)]
pub struct CsvFormatFactory {
    /// The CSV options to use when reading or writing, if any.
    ///
    /// `None` means the options are taken from the session defaults.
    pub options: Option<CsvOptions>,
}

impl CsvFormatFactory {
    /// Creates an instance of [`CsvFormatFactory`] with no predefined options
    pub fn new() -> Self {
        Self { options: None }
    }

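    /// Creates an instance of [`CsvFormatFactory`] with the given options.
    ///
    /// # Example
    ///
    /// An illustrative sketch (marked `ignore`, so it is not compiled as a
    /// doc test); it only assumes the public `delimiter` field of
    /// [`CsvOptions`] and its `Default` implementation:
    ///
    /// ```ignore
    /// use datafusion_common::config::CsvOptions;
    ///
    /// // Read/write semicolon-delimited files; keep every other option at its default.
    /// let options = CsvOptions {
    ///     delimiter: b';',
    ///     ..Default::default()
    /// };
    /// let factory = CsvFormatFactory::new_with_options(options);
    /// ```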
    pub fn new_with_options(options: CsvOptions) -> Self {
        Self {
            options: Some(options),
        }
    }
}

impl Debug for CsvFormatFactory {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("CsvFormatFactory")
            .field("options", &self.options)
            .finish()
    }
}

impl FileFormatFactory for CsvFormatFactory {
    fn create(
        &self,
        state: &dyn Session,
        format_options: &HashMap<String, String>,
    ) -> Result<Arc<dyn FileFormat>> {
        let csv_options = match &self.options {
            None => {
                // No explicit options were supplied: start from the session's
                // default table options and apply the per-statement string
                // overrides on top.
                let mut table_options = state.default_table_options();
                table_options.set_config_format(ConfigFileType::CSV);
                table_options.alter_with_string_hash_map(format_options)?;
                table_options.csv
            }
            Some(csv_options) => {
                // Explicit options were supplied: clone them and apply the
                // per-statement string overrides on top.
                let mut csv_options = csv_options.clone();
                for (k, v) in format_options {
                    csv_options.set(k, v)?;
                }
                csv_options
            }
        };

        Ok(Arc::new(CsvFormat::default().with_options(csv_options)))
    }

    fn default(&self) -> Arc<dyn FileFormat> {
        Arc::new(CsvFormat::default())
    }

    fn as_any(&self) -> &dyn Any {
        self
    }
}

impl GetExt for CsvFormatFactory {
    fn get_ext(&self) -> String {
        // Remove the leading dot, i.e. ".csv" -> "csv"
        DEFAULT_CSV_EXTENSION[1..].to_string()
    }
}

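/// Character-separated-values (CSV) [`FileFormat`] implementation.
///
/// # Example
///
/// An illustrative sketch of the builder-style configuration (marked
/// `ignore`, so it is not compiled as a doc test); the delimiter, header
/// flag and compression chosen below are arbitrary:
///
/// ```ignore
/// use datafusion_datasource::file_compression_type::FileCompressionType;
///
/// // Pipe-delimited, headerless files compressed with gzip.
/// let format = CsvFormat::default()
///     .with_delimiter(b'|')
///     .with_has_header(false)
///     .with_file_compression_type(FileCompressionType::GZIP);
/// ```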
#[derive(Debug, Default)]
pub struct CsvFormat {
    options: CsvOptions,
}

impl CsvFormat {
    /// Return a stream of newline-delimited chunks for the given object,
    /// decompressing it if necessary. Each returned `Bytes` holds a whole
    /// number of newline-delimited rows.
    async fn read_to_delimited_chunks<'a>(
        &self,
        store: &Arc<dyn ObjectStore>,
        object: &ObjectMeta,
    ) -> BoxStream<'a, Result<Bytes>> {
        let stream = store
            .get(&object.location)
            .await
            .map_err(|e| DataFusionError::ObjectStore(Box::new(e)));
        let stream = match stream {
            Ok(stream) => self
                .read_to_delimited_chunks_from_stream(
                    stream
                        .into_stream()
                        .map_err(|e| DataFusionError::ObjectStore(Box::new(e)))
                        .boxed(),
                )
                .await
                .map_err(DataFusionError::from)
                .left_stream(),
            Err(e) => {
                futures::stream::once(futures::future::ready(Err(e))).right_stream()
            }
        };
        stream.boxed()
    }

    /// Convert a stream of raw bytes into a stream of [`Bytes`] chunks that
    /// each contain a whole number of newline-delimited records,
    /// decompressing the input according to the configured compression type.
    pub async fn read_to_delimited_chunks_from_stream<'a>(
        &self,
        stream: BoxStream<'a, Result<Bytes>>,
    ) -> BoxStream<'a, Result<Bytes>> {
        let file_compression_type: FileCompressionType = self.options.compression.into();
        let decoder = file_compression_type.convert_stream(stream);
        let stream = match decoder {
            Ok(decoded_stream) => {
                newline_delimited_stream(decoded_stream.map_err(|e| match e {
                    DataFusionError::ObjectStore(e) => *e,
                    err => object_store::Error::Generic {
                        store: "read to delimited chunks failed",
                        source: Box::new(err),
                    },
                }))
                .map_err(DataFusionError::from)
                .left_stream()
            }
            Err(e) => {
                futures::stream::once(futures::future::ready(Err(e))).right_stream()
            }
        };
        stream.boxed()
    }

    /// Set the CSV options
    pub fn with_options(mut self, options: CsvOptions) -> Self {
        self.options = options;
        self
    }

    /// Retrieve the CSV options
    pub fn options(&self) -> &CsvOptions {
        &self.options
    }

    /// Set a limit in terms of records to scan to infer the schema
    /// - default to `DEFAULT_SCHEMA_INFER_MAX_RECORD`
    pub fn with_schema_infer_max_rec(mut self, max_rec: usize) -> Self {
        self.options.schema_infer_max_rec = Some(max_rec);
        self
    }

    /// Set true to indicate that the first line is a header.
    /// - default to true
    pub fn with_has_header(mut self, has_header: bool) -> Self {
        self.options.has_header = Some(has_header);
        self
    }

    /// Set the regex pattern used to identify null values while reading.
    /// - default is to treat only empty values as null
    pub fn with_null_regex(mut self, null_regex: Option<String>) -> Self {
        self.options.null_regex = null_regex;
        self
    }

    /// Returns `Some(true)` if the first line is a header, `Some(false)` if
    /// it is not, and `None` if it is not specified.
    pub fn has_header(&self) -> Option<bool> {
        self.options.has_header
    }

    /// Lines beginning with this byte are ignored.
    pub fn with_comment(mut self, comment: Option<u8>) -> Self {
        self.options.comment = comment;
        self
    }

    /// The character used to separate columns.
    /// - default to ','
    pub fn with_delimiter(mut self, delimiter: u8) -> Self {
        self.options.delimiter = delimiter;
        self
    }

    /// The quote character in a row.
    /// - default to '"'
    pub fn with_quote(mut self, quote: u8) -> Self {
        self.options.quote = quote;
        self
    }

    /// The escape character in a row.
    /// - default is None
    pub fn with_escape(mut self, escape: Option<u8>) -> Self {
        self.options.escape = escape;
        self
    }

    /// The character used to indicate the end of a row.
    /// - default is None (CRLF)
    pub fn with_terminator(mut self, terminator: Option<u8>) -> Self {
        self.options.terminator = terminator;
        self
    }

    /// Specifies whether newlines in (quoted) values are supported.
    ///
    /// Parsing newlines in quoted values may be affected by execution
    /// behaviour such as parallel file scanning. Setting this to `true`
    /// ensures that newlines in values are parsed successfully, which may
    /// reduce performance.
    ///
    /// The default behaviour depends on the `datafusion.catalog.newlines_in_values` setting.
    pub fn with_newlines_in_values(mut self, newlines_in_values: bool) -> Self {
        self.options.newlines_in_values = Some(newlines_in_values);
        self
    }

    /// Set the `FileCompressionType` used when reading and writing CSV files.
    /// - defaults to `FileCompressionType::UNCOMPRESSED`
    pub fn with_file_compression_type(
        mut self,
        file_compression_type: FileCompressionType,
    ) -> Self {
        self.options.compression = file_compression_type.into();
        self
    }

    /// The delimiter character in use.
    pub fn delimiter(&self) -> u8 {
        self.options.delimiter
    }

    /// The quote character in use.
    pub fn quote(&self) -> u8 {
        self.options.quote
    }

    /// The escape character in use, if any.
    pub fn escape(&self) -> Option<u8> {
        self.options.escape
    }
}

/// Adapter that implements the datasource [`Decoder`] trait on top of
/// Arrow's streaming CSV [`arrow::csv::reader::Decoder`].
#[derive(Debug)]
pub struct CsvDecoder {
    inner: arrow::csv::reader::Decoder,
}

impl CsvDecoder {
    /// Create a new `CsvDecoder` wrapping the given Arrow CSV decoder.
    pub fn new(decoder: arrow::csv::reader::Decoder) -> Self {
        Self { inner: decoder }
    }
}

impl Decoder for CsvDecoder {
    fn decode(&mut self, buf: &[u8]) -> Result<usize, ArrowError> {
        self.inner.decode(buf)
    }

    fn flush(&mut self) -> Result<Option<RecordBatch>, ArrowError> {
        self.inner.flush()
    }

    fn can_flush_early(&self) -> bool {
        self.inner.capacity() == 0
    }
}

impl Debug for CsvSerializer {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("CsvSerializer")
            .field("header", &self.header)
            .finish()
    }
}

#[async_trait]
impl FileFormat for CsvFormat {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn get_ext(&self) -> String {
        CsvFormatFactory::new().get_ext()
    }

    fn get_ext_with_compression(
        &self,
        file_compression_type: &FileCompressionType,
    ) -> Result<String> {
        let ext = self.get_ext();
        Ok(format!("{}{}", ext, file_compression_type.get_ext()))
    }

    fn compression_type(&self) -> Option<FileCompressionType> {
        Some(self.options.compression.into())
    }

    async fn infer_schema(
        &self,
        state: &dyn Session,
        store: &Arc<dyn ObjectStore>,
        objects: &[ObjectMeta],
    ) -> Result<SchemaRef> {
        let mut schemas = vec![];

        let mut records_to_read = self
            .options
            .schema_infer_max_rec
            .unwrap_or(DEFAULT_SCHEMA_INFER_MAX_RECORD);

        for object in objects {
            let stream = self.read_to_delimited_chunks(store, object).await;
            let (schema, records_read) = self
                .infer_schema_from_stream(state, records_to_read, stream)
                .await
                .map_err(|err| {
                    DataFusionError::Context(
                        format!("Error when processing CSV file {}", &object.location),
                        Box::new(err),
                    )
                })?;
            records_to_read -= records_read;
            schemas.push(schema);
            if records_to_read == 0 {
                break;
            }
        }

        let merged_schema = Schema::try_merge(schemas)?;
        Ok(Arc::new(merged_schema))
    }

    async fn infer_stats(
        &self,
        _state: &dyn Session,
        _store: &Arc<dyn ObjectStore>,
        table_schema: SchemaRef,
        _object: &ObjectMeta,
    ) -> Result<Statistics> {
        Ok(Statistics::new_unknown(&table_schema))
    }

    async fn create_physical_plan(
        &self,
        state: &dyn Session,
        conf: FileScanConfig,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        // Options not explicitly set fall back to the session-level catalog
        // configuration.
        let has_header = self
            .options
            .has_header
            .unwrap_or_else(|| state.config_options().catalog.has_header);
        let newlines_in_values = self
            .options
            .newlines_in_values
            .unwrap_or_else(|| state.config_options().catalog.newlines_in_values);

        let conf_builder = FileScanConfigBuilder::from(conf)
            .with_file_compression_type(self.options.compression.into())
            .with_newlines_in_values(newlines_in_values);

        let source = Arc::new(
            CsvSource::new(has_header, self.options.delimiter, self.options.quote)
                .with_escape(self.options.escape)
                .with_terminator(self.options.terminator)
                .with_comment(self.options.comment),
        );

        let config = conf_builder.with_source(source).build();

        Ok(DataSourceExec::from_data_source(config))
    }

    async fn create_writer_physical_plan(
        &self,
        input: Arc<dyn ExecutionPlan>,
        state: &dyn Session,
        conf: FileSinkConfig,
        order_requirements: Option<LexRequirement>,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        if conf.insert_op != InsertOp::Append {
            return not_impl_err!("Overwrites are not implemented yet for CSV");
        }

        // `has_header` and `newlines_in_values` may inherit their values from
        // the session configuration, so the writer options are built from a
        // copy of `self.options` with those fields resolved.
        let has_header = self
            .options()
            .has_header
            .unwrap_or_else(|| state.config_options().catalog.has_header);
        let newlines_in_values = self
            .options()
            .newlines_in_values
            .unwrap_or_else(|| state.config_options().catalog.newlines_in_values);

        let options = self
            .options()
            .clone()
            .with_has_header(has_header)
            .with_newlines_in_values(newlines_in_values);

        let writer_options = CsvWriterOptions::try_from(&options)?;

        let sink = Arc::new(CsvSink::new(conf, writer_options));

        Ok(Arc::new(DataSinkExec::new(input, sink, order_requirements)) as _)
    }

    fn file_source(&self) -> Arc<dyn FileSource> {
        Arc::new(CsvSource::default())
    }
}

impl CsvFormat {
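    /// Return the schema inferred from up to `records_to_read` records of the
    /// given stream of newline-delimited chunks, together with the number of
    /// records that were actually read.
    ///
    /// # Example
    ///
    /// An illustrative sketch (marked `ignore`, so it is not compiled as a
    /// doc test); `session` stands in for any concrete [`Session`]
    /// implementation available to the caller:
    ///
    /// ```ignore
    /// use bytes::Bytes;
    /// use futures::stream;
    ///
    /// let format = CsvFormat::default().with_has_header(true);
    /// let chunks = stream::iter(vec![Ok(Bytes::from("c1,c2\n1,foo\n2,bar\n"))]);
    /// let (schema, records_read) = format
    ///     .infer_schema_from_stream(&session, 1000, chunks)
    ///     .await?;
    /// ```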
    pub async fn infer_schema_from_stream(
        &self,
        state: &dyn Session,
        mut records_to_read: usize,
        stream: impl Stream<Item = Result<Bytes>>,
    ) -> Result<(Schema, usize)> {
        let mut total_records_read = 0;
        let mut column_names = vec![];
        let mut column_type_possibilities = vec![];
        let mut record_number = -1;

        pin_mut!(stream);

        while let Some(chunk) = stream.next().await.transpose()? {
            record_number += 1;
            let first_chunk = record_number == 0;
            let mut format = arrow::csv::reader::Format::default()
                .with_header(
                    first_chunk
                        && self
                            .options
                            .has_header
                            .unwrap_or_else(|| state.config_options().catalog.has_header),
                )
                .with_delimiter(self.options.delimiter)
                .with_quote(self.options.quote);

            if let Some(null_regex) = &self.options.null_regex {
                let regex = Regex::new(null_regex.as_str())
                    .expect("Unable to parse CSV null regex.");
                format = format.with_null_regex(regex);
            }

            if let Some(escape) = self.options.escape {
                format = format.with_escape(escape);
            }

            if let Some(comment) = self.options.comment {
                format = format.with_comment(comment);
            }

            let (Schema { fields, .. }, records_read) =
                format.infer_schema(chunk.reader(), Some(records_to_read))?;

            records_to_read -= records_read;
            total_records_read += records_read;

            if first_chunk {
                // Set up the initial structures for recording the inferred
                // schema across chunks.
                (column_names, column_type_possibilities) = fields
                    .into_iter()
                    .map(|field| {
                        let mut possibilities = HashSet::new();
                        if records_read > 0 {
                            // At least one data row was read, so record the
                            // inferred data type.
                            possibilities.insert(field.data_type().clone());
                        }
                        (field.name().clone(), possibilities)
                    })
                    .unzip();
            } else {
                if fields.len() != column_type_possibilities.len() {
                    return exec_err!(
                        "Encountered unequal lengths between records on CSV file whilst inferring schema. \
                         Expected {} fields, found {} fields at record {}",
                        column_type_possibilities.len(),
                        fields.len(),
                        record_number + 1
                    );
                }

                column_type_possibilities.iter_mut().zip(&fields).for_each(
                    |(possibilities, field)| {
                        possibilities.insert(field.data_type().clone());
                    },
                );
            }

            if records_to_read == 0 {
                break;
            }
        }

        let schema = build_schema_helper(column_names, &column_type_possibilities);
        Ok((schema, total_records_read))
    }
}

/// Build a [`Schema`] from the per-column sets of observed [`DataType`]s.
///
/// For each column: a single observed type is used as-is; the pair
/// {Int64, Float64} widens to Float64; any other combination falls back to
/// Utf8. All fields are nullable.
fn build_schema_helper(names: Vec<String>, types: &[HashSet<DataType>]) -> Schema {
    let fields = names
        .into_iter()
        .zip(types)
        .map(|(field_name, data_type_possibilities)| {
            // Determine the data type from the possibilities observed across
            // chunks; incompatible combinations fall back to Utf8.
            match data_type_possibilities.len() {
                1 => Field::new(
                    field_name,
                    data_type_possibilities.iter().next().unwrap().clone(),
                    true,
                ),
                2 => {
                    if data_type_possibilities.contains(&DataType::Int64)
                        && data_type_possibilities.contains(&DataType::Float64)
                    {
                        // Integers and floats widen to Float64
                        Field::new(field_name, DataType::Float64, true)
                    } else {
                        // Any other pair of types falls back to Utf8
                        Field::new(field_name, DataType::Utf8, true)
                    }
                }
                _ => Field::new(field_name, DataType::Utf8, true),
            }
        })
        .collect::<Fields>();
    Schema::new(fields)
}

impl Default for CsvSerializer {
    fn default() -> Self {
        Self::new()
    }
}

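/// Serializes [`RecordBatch`]es into CSV-encoded bytes.
///
/// # Example
///
/// An illustrative sketch (marked `ignore`, so it is not compiled as a doc
/// test); `batch` stands in for any [`RecordBatch`] built by the caller.
/// Only the first ("initial") batch receives the header row:
///
/// ```ignore
/// let serializer = CsvSerializer::new().with_header(true);
/// let first_bytes = serializer.serialize(batch.clone(), true)?; // includes header
/// let later_bytes = serializer.serialize(batch, false)?;        // data rows only
/// ```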
pub struct CsvSerializer {
    /// Inner Arrow CSV writer builder used to construct the writer
    builder: WriterBuilder,
    /// Whether to print the header on the (initial) output
    header: bool,
}

impl CsvSerializer {
    /// Constructor for the CsvSerializer object
    pub fn new() -> Self {
        Self {
            builder: WriterBuilder::new(),
            header: true,
        }
    }

    /// Method for setting the CSV writer builder
    pub fn with_builder(mut self, builder: WriterBuilder) -> Self {
        self.builder = builder;
        self
    }

    /// Method for setting the CSV writer header status
    pub fn with_header(mut self, header: bool) -> Self {
        self.header = header;
        self
    }
}

impl BatchSerializer for CsvSerializer {
    fn serialize(&self, batch: RecordBatch, initial: bool) -> Result<Bytes> {
        let mut buffer = Vec::with_capacity(4096);
        let builder = self.builder.clone();
        let header = self.header && initial;
        let mut writer = builder.with_header(header).build(&mut buffer);
        writer.write(&batch)?;
        drop(writer);
        Ok(Bytes::from(buffer))
    }
}

/// Implements [`DataSink`] for writing to a CSV file.
pub struct CsvSink {
    /// Config options for writing data
    config: FileSinkConfig,
    /// Options for the underlying CSV writer
    writer_options: CsvWriterOptions,
}

impl Debug for CsvSink {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("CsvSink").finish()
    }
}

impl DisplayAs for CsvSink {
    fn fmt_as(&self, t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match t {
            DisplayFormatType::Default | DisplayFormatType::Verbose => {
                write!(f, "CsvSink(file_groups=")?;
                FileGroupDisplay(&self.config.file_group).fmt_as(t, f)?;
                write!(f, ")")
            }
            DisplayFormatType::TreeRender => {
                writeln!(f, "format: csv")?;
                write!(f, "file={}", &self.config.original_url)
            }
        }
    }
}

impl CsvSink {
    /// Create from config.
    pub fn new(config: FileSinkConfig, writer_options: CsvWriterOptions) -> Self {
        Self {
            config,
            writer_options,
        }
    }

    /// Retrieve the writer options
    pub fn writer_options(&self) -> &CsvWriterOptions {
        &self.writer_options
    }
}

#[async_trait]
impl FileSink for CsvSink {
    fn config(&self) -> &FileSinkConfig {
        &self.config
    }

    async fn spawn_writer_tasks_and_join(
        &self,
        context: &Arc<TaskContext>,
        demux_task: SpawnedTask<Result<()>>,
        file_stream_rx: DemuxedStreamReceiver,
        object_store: Arc<dyn ObjectStore>,
    ) -> Result<u64> {
        let builder = self.writer_options.writer_options.clone();
        let header = builder.header();
        let serializer = Arc::new(
            CsvSerializer::new()
                .with_builder(builder)
                .with_header(header),
        ) as _;
        spawn_writer_tasks_and_join(
            context,
            serializer,
            self.writer_options.compression.into(),
            object_store,
            demux_task,
            file_stream_rx,
        )
        .await
    }
}

#[async_trait]
impl DataSink for CsvSink {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn schema(&self) -> &SchemaRef {
        self.config.output_schema()
    }

    async fn write_all(
        &self,
        data: SendableRecordBatchStream,
        context: &Arc<TaskContext>,
    ) -> Result<u64> {
        FileSink::write_all(self, data, context).await
    }
}
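
// Illustrative check of the schema-merge rules in `build_schema_helper`.
// This is a minimal sketch: the module and test names below are arbitrary
// and independent of any other test modules in this crate.
#[cfg(test)]
mod build_schema_helper_examples {
    use super::*;
    use std::collections::HashSet;

    #[test]
    fn merges_column_type_possibilities() {
        let names = vec!["ints_and_floats".to_string(), "mixed".to_string()];
        let types = vec![
            // Int64 + Float64 widens to Float64
            HashSet::from([DataType::Int64, DataType::Float64]),
            // Any other combination falls back to Utf8
            HashSet::from([DataType::Int64, DataType::Utf8]),
        ];

        let schema = build_schema_helper(names, &types);

        assert_eq!(schema.field(0).data_type(), &DataType::Float64);
        assert_eq!(schema.field(1).data_type(), &DataType::Utf8);
        // Inferred CSV columns are always nullable
        assert!(schema.field(0).is_nullable());
    }
}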