datafusion_datasource_csv/file_format.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

//! [`CsvFormat`], Comma Separated Value (CSV) [`FileFormat`] abstractions

use std::any::Any;
use std::collections::{HashMap, HashSet};
use std::fmt::{self, Debug};
use std::sync::Arc;

use crate::source::CsvSource;

use arrow::array::RecordBatch;
use arrow::csv::WriterBuilder;
use arrow::datatypes::{DataType, Field, Fields, Schema, SchemaRef};
use arrow::error::ArrowError;
use datafusion_common::config::{ConfigField, ConfigFileType, CsvOptions};
use datafusion_common::file_options::csv_writer::CsvWriterOptions;
use datafusion_common::{
    DEFAULT_CSV_EXTENSION, DataFusionError, GetExt, Result, Statistics, exec_err,
    not_impl_err,
};
use datafusion_common_runtime::SpawnedTask;
use datafusion_datasource::TableSchema;
use datafusion_datasource::decoder::Decoder;
use datafusion_datasource::display::FileGroupDisplay;
use datafusion_datasource::file::FileSource;
use datafusion_datasource::file_compression_type::FileCompressionType;
use datafusion_datasource::file_format::{
    DEFAULT_SCHEMA_INFER_MAX_RECORD, FileFormat, FileFormatFactory,
};
use datafusion_datasource::file_scan_config::{FileScanConfig, FileScanConfigBuilder};
use datafusion_datasource::file_sink_config::{FileSink, FileSinkConfig};
use datafusion_datasource::sink::{DataSink, DataSinkExec};
use datafusion_datasource::write::BatchSerializer;
use datafusion_datasource::write::demux::DemuxedStreamReceiver;
use datafusion_datasource::write::orchestration::spawn_writer_tasks_and_join;
use datafusion_execution::{SendableRecordBatchStream, TaskContext};
use datafusion_expr::dml::InsertOp;
use datafusion_physical_expr_common::sort_expr::LexRequirement;
use datafusion_physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan};
use datafusion_session::Session;

use async_trait::async_trait;
use bytes::{Buf, Bytes};
use datafusion_datasource::source::DataSourceExec;
use futures::stream::BoxStream;
use futures::{Stream, StreamExt, TryStreamExt, pin_mut};
use object_store::{ObjectMeta, ObjectStore, delimited::newline_delimited_stream};
use regex::Regex;

#[derive(Default)]
/// Factory used to create [`CsvFormat`]
pub struct CsvFormatFactory {
    /// the options for csv file read
    pub options: Option<CsvOptions>,
}

impl CsvFormatFactory {
    /// Creates an instance of [`CsvFormatFactory`]
    pub fn new() -> Self {
        Self { options: None }
    }

    /// Creates an instance of [`CsvFormatFactory`] with customized default options
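    ///
    /// # Example
    ///
    /// A minimal sketch of overriding a single option. The field names come from
    /// [`CsvOptions`]; the import path shown is an assumption and may differ
    /// depending on how the crate is consumed.
    ///
    /// ```ignore
    /// use datafusion_common::config::CsvOptions;
    ///
    /// // Start from the defaults and switch the delimiter to '|'.
    /// let mut options = CsvOptions::default();
    /// options.delimiter = b'|';
    ///
    /// let factory = CsvFormatFactory::new_with_options(options);
    /// assert!(factory.options.is_some());
    /// ```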
    pub fn new_with_options(options: CsvOptions) -> Self {
        Self {
            options: Some(options),
        }
    }
}

impl Debug for CsvFormatFactory {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("CsvFormatFactory")
            .field("options", &self.options)
            .finish()
    }
}

impl FileFormatFactory for CsvFormatFactory {
    fn create(
        &self,
        state: &dyn Session,
        format_options: &HashMap<String, String>,
    ) -> Result<Arc<dyn FileFormat>> {
        let csv_options = match &self.options {
            None => {
                let mut table_options = state.default_table_options();
                table_options.set_config_format(ConfigFileType::CSV);
                table_options.alter_with_string_hash_map(format_options)?;
                table_options.csv
            }
            Some(csv_options) => {
                let mut csv_options = csv_options.clone();
                for (k, v) in format_options {
                    csv_options.set(k, v)?;
                }
                csv_options
            }
        };

        Ok(Arc::new(CsvFormat::default().with_options(csv_options)))
    }

    fn default(&self) -> Arc<dyn FileFormat> {
        Arc::new(CsvFormat::default())
    }

    fn as_any(&self) -> &dyn Any {
        self
    }
}

impl GetExt for CsvFormatFactory {
    fn get_ext(&self) -> String {
        // Removes the dot, i.e. ".csv" -> "csv"
        DEFAULT_CSV_EXTENSION[1..].to_string()
    }
}

/// Character Separated Value [`FileFormat`] implementation.
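///
/// # Example
///
/// A minimal sketch of the builder-style configuration, using only methods
/// defined on this type (the import path is assumed):
///
/// ```ignore
/// // Read ';'-delimited files that do not carry a header row.
/// let format = CsvFormat::default()
///     .with_delimiter(b';')
///     .with_has_header(false)
///     .with_quote(b'"');
///
/// assert_eq!(format.delimiter(), b';');
/// assert_eq!(format.has_header(), Some(false));
/// ```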
#[derive(Debug, Default)]
pub struct CsvFormat {
    options: CsvOptions,
}

impl CsvFormat {
    /// Return a newline-delimited stream from the specified file in the given
    /// object store, decompressing if necessary.
    /// Each returned `Bytes` contains a whole number of newline-delimited rows.
    async fn read_to_delimited_chunks<'a>(
        &self,
        store: &Arc<dyn ObjectStore>,
        object: &ObjectMeta,
    ) -> BoxStream<'a, Result<Bytes>> {
        // stream to only read as many rows as needed into memory
        let stream = store
            .get(&object.location)
            .await
            .map_err(|e| DataFusionError::ObjectStore(Box::new(e)));
        let stream = match stream {
            Ok(stream) => self
                .read_to_delimited_chunks_from_stream(
                    stream
                        .into_stream()
                        .map_err(|e| DataFusionError::ObjectStore(Box::new(e)))
                        .boxed(),
                )
                .await
                .map_err(DataFusionError::from)
                .left_stream(),
            Err(e) => {
                futures::stream::once(futures::future::ready(Err(e))).right_stream()
            }
        };
        stream.boxed()
    }

    /// Convert a stream of bytes into a stream of [`Bytes`] containing newline
    /// delimited CSV records, while accounting for `\` and `"`.
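    ///
    /// A sketch of how this behaves; the surrounding `format` value and the
    /// synthetic input stream are assumptions for illustration:
    ///
    /// ```ignore
    /// use bytes::Bytes;
    /// use futures::StreamExt;
    ///
    /// // A row is split across two input chunks...
    /// let raw = futures::stream::iter(vec![
    ///     Ok(Bytes::from("a,b\n1,2\n3,")),
    ///     Ok(Bytes::from("4\n")),
    /// ])
    /// .boxed();
    ///
    /// // ...but every chunk yielded here ends on a record boundary.
    /// let mut chunks = format.read_to_delimited_chunks_from_stream(raw).await;
    /// while let Some(chunk) = chunks.next().await {
    ///     println!("{:?}", chunk?);
    /// }
    /// ```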
    pub async fn read_to_delimited_chunks_from_stream<'a>(
        &self,
        stream: BoxStream<'a, Result<Bytes>>,
    ) -> BoxStream<'a, Result<Bytes>> {
        let file_compression_type: FileCompressionType = self.options.compression.into();
        let decoder = file_compression_type.convert_stream(stream);
        let stream = match decoder {
            Ok(decoded_stream) => {
                newline_delimited_stream(decoded_stream.map_err(|e| match e {
                    DataFusionError::ObjectStore(e) => *e,
                    err => object_store::Error::Generic {
                        store: "read to delimited chunks failed",
                        source: Box::new(err),
                    },
                }))
                .map_err(DataFusionError::from)
                .left_stream()
            }
            Err(e) => {
                futures::stream::once(futures::future::ready(Err(e))).right_stream()
            }
        };
        stream.boxed()
    }

    /// Set the csv options
    pub fn with_options(mut self, options: CsvOptions) -> Self {
        self.options = options;
        self
    }

    /// Retrieve the csv options
    pub fn options(&self) -> &CsvOptions {
        &self.options
    }

    /// Set a limit on the number of records to scan when inferring the schema
    /// - default to `DEFAULT_SCHEMA_INFER_MAX_RECORD`
    ///
    /// # Behavior when set to 0
    ///
    /// When `max_rec` is set to 0, schema inference is disabled and all fields
    /// will be inferred as `Utf8` (string) type, regardless of their actual content.
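    ///
    /// # Example
    ///
    /// A minimal sketch (builder calls only, using methods defined on this type):
    ///
    /// ```ignore
    /// // Sample at most 100 records per file during schema inference.
    /// let sampled = CsvFormat::default().with_schema_infer_max_rec(100);
    ///
    /// // Disable inference entirely: every column will be read as Utf8.
    /// let all_strings = CsvFormat::default().with_schema_infer_max_rec(0);
    /// ```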
    pub fn with_schema_infer_max_rec(mut self, max_rec: usize) -> Self {
        self.options.schema_infer_max_rec = Some(max_rec);
        self
    }

    /// Set true to indicate that the first line is a header.
    /// - default to true
    pub fn with_has_header(mut self, has_header: bool) -> Self {
        self.options.has_header = Some(has_header);
        self
    }

    /// Set whether rows with fewer fields than the schema are accepted.
    /// - default to false
    pub fn with_truncated_rows(mut self, truncated_rows: bool) -> Self {
        self.options.truncated_rows = Some(truncated_rows);
        self
    }

    /// Set the regex to use for null values in the CSV reader.
    /// - default to treat empty values as null.
    pub fn with_null_regex(mut self, null_regex: Option<String>) -> Self {
        self.options.null_regex = null_regex;
        self
    }

    /// Returns `Some(true)` if the first line is a header, `Some(false)` if
    /// it is not, and `None` if it is not specified.
    pub fn has_header(&self) -> Option<bool> {
        self.options.has_header
    }

    /// Lines beginning with this byte are ignored.
    pub fn with_comment(mut self, comment: Option<u8>) -> Self {
        self.options.comment = comment;
        self
    }

    /// The character separating values within a row.
    /// - default to ','
    pub fn with_delimiter(mut self, delimiter: u8) -> Self {
        self.options.delimiter = delimiter;
        self
    }

    /// The quote character in a row.
    /// - default to '"'
    pub fn with_quote(mut self, quote: u8) -> Self {
        self.options.quote = quote;
        self
    }

    /// The escape character in a row.
    /// - default is None
    pub fn with_escape(mut self, escape: Option<u8>) -> Self {
        self.options.escape = escape;
        self
    }

    /// The character used to indicate the end of a row.
    /// - default to None (CRLF)
    pub fn with_terminator(mut self, terminator: Option<u8>) -> Self {
        self.options.terminator = terminator;
        self
    }

    /// Specifies whether newlines in (quoted) values are supported.
    ///
    /// Parsing newlines in quoted values may be affected by execution behaviour such as
    /// parallel file scanning. Setting this to `true` ensures that newlines in values are
    /// parsed successfully, which may reduce performance.
    ///
    /// The default behaviour depends on the `datafusion.catalog.newlines_in_values` setting.
    pub fn with_newlines_in_values(mut self, newlines_in_values: bool) -> Self {
        self.options.newlines_in_values = Some(newlines_in_values);
        self
    }

    /// Set a `FileCompressionType` of CSV
    /// - defaults to `FileCompressionType::UNCOMPRESSED`
    pub fn with_file_compression_type(
        mut self,
        file_compression_type: FileCompressionType,
    ) -> Self {
        self.options.compression = file_compression_type.into();
        self
    }

    /// Set whether rows with fewer fields than the schema are accepted.
    /// This sets the same option as [`Self::with_truncated_rows`].
    /// - defaults to false
    pub fn with_truncate_rows(mut self, truncate_rows: bool) -> Self {
        self.options.truncated_rows = Some(truncate_rows);
        self
    }

    /// The delimiter character.
    pub fn delimiter(&self) -> u8 {
        self.options.delimiter
    }

    /// The quote character.
    pub fn quote(&self) -> u8 {
        self.options.quote
    }

    /// The escape character.
    pub fn escape(&self) -> Option<u8> {
        self.options.escape
    }
}

#[derive(Debug)]
pub struct CsvDecoder {
    inner: arrow::csv::reader::Decoder,
}

impl CsvDecoder {
    pub fn new(decoder: arrow::csv::reader::Decoder) -> Self {
        Self { inner: decoder }
    }
}

impl Decoder for CsvDecoder {
    fn decode(&mut self, buf: &[u8]) -> Result<usize, ArrowError> {
        self.inner.decode(buf)
    }

    fn flush(&mut self) -> Result<Option<RecordBatch>, ArrowError> {
        self.inner.flush()
    }

    fn can_flush_early(&self) -> bool {
        self.inner.capacity() == 0
    }
}

impl Debug for CsvSerializer {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("CsvSerializer")
            .field("header", &self.header)
            .finish()
    }
}

#[async_trait]
impl FileFormat for CsvFormat {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn get_ext(&self) -> String {
        CsvFormatFactory::new().get_ext()
    }

    fn get_ext_with_compression(
        &self,
        file_compression_type: &FileCompressionType,
    ) -> Result<String> {
        let ext = self.get_ext();
        Ok(format!("{}{}", ext, file_compression_type.get_ext()))
    }

    fn compression_type(&self) -> Option<FileCompressionType> {
        Some(self.options.compression.into())
    }

    async fn infer_schema(
        &self,
        state: &dyn Session,
        store: &Arc<dyn ObjectStore>,
        objects: &[ObjectMeta],
    ) -> Result<SchemaRef> {
        let mut schemas = vec![];

        let mut records_to_read = self
            .options
            .schema_infer_max_rec
            .unwrap_or(DEFAULT_SCHEMA_INFER_MAX_RECORD);

        for object in objects {
            let stream = self.read_to_delimited_chunks(store, object).await;
            let (schema, records_read) = self
                .infer_schema_from_stream(state, records_to_read, stream)
                .await
                .map_err(|err| {
                    DataFusionError::Context(
                        format!("Error when processing CSV file {}", &object.location),
                        Box::new(err),
                    )
                })?;
            records_to_read -= records_read;
            schemas.push(schema);
            if records_to_read == 0 {
                break;
            }
        }

        let merged_schema = Schema::try_merge(schemas)?;
        Ok(Arc::new(merged_schema))
    }

    async fn infer_stats(
        &self,
        _state: &dyn Session,
        _store: &Arc<dyn ObjectStore>,
        table_schema: SchemaRef,
        _object: &ObjectMeta,
    ) -> Result<Statistics> {
        Ok(Statistics::new_unknown(&table_schema))
    }

    async fn create_physical_plan(
        &self,
        state: &dyn Session,
        conf: FileScanConfig,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        // Consult configuration options for default values
        let has_header = self
            .options
            .has_header
            .unwrap_or_else(|| state.config_options().catalog.has_header);
        let newlines_in_values = self
            .options
            .newlines_in_values
            .unwrap_or_else(|| state.config_options().catalog.newlines_in_values);

        let mut csv_options = self.options.clone();
        csv_options.has_header = Some(has_header);
        csv_options.newlines_in_values = Some(newlines_in_values);

        // Get the existing CsvSource and update its options
        // We need to preserve the table_schema from the original source (which includes partition columns)
        let csv_source = conf
            .file_source
            .as_any()
            .downcast_ref::<CsvSource>()
            .expect("file_source should be a CsvSource");
        let source = Arc::new(csv_source.clone().with_csv_options(csv_options));

        let config = FileScanConfigBuilder::from(conf)
            .with_file_compression_type(self.options.compression.into())
            .with_source(source)
            .build();

        Ok(DataSourceExec::from_data_source(config))
    }

    async fn create_writer_physical_plan(
        &self,
        input: Arc<dyn ExecutionPlan>,
        state: &dyn Session,
        conf: FileSinkConfig,
        order_requirements: Option<LexRequirement>,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        if conf.insert_op != InsertOp::Append {
            return not_impl_err!("Overwrites are not implemented yet for CSV");
        }

        // The `has_header` and `newlines_in_values` fields of CsvOptions may inherit
        // their values from the session configuration settings. To support this,
        // writer options are built from a copy of `self.options` with those two
        // fields resolved to concrete values.
        let has_header = self
            .options()
            .has_header
            .unwrap_or_else(|| state.config_options().catalog.has_header);
        let newlines_in_values = self
            .options()
            .newlines_in_values
            .unwrap_or_else(|| state.config_options().catalog.newlines_in_values);

        let options = self
            .options()
            .clone()
            .with_has_header(has_header)
            .with_newlines_in_values(newlines_in_values);

        let writer_options = CsvWriterOptions::try_from(&options)?;

        let sink = Arc::new(CsvSink::new(conf, writer_options));

        Ok(Arc::new(DataSinkExec::new(input, sink, order_requirements)) as _)
    }

    fn file_source(&self, table_schema: TableSchema) -> Arc<dyn FileSource> {
        let mut csv_options = self.options.clone();
        if csv_options.has_header.is_none() {
            csv_options.has_header = Some(true);
        }
        Arc::new(CsvSource::new(table_schema).with_csv_options(csv_options))
    }
}

impl CsvFormat {
    /// Infer the schema by reading up to `records_to_read` records from a
    /// stream of delimited chunks, returning the inferred schema and the
    /// number of records that were read.
    ///
    /// This method can handle CSV files with different numbers of columns.
    /// The inferred schema will be the union of all columns found across all files.
    /// Files with fewer columns will have missing columns filled with null values.
    ///
    /// # Example
    ///
    /// If you have two CSV files:
    /// - `file1.csv`: `col1,col2,col3`
    /// - `file2.csv`: `col1,col2,col3,col4,col5`
    ///
    /// The inferred schema will contain all 5 columns, with files that don't
    /// have columns 4 and 5 having null values for those columns.
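    ///
    /// A sketch of driving this directly; the `format` and `state` values are
    /// assumed to exist, and in practice the chunk stream usually comes from
    /// [`Self::read_to_delimited_chunks_from_stream`]:
    ///
    /// ```ignore
    /// use bytes::Bytes;
    ///
    /// let chunks = futures::stream::iter(vec![Ok(Bytes::from(
    ///     "c1,c2\n1,hello\n2,world\n",
    /// ))]);
    /// let (schema, _records_read) = format
    ///     .infer_schema_from_stream(&state, 1000, chunks)
    ///     .await?;
    /// assert_eq!(schema.fields().len(), 2);
    /// ```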
    pub async fn infer_schema_from_stream(
        &self,
        state: &dyn Session,
        mut records_to_read: usize,
        stream: impl Stream<Item = Result<Bytes>>,
    ) -> Result<(Schema, usize)> {
        let mut total_records_read = 0;
        let mut column_names = vec![];
        let mut column_type_possibilities = vec![];
        let mut record_number = -1;
        let initial_records_to_read = records_to_read;

        pin_mut!(stream);

        while let Some(chunk) = stream.next().await.transpose()? {
            record_number += 1;
            let first_chunk = record_number == 0;
            let mut format = arrow::csv::reader::Format::default()
                .with_header(
                    first_chunk
                        && self
                            .options
                            .has_header
                            .unwrap_or_else(|| state.config_options().catalog.has_header),
                )
                .with_delimiter(self.options.delimiter)
                .with_quote(self.options.quote)
                .with_truncated_rows(self.options.truncated_rows.unwrap_or(false));

            if let Some(null_regex) = &self.options.null_regex {
                let regex = Regex::new(null_regex.as_str())
                    .expect("Unable to parse CSV null regex.");
                format = format.with_null_regex(regex);
            }

            if let Some(escape) = self.options.escape {
                format = format.with_escape(escape);
            }

            if let Some(comment) = self.options.comment {
                format = format.with_comment(comment);
            }

            let (Schema { fields, .. }, records_read) =
                format.infer_schema(chunk.reader(), Some(records_to_read))?;

            records_to_read -= records_read;
            total_records_read += records_read;

            if first_chunk {
                // set up initial structures for recording inferred schema across chunks
                (column_names, column_type_possibilities) = fields
                    .into_iter()
                    .map(|field| {
                        let mut possibilities = HashSet::new();
                        if records_read > 0 {
                            // at least 1 data row read, record the inferred datatype
                            possibilities.insert(field.data_type().clone());
                        }
                        (field.name().clone(), possibilities)
                    })
                    .unzip();
            } else {
                if fields.len() != column_type_possibilities.len()
                    && !self.options.truncated_rows.unwrap_or(false)
                {
                    return exec_err!(
                        "Encountered unequal lengths between records on CSV file whilst inferring schema. \
                         Expected {} fields, found {} fields at record {}",
                        column_type_possibilities.len(),
                        fields.len(),
                        record_number + 1
                    );
                }

                // First update type possibilities for existing columns using zip
                column_type_possibilities.iter_mut().zip(&fields).for_each(
                    |(possibilities, field)| {
                        possibilities.insert(field.data_type().clone());
                    },
                );

                // Handle files with different numbers of columns by extending the schema
                if fields.len() > column_type_possibilities.len() {
                    // New columns found - extend our tracking structures
                    for field in fields.iter().skip(column_type_possibilities.len()) {
                        column_names.push(field.name().clone());
                        let mut possibilities = HashSet::new();
                        if records_read > 0 {
                            possibilities.insert(field.data_type().clone());
                        }
                        column_type_possibilities.push(possibilities);
                    }
                }
            }

            if records_to_read == 0 {
                break;
            }
        }

        let schema = build_schema_helper(
            column_names,
            column_type_possibilities,
            initial_records_to_read == 0,
        );
        Ok((schema, total_records_read))
    }
}

/// Builds a schema from column names and their possible data types.
///
/// # Arguments
///
/// * `names` - Vector of column names
/// * `types` - Vector of possible data types for each column (as HashSets)
/// * `disable_inference` - When true, forces all columns with no inferred types to be Utf8.
///   This should be set to true when `schema_infer_max_rec` is explicitly
///   set to 0, indicating the user wants to skip type inference and treat
///   all fields as strings. When false, columns with no inferred types
///   will be set to Null, allowing schema merging to work properly.
fn build_schema_helper(
    names: Vec<String>,
    types: Vec<HashSet<DataType>>,
    disable_inference: bool,
) -> Schema {
    let fields = names
        .into_iter()
        .zip(types)
        .map(|(field_name, mut data_type_possibilities)| {
            // ripped from arrow::csv::reader::infer_reader_schema_with_csv_options
            // determine data type based on possible types
            // if there are incompatible types, use DataType::Utf8

            // ignore nulls, to avoid conflicting datatypes (e.g. [nulls, int]) being inferred as Utf8.
            data_type_possibilities.remove(&DataType::Null);

            match data_type_possibilities.len() {
                // When no types were inferred (empty HashSet):
                // - If schema_infer_max_rec was explicitly set to 0, return Utf8
                // - Otherwise return Null (whether from reading null values or empty files)
                //   This allows schema merging to work when reading folders with empty files
                0 => {
                    if disable_inference {
                        Field::new(field_name, DataType::Utf8, true)
                    } else {
                        Field::new(field_name, DataType::Null, true)
                    }
                }
                1 => Field::new(
                    field_name,
                    data_type_possibilities.iter().next().unwrap().clone(),
                    true,
                ),
                2 => {
                    if data_type_possibilities.contains(&DataType::Int64)
                        && data_type_possibilities.contains(&DataType::Float64)
                    {
                        // we have an integer and double, fall down to double
                        Field::new(field_name, DataType::Float64, true)
                    } else {
                        // default to Utf8 for conflicting datatypes (e.g. bool and int)
                        Field::new(field_name, DataType::Utf8, true)
                    }
                }
                _ => Field::new(field_name, DataType::Utf8, true),
            }
        })
        .collect::<Fields>();
    Schema::new(fields)
}

impl Default for CsvSerializer {
    fn default() -> Self {
        Self::new()
    }
}

/// Define a struct for serializing CSV records to a stream
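///
/// # Example
///
/// A minimal sketch of serializing one batch to CSV bytes; the import path for
/// `CsvSerializer` and the exact output bytes are assumptions for illustration:
///
/// ```ignore
/// use std::sync::Arc;
/// use arrow::array::{Int32Array, RecordBatch};
/// use arrow::datatypes::{DataType, Field, Schema};
/// use datafusion_datasource::write::BatchSerializer;
///
/// let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
/// let batch =
///     RecordBatch::try_new(schema, vec![Arc::new(Int32Array::from(vec![1, 2]))]).unwrap();
///
/// // `initial == true` marks the first batch of a file, so the header (if enabled)
/// // is emitted exactly once.
/// let serializer = CsvSerializer::new().with_header(true);
/// let bytes = serializer.serialize(batch, true).unwrap();
/// assert_eq!(bytes.as_ref(), b"a\n1\n2\n");
/// ```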
pub struct CsvSerializer {
    // CSV writer builder
    builder: WriterBuilder,
    // Flag to indicate whether there will be a header
    header: bool,
}

impl CsvSerializer {
    /// Constructor for the CsvSerializer object
    pub fn new() -> Self {
        Self {
            builder: WriterBuilder::new(),
            header: true,
        }
    }

    /// Method for setting the CSV writer builder
    pub fn with_builder(mut self, builder: WriterBuilder) -> Self {
        self.builder = builder;
        self
    }

    /// Method for setting the CSV writer header status
    pub fn with_header(mut self, header: bool) -> Self {
        self.header = header;
        self
    }
}

impl BatchSerializer for CsvSerializer {
    fn serialize(&self, batch: RecordBatch, initial: bool) -> Result<Bytes> {
        let mut buffer = Vec::with_capacity(4096);
        let builder = self.builder.clone();
        let header = self.header && initial;
        let mut writer = builder.with_header(header).build(&mut buffer);
        writer.write(&batch)?;
        drop(writer);
        Ok(Bytes::from(buffer))
    }
}

/// Implements [`DataSink`] for writing to a CSV file.
pub struct CsvSink {
    /// Config options for writing data
    config: FileSinkConfig,
    writer_options: CsvWriterOptions,
}

impl Debug for CsvSink {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("CsvSink").finish()
    }
}

impl DisplayAs for CsvSink {
    fn fmt_as(&self, t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match t {
            DisplayFormatType::Default | DisplayFormatType::Verbose => {
                write!(f, "CsvSink(file_groups=",)?;
                FileGroupDisplay(&self.config.file_group).fmt_as(t, f)?;
                write!(f, ")")
            }
            DisplayFormatType::TreeRender => {
                writeln!(f, "format: csv")?;
                write!(f, "file={}", &self.config.original_url)
            }
        }
    }
}

impl CsvSink {
    /// Create from config.
    pub fn new(config: FileSinkConfig, writer_options: CsvWriterOptions) -> Self {
        Self {
            config,
            writer_options,
        }
    }

    /// Retrieve the writer options
    pub fn writer_options(&self) -> &CsvWriterOptions {
        &self.writer_options
    }
}

#[async_trait]
impl FileSink for CsvSink {
    fn config(&self) -> &FileSinkConfig {
        &self.config
    }

    async fn spawn_writer_tasks_and_join(
        &self,
        context: &Arc<TaskContext>,
        demux_task: SpawnedTask<Result<()>>,
        file_stream_rx: DemuxedStreamReceiver,
        object_store: Arc<dyn ObjectStore>,
    ) -> Result<u64> {
        let builder = self.writer_options.writer_options.clone();
        let header = builder.header();
        let serializer = Arc::new(
            CsvSerializer::new()
                .with_builder(builder)
                .with_header(header),
        ) as _;
        spawn_writer_tasks_and_join(
            context,
            serializer,
            self.writer_options.compression.into(),
            self.writer_options.compression_level,
            object_store,
            demux_task,
            file_stream_rx,
        )
        .await
    }
}

#[async_trait]
impl DataSink for CsvSink {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn schema(&self) -> &SchemaRef {
        self.config.output_schema()
    }

    async fn write_all(
        &self,
        data: SendableRecordBatchStream,
        context: &Arc<TaskContext>,
    ) -> Result<u64> {
        FileSink::write_all(self, data, context).await
    }
}

#[cfg(test)]
mod tests {
    use super::build_schema_helper;
    use arrow::datatypes::DataType;
    use std::collections::HashSet;

    #[test]
    fn test_build_schema_helper_different_column_counts() {
        // Test the core schema building logic with different column counts
        let mut column_names =
            vec!["col1".to_string(), "col2".to_string(), "col3".to_string()];

        // Simulate adding two more columns from another file
        column_names.push("col4".to_string());
        column_names.push("col5".to_string());

        let column_type_possibilities = vec![
            HashSet::from([DataType::Int64]),
            HashSet::from([DataType::Utf8]),
            HashSet::from([DataType::Float64]),
            HashSet::from([DataType::Utf8]), // col4
            HashSet::from([DataType::Utf8]), // col5
        ];

        let schema = build_schema_helper(column_names, column_type_possibilities, false);

        // Verify schema has 5 columns
        assert_eq!(schema.fields().len(), 5);
        assert_eq!(schema.field(0).name(), "col1");
        assert_eq!(schema.field(1).name(), "col2");
        assert_eq!(schema.field(2).name(), "col3");
        assert_eq!(schema.field(3).name(), "col4");
        assert_eq!(schema.field(4).name(), "col5");

        // All fields should be nullable
        for field in schema.fields() {
            assert!(
                field.is_nullable(),
                "Field {} should be nullable",
                field.name()
            );
        }
    }

    #[test]
    fn test_build_schema_helper_type_merging() {
        // Test type merging logic
        let column_names = vec!["col1".to_string(), "col2".to_string()];

        let column_type_possibilities = vec![
            HashSet::from([DataType::Int64, DataType::Float64]), // Should resolve to Float64
            HashSet::from([DataType::Utf8]),                     // Should remain Utf8
        ];

        let schema = build_schema_helper(column_names, column_type_possibilities, false);

        // col1 should be Float64 due to Int64 + Float64 = Float64
        assert_eq!(*schema.field(0).data_type(), DataType::Float64);

        // col2 should remain Utf8
        assert_eq!(*schema.field(1).data_type(), DataType::Utf8);
    }

    #[test]
    fn test_build_schema_helper_conflicting_types() {
        // Test when we have incompatible types - should default to Utf8
        let column_names = vec!["col1".to_string()];

        let column_type_possibilities = vec![
            HashSet::from([DataType::Boolean, DataType::Int64, DataType::Utf8]), // Should resolve to Utf8 due to conflicts
        ];

        let schema = build_schema_helper(column_names, column_type_possibilities, false);

        // Should default to Utf8 for conflicting types
        assert_eq!(*schema.field(0).data_type(), DataType::Utf8);
    }
}