datafusion_datasource_csv/file_format.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

//! [`CsvFormat`], Comma Separated Value (CSV) [`FileFormat`] abstractions

use std::any::Any;
use std::collections::{HashMap, HashSet};
use std::fmt::{self, Debug};
use std::sync::Arc;

use crate::source::CsvSource;

use arrow::array::RecordBatch;
use arrow::csv::WriterBuilder;
use arrow::datatypes::{DataType, Field, Fields, Schema, SchemaRef};
use arrow::error::ArrowError;
use datafusion_common::config::{ConfigField, ConfigFileType, CsvOptions};
use datafusion_common::file_options::csv_writer::CsvWriterOptions;
use datafusion_common::{
    exec_err, not_impl_err, DataFusionError, GetExt, Result, Statistics,
    DEFAULT_CSV_EXTENSION,
};
use datafusion_common_runtime::SpawnedTask;
use datafusion_datasource::decoder::Decoder;
use datafusion_datasource::display::FileGroupDisplay;
use datafusion_datasource::file::FileSource;
use datafusion_datasource::file_compression_type::FileCompressionType;
use datafusion_datasource::file_format::{
    FileFormat, FileFormatFactory, DEFAULT_SCHEMA_INFER_MAX_RECORD,
};
use datafusion_datasource::file_scan_config::{FileScanConfig, FileScanConfigBuilder};
use datafusion_datasource::file_sink_config::{FileSink, FileSinkConfig};
use datafusion_datasource::sink::{DataSink, DataSinkExec};
use datafusion_datasource::write::demux::DemuxedStreamReceiver;
use datafusion_datasource::write::orchestration::spawn_writer_tasks_and_join;
use datafusion_datasource::write::BatchSerializer;
use datafusion_execution::{SendableRecordBatchStream, TaskContext};
use datafusion_expr::dml::InsertOp;
use datafusion_physical_expr_common::sort_expr::LexRequirement;
use datafusion_physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan};
use datafusion_session::Session;

use async_trait::async_trait;
use bytes::{Buf, Bytes};
use datafusion_datasource::source::DataSourceExec;
use futures::stream::BoxStream;
use futures::{pin_mut, Stream, StreamExt, TryStreamExt};
use object_store::{delimited::newline_delimited_stream, ObjectMeta, ObjectStore};
use regex::Regex;

#[derive(Default)]
/// Factory used to create [`CsvFormat`]
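///
/// # Example
///
/// A minimal sketch of the intended usage; `session` is assumed to be any
/// value implementing [`Session`], and the snippet is not compiled here:
///
/// ```ignore
/// use std::collections::HashMap;
///
/// let factory = CsvFormatFactory::new_with_options(CsvOptions::default());
/// // Per-statement options (if any) are applied on top of the stored defaults.
/// let format = factory.create(&session, &HashMap::new())?;
/// assert_eq!(format.get_ext(), "csv");
/// ```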
pub struct CsvFormatFactory {
    /// the options for csv file read
    pub options: Option<CsvOptions>,
}

impl CsvFormatFactory {
    /// Creates an instance of [`CsvFormatFactory`]
    pub fn new() -> Self {
        Self { options: None }
    }

    /// Creates an instance of [`CsvFormatFactory`] with customized default options
    pub fn new_with_options(options: CsvOptions) -> Self {
        Self {
            options: Some(options),
        }
    }
}

impl Debug for CsvFormatFactory {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("CsvFormatFactory")
            .field("options", &self.options)
            .finish()
    }
}

impl FileFormatFactory for CsvFormatFactory {
    fn create(
        &self,
        state: &dyn Session,
        format_options: &HashMap<String, String>,
    ) -> Result<Arc<dyn FileFormat>> {
        let csv_options = match &self.options {
            None => {
                let mut table_options = state.default_table_options();
                table_options.set_config_format(ConfigFileType::CSV);
                table_options.alter_with_string_hash_map(format_options)?;
                table_options.csv
            }
            Some(csv_options) => {
                let mut csv_options = csv_options.clone();
                for (k, v) in format_options {
                    csv_options.set(k, v)?;
                }
                csv_options
            }
        };

        Ok(Arc::new(CsvFormat::default().with_options(csv_options)))
    }

    fn default(&self) -> Arc<dyn FileFormat> {
        Arc::new(CsvFormat::default())
    }

    fn as_any(&self) -> &dyn Any {
        self
    }
}

impl GetExt for CsvFormatFactory {
    fn get_ext(&self) -> String {
        // Removes the dot, i.e. ".csv" -> "csv"
        DEFAULT_CSV_EXTENSION[1..].to_string()
    }
}

/// Character Separated Value [`FileFormat`] implementation.
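///
/// # Example
///
/// An illustrative sketch of configuring the format through its builder
/// methods; the chosen values are arbitrary and the snippet is not compiled
/// here:
///
/// ```ignore
/// let format = CsvFormat::default()
///     .with_has_header(true)
///     .with_delimiter(b';')
///     .with_quote(b'"')
///     .with_file_compression_type(FileCompressionType::UNCOMPRESSED);
/// assert_eq!(format.delimiter(), b';');
/// ```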
#[derive(Debug, Default)]
pub struct CsvFormat {
    options: CsvOptions,
}

impl CsvFormat {
    /// Return a newline-delimited stream of [`Bytes`] from the specified file,
    /// decompressing if necessary.
    /// Each returned `Bytes` contains a whole number of newline-delimited rows.
    async fn read_to_delimited_chunks<'a>(
        &self,
        store: &Arc<dyn ObjectStore>,
        object: &ObjectMeta,
    ) -> BoxStream<'a, Result<Bytes>> {
        // stream to only read as many rows as needed into memory
        let stream = store
            .get(&object.location)
            .await
            .map_err(|e| DataFusionError::ObjectStore(Box::new(e)));
        let stream = match stream {
            Ok(stream) => self
                .read_to_delimited_chunks_from_stream(
                    stream
                        .into_stream()
                        .map_err(|e| DataFusionError::ObjectStore(Box::new(e)))
                        .boxed(),
                )
                .await
                .map_err(DataFusionError::from)
                .left_stream(),
            Err(e) => {
                futures::stream::once(futures::future::ready(Err(e))).right_stream()
            }
        };
        stream.boxed()
    }

    /// Convert a stream of bytes into a stream of [`Bytes`] containing newline
    /// delimited CSV records, while accounting for `\` and `"`.
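    ///
    /// # Example
    ///
    /// An illustrative sketch only (not compiled here); the input may be any
    /// boxed stream of byte chunks, for example one built from in-memory data:
    ///
    /// ```ignore
    /// use futures::StreamExt;
    ///
    /// let input = futures::stream::iter(vec![
    ///     Ok(bytes::Bytes::from("a,b\n1,")),
    ///     Ok(bytes::Bytes::from("2\n3,4\n")),
    /// ])
    /// .boxed();
    /// let mut chunks = CsvFormat::default()
    ///     .read_to_delimited_chunks_from_stream(input)
    ///     .await;
    /// // Each yielded `Bytes` ends on a record boundary, so rows are never
    /// // split across chunks.
    /// while let Some(chunk) = chunks.next().await {
    ///     println!("{} bytes", chunk?.len());
    /// }
    /// ```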
    pub async fn read_to_delimited_chunks_from_stream<'a>(
        &self,
        stream: BoxStream<'a, Result<Bytes>>,
    ) -> BoxStream<'a, Result<Bytes>> {
        let file_compression_type: FileCompressionType = self.options.compression.into();
        let decoder = file_compression_type.convert_stream(stream);
        let stream = match decoder {
            Ok(decoded_stream) => {
                newline_delimited_stream(decoded_stream.map_err(|e| match e {
                    DataFusionError::ObjectStore(e) => *e,
                    err => object_store::Error::Generic {
                        store: "read to delimited chunks failed",
                        source: Box::new(err),
                    },
                }))
                .map_err(DataFusionError::from)
                .left_stream()
            }
            Err(e) => {
                futures::stream::once(futures::future::ready(Err(e))).right_stream()
            }
        };
        stream.boxed()
    }

    /// Set the csv options
    pub fn with_options(mut self, options: CsvOptions) -> Self {
        self.options = options;
        self
    }

    /// Retrieve the csv options
    pub fn options(&self) -> &CsvOptions {
        &self.options
    }

    /// Set a limit in terms of records to scan to infer the schema
    /// - default to `DEFAULT_SCHEMA_INFER_MAX_RECORD`
    pub fn with_schema_infer_max_rec(mut self, max_rec: usize) -> Self {
        self.options.schema_infer_max_rec = Some(max_rec);
        self
    }

    /// Set true to indicate that the first line is a header.
    /// - default to true
    pub fn with_has_header(mut self, has_header: bool) -> Self {
        self.options.has_header = Some(has_header);
        self
    }

    /// Set whether to allow rows with fewer columns than the schema (truncated rows).
    /// - default to false
    pub fn with_truncated_rows(mut self, truncated_rows: bool) -> Self {
        self.options.truncated_rows = Some(truncated_rows);
        self
    }

    /// Set the regex to use for null values in the CSV reader.
    /// - default to treat empty values as null.
    pub fn with_null_regex(mut self, null_regex: Option<String>) -> Self {
        self.options.null_regex = null_regex;
        self
    }

    /// Returns `Some(true)` if the first line is a header, `Some(false)` if
    /// it is not, and `None` if it is not specified.
    pub fn has_header(&self) -> Option<bool> {
        self.options.has_header
    }

    /// Lines beginning with this byte are ignored.
    pub fn with_comment(mut self, comment: Option<u8>) -> Self {
        self.options.comment = comment;
        self
    }

    /// The character separating values within a row.
    /// - default to ','
    pub fn with_delimiter(mut self, delimiter: u8) -> Self {
        self.options.delimiter = delimiter;
        self
    }

    /// The quote character in a row.
    /// - default to '"'
    pub fn with_quote(mut self, quote: u8) -> Self {
        self.options.quote = quote;
        self
    }

    /// The escape character in a row.
    /// - default is None
    pub fn with_escape(mut self, escape: Option<u8>) -> Self {
        self.options.escape = escape;
        self
    }

    /// The character used to indicate the end of a row.
    /// - default to None (CRLF)
    pub fn with_terminator(mut self, terminator: Option<u8>) -> Self {
        self.options.terminator = terminator;
        self
    }

    /// Specifies whether newlines in (quoted) values are supported.
    ///
    /// Parsing newlines in quoted values may be affected by execution behaviour such as
    /// parallel file scanning. Setting this to `true` ensures that newlines in values are
    /// parsed successfully, which may reduce performance.
    ///
    /// The default behaviour depends on the `datafusion.catalog.newlines_in_values` setting.
    pub fn with_newlines_in_values(mut self, newlines_in_values: bool) -> Self {
        self.options.newlines_in_values = Some(newlines_in_values);
        self
    }

    /// Set the `FileCompressionType` of the CSV file
    /// - defaults to `FileCompressionType::UNCOMPRESSED`
    pub fn with_file_compression_type(
        mut self,
        file_compression_type: FileCompressionType,
    ) -> Self {
        self.options.compression = file_compression_type.into();
        self
    }

    /// Set whether to allow rows with fewer columns than the schema (truncated rows)
    /// - defaults to false
    pub fn with_truncate_rows(mut self, truncate_rows: bool) -> Self {
        self.options.truncated_rows = Some(truncate_rows);
        self
    }

    /// The delimiter character.
    pub fn delimiter(&self) -> u8 {
        self.options.delimiter
    }

    /// The quote character.
    pub fn quote(&self) -> u8 {
        self.options.quote
    }

    /// The escape character.
    pub fn escape(&self) -> Option<u8> {
        self.options.escape
    }
}

#[derive(Debug)]
/// A [`Decoder`] implementation that wraps the Arrow CSV [`arrow::csv::reader::Decoder`].
pub struct CsvDecoder {
    inner: arrow::csv::reader::Decoder,
}

impl CsvDecoder {
    /// Create a new [`CsvDecoder`] from an Arrow CSV decoder.
    pub fn new(decoder: arrow::csv::reader::Decoder) -> Self {
        Self { inner: decoder }
    }
}

impl Decoder for CsvDecoder {
    fn decode(&mut self, buf: &[u8]) -> Result<usize, ArrowError> {
        self.inner.decode(buf)
    }

    fn flush(&mut self) -> Result<Option<RecordBatch>, ArrowError> {
        self.inner.flush()
    }

    fn can_flush_early(&self) -> bool {
        self.inner.capacity() == 0
    }
}

impl Debug for CsvSerializer {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("CsvSerializer")
            .field("header", &self.header)
            .finish()
    }
}

#[async_trait]
impl FileFormat for CsvFormat {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn get_ext(&self) -> String {
        CsvFormatFactory::new().get_ext()
    }

    fn get_ext_with_compression(
        &self,
        file_compression_type: &FileCompressionType,
    ) -> Result<String> {
        let ext = self.get_ext();
        Ok(format!("{}{}", ext, file_compression_type.get_ext()))
    }

    fn compression_type(&self) -> Option<FileCompressionType> {
        Some(self.options.compression.into())
    }

    async fn infer_schema(
        &self,
        state: &dyn Session,
        store: &Arc<dyn ObjectStore>,
        objects: &[ObjectMeta],
    ) -> Result<SchemaRef> {
        let mut schemas = vec![];

        let mut records_to_read = self
            .options
            .schema_infer_max_rec
            .unwrap_or(DEFAULT_SCHEMA_INFER_MAX_RECORD);

        for object in objects {
            let stream = self.read_to_delimited_chunks(store, object).await;
            let (schema, records_read) = self
                .infer_schema_from_stream(state, records_to_read, stream)
                .await
                .map_err(|err| {
                    DataFusionError::Context(
                        format!("Error when processing CSV file {}", &object.location),
                        Box::new(err),
                    )
                })?;
            records_to_read -= records_read;
            schemas.push(schema);
            if records_to_read == 0 {
                break;
            }
        }

        let merged_schema = Schema::try_merge(schemas)?;
        Ok(Arc::new(merged_schema))
    }

    async fn infer_stats(
        &self,
        _state: &dyn Session,
        _store: &Arc<dyn ObjectStore>,
        table_schema: SchemaRef,
        _object: &ObjectMeta,
    ) -> Result<Statistics> {
        Ok(Statistics::new_unknown(&table_schema))
    }

    async fn create_physical_plan(
        &self,
        state: &dyn Session,
        conf: FileScanConfig,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        // Consult configuration options for default values
        let has_header = self
            .options
            .has_header
            .unwrap_or_else(|| state.config_options().catalog.has_header);
        let newlines_in_values = self
            .options
            .newlines_in_values
            .unwrap_or_else(|| state.config_options().catalog.newlines_in_values);

        let conf_builder = FileScanConfigBuilder::from(conf)
            .with_file_compression_type(self.options.compression.into())
            .with_newlines_in_values(newlines_in_values);

        let truncated_rows = self.options.truncated_rows.unwrap_or(false);
        let source = Arc::new(
            CsvSource::new(has_header, self.options.delimiter, self.options.quote)
                .with_escape(self.options.escape)
                .with_terminator(self.options.terminator)
                .with_comment(self.options.comment)
                .with_truncate_rows(truncated_rows),
        );

        let config = conf_builder.with_source(source).build();

        Ok(DataSourceExec::from_data_source(config))
    }

    async fn create_writer_physical_plan(
        &self,
        input: Arc<dyn ExecutionPlan>,
        state: &dyn Session,
        conf: FileSinkConfig,
        order_requirements: Option<LexRequirement>,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        if conf.insert_op != InsertOp::Append {
            return not_impl_err!("Overwrites are not implemented yet for CSV");
        }

        // The `has_header` and `newlines_in_values` fields of CsvOptions may inherit
        // their values from the session configuration settings. To support this,
        // the writer options are built from a copy of `self.options` with these
        // special fields resolved to concrete values.
        let has_header = self
            .options()
            .has_header
            .unwrap_or_else(|| state.config_options().catalog.has_header);
        let newlines_in_values = self
            .options()
            .newlines_in_values
            .unwrap_or_else(|| state.config_options().catalog.newlines_in_values);

        let options = self
            .options()
            .clone()
            .with_has_header(has_header)
            .with_newlines_in_values(newlines_in_values);

        let writer_options = CsvWriterOptions::try_from(&options)?;

        let sink = Arc::new(CsvSink::new(conf, writer_options));

        Ok(Arc::new(DataSinkExec::new(input, sink, order_requirements)) as _)
    }

    fn file_source(&self) -> Arc<dyn FileSource> {
        Arc::new(CsvSource::default())
    }
}

impl CsvFormat {
    /// Infer the schema by reading up to `records_to_read` records from a
    /// stream of delimited chunks, returning the inferred schema and the
    /// number of records that were read.
    ///
    /// This method can handle CSV files with different numbers of columns.
    /// The inferred schema will be the union of all columns found across all files.
    /// Files with fewer columns will have missing columns filled with null values.
    ///
    /// # Example
    ///
    /// If you have two CSV files:
    /// - `file1.csv`: `col1,col2,col3`
    /// - `file2.csv`: `col1,col2,col3,col4,col5`
    ///
    /// The inferred schema will contain all 5 columns, with files that don't
    /// have columns 4 and 5 having null values for those columns.
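    ///
    /// # Usage sketch
    ///
    /// A minimal, non-compiled sketch of the call shape; `session` is assumed
    /// to be any value implementing [`Session`]:
    ///
    /// ```ignore
    /// let format = CsvFormat::default().with_has_header(true);
    /// let stream = futures::stream::iter(vec![Ok(bytes::Bytes::from("a,b\n1,2\n"))]);
    /// let (schema, records_read) = format
    ///     .infer_schema_from_stream(&session, 1000, stream)
    ///     .await?;
    /// assert_eq!(schema.fields().len(), 2);
    /// ```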
    pub async fn infer_schema_from_stream(
        &self,
        state: &dyn Session,
        mut records_to_read: usize,
        stream: impl Stream<Item = Result<Bytes>>,
    ) -> Result<(Schema, usize)> {
        let mut total_records_read = 0;
        let mut column_names = vec![];
        let mut column_type_possibilities = vec![];
        let mut record_number = -1;

        pin_mut!(stream);

        while let Some(chunk) = stream.next().await.transpose()? {
            record_number += 1;
            let first_chunk = record_number == 0;
            let mut format = arrow::csv::reader::Format::default()
                .with_header(
                    first_chunk
                        && self
                            .options
                            .has_header
                            .unwrap_or_else(|| state.config_options().catalog.has_header),
                )
                .with_delimiter(self.options.delimiter)
                .with_quote(self.options.quote)
                .with_truncated_rows(self.options.truncated_rows.unwrap_or(false));

            if let Some(null_regex) = &self.options.null_regex {
                let regex = Regex::new(null_regex.as_str())
                    .expect("Unable to parse CSV null regex.");
                format = format.with_null_regex(regex);
            }

            if let Some(escape) = self.options.escape {
                format = format.with_escape(escape);
            }

            if let Some(comment) = self.options.comment {
                format = format.with_comment(comment);
            }

            let (Schema { fields, .. }, records_read) =
                format.infer_schema(chunk.reader(), Some(records_to_read))?;

            records_to_read -= records_read;
            total_records_read += records_read;

            if first_chunk {
                // set up initial structures for recording inferred schema across chunks
                (column_names, column_type_possibilities) = fields
                    .into_iter()
                    .map(|field| {
                        let mut possibilities = HashSet::new();
                        if records_read > 0 {
                            // at least 1 data row read, record the inferred datatype
                            possibilities.insert(field.data_type().clone());
                        }
                        (field.name().clone(), possibilities)
                    })
                    .unzip();
            } else {
                if fields.len() != column_type_possibilities.len()
                    && !self.options.truncated_rows.unwrap_or(false)
                {
                    return exec_err!(
                        "Encountered unequal lengths between records on CSV file whilst inferring schema. \
                         Expected {} fields, found {} fields at record {}",
                        column_type_possibilities.len(),
                        fields.len(),
                        record_number + 1
                    );
                }

                // First update type possibilities for existing columns using zip
                column_type_possibilities.iter_mut().zip(&fields).for_each(
                    |(possibilities, field)| {
                        possibilities.insert(field.data_type().clone());
                    },
                );

                // Handle files with different numbers of columns by extending the schema
                if fields.len() > column_type_possibilities.len() {
                    // New columns found - extend our tracking structures
                    for field in fields.iter().skip(column_type_possibilities.len()) {
                        column_names.push(field.name().clone());
                        let mut possibilities = HashSet::new();
                        if records_read > 0 {
                            possibilities.insert(field.data_type().clone());
                        }
                        column_type_possibilities.push(possibilities);
                    }
                }
            }

            if records_to_read == 0 {
                break;
            }
        }

        let schema = build_schema_helper(column_names, column_type_possibilities);
        Ok((schema, total_records_read))
    }
}

fn build_schema_helper(names: Vec<String>, types: Vec<HashSet<DataType>>) -> Schema {
    let fields = names
        .into_iter()
        .zip(types)
        .map(|(field_name, mut data_type_possibilities)| {
            // ripped from arrow::csv::reader::infer_reader_schema_with_csv_options
            // determine data type based on possible types
            // if there are incompatible types, use DataType::Utf8

            // ignore nulls, to avoid conflicting datatypes (e.g. [nulls, int]) being inferred as Utf8.
            data_type_possibilities.remove(&DataType::Null);

            match data_type_possibilities.len() {
                // Return Null for columns with only nulls / empty files.
                // This allows schema merging to work when reading folders
                // containing such files along with normal files.
                0 => Field::new(field_name, DataType::Null, true),
                1 => Field::new(
                    field_name,
                    data_type_possibilities.iter().next().unwrap().clone(),
                    true,
                ),
                2 => {
                    if data_type_possibilities.contains(&DataType::Int64)
                        && data_type_possibilities.contains(&DataType::Float64)
                    {
                        // we have an integer and a double, fall back to double
                        Field::new(field_name, DataType::Float64, true)
                    } else {
                        // default to Utf8 for conflicting datatypes (e.g. bool and int)
                        Field::new(field_name, DataType::Utf8, true)
                    }
                }
                _ => Field::new(field_name, DataType::Utf8, true),
            }
        })
        .collect::<Fields>();
    Schema::new(fields)
}

impl Default for CsvSerializer {
    fn default() -> Self {
        Self::new()
    }
}

/// Define a struct for serializing CSV records to a stream
pub struct CsvSerializer {
    // CSV writer builder
    builder: WriterBuilder,
    // Flag to indicate whether there will be a header
    header: bool,
}

impl CsvSerializer {
    /// Constructor for the CsvSerializer object
    pub fn new() -> Self {
        Self {
            builder: WriterBuilder::new(),
            header: true,
        }
    }

    /// Method for setting the CSV writer builder
    pub fn with_builder(mut self, builder: WriterBuilder) -> Self {
        self.builder = builder;
        self
    }

    /// Method for setting the CSV writer header status
    pub fn with_header(mut self, header: bool) -> Self {
        self.header = header;
        self
    }
}

impl BatchSerializer for CsvSerializer {
    fn serialize(&self, batch: RecordBatch, initial: bool) -> Result<Bytes> {
        let mut buffer = Vec::with_capacity(4096);
        let builder = self.builder.clone();
        let header = self.header && initial;
        let mut writer = builder.with_header(header).build(&mut buffer);
        writer.write(&batch)?;
        drop(writer);
        Ok(Bytes::from(buffer))
    }
}

/// Implements [`DataSink`] for writing to a CSV file.
pub struct CsvSink {
    /// Config options for writing data
    config: FileSinkConfig,
    writer_options: CsvWriterOptions,
}

impl Debug for CsvSink {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("CsvSink").finish()
    }
}

impl DisplayAs for CsvSink {
    fn fmt_as(&self, t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match t {
            DisplayFormatType::Default | DisplayFormatType::Verbose => {
                write!(f, "CsvSink(file_groups=")?;
                FileGroupDisplay(&self.config.file_group).fmt_as(t, f)?;
                write!(f, ")")
            }
            DisplayFormatType::TreeRender => {
                writeln!(f, "format: csv")?;
                write!(f, "file={}", &self.config.original_url)
            }
        }
    }
}

impl CsvSink {
    /// Create from config.
    pub fn new(config: FileSinkConfig, writer_options: CsvWriterOptions) -> Self {
        Self {
            config,
            writer_options,
        }
    }

    /// Retrieve the writer options
    pub fn writer_options(&self) -> &CsvWriterOptions {
        &self.writer_options
    }
}

#[async_trait]
impl FileSink for CsvSink {
    fn config(&self) -> &FileSinkConfig {
        &self.config
    }

    async fn spawn_writer_tasks_and_join(
        &self,
        context: &Arc<TaskContext>,
        demux_task: SpawnedTask<Result<()>>,
        file_stream_rx: DemuxedStreamReceiver,
        object_store: Arc<dyn ObjectStore>,
    ) -> Result<u64> {
        let builder = self.writer_options.writer_options.clone();
        let header = builder.header();
        let serializer = Arc::new(
            CsvSerializer::new()
                .with_builder(builder)
                .with_header(header),
        ) as _;
        spawn_writer_tasks_and_join(
            context,
            serializer,
            self.writer_options.compression.into(),
            object_store,
            demux_task,
            file_stream_rx,
        )
        .await
    }
}

#[async_trait]
impl DataSink for CsvSink {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn schema(&self) -> &SchemaRef {
        self.config.output_schema()
    }

    async fn write_all(
        &self,
        data: SendableRecordBatchStream,
        context: &Arc<TaskContext>,
    ) -> Result<u64> {
        FileSink::write_all(self, data, context).await
    }
}

#[cfg(test)]
mod tests {
    use super::build_schema_helper;
    use arrow::datatypes::DataType;
    use std::collections::HashSet;

    #[test]
    fn test_build_schema_helper_different_column_counts() {
        // Test the core schema building logic with different column counts
        let mut column_names =
            vec!["col1".to_string(), "col2".to_string(), "col3".to_string()];

        // Simulate adding two more columns from another file
        column_names.push("col4".to_string());
        column_names.push("col5".to_string());

        let column_type_possibilities = vec![
            HashSet::from([DataType::Int64]),
            HashSet::from([DataType::Utf8]),
            HashSet::from([DataType::Float64]),
            HashSet::from([DataType::Utf8]), // col4
            HashSet::from([DataType::Utf8]), // col5
        ];

        let schema = build_schema_helper(column_names, column_type_possibilities);

        // Verify schema has 5 columns
        assert_eq!(schema.fields().len(), 5);
        assert_eq!(schema.field(0).name(), "col1");
        assert_eq!(schema.field(1).name(), "col2");
        assert_eq!(schema.field(2).name(), "col3");
        assert_eq!(schema.field(3).name(), "col4");
        assert_eq!(schema.field(4).name(), "col5");

        // All fields should be nullable
        for field in schema.fields() {
            assert!(
                field.is_nullable(),
                "Field {} should be nullable",
                field.name()
            );
        }
    }

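    #[test]
    fn test_build_schema_helper_null_columns() {
        // Illustrative addition: exercises the null handling in
        // `build_schema_helper`: `DataType::Null` candidates are ignored, and a
        // column with no remaining candidates is typed as `DataType::Null`.
        let column_names = vec!["only_nulls".to_string(), "null_and_int".to_string()];

        let column_type_possibilities = vec![
            HashSet::from([DataType::Null]),
            HashSet::from([DataType::Null, DataType::Int64]),
        ];

        let schema = build_schema_helper(column_names, column_type_possibilities);

        // A column seen only as null stays Null so it can merge with other files
        assert_eq!(*schema.field(0).data_type(), DataType::Null);
        // Nulls do not conflict with a concrete type
        assert_eq!(*schema.field(1).data_type(), DataType::Int64);
    }
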
    #[test]
    fn test_build_schema_helper_type_merging() {
        // Test type merging logic
        let column_names = vec!["col1".to_string(), "col2".to_string()];

        let column_type_possibilities = vec![
            HashSet::from([DataType::Int64, DataType::Float64]), // Should resolve to Float64
            HashSet::from([DataType::Utf8]),                     // Should remain Utf8
        ];

        let schema = build_schema_helper(column_names, column_type_possibilities);

        // col1 should be Float64 due to Int64 + Float64 = Float64
        assert_eq!(*schema.field(0).data_type(), DataType::Float64);

        // col2 should remain Utf8
        assert_eq!(*schema.field(1).data_type(), DataType::Utf8);
    }

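    #[test]
    fn test_csv_serializer_header_only_on_initial_batch() {
        // Illustrative addition (a minimal sketch, not from the original file):
        // `CsvSerializer::serialize` should emit the header row only when the
        // serializer is configured with a header AND the batch is the initial one.
        use super::CsvSerializer;
        use arrow::array::{ArrayRef, Int64Array, RecordBatch};
        use arrow::datatypes::{Field, Schema};
        use datafusion_datasource::write::BatchSerializer;
        use std::sync::Arc;

        let schema =
            Arc::new(Schema::new(vec![Field::new("a", DataType::Int64, false)]));
        let batch = RecordBatch::try_new(
            schema,
            vec![Arc::new(Int64Array::from(vec![1, 2])) as ArrayRef],
        )
        .unwrap();

        let serializer = CsvSerializer::new().with_header(true);
        let first = serializer.serialize(batch.clone(), true).unwrap();
        let later = serializer.serialize(batch, false).unwrap();

        // Only the initial chunk starts with the header row
        assert!(String::from_utf8_lossy(&first).starts_with("a\n"));
        assert!(!String::from_utf8_lossy(&later).starts_with("a\n"));
    }
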
    #[test]
    fn test_build_schema_helper_conflicting_types() {
        // Test when we have incompatible types - should default to Utf8
        let column_names = vec!["col1".to_string()];

        let column_type_possibilities = vec![
            HashSet::from([DataType::Boolean, DataType::Int64, DataType::Utf8]), // Should resolve to Utf8 due to conflicts
        ];

        let schema = build_schema_helper(column_names, column_type_possibilities);

        // Should default to Utf8 for conflicting types
        assert_eq!(*schema.field(0).data_type(), DataType::Utf8);
    }
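
    #[test]
    fn test_csv_decoder_decode_and_flush() {
        // Illustrative addition (a minimal sketch, not from the original file):
        // feed whole CSV lines into the `CsvDecoder` wrapper and flush a batch.
        // The ReaderBuilder settings below are assumptions made for the example.
        use super::CsvDecoder;
        use arrow::csv::ReaderBuilder;
        use arrow::datatypes::{Field, Schema};
        use datafusion_datasource::decoder::Decoder;
        use std::sync::Arc;

        let schema =
            Arc::new(Schema::new(vec![Field::new("a", DataType::Int64, true)]));
        let inner = ReaderBuilder::new(schema).with_header(false).build_decoder();
        let mut decoder = CsvDecoder::new(inner);

        let data = b"1\n2\n";
        let consumed = decoder.decode(data).unwrap();
        assert_eq!(consumed, data.len());

        let batch = decoder.flush().unwrap().expect("a record batch");
        assert_eq!(batch.num_rows(), 2);
    }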
}