datafusion/datasource/file_format/arrow.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

//! [`ArrowFormat`]: Apache Arrow [`FileFormat`] abstractions
//!
//! Works with files following the [Arrow IPC format](https://arrow.apache.org/docs/format/Columnar.html#ipc-file-format)
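//!
//! # Example
//!
//! A minimal sketch of reading an Arrow IPC file through this format. The table
//! name and file path below are hypothetical, and the exact import paths may vary
//! between DataFusion versions, so the snippet is not compiled as a doc test.
//!
//! ```ignore
//! use datafusion::execution::options::ArrowReadOptions;
//! use datafusion::prelude::*;
//!
//! #[tokio::main]
//! async fn main() -> datafusion::error::Result<()> {
//!     let ctx = SessionContext::new();
//!     // Register an Arrow IPC file as a table, then query it with SQL
//!     ctx.register_arrow("example", "data/example.arrow", ArrowReadOptions::default())
//!         .await?;
//!     ctx.sql("SELECT * FROM example LIMIT 10").await?.show().await?;
//!     Ok(())
//! }
//! ```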

use std::any::Any;
use std::borrow::Cow;
use std::collections::HashMap;
use std::fmt::{self, Debug};
use std::sync::Arc;

use super::file_compression_type::FileCompressionType;
use super::write::demux::DemuxedStreamReceiver;
use super::write::SharedBuffer;
use super::FileFormatFactory;
use crate::datasource::file_format::write::get_writer_schema;
use crate::datasource::file_format::FileFormat;
use crate::datasource::physical_plan::{ArrowSource, FileSink, FileSinkConfig};
use crate::error::Result;
use crate::physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan};

use arrow::datatypes::{Schema, SchemaRef};
use arrow::error::ArrowError;
use arrow::ipc::convert::fb_to_schema;
use arrow::ipc::reader::FileReader;
use arrow::ipc::writer::IpcWriteOptions;
use arrow::ipc::{root_as_message, CompressionType};
use datafusion_catalog::Session;
use datafusion_common::parsers::CompressionTypeVariant;
use datafusion_common::{
    not_impl_err, DataFusionError, GetExt, Statistics, DEFAULT_ARROW_EXTENSION,
};
use datafusion_common_runtime::{JoinSet, SpawnedTask};
use datafusion_datasource::display::FileGroupDisplay;
use datafusion_datasource::file::FileSource;
use datafusion_datasource::file_scan_config::{FileScanConfig, FileScanConfigBuilder};
use datafusion_datasource::sink::{DataSink, DataSinkExec};
use datafusion_datasource::write::ObjectWriterBuilder;
use datafusion_execution::{SendableRecordBatchStream, TaskContext};
use datafusion_expr::dml::InsertOp;
use datafusion_physical_expr_common::sort_expr::LexRequirement;

use async_trait::async_trait;
use bytes::Bytes;
use datafusion_datasource::source::DataSourceExec;
use futures::stream::BoxStream;
use futures::StreamExt;
use object_store::{GetResultPayload, ObjectMeta, ObjectStore};
use tokio::io::AsyncWriteExt;

/// Initial write buffer size. This is only a size hint for efficiency; the
/// buffer will grow beyond this value if needed.
const INITIAL_BUFFER_BYTES: usize = 1048576;

/// If the buffered Arrow data exceeds this size, it is flushed to the object store
const BUFFER_FLUSH_BYTES: usize = 1024000;

#[derive(Default, Debug)]
/// Factory struct used to create [ArrowFormat]
pub struct ArrowFormatFactory;

impl ArrowFormatFactory {
    /// Creates an instance of [ArrowFormatFactory]
    pub fn new() -> Self {
        Self {}
    }
}

impl FileFormatFactory for ArrowFormatFactory {
    fn create(
        &self,
        _state: &dyn Session,
        _format_options: &HashMap<String, String>,
    ) -> Result<Arc<dyn FileFormat>> {
        Ok(Arc::new(ArrowFormat))
    }

    fn default(&self) -> Arc<dyn FileFormat> {
        Arc::new(ArrowFormat)
    }

    fn as_any(&self) -> &dyn Any {
        self
    }
}

impl GetExt for ArrowFormatFactory {
    fn get_ext(&self) -> String {
        // Removes the dot, i.e. ".arrow" -> "arrow"
        DEFAULT_ARROW_EXTENSION[1..].to_string()
    }
}

/// Arrow `FileFormat` implementation.
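///
/// # Example
///
/// A minimal sketch of pairing this format with a listing table; the re-export
/// paths assumed below may differ between DataFusion versions, so the snippet is
/// not compiled as a doc test.
///
/// ```ignore
/// use std::sync::Arc;
/// use datafusion::datasource::file_format::arrow::ArrowFormat;
/// use datafusion::datasource::listing::ListingOptions;
///
/// // Scan ".arrow" files with the Arrow IPC format
/// let listing_options =
///     ListingOptions::new(Arc::new(ArrowFormat)).with_file_extension(".arrow");
/// ```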
#[derive(Default, Debug)]
pub struct ArrowFormat;

#[async_trait]
impl FileFormat for ArrowFormat {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn get_ext(&self) -> String {
        ArrowFormatFactory::new().get_ext()
    }

    fn get_ext_with_compression(
        &self,
        file_compression_type: &FileCompressionType,
    ) -> Result<String> {
        let ext = self.get_ext();
        match file_compression_type.get_variant() {
            CompressionTypeVariant::UNCOMPRESSED => Ok(ext),
            _ => Err(DataFusionError::Internal(
                "Arrow FileFormat does not support compression.".into(),
            )),
        }
    }

    fn compression_type(&self) -> Option<FileCompressionType> {
        None
    }

    async fn infer_schema(
        &self,
        _state: &dyn Session,
        store: &Arc<dyn ObjectStore>,
        objects: &[ObjectMeta],
    ) -> Result<SchemaRef> {
        let mut schemas = vec![];
        for object in objects {
            let r = store.as_ref().get(&object.location).await?;
            let schema = match r.payload {
                #[cfg(not(target_arch = "wasm32"))]
                GetResultPayload::File(mut file, _) => {
                    let reader = FileReader::try_new(&mut file, None)?;
                    reader.schema()
                }
                GetResultPayload::Stream(stream) => {
                    infer_schema_from_file_stream(stream).await?
                }
            };
            schemas.push(schema.as_ref().clone());
        }
        let merged_schema = Schema::try_merge(schemas)?;
        Ok(Arc::new(merged_schema))
    }

    async fn infer_stats(
        &self,
        _state: &dyn Session,
        _store: &Arc<dyn ObjectStore>,
        table_schema: SchemaRef,
        _object: &ObjectMeta,
    ) -> Result<Statistics> {
        Ok(Statistics::new_unknown(&table_schema))
    }

    async fn create_physical_plan(
        &self,
        _state: &dyn Session,
        conf: FileScanConfig,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        let source = Arc::new(ArrowSource::default());
        let config = FileScanConfigBuilder::from(conf)
            .with_source(source)
            .build();

        Ok(DataSourceExec::from_data_source(config))
    }

    async fn create_writer_physical_plan(
        &self,
        input: Arc<dyn ExecutionPlan>,
        _state: &dyn Session,
        conf: FileSinkConfig,
        order_requirements: Option<LexRequirement>,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        if conf.insert_op != InsertOp::Append {
            return not_impl_err!("Overwrites are not implemented yet for Arrow format");
        }

        let sink = Arc::new(ArrowFileSink::new(conf));

        Ok(Arc::new(DataSinkExec::new(input, sink, order_requirements)) as _)
    }

    fn file_source(&self) -> Arc<dyn FileSource> {
        Arc::new(ArrowSource::default())
    }
}

/// Implements [`FileSink`] for writing to Arrow IPC files
struct ArrowFileSink {
    config: FileSinkConfig,
}

impl ArrowFileSink {
    fn new(config: FileSinkConfig) -> Self {
        Self { config }
    }
}

#[async_trait]
impl FileSink for ArrowFileSink {
    fn config(&self) -> &FileSinkConfig {
        &self.config
    }

    async fn spawn_writer_tasks_and_join(
        &self,
        context: &Arc<TaskContext>,
        demux_task: SpawnedTask<Result<()>>,
        mut file_stream_rx: DemuxedStreamReceiver,
        object_store: Arc<dyn ObjectStore>,
    ) -> Result<u64> {
        let mut file_write_tasks: JoinSet<std::result::Result<usize, DataFusionError>> =
            JoinSet::new();

        // IPC write options: 64-byte alignment, non-legacy format, LZ4-compressed batches
        let ipc_options =
            IpcWriteOptions::try_new(64, false, arrow_ipc::MetadataVersion::V5)?
                .try_with_compression(Some(CompressionType::LZ4_FRAME))?;
        while let Some((path, mut rx)) = file_stream_rx.recv().await {
            let shared_buffer = SharedBuffer::new(INITIAL_BUFFER_BYTES);
            let mut arrow_writer = arrow_ipc::writer::FileWriter::try_new_with_options(
                shared_buffer.clone(),
                &get_writer_schema(&self.config),
                ipc_options.clone(),
            )?;
            let mut object_store_writer = ObjectWriterBuilder::new(
                FileCompressionType::UNCOMPRESSED,
                &path,
                Arc::clone(&object_store),
            )
            .with_buffer_size(Some(
                context
                    .session_config()
                    .options()
                    .execution
                    .objectstore_writer_buffer_size,
            ))
            .build()?;
            // Spawn one task per output file that streams incoming batches into the IPC writer
            file_write_tasks.spawn(async move {
                let mut row_count = 0;
                while let Some(batch) = rx.recv().await {
                    row_count += batch.num_rows();
                    arrow_writer.write(&batch)?;
                    // Flush the shared in-memory buffer to the object store once it exceeds BUFFER_FLUSH_BYTES
                    let mut buff_to_flush = shared_buffer.buffer.try_lock().unwrap();
                    if buff_to_flush.len() > BUFFER_FLUSH_BYTES {
                        object_store_writer
                            .write_all(buff_to_flush.as_slice())
                            .await?;
                        buff_to_flush.clear();
                    }
                }
                arrow_writer.finish()?;
                let final_buff = shared_buffer.buffer.try_lock().unwrap();

                object_store_writer.write_all(final_buff.as_slice()).await?;
                object_store_writer.shutdown().await?;
                Ok(row_count)
            });
        }

        // Wait for all writer tasks, accumulating the total row count and propagating panics
        let mut row_count = 0;
        while let Some(result) = file_write_tasks.join_next().await {
            match result {
                Ok(r) => {
                    row_count += r?;
                }
                Err(e) => {
                    if e.is_panic() {
                        std::panic::resume_unwind(e.into_panic());
                    } else {
                        unreachable!();
                    }
                }
            }
        }

        demux_task
            .join_unwind()
            .await
            .map_err(|e| DataFusionError::ExecutionJoin(Box::new(e)))??;
        Ok(row_count as u64)
    }
}

impl Debug for ArrowFileSink {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("ArrowFileSink").finish()
    }
}

impl DisplayAs for ArrowFileSink {
    fn fmt_as(&self, t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match t {
            DisplayFormatType::Default | DisplayFormatType::Verbose => {
                write!(f, "ArrowFileSink(file_groups=",)?;
                FileGroupDisplay(&self.config.file_group).fmt_as(t, f)?;
                write!(f, ")")
            }
            DisplayFormatType::TreeRender => {
                writeln!(f, "format: arrow")?;
                write!(f, "file={}", &self.config.original_url)
            }
        }
    }
}

#[async_trait]
impl DataSink for ArrowFileSink {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn schema(&self) -> &SchemaRef {
        self.config.output_schema()
    }

    async fn write_all(
        &self,
        data: SendableRecordBatchStream,
        context: &Arc<TaskContext>,
    ) -> Result<u64> {
        FileSink::write_all(self, data, context).await
    }
}

/// Magic bytes ("ARROW1") at the start of an Arrow IPC file
const ARROW_MAGIC: [u8; 6] = [b'A', b'R', b'R', b'O', b'W', b'1'];
/// IPC continuation marker (0xFFFFFFFF) that precedes the metadata length in newer format versions
const CONTINUATION_MARKER: [u8; 4] = [0xff; 4];

/// Custom implementation of schema inference from an Arrow IPC byte stream. Should eventually be moved upstream to arrow-rs.
/// See <https://github.com/apache/arrow-rs/issues/5021>
async fn infer_schema_from_file_stream(
    mut stream: BoxStream<'static, object_store::Result<Bytes>>,
) -> Result<SchemaRef> {
    // Expected format:
    // <magic number "ARROW1"> - 6 bytes
    // <empty padding bytes [to 8 byte boundary]> - 2 bytes
    // <continuation: 0xFFFFFFFF> - 4 bytes, not present below v0.15.0
    // <metadata_size: int32> - 4 bytes
    // <metadata_flatbuffer: bytes>
    // <rest of file bytes>

    // So in the first read we need at least all of the fixed-size sections,
    // which is 6 + 2 + 4 + 4 = 16 bytes.
    let bytes = collect_at_least_n_bytes(&mut stream, 16, None).await?;

    // Files should start with these magic bytes
    if bytes[0..6] != ARROW_MAGIC {
        return Err(ArrowError::ParseError(
            "Arrow file does not contain correct header".to_string(),
        ))?;
    }

    // The continuation marker was only added in later format versions, so it may be absent
    let (meta_len, rest_of_bytes_start_index) = if bytes[8..12] == CONTINUATION_MARKER {
        (&bytes[12..16], 16)
    } else {
        (&bytes[8..12], 12)
    };

    let meta_len = [meta_len[0], meta_len[1], meta_len[2], meta_len[3]];
    let meta_len = i32::from_le_bytes(meta_len);

    // Read bytes for Schema message
    let block_data = if bytes[rest_of_bytes_start_index..].len() < meta_len as usize {
        // Need to read more bytes to decode Message
        let mut block_data = Vec::with_capacity(meta_len as usize);
        // In case we had some spare bytes in our initial read chunk
        block_data.extend_from_slice(&bytes[rest_of_bytes_start_index..]);
        let size_to_read = meta_len as usize - block_data.len();
        let block_data =
            collect_at_least_n_bytes(&mut stream, size_to_read, Some(block_data)).await?;
        Cow::Owned(block_data)
    } else {
        // Already have the bytes we need
        let end_index = meta_len as usize + rest_of_bytes_start_index;
        let block_data = &bytes[rest_of_bytes_start_index..end_index];
        Cow::Borrowed(block_data)
    };

    // Decode Schema message
    let message = root_as_message(&block_data).map_err(|err| {
        ArrowError::ParseError(format!("Unable to read IPC message as metadata: {err:?}"))
    })?;
    let ipc_schema = message.header_as_schema().ok_or_else(|| {
        ArrowError::IpcError("Unable to read IPC message as schema".to_string())
    })?;
    let schema = fb_to_schema(ipc_schema);

    Ok(Arc::new(schema))
}

/// Collects at least `n` additional bytes from the stream, appending to
/// `extend_from` if provided; errors if the stream ends before enough bytes arrive.
async fn collect_at_least_n_bytes(
    stream: &mut BoxStream<'static, object_store::Result<Bytes>>,
    n: usize,
    extend_from: Option<Vec<u8>>,
) -> Result<Vec<u8>> {
    let mut buf = extend_from.unwrap_or_else(|| Vec::with_capacity(n));
    // If extending an existing buffer, ensure we read n additional bytes
    let n = n + buf.len();
    while let Some(bytes) = stream.next().await.transpose()? {
        buf.extend_from_slice(&bytes);
        if buf.len() >= n {
            break;
        }
    }
    if buf.len() < n {
        return Err(ArrowError::ParseError(
            "Unexpected end of byte stream for Arrow IPC file".to_string(),
        ))?;
    }
    Ok(buf)
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::execution::context::SessionContext;

    use chrono::DateTime;
    use object_store::{chunked::ChunkedStore, memory::InMemory, path::Path};

    #[tokio::test]
    async fn test_infer_schema_stream() -> Result<()> {
        let mut bytes = std::fs::read("tests/data/example.arrow")?;
        bytes.truncate(bytes.len() - 20); // mangle the end to show we don't need to read the whole file
        let location = Path::parse("example.arrow")?;
        let in_memory_store: Arc<dyn ObjectStore> = Arc::new(InMemory::new());
        in_memory_store.put(&location, bytes.into()).await?;

        let session_ctx = SessionContext::new();
        let state = session_ctx.state();
        let object_meta = ObjectMeta {
            location,
            last_modified: DateTime::default(),
            size: u64::MAX,
            e_tag: None,
            version: None,
        };

        let arrow_format = ArrowFormat {};
        let expected = vec!["f0: Int64", "f1: Utf8", "f2: Boolean"];

        // Test one chunk size that is too small, so we keep having to read more bytes,
        // and one large enough that the first read contains all we need
        for chunk_size in [7, 3000] {
            let store = Arc::new(ChunkedStore::new(in_memory_store.clone(), chunk_size));
            let inferred_schema = arrow_format
                .infer_schema(
                    &state,
                    &(store.clone() as Arc<dyn ObjectStore>),
                    std::slice::from_ref(&object_meta),
                )
                .await?;
            let actual_fields = inferred_schema
                .fields()
                .iter()
                .map(|f| format!("{}: {:?}", f.name(), f.data_type()))
                .collect::<Vec<_>>();
            assert_eq!(expected, actual_fields);
        }

        Ok(())
    }

    #[tokio::test]
    async fn test_infer_schema_short_stream() -> Result<()> {
        let mut bytes = std::fs::read("tests/data/example.arrow")?;
        bytes.truncate(20); // should cause an error because the file is shorter than expected
        let location = Path::parse("example.arrow")?;
        let in_memory_store: Arc<dyn ObjectStore> = Arc::new(InMemory::new());
        in_memory_store.put(&location, bytes.into()).await?;

        let session_ctx = SessionContext::new();
        let state = session_ctx.state();
        let object_meta = ObjectMeta {
            location,
            last_modified: DateTime::default(),
            size: u64::MAX,
            e_tag: None,
            version: None,
        };

        let arrow_format = ArrowFormat {};

        let store = Arc::new(ChunkedStore::new(in_memory_store.clone(), 7));
        let err = arrow_format
            .infer_schema(
                &state,
                &(store.clone() as Arc<dyn ObjectStore>),
                std::slice::from_ref(&object_meta),
            )
            .await;

        assert!(err.is_err());
        assert_eq!(
            "Arrow error: Parser error: Unexpected end of byte stream for Arrow IPC file",
            err.unwrap_err().to_string().lines().next().unwrap()
        );

        Ok(())
    }
}