1use std::any::Any;
23use std::collections::HashMap;
24use std::fmt::{self, Debug};
25use std::io::{Seek, SeekFrom};
26use std::sync::Arc;
27
28use arrow::datatypes::{Schema, SchemaRef};
29use arrow::error::ArrowError;
30use arrow::ipc::convert::fb_to_schema;
31use arrow::ipc::reader::{FileReader, StreamReader};
32use arrow::ipc::writer::IpcWriteOptions;
33use arrow::ipc::{CompressionType, root_as_message};
34use datafusion_common::error::Result;
35use datafusion_common::parsers::CompressionTypeVariant;
36use datafusion_common::{
37 DEFAULT_ARROW_EXTENSION, DataFusionError, GetExt, Statistics,
38 internal_datafusion_err, not_impl_err,
39};
40use datafusion_common_runtime::{JoinSet, SpawnedTask};
41use datafusion_datasource::TableSchema;
42use datafusion_datasource::display::FileGroupDisplay;
43use datafusion_datasource::file::FileSource;
44use datafusion_datasource::file_scan_config::{FileScanConfig, FileScanConfigBuilder};
45use datafusion_datasource::sink::{DataSink, DataSinkExec};
46use datafusion_datasource::write::{
47 ObjectWriterBuilder, SharedBuffer, get_writer_schema,
48};
49use datafusion_execution::{SendableRecordBatchStream, TaskContext};
50use datafusion_expr::dml::InsertOp;
51use datafusion_physical_expr_common::sort_expr::LexRequirement;
52
53use crate::source::ArrowSource;
54use async_trait::async_trait;
55use bytes::Bytes;
56use datafusion_datasource::file_compression_type::FileCompressionType;
57use datafusion_datasource::file_format::{FileFormat, FileFormatFactory};
58use datafusion_datasource::file_sink_config::{FileSink, FileSinkConfig};
59use datafusion_datasource::source::DataSourceExec;
60use datafusion_datasource::write::demux::DemuxedStreamReceiver;
61use datafusion_physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan};
62use datafusion_session::Session;
63use futures::StreamExt;
64use futures::stream::BoxStream;
65use object_store::{
66 GetOptions, GetRange, GetResultPayload, ObjectMeta, ObjectStore, ObjectStoreExt,
67 path::Path,
68};
69use tokio::io::AsyncWriteExt;
70
/// Initial capacity (1 MiB) of the in-memory buffer that accumulates Arrow IPC
/// output before it is handed to the object store writer.
const INITIAL_BUFFER_BYTES: usize = 1024 * 1024;

/// Once the shared write buffer exceeds this many bytes, its contents are
/// flushed to the underlying object store writer and the buffer is cleared.
const BUFFER_FLUSH_BYTES: usize = 1024000;
/// Factory used to create [`ArrowFormat`] instances; also supplies the
/// default file extension (via its [`GetExt`] impl).
#[derive(Default, Debug)]
pub struct ArrowFormatFactory;
81
82impl ArrowFormatFactory {
83 pub fn new() -> Self {
85 Self {}
86 }
87}
88
impl FileFormatFactory for ArrowFormatFactory {
    /// Creates an [`ArrowFormat`]. The Arrow format has no configurable
    /// options, so both the session state and `_format_options` are ignored.
    fn create(
        &self,
        _state: &dyn Session,
        _format_options: &HashMap<String, String>,
    ) -> Result<Arc<dyn FileFormat>> {
        Ok(Arc::new(ArrowFormat))
    }

    /// Returns the default [`ArrowFormat`] instance.
    fn default(&self) -> Arc<dyn FileFormat> {
        Arc::new(ArrowFormat)
    }

