use std::any::Any;
use std::collections::{HashMap, HashSet};
use std::fmt::{self, Debug};
use std::sync::Arc;

use crate::source::CsvSource;

use arrow::array::RecordBatch;
use arrow::csv::WriterBuilder;
use arrow::datatypes::{DataType, Field, Fields, Schema, SchemaRef};
use arrow::error::ArrowError;
use datafusion_common::config::{ConfigField, ConfigFileType, CsvOptions};
use datafusion_common::file_options::csv_writer::CsvWriterOptions;
use datafusion_common::{
    exec_err, not_impl_err, DataFusionError, GetExt, Result, Statistics,
    DEFAULT_CSV_EXTENSION,
};
use datafusion_common_runtime::SpawnedTask;
use datafusion_datasource::decoder::Decoder;
use datafusion_datasource::display::FileGroupDisplay;
use datafusion_datasource::file::FileSource;
use datafusion_datasource::file_compression_type::FileCompressionType;
use datafusion_datasource::file_format::{
    FileFormat, FileFormatFactory, DEFAULT_SCHEMA_INFER_MAX_RECORD,
};
use datafusion_datasource::file_scan_config::{FileScanConfig, FileScanConfigBuilder};
use datafusion_datasource::file_sink_config::{FileSink, FileSinkConfig};
use datafusion_datasource::sink::{DataSink, DataSinkExec};
use datafusion_datasource::write::demux::DemuxedStreamReceiver;
use datafusion_datasource::write::orchestration::spawn_writer_tasks_and_join;
use datafusion_datasource::write::BatchSerializer;
use datafusion_execution::{SendableRecordBatchStream, TaskContext};
use datafusion_expr::dml::InsertOp;
use datafusion_physical_expr::PhysicalExpr;
use datafusion_physical_expr_common::sort_expr::LexRequirement;
use datafusion_physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan};
use datafusion_session::Session;

use async_trait::async_trait;
use bytes::{Buf, Bytes};
use datafusion_datasource::source::DataSourceExec;
use futures::stream::BoxStream;
use futures::{pin_mut, Stream, StreamExt, TryStreamExt};
use object_store::{delimited::newline_delimited_stream, ObjectMeta, ObjectStore};
use regex::Regex;

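/// Factory used to create [`CsvFormat`] instances: options come either from a
/// pre-configured [`CsvOptions`], or, when none is set, from the session's
/// default table options, optionally overridden by string key/value format
/// options passed to [`FileFormatFactory::create`].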
#[derive(Default)]
pub struct CsvFormatFactory {
    /// The CSV options to apply when creating the format; `None` means the
    /// session's default table options are used instead
    pub options: Option<CsvOptions>,
}

impl CsvFormatFactory {
    /// Creates an instance of [`CsvFormatFactory`] with no options preset
    pub fn new() -> Self {
        Self { options: None }
    }

    /// Creates an instance of [`CsvFormatFactory`] with the given options
    pub fn new_with_options(options: CsvOptions) -> Self {
        Self {
            options: Some(options),
        }
    }
}

impl Debug for CsvFormatFactory {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("CsvFormatFactory")
            .field("options", &self.options)
            .finish()
    }
}

impl FileFormatFactory for CsvFormatFactory {
    fn create(
        &self,
        state: &dyn Session,
        format_options: &HashMap<String, String>,
    ) -> Result<Arc<dyn FileFormat>> {
        let csv_options = match &self.options {
            None => {
                let mut table_options = state.default_table_options();
                table_options.set_config_format(ConfigFileType::CSV);
                table_options.alter_with_string_hash_map(format_options)?;
                table_options.csv
            }
            Some(csv_options) => {
                let mut csv_options = csv_options.clone();
                for (k, v) in format_options {
                    csv_options.set(k, v)?;
                }
                csv_options
            }
        };

        Ok(Arc::new(CsvFormat::default().with_options(csv_options)))
    }

    fn default(&self) -> Arc<dyn FileFormat> {
        Arc::new(CsvFormat::default())
    }

    fn as_any(&self) -> &dyn Any {
        self
    }
}

impl GetExt for CsvFormatFactory {
    fn get_ext(&self) -> String {
        // Removes the leading dot, i.e. ".csv" -> "csv"
        DEFAULT_CSV_EXTENSION[1..].to_string()
    }
}

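/// Character-separated values [`FileFormat`] implementation.
///
/// Reading and writing behaviour is controlled by the wrapped [`CsvOptions`];
/// options left unset fall back to the session's catalog defaults when the
/// format is used. A minimal usage sketch (builder methods only, not compiled
/// here):
///
/// ```ignore
/// let format = CsvFormat::default()
///     .with_has_header(true)
///     .with_delimiter(b'|')
///     .with_quote(b'"');
/// ```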
#[derive(Debug, Default)]
pub struct CsvFormat {
    options: CsvOptions,
}

impl CsvFormat {
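    /// Fetches `object` from the object store and returns a stream of byte
    /// chunks, decompressed and split on newline boundaries, suitable for CSV
    /// schema inference. Errors from the store are surfaced through the stream.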
    async fn read_to_delimited_chunks<'a>(
        &self,
        store: &Arc<dyn ObjectStore>,
        object: &ObjectMeta,
    ) -> BoxStream<'a, Result<Bytes>> {
        let stream = store
            .get(&object.location)
            .await
            .map_err(DataFusionError::ObjectStore);
        let stream = match stream {
            Ok(stream) => self
                .read_to_delimited_chunks_from_stream(
                    stream
                        .into_stream()
                        .map_err(DataFusionError::ObjectStore)
                        .boxed(),
                )
                .await
                .map_err(DataFusionError::from)
                .left_stream(),
            Err(e) => {
                futures::stream::once(futures::future::ready(Err(e))).right_stream()
            }
        };
        stream.boxed()
    }

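    /// Decompresses the incoming byte stream (according to the configured
    /// [`FileCompressionType`]) and re-chunks it on newline boundaries so that
    /// each emitted `Bytes` contains only complete CSV lines.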
    pub async fn read_to_delimited_chunks_from_stream<'a>(
        &self,
        stream: BoxStream<'a, Result<Bytes>>,
    ) -> BoxStream<'a, Result<Bytes>> {
        let file_compression_type: FileCompressionType = self.options.compression.into();
        let decoder = file_compression_type.convert_stream(stream);
        let stream = match decoder {
            Ok(decoded_stream) => {
                newline_delimited_stream(decoded_stream.map_err(|e| match e {
                    DataFusionError::ObjectStore(e) => e,
                    err => object_store::Error::Generic {
                        store: "read to delimited chunks failed",
                        source: Box::new(err),
                    },
                }))
                .map_err(DataFusionError::from)
                .left_stream()
            }
            Err(e) => {
                futures::stream::once(futures::future::ready(Err(e))).right_stream()
            }
        };
        stream.boxed()
    }

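    /// Set the [`CsvOptions`] for this format, replacing any previously set options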
    pub fn with_options(mut self, options: CsvOptions) -> Self {
        self.options = options;
        self
    }

    /// The current [`CsvOptions`]
    pub fn options(&self) -> &CsvOptions {
        &self.options
    }

    /// Set the maximum number of records to scan when inferring the schema
    /// (falls back to `DEFAULT_SCHEMA_INFER_MAX_RECORD` when unset)
    pub fn with_schema_infer_max_rec(mut self, max_rec: usize) -> Self {
        self.options.schema_infer_max_rec = Some(max_rec);
        self
    }

    /// Set whether the first line is treated as a header
    /// (falls back to the session's `catalog.has_header` setting when unset)
    pub fn with_has_header(mut self, has_header: bool) -> Self {
        self.options.has_header = Some(has_header);
        self
    }

    /// Set the regex used to match null values (`None` disables null matching)
    pub fn with_null_regex(mut self, null_regex: Option<String>) -> Self {
        self.options.null_regex = null_regex;
        self
    }

    /// Returns `Some(true)` if the first line is a header, `Some(false)` if it
    /// is not, and `None` if it has not been configured
    pub fn has_header(&self) -> Option<bool> {
        self.options.has_header
    }

    /// Set the comment byte; lines beginning with it are ignored
    pub fn with_comment(mut self, comment: Option<u8>) -> Self {
        self.options.comment = comment;
        self
    }

    /// Set the byte separating values within a row
    pub fn with_delimiter(mut self, delimiter: u8) -> Self {
        self.options.delimiter = delimiter;
        self
    }

    /// Set the quote byte
    pub fn with_quote(mut self, quote: u8) -> Self {
        self.options.quote = quote;
        self
    }

    /// Set the escape byte, if any
    pub fn with_escape(mut self, escape: Option<u8>) -> Self {
        self.options.escape = escape;
        self
    }

    /// Set the line terminator byte, if any
    pub fn with_terminator(mut self, terminator: Option<u8>) -> Self {
        self.options.terminator = terminator;
        self
    }

    /// Set whether quoted values may contain embedded newlines
    /// (falls back to the session's `catalog.newlines_in_values` setting when unset)
    pub fn with_newlines_in_values(mut self, newlines_in_values: bool) -> Self {
        self.options.newlines_in_values = Some(newlines_in_values);
        self
    }

    /// Set the [`FileCompressionType`] used when reading and writing files
    pub fn with_file_compression_type(
        mut self,
        file_compression_type: FileCompressionType,
    ) -> Self {
        self.options.compression = file_compression_type.into();
        self
    }

    /// The configured delimiter byte
    pub fn delimiter(&self) -> u8 {
        self.options.delimiter
    }

    /// The configured quote byte
    pub fn quote(&self) -> u8 {
        self.options.quote
    }

    /// The configured escape byte, if any
    pub fn escape(&self) -> Option<u8> {
        self.options.escape
    }
}

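/// A [`Decoder`] implementation that delegates to the Arrow CSV
/// [`arrow::csv::reader::Decoder`], used for decoding streaming CSV input.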
#[derive(Debug)]
pub struct CsvDecoder {
    inner: arrow::csv::reader::Decoder,
}

impl CsvDecoder {
    pub fn new(decoder: arrow::csv::reader::Decoder) -> Self {
        Self { inner: decoder }
    }
}

impl Decoder for CsvDecoder {
    fn decode(&mut self, buf: &[u8]) -> Result<usize, ArrowError> {
        self.inner.decode(buf)
    }

    fn flush(&mut self) -> Result<Option<RecordBatch>, ArrowError> {
        self.inner.flush()
    }

    fn can_flush_early(&self) -> bool {
        self.inner.capacity() == 0
    }
}

impl Debug for CsvSerializer {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("CsvSerializer")
            .field("header", &self.header)
            .finish()
    }
}

#[async_trait]
impl FileFormat for CsvFormat {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn get_ext(&self) -> String {
        CsvFormatFactory::new().get_ext()
    }

    fn get_ext_with_compression(
        &self,
        file_compression_type: &FileCompressionType,
    ) -> Result<String> {
        let ext = self.get_ext();
        Ok(format!("{}{}", ext, file_compression_type.get_ext()))
    }

    async fn infer_schema(
        &self,
        state: &dyn Session,
        store: &Arc<dyn ObjectStore>,
        objects: &[ObjectMeta],
    ) -> Result<SchemaRef> {
        let mut schemas = vec![];

        let mut records_to_read = self
            .options
            .schema_infer_max_rec
            .unwrap_or(DEFAULT_SCHEMA_INFER_MAX_RECORD);

        for object in objects {
            let stream = self.read_to_delimited_chunks(store, object).await;
            let (schema, records_read) = self
                .infer_schema_from_stream(state, records_to_read, stream)
                .await
                .map_err(|err| {
                    DataFusionError::Context(
                        format!("Error when processing CSV file {}", &object.location),
                        Box::new(err),
                    )
                })?;
            records_to_read -= records_read;
            schemas.push(schema);
            if records_to_read == 0 {
                break;
            }
        }

        let merged_schema = Schema::try_merge(schemas)?;
        Ok(Arc::new(merged_schema))
    }

    async fn infer_stats(
        &self,
        _state: &dyn Session,
        _store: &Arc<dyn ObjectStore>,
        table_schema: SchemaRef,
        _object: &ObjectMeta,
    ) -> Result<Statistics> {
        Ok(Statistics::new_unknown(&table_schema))
    }

    async fn create_physical_plan(
        &self,
        state: &dyn Session,
        conf: FileScanConfig,
        _filters: Option<&Arc<dyn PhysicalExpr>>,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        // Options not set explicitly on the format fall back to the
        // session-level catalog defaults
        let has_header = self
            .options
            .has_header
            .unwrap_or(state.config_options().catalog.has_header);
        let newlines_in_values = self
            .options
            .newlines_in_values
            .unwrap_or(state.config_options().catalog.newlines_in_values);

        let conf_builder = FileScanConfigBuilder::from(conf)
            .with_file_compression_type(self.options.compression.into())
            .with_newlines_in_values(newlines_in_values);

        let source = Arc::new(
            CsvSource::new(has_header, self.options.delimiter, self.options.quote)
                .with_escape(self.options.escape)
                .with_terminator(self.options.terminator)
                .with_comment(self.options.comment),
        );

        let config = conf_builder.with_source(source).build();

        Ok(DataSourceExec::from_data_source(config))
    }

    async fn create_writer_physical_plan(
        &self,
        input: Arc<dyn ExecutionPlan>,
        state: &dyn Session,
        conf: FileSinkConfig,
        order_requirements: Option<LexRequirement>,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        if conf.insert_op != InsertOp::Append {
            return not_impl_err!("Overwrites are not implemented yet for CSV");
        }

        // Resolve unset options against the session-level catalog defaults
        // before building the writer options
        let has_header = self
            .options()
            .has_header
            .unwrap_or(state.config_options().catalog.has_header);
        let newlines_in_values = self
            .options()
            .newlines_in_values
            .unwrap_or(state.config_options().catalog.newlines_in_values);

        let options = self
            .options()
            .clone()
            .with_has_header(has_header)
            .with_newlines_in_values(newlines_in_values);

        let writer_options = CsvWriterOptions::try_from(&options)?;

        let sink = Arc::new(CsvSink::new(conf, writer_options));

        Ok(Arc::new(DataSinkExec::new(input, sink, order_requirements)) as _)
    }

    fn file_source(&self) -> Arc<dyn FileSource> {
        Arc::new(CsvSource::default())
    }
}

impl CsvFormat {
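    /// Infer the schema of a CSV file by reading at most `records_to_read`
    /// records from the byte `stream`, returning the inferred schema and the
    /// number of records actually read. Column types observed across chunks
    /// are merged by [`build_schema_helper`].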
    pub async fn infer_schema_from_stream(
        &self,
        state: &dyn Session,
        mut records_to_read: usize,
        stream: impl Stream<Item = Result<Bytes>>,
    ) -> Result<(Schema, usize)> {
        let mut total_records_read = 0;
        let mut column_names = vec![];
        let mut column_type_possibilities = vec![];
        let mut record_number = -1;

        pin_mut!(stream);

        while let Some(chunk) = stream.next().await.transpose()? {
            record_number += 1;
            let first_chunk = record_number == 0;
            let mut format = arrow::csv::reader::Format::default()
                .with_header(
                    first_chunk
                        && self
                            .options
                            .has_header
                            .unwrap_or(state.config_options().catalog.has_header),
                )
                .with_delimiter(self.options.delimiter)
                .with_quote(self.options.quote);

            if let Some(null_regex) = &self.options.null_regex {
                let regex = Regex::new(null_regex.as_str())
                    .expect("Unable to parse CSV null regex.");
                format = format.with_null_regex(regex);
            }

            if let Some(escape) = self.options.escape {
                format = format.with_escape(escape);
            }

            if let Some(comment) = self.options.comment {
                format = format.with_comment(comment);
            }

            let (Schema { fields, .. }, records_read) =
                format.infer_schema(chunk.reader(), Some(records_to_read))?;

            records_to_read -= records_read;
            total_records_read += records_read;

            if first_chunk {
                // Set up the initial structures for recording the inferred
                // schema across chunks
                (column_names, column_type_possibilities) = fields
                    .into_iter()
                    .map(|field| {
                        let mut possibilities = HashSet::new();
                        if records_read > 0 {
                            // At least one data row was read, so record the
                            // inferred data type
                            possibilities.insert(field.data_type().clone());
                        }
                        (field.name().clone(), possibilities)
                    })
                    .unzip();
            } else {
                if fields.len() != column_type_possibilities.len() {
                    return exec_err!(
                        "Encountered unequal lengths between records on CSV file whilst inferring schema. \
                         Expected {} fields, found {} fields at record {}",
                        column_type_possibilities.len(),
                        fields.len(),
                        record_number + 1
                    );
                }

                column_type_possibilities.iter_mut().zip(&fields).for_each(
                    |(possibilities, field)| {
                        possibilities.insert(field.data_type().clone());
                    },
                );
            }

            if records_to_read == 0 {
                break;
            }
        }

        let schema = build_schema_helper(column_names, &column_type_possibilities);
        Ok((schema, total_records_read))
    }
}

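/// Builds a [`Schema`] from the column names and the set of data types
/// observed for each column: a single observed type is used as-is, a mix of
/// `Int64` and `Float64` widens to `Float64`, and any other combination
/// (or no observation at all) falls back to `Utf8`. All fields are nullable.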
fn build_schema_helper(names: Vec<String>, types: &[HashSet<DataType>]) -> Schema {
    let fields = names
        .into_iter()
        .zip(types)
        .map(|(field_name, data_type_possibilities)| {
            match data_type_possibilities.len() {
                1 => Field::new(
                    field_name,
                    data_type_possibilities.iter().next().unwrap().clone(),
                    true,
                ),
                2 => {
                    if data_type_possibilities.contains(&DataType::Int64)
                        && data_type_possibilities.contains(&DataType::Float64)
                    {
                        // Widen a mix of Int64 and Float64 to Float64
                        Field::new(field_name, DataType::Float64, true)
                    } else {
                        Field::new(field_name, DataType::Utf8, true)
                    }
                }
                // Any other combination falls back to Utf8
                _ => Field::new(field_name, DataType::Utf8, true),
            }
        })
        .collect::<Fields>();
    Schema::new(fields)
}

impl Default for CsvSerializer {
    fn default() -> Self {
        Self::new()
    }
}

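/// Serializes [`RecordBatch`]es to CSV-encoded bytes using an Arrow
/// [`WriterBuilder`].
///
/// A rough usage sketch (not compiled here; `batch` stands in for any
/// [`RecordBatch`]):
///
/// ```ignore
/// let serializer = CsvSerializer::new().with_header(true);
/// // Emit the header only for the first call by passing `initial = true`
/// let bytes = serializer.serialize(batch, true)?;
/// ```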
pub struct CsvSerializer {
    /// Inner writer builder used to configure the Arrow CSV writer
    builder: WriterBuilder,
    /// Flag indicating whether a header line should be emitted
    header: bool,
}

impl CsvSerializer {
    /// Constructor for the CsvSerializer object
    pub fn new() -> Self {
        Self {
            builder: WriterBuilder::new(),
            header: true,
        }
    }

    /// Method for setting the CSV writer builder
    pub fn with_builder(mut self, builder: WriterBuilder) -> Self {
        self.builder = builder;
        self
    }

    /// Method for setting the CSV writer header status
    pub fn with_header(mut self, header: bool) -> Self {
        self.header = header;
        self
    }
}

impl BatchSerializer for CsvSerializer {
    fn serialize(&self, batch: RecordBatch, initial: bool) -> Result<Bytes> {
        let mut buffer = Vec::with_capacity(4096);
        let builder = self.builder.clone();
        // Only write the header for the first batch of a file
        let header = self.header && initial;
        let mut writer = builder.with_header(header).build(&mut buffer);
        writer.write(&batch)?;
        drop(writer);
        Ok(Bytes::from(buffer))
    }
}

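/// Implements [`DataSink`] for writing CSV files: batches are serialized with
/// [`CsvSerializer`] and written out through the shared writer orchestration
/// in [`spawn_writer_tasks_and_join`].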
pub struct CsvSink {
    /// Config options for writing data
    config: FileSinkConfig,
    /// Writer options for the underlying CSV writer
    writer_options: CsvWriterOptions,
}

impl Debug for CsvSink {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("CsvSink").finish()
    }
}

impl DisplayAs for CsvSink {
    fn fmt_as(&self, t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match t {
            DisplayFormatType::Default | DisplayFormatType::Verbose => {
                write!(f, "CsvSink(file_groups=")?;
                FileGroupDisplay(&self.config.file_group).fmt_as(t, f)?;
                write!(f, ")")
            }
            DisplayFormatType::TreeRender => {
                writeln!(f, "format: csv")?;
                write!(f, "file={}", &self.config.original_url)
            }
        }
    }
}

impl CsvSink {
    /// Create from config.
    pub fn new(config: FileSinkConfig, writer_options: CsvWriterOptions) -> Self {
        Self {
            config,
            writer_options,
        }
    }

    /// Retrieve the CSV writer options
    pub fn writer_options(&self) -> &CsvWriterOptions {
        &self.writer_options
    }
}

#[async_trait]
impl FileSink for CsvSink {
    fn config(&self) -> &FileSinkConfig {
        &self.config
    }

    async fn spawn_writer_tasks_and_join(
        &self,
        context: &Arc<TaskContext>,
        demux_task: SpawnedTask<Result<()>>,
        file_stream_rx: DemuxedStreamReceiver,
        object_store: Arc<dyn ObjectStore>,
    ) -> Result<u64> {
        let builder = self.writer_options.writer_options.clone();
        let header = builder.header();
        let serializer = Arc::new(
            CsvSerializer::new()
                .with_builder(builder)
                .with_header(header),
        ) as _;
        spawn_writer_tasks_and_join(
            context,
            serializer,
            self.writer_options.compression.into(),
            object_store,
            demux_task,
            file_stream_rx,
        )
        .await
    }
}

#[async_trait]
impl DataSink for CsvSink {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn schema(&self) -> &SchemaRef {
        self.config.output_schema()
    }

    async fn write_all(
        &self,
        data: SendableRecordBatchStream,
        context: &Arc<TaskContext>,
    ) -> Result<u64> {
        FileSink::write_all(self, data, context).await
    }
}