    fn as_any(&self) -> &dyn Any {
        self
    }
}
106
impl GetExt for ArrowFormatFactory {
    /// Returns the Arrow file extension with the leading `.` stripped
    /// (e.g. `"arrow"` rather than `".arrow"`).
    fn get_ext(&self) -> String {
        DEFAULT_ARROW_EXTENSION[1..].to_string()
    }
}
113
/// [`FileFormat`] implementation for the Arrow IPC format (both the
/// file-format and stream-format on-disk layouts).
#[derive(Default, Debug)]
pub struct ArrowFormat;
117
#[async_trait]
impl FileFormat for ArrowFormat {
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// File extension for Arrow IPC files, delegated to [`ArrowFormatFactory`].
    fn get_ext(&self) -> String {
        ArrowFormatFactory::new().get_ext()
    }

    /// Returns the extension for the given compression variant.
    ///
    /// Only `UNCOMPRESSED` is accepted: file-level (external) compression is
    /// not supported for Arrow IPC, so any other variant is an error.
    fn get_ext_with_compression(
        &self,
        file_compression_type: &FileCompressionType,
    ) -> Result<String> {
        let ext = self.get_ext();
        match file_compression_type.get_variant() {
            CompressionTypeVariant::UNCOMPRESSED => Ok(ext),
            _ => Err(internal_datafusion_err!(
                "Arrow FileFormat does not support compression."
            )),
        }
    }

    fn compression_type(&self) -> Option<FileCompressionType> {
        None
    }

    /// Infers a merged schema across all `objects`.
    ///
    /// For a local file payload, the IPC *file* format is tried first; on
    /// failure the file is rewound and re-read as the IPC *stream* format.
    /// For a streamed payload, the schema is decoded incrementally from the
    /// byte stream via [`infer_stream_schema`].
    async fn infer_schema(
        &self,
        _state: &dyn Session,
        store: &Arc<dyn ObjectStore>,
        objects: &[ObjectMeta],
    ) -> Result<SchemaRef> {
        let mut schemas = vec![];
        for object in objects {
            let r = store.as_ref().get(&object.location).await?;
            let schema = match r.payload {
                #[cfg(not(target_arch = "wasm32"))]
                GetResultPayload::File(mut file, _) => {
                    match FileReader::try_new(&mut file, None) {
                        Ok(reader) => reader.schema(),
                        Err(file_error) => {
                            // Rewind before retrying: the failed file-format
                            // attempt has already consumed bytes.
                            file.seek(SeekFrom::Start(0))?;
                            match StreamReader::try_new(&mut file, None) {
                                Ok(reader) => reader.schema(),
                                Err(stream_error) => {
                                    // Neither layout parsed; report both errors
                                    // so the user can diagnose the file.
                                    return Err(internal_datafusion_err!(
                                        "Failed to parse Arrow file as either file format or stream format. File format error: {file_error}. Stream format error: {stream_error}"
                                    ));
                                }
                            }
                        }
                    }
                }
                GetResultPayload::Stream(stream) => infer_stream_schema(stream).await?,
            };
            schemas.push(schema.as_ref().clone());
        }
        let merged_schema = Schema::try_merge(schemas)?;
        Ok(Arc::new(merged_schema))
    }

    /// Statistics are not collected for Arrow files; always returns "unknown".
    async fn infer_stats(
        &self,
        _state: &dyn Session,
        _store: &Arc<dyn ObjectStore>,
        table_schema: SchemaRef,
        _object: &ObjectMeta,
    ) -> Result<Statistics> {
        Ok(Statistics::new_unknown(&table_schema))
    }

    /// Builds the physical scan plan for this format.
    ///
    /// Sniffs the leading bytes of the first file to decide between the IPC
    /// file-format and stream-format source implementations, then applies any
    /// projection pushdown before constructing the [`DataSourceExec`].
    async fn create_physical_plan(
        &self,
        state: &dyn Session,
        conf: FileScanConfig,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        let object_store = state.runtime_env().object_store(&conf.object_store_url)?;
        // Format detection is based on the first file of the first group; the
        // remaining files are assumed to use the same layout.
        let object_location = &conf
            .file_groups
            .first()
            .ok_or_else(|| internal_datafusion_err!("No files found in file group"))?
            .files()
            .first()
            .ok_or_else(|| internal_datafusion_err!("No files found in file group"))?
            .object_meta
            .location;

        let table_schema = TableSchema::new(
            Arc::clone(conf.file_schema()),
            conf.table_partition_cols().clone(),
        );

        let mut source: Arc<dyn FileSource> =
            match is_object_in_arrow_ipc_file_format(object_store, object_location).await
            {
                Ok(true) => Arc::new(ArrowSource::new_file_source(table_schema)),
                Ok(false) => Arc::new(ArrowSource::new_stream_file_source(table_schema)),
                Err(e) => Err(e)?,
            };

        // Carry over any projection from the incoming config onto the newly
        // selected source, if the source supports pushdown.
        if let Some(projection) = conf.file_source.projection()
            && let Some(new_source) = source.try_pushdown_projection(projection)?
        {
            source = new_source;
        }

        let config = FileScanConfigBuilder::from(conf)
            .with_source(source)
            .build();

        Ok(DataSourceExec::from_data_source(config))
    }

    /// Builds the physical write plan (a [`DataSinkExec`] wrapping an
    /// [`ArrowFileSink`]). Only `InsertOp::Append` is supported.
    async fn create_writer_physical_plan(
        &self,
        input: Arc<dyn ExecutionPlan>,
        _state: &dyn Session,
        conf: FileSinkConfig,
        order_requirements: Option<LexRequirement>,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        if conf.insert_op != InsertOp::Append {
            return not_impl_err!("Overwrites are not implemented yet for Arrow format");
        }

        let sink = Arc::new(ArrowFileSink::new(conf));

        Ok(Arc::new(DataSinkExec::new(input, sink, order_requirements)) as _)
    }

    /// Default file source (IPC file format, not the stream variant).
    fn file_source(&self, table_schema: TableSchema) -> Arc<dyn FileSource> {
        Arc::new(ArrowSource::new_file_source(table_schema))
    }
}
256
/// [`DataSink`] implementation that writes record batches out as Arrow IPC
/// files according to the supplied [`FileSinkConfig`].
struct ArrowFileSink {
    config: FileSinkConfig,
}
261
262impl ArrowFileSink {
263 fn new(config: FileSinkConfig) -> Self {
264 Self { config }
265 }
266}
267
#[async_trait]
impl FileSink for ArrowFileSink {
    fn config(&self) -> &FileSinkConfig {
        &self.config
    }

    /// Consumes the demuxed per-file batch streams, writing each one as an
    /// LZ4-compressed Arrow IPC file, and returns the total row count written.
    async fn spawn_writer_tasks_and_join(
        &self,
        context: &Arc<TaskContext>,
        demux_task: SpawnedTask<Result<()>>,
        mut file_stream_rx: DemuxedStreamReceiver,
        object_store: Arc<dyn ObjectStore>,
    ) -> Result<u64> {
        let mut file_write_tasks: JoinSet<std::result::Result<usize, DataFusionError>> =
            JoinSet::new();

        // 64-byte buffer alignment, non-legacy IPC layout, metadata version V5,
        // with LZ4 frame compression for record batch bodies.
        let ipc_options =
            IpcWriteOptions::try_new(64, false, arrow_ipc::MetadataVersion::V5)?
                .try_with_compression(Some(CompressionType::LZ4_FRAME))?;
        // One writer task is spawned per output file path yielded by the demuxer.
        while let Some((path, mut rx)) = file_stream_rx.recv().await {
            let shared_buffer = SharedBuffer::new(INITIAL_BUFFER_BYTES);
            let mut arrow_writer = arrow_ipc::writer::FileWriter::try_new_with_options(
                shared_buffer.clone(),
                &get_writer_schema(&self.config),
                ipc_options.clone(),
            )?;
            let mut object_store_writer = ObjectWriterBuilder::new(
                // The IPC bytes are written as-is; no file-level compression.
                FileCompressionType::UNCOMPRESSED,
                &path,
                Arc::clone(&object_store),
            )
            .with_buffer_size(Some(
                context
                    .session_config()
                    .options()
                    .execution
                    .objectstore_writer_buffer_size,
            ))
            .build()?;
            file_write_tasks.spawn(async move {
                let mut row_count = 0;
                while let Some(batch) = rx.recv().await {
                    row_count += batch.num_rows();
                    arrow_writer.write(&batch)?;
                    // This task is the only active holder of the buffer here,
                    // so try_lock is expected to succeed.
                    let mut buff_to_flush = shared_buffer.buffer.try_lock().unwrap();
                    // Flush once enough bytes accumulate, keeping memory use
                    // bounded for large outputs.
                    if buff_to_flush.len() > BUFFER_FLUSH_BYTES {
                        object_store_writer
                            .write_all(buff_to_flush.as_slice())
                            .await?;
                        buff_to_flush.clear();
                    }
                }
                // Write the IPC footer, then push whatever remains in the buffer.
                arrow_writer.finish()?;
                let final_buff = shared_buffer.buffer.try_lock().unwrap();

                object_store_writer.write_all(final_buff.as_slice()).await?;
                object_store_writer.shutdown().await?;
                Ok(row_count)
            });
        }

        // Collect per-file row counts; re-raise panics from writer tasks and
        // propagate their errors.
        let mut row_count = 0;
        while let Some(result) = file_write_tasks.join_next().await {
            match result {
                Ok(r) => {
                    row_count += r?;
                }
                Err(e) => {
                    if e.is_panic() {
                        std::panic::resume_unwind(e.into_panic());
                    } else {
                        // Tasks are never cancelled, so a non-panic join error
                        // should be impossible here.
                        unreachable!();
                    }
                }
            }
        }

        // Wait for the demux task itself and propagate any error it produced.
        demux_task
            .join_unwind()
            .await
            .map_err(|e| DataFusionError::ExecutionJoin(Box::new(e)))??;
        Ok(row_count as u64)
    }
}
352
353impl Debug for ArrowFileSink {
354 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
355 f.debug_struct("ArrowFileSink").finish()
356 }
357}
358
impl DisplayAs for ArrowFileSink {
    /// Renders the sink for `EXPLAIN` output.
    ///
    /// Default/Verbose mode shows the file groups; TreeRender mode shows the
    /// format name and the original output URL.
    fn fmt_as(&self, t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match t {
            DisplayFormatType::Default | DisplayFormatType::Verbose => {
                write!(f, "ArrowFileSink(file_groups=",)?;
                FileGroupDisplay(&self.config.file_group).fmt_as(t, f)?;
                write!(f, ")")
            }
            DisplayFormatType::TreeRender => {
                writeln!(f, "format: arrow")?;
                write!(f, "file={}", &self.config.original_url)
            }
        }
    }
}
374
#[async_trait]
impl DataSink for ArrowFileSink {
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// Schema of the data this sink writes (the configured output schema).
    fn schema(&self) -> &SchemaRef {
        self.config.output_schema()
    }

    /// Delegates to the [`FileSink`] implementation, which demuxes the input
    /// into per-file writer tasks.
    async fn write_all(
        &self,
        data: SendableRecordBatchStream,
        context: &Arc<TaskContext>,
    ) -> Result<u64> {
        FileSink::write_all(self, data, context).await
    }
}
393
/// Leading magic bytes of an Arrow IPC *file*-format payload ("ARROW1").
const ARROW_MAGIC: [u8; 6] = *b"ARROW1";
/// Four 0xFF bytes marking an IPC encapsulated-message continuation.
const CONTINUATION_MARKER: [u8; 4] = [0xff, 0xff, 0xff, 0xff];
399
/// Reads just enough of an IPC byte stream to decode the schema message.
///
/// Handles both on-disk layouts:
/// - IPC *file* format: 6-byte "ARROW1" magic plus 2 padding bytes (preamble
///   of 8), optionally followed by a 4-byte continuation marker (preamble 12);
/// - IPC *stream* format: an optional 4-byte continuation marker, or none.
///
/// After the preamble comes a little-endian `i32` metadata length and then the
/// flatbuffer-encoded message, whose header must be a Schema.
async fn infer_stream_schema(
    mut stream: BoxStream<'static, object_store::Result<Bytes>>,
) -> Result<SchemaRef> {
    // 16 bytes covers the longest possible preamble (12) plus the 4-byte
    // metadata length that follows it.
    let bytes = extend_bytes_to_n_length_from_stream(vec![], 16, &mut stream).await?;

    let preamble_len = if bytes[0..6] == ARROW_MAGIC {
        // File format: magic + 2 padding bytes, then possibly a continuation
        // marker before the first message.
        if bytes[8..12] == CONTINUATION_MARKER {
            12
        } else {
            8
        }
    } else if bytes[0..4] == CONTINUATION_MARKER {
        // Stream format starting with an encapsulated-message marker.
        4
    } else {
        // Stream without a continuation marker (older IPC encoding): the
        // metadata length comes first.
        0
    };

    let meta_len_bytes: [u8; 4] = bytes[preamble_len..preamble_len + 4]
        .try_into()
        .map_err(|err| {
            ArrowError::ParseError(format!(
                "Unable to read IPC message metadata length: {err:?}"
            ))
        })?;

    let meta_len = i32::from_le_bytes([
        meta_len_bytes[0],
        meta_len_bytes[1],
        meta_len_bytes[2],
        meta_len_bytes[3],
    ]);

    // Guard against corrupt input before the cast to usize below.
    if meta_len < 0 {
        return Err(ArrowError::ParseError(
            "IPC message metadata length is negative".to_string(),
        )
        .into());
    }

    // Pull additional chunks until the whole metadata message is buffered.
    let bytes = extend_bytes_to_n_length_from_stream(
        bytes,
        preamble_len + 4 + (meta_len as usize),
        &mut stream,
    )
    .await?;

    let message = root_as_message(&bytes[preamble_len + 4..]).map_err(|err| {
        ArrowError::ParseError(format!("Unable to read IPC message metadata: {err:?}"))
    })?;
    let fb_schema = message.header_as_schema().ok_or_else(|| {
        ArrowError::IpcError("Unable to read IPC message schema".to_string())
    })?;
    let schema = fb_to_schema(fb_schema);

    Ok(Arc::new(schema))
}
502
503async fn extend_bytes_to_n_length_from_stream(
504 bytes: Vec<u8>,
505 n: usize,
506 stream: &mut BoxStream<'static, object_store::Result<Bytes>>,
507) -> Result<Vec<u8>> {
508 if bytes.len() >= n {
509 return Ok(bytes);
510 }
511
512 let mut buf = bytes;
513
514 while let Some(b) = stream.next().await.transpose()? {
515 buf.extend_from_slice(&b);
516
517 if buf.len() >= n {
518 break;
519 }
520 }
521
522 if buf.len() < n {
523 return Err(ArrowError::ParseError(
524 "Unexpected end of byte stream for Arrow IPC file".to_string(),
525 )
526 .into());
527 }
528
529 Ok(buf)
530}
531
/// Returns `true` if the object begins with the Arrow IPC *file* magic
/// ("ARROW1"), distinguishing the file format from the stream format.
///
/// Only the first 6 bytes are fetched via a ranged GET; requesting that range
/// on an object shorter than 6 bytes may surface as an object store error.
async fn is_object_in_arrow_ipc_file_format(
    store: Arc<dyn ObjectStore>,
    object_location: &Path,
) -> Result<bool> {
    let get_opts = GetOptions {
        range: Some(GetRange::Bounded(0..6)),
        ..Default::default()
    };
    let bytes = store
        .get_opts(object_location, get_opts)
        .await?
        .bytes()
        .await?;
    Ok(bytes.len() >= 6 && bytes[0..6] == ARROW_MAGIC)
}
547
#[cfg(test)]
mod tests {
    use super::*;

    use chrono::DateTime;
    use datafusion_common::DFSchema;
    use datafusion_common::config::TableOptions;
    use datafusion_execution::config::SessionConfig;
    use datafusion_execution::runtime_env::RuntimeEnv;
    use datafusion_expr::execution_props::ExecutionProps;
    use datafusion_expr::{AggregateUDF, Expr, LogicalPlan, ScalarUDF, WindowUDF};
    use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
    use object_store::{chunked::ChunkedStore, memory::InMemory, path::Path};

    /// Minimal [`Session`] stub for schema-inference tests. Only `config()`
    /// and `runtime_env()` are exercised; every other method is unimplemented.
    struct MockSession {
        config: SessionConfig,
        runtime_env: Arc<RuntimeEnv>,
    }

    impl MockSession {
        fn new() -> Self {
            Self {
                config: SessionConfig::new(),
                runtime_env: Arc::new(RuntimeEnv::default()),
            }
        }
    }

    #[async_trait::async_trait]
    impl Session for MockSession {
        fn session_id(&self) -> &str {
            unimplemented!()
        }

        fn config(&self) -> &SessionConfig {
            &self.config
        }

        async fn create_physical_plan(
            &self,
            _logical_plan: &LogicalPlan,
        ) -> Result<Arc<dyn ExecutionPlan>> {
            unimplemented!()
        }

        fn create_physical_expr(
            &self,
            _expr: Expr,
            _df_schema: &DFSchema,
        ) -> Result<Arc<dyn PhysicalExpr>> {
            unimplemented!()
        }

        fn scalar_functions(&self) -> &HashMap<String, Arc<ScalarUDF>> {
            unimplemented!()
        }

        fn aggregate_functions(&self) -> &HashMap<String, Arc<AggregateUDF>> {
            unimplemented!()
        }

        fn window_functions(&self) -> &HashMap<String, Arc<WindowUDF>> {
            unimplemented!()
        }

        fn runtime_env(&self) -> &Arc<RuntimeEnv> {
            &self.runtime_env
        }

        fn execution_props(&self) -> &ExecutionProps {
            unimplemented!()
        }

        fn as_any(&self) -> &dyn Any {
            unimplemented!()
        }

        fn table_options(&self) -> &TableOptions {
            unimplemented!()
        }

        fn table_options_mut(&mut self) -> &mut TableOptions {
            unimplemented!()
        }

        fn task_ctx(&self) -> Arc<TaskContext> {
            unimplemented!()
        }
    }

    /// Schema inference must work on truncated streamed payloads for both the
    /// IPC file format and the IPC stream format.
    #[tokio::test]
    async fn test_infer_schema_stream() -> Result<()> {
        for file in ["example.arrow", "example_stream.arrow"] {
            let mut bytes = std::fs::read(format!("tests/data/{file}"))?;
            // Drop the tail so inference cannot rely on a complete file/footer.
            bytes.truncate(bytes.len() - 20);
            let location = Path::parse(file)?;
            let in_memory_store: Arc<dyn ObjectStore> = Arc::new(InMemory::new());
            in_memory_store.put(&location, bytes.into()).await?;

            let state = MockSession::new();
            // Placeholder metadata; size is deliberately bogus.
            let object_meta = ObjectMeta {
                location,
                last_modified: DateTime::default(),
                size: u64::MAX,
                e_tag: None,
                version: None,
            };

            let arrow_format = ArrowFormat {};
            let expected = vec!["f0: Int64", "f1: Utf8", "f2: Boolean"];

            // Exercise both tiny and large stream chunk sizes.
            for chunk_size in [7, 3000] {
                let store =
                    Arc::new(ChunkedStore::new(in_memory_store.clone(), chunk_size));
                let inferred_schema = arrow_format
                    .infer_schema(
                        &state,
                        &(store.clone() as Arc<dyn ObjectStore>),
                        std::slice::from_ref(&object_meta),
                    )
                    .await?;
                let actual_fields = inferred_schema
                    .fields()
                    .iter()
                    .map(|f| format!("{}: {:?}", f.name(), f.data_type()))
                    .collect::<Vec<_>>();
                assert_eq!(expected, actual_fields);
            }
        }
        Ok(())
    }

    /// A stream too short to contain the schema message must produce a clear
    /// "unexpected end of byte stream" parse error.
    #[tokio::test]
    async fn test_infer_schema_short_stream() -> Result<()> {
        for file in ["example.arrow", "example_stream.arrow"] {
            let mut bytes = std::fs::read(format!("tests/data/{file}"))?;
            // Keep only the first 20 bytes — not enough for the schema message.
            bytes.truncate(20);
            let location = Path::parse(file)?;
            let in_memory_store: Arc<dyn ObjectStore> = Arc::new(InMemory::new());
            in_memory_store.put(&location, bytes.into()).await?;

            let state = MockSession::new();
            let object_meta = ObjectMeta {
                location,
                last_modified: DateTime::default(),
                size: u64::MAX,
                e_tag: None,
                version: None,
            };

            let arrow_format = ArrowFormat {};

            let store = Arc::new(ChunkedStore::new(in_memory_store.clone(), 7));
            let err = arrow_format
                .infer_schema(
                    &state,
                    &(store.clone() as Arc<dyn ObjectStore>),
                    std::slice::from_ref(&object_meta),
                )
                .await;

            assert!(err.is_err());
            assert_eq!(
                "Arrow error: Parser error: Unexpected end of byte stream for Arrow IPC file",
                err.unwrap_err().to_string().lines().next().unwrap()
            );
        }

        Ok(())
    }

    /// An IPC file-format fixture should be detected as the file format.
    #[tokio::test]
    async fn test_format_detection_file_format() -> Result<()> {
        let store = Arc::new(InMemory::new());
        let path = Path::from("test.arrow");

        let file_bytes = std::fs::read("tests/data/example.arrow")?;
        store.put(&path, file_bytes.into()).await?;

        let is_file = is_object_in_arrow_ipc_file_format(store.clone(), &path).await?;
        assert!(is_file, "Should detect file format");
        Ok(())
    }

    /// An IPC stream-format fixture must NOT be detected as the file format.
    #[tokio::test]
    async fn test_format_detection_stream_format() -> Result<()> {
        let store = Arc::new(InMemory::new());
        let path = Path::from("test_stream.arrow");

        let stream_bytes = std::fs::read("tests/data/example_stream.arrow")?;
        store.put(&path, stream_bytes.into()).await?;

        let is_file = is_object_in_arrow_ipc_file_format(store.clone(), &path).await?;

        assert!(!is_file, "Should detect stream format (not file)");

        Ok(())
    }

    /// Bytes that are not the "ARROW1" magic must not be treated as the file
    /// format.
    #[tokio::test]
    async fn test_format_detection_corrupted_file() -> Result<()> {
        let store = Arc::new(InMemory::new());
        let path = Path::from("corrupted.arrow");

        store
            .put(&path, Bytes::from(vec![0x43, 0x4f, 0x52, 0x41]).into())
            .await?;

        let is_file = is_object_in_arrow_ipc_file_format(store.clone(), &path).await?;

        assert!(
            !is_file,
            "Corrupted file should not be detected as file format"
        );

        Ok(())
    }

    /// An empty object cannot satisfy the 0..6 ranged GET, so detection must
    /// return an error rather than a false positive.
    #[tokio::test]
    async fn test_format_detection_empty_file() -> Result<()> {
        let store = Arc::new(InMemory::new());
        let path = Path::from("empty.arrow");

        store.put(&path, Bytes::new().into()).await?;

        let result = is_object_in_arrow_ipc_file_format(store.clone(), &path).await;

        assert!(result.is_err(), "Empty file should error");

        Ok(())
    }
}