use std::any::Any;
use std::collections::HashMap;
use std::fmt::{self, Debug};
use std::io::{Seek, SeekFrom};
use std::sync::Arc;

use arrow::datatypes::{Schema, SchemaRef};
use arrow::error::ArrowError;
use arrow::ipc::convert::fb_to_schema;
use arrow::ipc::reader::{FileReader, StreamReader};
use arrow::ipc::writer::IpcWriteOptions;
use arrow::ipc::{CompressionType, root_as_message};
use datafusion_common::error::Result;
use datafusion_common::parsers::CompressionTypeVariant;
use datafusion_common::{
    DEFAULT_ARROW_EXTENSION, DataFusionError, GetExt, Statistics,
    internal_datafusion_err, not_impl_err,
};
use datafusion_common_runtime::{JoinSet, SpawnedTask};
use datafusion_datasource::TableSchema;
use datafusion_datasource::display::FileGroupDisplay;
use datafusion_datasource::file::FileSource;
use datafusion_datasource::file_scan_config::{FileScanConfig, FileScanConfigBuilder};
use datafusion_datasource::sink::{DataSink, DataSinkExec};
use datafusion_datasource::write::{
    ObjectWriterBuilder, SharedBuffer, get_writer_schema,
};
use datafusion_execution::{SendableRecordBatchStream, TaskContext};
use datafusion_expr::dml::InsertOp;
use datafusion_physical_expr_common::sort_expr::LexRequirement;

use crate::source::ArrowSource;
use async_trait::async_trait;
use bytes::Bytes;
use datafusion_datasource::file_compression_type::FileCompressionType;
use datafusion_datasource::file_format::{FileFormat, FileFormatFactory};
use datafusion_datasource::file_sink_config::{FileSink, FileSinkConfig};
use datafusion_datasource::source::DataSourceExec;
use datafusion_datasource::write::demux::DemuxedStreamReceiver;
use datafusion_physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan};
use datafusion_session::Session;
use futures::StreamExt;
use futures::stream::BoxStream;
use object_store::{
    GetOptions, GetRange, GetResultPayload, ObjectMeta, ObjectStore, path::Path,
};
use tokio::io::AsyncWriteExt;

/// Initial size of the in-memory write buffer. This is only a size hint; the
/// buffer grows beyond it as needed.
const INITIAL_BUFFER_BYTES: usize = 1048576;

/// Once the buffered Arrow IPC bytes exceed this size, they are flushed to the
/// object store writer.
const BUFFER_FLUSH_BYTES: usize = 1024000;

/// Factory used to create instances of [`ArrowFormat`].
#[derive(Default, Debug)]
pub struct ArrowFormatFactory;

impl ArrowFormatFactory {
    /// Creates an instance of [`ArrowFormatFactory`].
    pub fn new() -> Self {
        Self {}
    }
}

impl FileFormatFactory for ArrowFormatFactory {
    fn create(
        &self,
        _state: &dyn Session,
        _format_options: &HashMap<String, String>,
    ) -> Result<Arc<dyn FileFormat>> {
        Ok(Arc::new(ArrowFormat))
    }

    fn default(&self) -> Arc<dyn FileFormat> {
        Arc::new(ArrowFormat)
    }

    fn as_any(&self) -> &dyn Any {
        self
    }
}

impl GetExt for ArrowFormatFactory {
    fn get_ext(&self) -> String {
        // Remove the leading '.' from the default extension, e.g. ".arrow" -> "arrow"
        DEFAULT_ARROW_EXTENSION[1..].to_string()
    }
}

/// [`FileFormat`] implementation for data in the Arrow IPC file and stream formats.
#[derive(Default, Debug)]
pub struct ArrowFormat;

#[async_trait]
impl FileFormat for ArrowFormat {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn get_ext(&self) -> String {
        ArrowFormatFactory::new().get_ext()
    }

    fn get_ext_with_compression(
        &self,
        file_compression_type: &FileCompressionType,
    ) -> Result<String> {
        let ext = self.get_ext();
        match file_compression_type.get_variant() {
            CompressionTypeVariant::UNCOMPRESSED => Ok(ext),
            _ => Err(internal_datafusion_err!(
                "Arrow FileFormat does not support compression."
            )),
        }
    }

    fn compression_type(&self) -> Option<FileCompressionType> {
        None
    }

    async fn infer_schema(
        &self,
        _state: &dyn Session,
        store: &Arc<dyn ObjectStore>,
        objects: &[ObjectMeta],
    ) -> Result<SchemaRef> {
        let mut schemas = vec![];
        for object in objects {
            let r = store.as_ref().get(&object.location).await?;
            let schema = match r.payload {
                #[cfg(not(target_arch = "wasm32"))]
                GetResultPayload::File(mut file, _) => {
                    match FileReader::try_new(&mut file, None) {
                        Ok(reader) => reader.schema(),
                        Err(file_error) => {
                            // Not a valid IPC file; rewind and retry parsing
                            // it as the IPC stream format.
                            file.seek(SeekFrom::Start(0))?;
                            match StreamReader::try_new(&mut file, None) {
                                Ok(reader) => reader.schema(),
                                Err(stream_error) => {
                                    return Err(internal_datafusion_err!(
                                        "Failed to parse Arrow file as either file format or stream format. File format error: {file_error}. Stream format error: {stream_error}"
                                    ));
                                }
                            }
                        }
                    }
                }
                GetResultPayload::Stream(stream) => infer_stream_schema(stream).await?,
            };
            schemas.push(schema.as_ref().clone());
        }
        let merged_schema = Schema::try_merge(schemas)?;
        Ok(Arc::new(merged_schema))
    }

    async fn infer_stats(
        &self,
        _state: &dyn Session,
        _store: &Arc<dyn ObjectStore>,
        table_schema: SchemaRef,
        _object: &ObjectMeta,
    ) -> Result<Statistics> {
        Ok(Statistics::new_unknown(&table_schema))
    }

    async fn create_physical_plan(
        &self,
        state: &dyn Session,
        conf: FileScanConfig,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        let object_store = state.runtime_env().object_store(&conf.object_store_url)?;
        let object_location = &conf
            .file_groups
            .first()
            .ok_or_else(|| internal_datafusion_err!("No files found in file group"))?
            .files()
            .first()
            .ok_or_else(|| internal_datafusion_err!("No files found in file group"))?
            .object_meta
            .location;

        let table_schema = TableSchema::new(
            Arc::clone(conf.file_schema()),
            conf.table_partition_cols().clone(),
        );

        // Sniff the first file to decide between the IPC file format and the
        // IPC stream format, and pick the matching source implementation.
        let mut source: Arc<dyn FileSource> =
            match is_object_in_arrow_ipc_file_format(object_store, object_location).await
            {
                Ok(true) => Arc::new(ArrowSource::new_file_source(table_schema)),
                Ok(false) => Arc::new(ArrowSource::new_stream_file_source(table_schema)),
                Err(e) => Err(e)?,
            };

        // Carry over any projection that was pushed down to the original source.
        if let Some(projection) = conf.file_source.projection()
            && let Some(new_source) = source.try_pushdown_projection(projection)?
        {
            source = new_source;
        }

        let config = FileScanConfigBuilder::from(conf)
            .with_source(source)
            .build();

        Ok(DataSourceExec::from_data_source(config))
    }

    async fn create_writer_physical_plan(
        &self,
        input: Arc<dyn ExecutionPlan>,
        _state: &dyn Session,
        conf: FileSinkConfig,
        order_requirements: Option<LexRequirement>,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        if conf.insert_op != InsertOp::Append {
            return not_impl_err!("Overwrites are not implemented yet for Arrow format");
        }

        let sink = Arc::new(ArrowFileSink::new(conf));

        Ok(Arc::new(DataSinkExec::new(input, sink, order_requirements)) as _)
    }

    fn file_source(&self, table_schema: TableSchema) -> Arc<dyn FileSource> {
        Arc::new(ArrowSource::new_file_source(table_schema))
    }
}

/// Implements [`DataSink`] for writing buffered record batches to Arrow IPC files.
struct ArrowFileSink {
    config: FileSinkConfig,
}

impl ArrowFileSink {
    fn new(config: FileSinkConfig) -> Self {
        Self { config }
    }
}

#[async_trait]
impl FileSink for ArrowFileSink {
    fn config(&self) -> &FileSinkConfig {
        &self.config
    }

    async fn spawn_writer_tasks_and_join(
        &self,
        context: &Arc<TaskContext>,
        demux_task: SpawnedTask<Result<()>>,
        mut file_stream_rx: DemuxedStreamReceiver,
        object_store: Arc<dyn ObjectStore>,
    ) -> Result<u64> {
        let mut file_write_tasks: JoinSet<std::result::Result<usize, DataFusionError>> =
            JoinSet::new();

        let ipc_options =
            IpcWriteOptions::try_new(64, false, arrow_ipc::MetadataVersion::V5)?
                .try_with_compression(Some(CompressionType::LZ4_FRAME))?;
        // Spawn one writer task per demuxed output file.
        while let Some((path, mut rx)) = file_stream_rx.recv().await {
            let shared_buffer = SharedBuffer::new(INITIAL_BUFFER_BYTES);
            let mut arrow_writer = arrow_ipc::writer::FileWriter::try_new_with_options(
                shared_buffer.clone(),
                &get_writer_schema(&self.config),
                ipc_options.clone(),
            )?;
            let mut object_store_writer = ObjectWriterBuilder::new(
                FileCompressionType::UNCOMPRESSED,
                &path,
                Arc::clone(&object_store),
            )
            .with_buffer_size(Some(
                context
                    .session_config()
                    .options()
                    .execution
                    .objectstore_writer_buffer_size,
            ))
            .build()?;
            file_write_tasks.spawn(async move {
                let mut row_count = 0;
                while let Some(batch) = rx.recv().await {
                    row_count += batch.num_rows();
                    arrow_writer.write(&batch)?;
                    // Flush the in-memory buffer to the object store once it
                    // grows past the threshold.
                    let mut buff_to_flush = shared_buffer.buffer.try_lock().unwrap();
                    if buff_to_flush.len() > BUFFER_FLUSH_BYTES {
                        object_store_writer
                            .write_all(buff_to_flush.as_slice())
                            .await?;
                        buff_to_flush.clear();
                    }
                }
                // `finish` writes the IPC footer; flush whatever remains.
                arrow_writer.finish()?;
                let final_buff = shared_buffer.buffer.try_lock().unwrap();

                object_store_writer.write_all(final_buff.as_slice()).await?;
                object_store_writer.shutdown().await?;
                Ok(row_count)
            });
        }

        let mut row_count = 0;
        while let Some(result) = file_write_tasks.join_next().await {
            match result {
                Ok(r) => {
                    row_count += r?;
                }
                Err(e) => {
                    if e.is_panic() {
                        std::panic::resume_unwind(e.into_panic());
                    } else {
                        unreachable!();
                    }
                }
            }
        }

        demux_task
            .join_unwind()
            .await
            .map_err(|e| DataFusionError::ExecutionJoin(Box::new(e)))??;
        Ok(row_count as u64)
    }
}

impl Debug for ArrowFileSink {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("ArrowFileSink").finish()
    }
}

impl DisplayAs for ArrowFileSink {
    fn fmt_as(&self, t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match t {
            DisplayFormatType::Default | DisplayFormatType::Verbose => {
                write!(f, "ArrowFileSink(file_groups=")?;
                FileGroupDisplay(&self.config.file_group).fmt_as(t, f)?;
                write!(f, ")")
            }
            DisplayFormatType::TreeRender => {
                writeln!(f, "format: arrow")?;
                write!(f, "file={}", &self.config.original_url)
            }
        }
    }
}

#[async_trait]
impl DataSink for ArrowFileSink {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn schema(&self) -> &SchemaRef {
        self.config.output_schema()
    }

    async fn write_all(
        &self,
        data: SendableRecordBatchStream,
        context: &Arc<TaskContext>,
    ) -> Result<u64> {
        FileSink::write_all(self, data, context).await
    }
}

/// Magic bytes ("ARROW1") at the start of an Arrow IPC file.
const ARROW_MAGIC: [u8; 6] = [b'A', b'R', b'R', b'O', b'W', b'1'];
/// Continuation marker (0xFFFFFFFF) that precedes an encapsulated IPC message.
const CONTINUATION_MARKER: [u8; 4] = [0xff; 4];

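/// Infers the schema from the first IPC message of `stream`, reading only as
/// many bytes as needed rather than fetching the whole object.
///
/// A sketch of the byte prefixes the detection logic below distinguishes, per
/// the Arrow IPC specification (the continuation marker was added in a later
/// format version, so older data may omit it):
///
/// ```text
/// file format:   41 52 52 4f 57 31 00 00 [ff ff ff ff] <i32 len> <schema>...
///                 A  R  R  O  W  1  pad
/// stream format:                         [ff ff ff ff] <i32 len> <schema>...
/// legacy stream:                                       <i32 len> <schema>...
/// ```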
async fn infer_stream_schema(
    mut stream: BoxStream<'static, object_store::Result<Bytes>>,
) -> Result<SchemaRef> {
    // Fetch enough bytes to cover the longest possible preamble (12 bytes)
    // plus the 4-byte metadata length.
    let bytes = extend_bytes_to_n_length_from_stream(vec![], 16, &mut stream).await?;

    let preamble_len = if bytes[0..6] == ARROW_MAGIC {
        // File format: 6 magic bytes and 2 padding bytes, optionally followed
        // by a continuation marker.
        if bytes[8..12] == CONTINUATION_MARKER {
            12
        } else {
            8
        }
    } else if bytes[0..4] == CONTINUATION_MARKER {
        // Stream format starting with a continuation marker.
        4
    } else {
        // Legacy stream format starting directly with the metadata length.
        0
    };

    let meta_len_bytes: [u8; 4] = bytes[preamble_len..preamble_len + 4]
        .try_into()
        .map_err(|err| {
            ArrowError::ParseError(format!(
                "Unable to read IPC message metadata length: {err:?}"
            ))
        })?;

    let meta_len = i32::from_le_bytes(meta_len_bytes);

    if meta_len < 0 {
        return Err(ArrowError::ParseError(
            "IPC message metadata length is negative".to_string(),
        )
        .into());
    }

    // Read the rest of the schema message, then decode it from its
    // flatbuffers representation.
    let bytes = extend_bytes_to_n_length_from_stream(
        bytes,
        preamble_len + 4 + (meta_len as usize),
        &mut stream,
    )
    .await?;

    let message = root_as_message(&bytes[preamble_len + 4..]).map_err(|err| {
        ArrowError::ParseError(format!("Unable to read IPC message metadata: {err:?}"))
    })?;
    let fb_schema = message.header_as_schema().ok_or_else(|| {
        ArrowError::IpcError("Unable to read IPC message schema".to_string())
    })?;
    let schema = fb_to_schema(fb_schema);

    Ok(Arc::new(schema))
}

/// Reads from `stream` until `bytes` holds at least `n` bytes. Whole chunks
/// are appended, so the result may be longer than `n`.
async fn extend_bytes_to_n_length_from_stream(
    bytes: Vec<u8>,
    n: usize,
    stream: &mut BoxStream<'static, object_store::Result<Bytes>>,
) -> Result<Vec<u8>> {
    if bytes.len() >= n {
        return Ok(bytes);
    }

    let mut buf = bytes;

    while let Some(b) = stream.next().await.transpose()? {
        buf.extend_from_slice(&b);

        if buf.len() >= n {
            break;
        }
    }

    if buf.len() < n {
        return Err(ArrowError::ParseError(
            "Unexpected end of byte stream for Arrow IPC file".to_string(),
        )
        .into());
    }

    Ok(buf)
}

/// Returns true if the object starts with the Arrow IPC file magic, i.e. it is
/// in the IPC file format rather than the IPC stream format.
async fn is_object_in_arrow_ipc_file_format(
    store: Arc<dyn ObjectStore>,
    object_location: &Path,
) -> Result<bool> {
    // Only the first 6 bytes are needed to check for the magic.
    let get_opts = GetOptions {
        range: Some(GetRange::Bounded(0..6)),
        ..Default::default()
    };
    let bytes = store
        .get_opts(object_location, get_opts)
        .await?
        .bytes()
        .await?;
    Ok(bytes.len() >= 6 && bytes[0..6] == ARROW_MAGIC)
}

#[cfg(test)]
mod tests {
    use super::*;

    use chrono::DateTime;
    use datafusion_common::DFSchema;
    use datafusion_common::config::TableOptions;
    use datafusion_execution::config::SessionConfig;
    use datafusion_execution::runtime_env::RuntimeEnv;
    use datafusion_expr::execution_props::ExecutionProps;
    use datafusion_expr::{AggregateUDF, Expr, LogicalPlan, ScalarUDF, WindowUDF};
    use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
    use object_store::{chunked::ChunkedStore, memory::InMemory, path::Path};

    struct MockSession {
        config: SessionConfig,
        runtime_env: Arc<RuntimeEnv>,
    }

    impl MockSession {
        fn new() -> Self {
            Self {
                config: SessionConfig::new(),
                runtime_env: Arc::new(RuntimeEnv::default()),
            }
        }
    }

    #[async_trait::async_trait]
    impl Session for MockSession {
        fn session_id(&self) -> &str {
            unimplemented!()
        }

        fn config(&self) -> &SessionConfig {
            &self.config
        }

        async fn create_physical_plan(
            &self,
            _logical_plan: &LogicalPlan,
        ) -> Result<Arc<dyn ExecutionPlan>> {
            unimplemented!()
        }

        fn create_physical_expr(
            &self,
            _expr: Expr,
            _df_schema: &DFSchema,
        ) -> Result<Arc<dyn PhysicalExpr>> {
            unimplemented!()
        }

        fn scalar_functions(&self) -> &HashMap<String, Arc<ScalarUDF>> {
            unimplemented!()
        }

        fn aggregate_functions(&self) -> &HashMap<String, Arc<AggregateUDF>> {
            unimplemented!()
        }

        fn window_functions(&self) -> &HashMap<String, Arc<WindowUDF>> {
            unimplemented!()
        }

        fn runtime_env(&self) -> &Arc<RuntimeEnv> {
            &self.runtime_env
        }

        fn execution_props(&self) -> &ExecutionProps {
            unimplemented!()
        }

        fn as_any(&self) -> &dyn Any {
            unimplemented!()
        }

        fn table_options(&self) -> &TableOptions {
            unimplemented!()
        }

        fn table_options_mut(&mut self) -> &mut TableOptions {
            unimplemented!()
        }

        fn task_ctx(&self) -> Arc<TaskContext> {
            unimplemented!()
        }
    }

    #[tokio::test]
    async fn test_infer_schema_stream() -> Result<()> {
        for file in ["example.arrow", "example_stream.arrow"] {
            let mut bytes = std::fs::read(format!("tests/data/{file}"))?;
            // Truncate the tail so the files are incomplete; schema inference
            // must still succeed by parsing the schema message at the head.
            bytes.truncate(bytes.len() - 20);
            let location = Path::parse(file)?;
            let in_memory_store: Arc<dyn ObjectStore> = Arc::new(InMemory::new());
            in_memory_store.put(&location, bytes.into()).await?;

            let state = MockSession::new();
            let object_meta = ObjectMeta {
                location,
                last_modified: DateTime::default(),
                size: u64::MAX,
                e_tag: None,
                version: None,
            };

            let arrow_format = ArrowFormat {};
            let expected = vec!["f0: Int64", "f1: Utf8", "f2: Boolean"];

            // Exercise both many small chunks and one chunk larger than the file.
            for chunk_size in [7, 3000] {
                let store =
                    Arc::new(ChunkedStore::new(in_memory_store.clone(), chunk_size));
                let inferred_schema = arrow_format
                    .infer_schema(
                        &state,
                        &(store.clone() as Arc<dyn ObjectStore>),
                        std::slice::from_ref(&object_meta),
                    )
                    .await?;
                let actual_fields = inferred_schema
                    .fields()
                    .iter()
                    .map(|f| format!("{}: {:?}", f.name(), f.data_type()))
                    .collect::<Vec<_>>();
                assert_eq!(expected, actual_fields);
            }
        }
        Ok(())
    }

    #[tokio::test]
    async fn test_infer_schema_short_stream() -> Result<()> {
        for file in ["example.arrow", "example_stream.arrow"] {
            let mut bytes = std::fs::read(format!("tests/data/{file}"))?;
            // Keep only the first 20 bytes so the stream ends before a
            // complete schema message can be read.
            bytes.truncate(20);
            let location = Path::parse(file)?;
            let in_memory_store: Arc<dyn ObjectStore> = Arc::new(InMemory::new());
            in_memory_store.put(&location, bytes.into()).await?;

            let state = MockSession::new();
            let object_meta = ObjectMeta {
                location,
                last_modified: DateTime::default(),
                size: u64::MAX,
                e_tag: None,
                version: None,
            };

            let arrow_format = ArrowFormat {};

            let store = Arc::new(ChunkedStore::new(in_memory_store.clone(), 7));
            let err = arrow_format
                .infer_schema(
                    &state,
                    &(store.clone() as Arc<dyn ObjectStore>),
                    std::slice::from_ref(&object_meta),
                )
                .await;

            assert!(err.is_err());
            assert_eq!(
                "Arrow error: Parser error: Unexpected end of byte stream for Arrow IPC file",
                err.unwrap_err().to_string().lines().next().unwrap()
            );
        }

        Ok(())
    }

    #[tokio::test]
    async fn test_format_detection_file_format() -> Result<()> {
        let store = Arc::new(InMemory::new());
        let path = Path::from("test.arrow");

        let file_bytes = std::fs::read("tests/data/example.arrow")?;
        store.put(&path, file_bytes.into()).await?;

        let is_file = is_object_in_arrow_ipc_file_format(store.clone(), &path).await?;
        assert!(is_file, "Should detect file format");
        Ok(())
    }

    #[tokio::test]
    async fn test_format_detection_stream_format() -> Result<()> {
        let store = Arc::new(InMemory::new());
        let path = Path::from("test_stream.arrow");

        let stream_bytes = std::fs::read("tests/data/example_stream.arrow")?;
        store.put(&path, stream_bytes.into()).await?;

        let is_file = is_object_in_arrow_ipc_file_format(store.clone(), &path).await?;

        assert!(!is_file, "Should detect stream format (not file)");

        Ok(())
    }

    #[tokio::test]
    async fn test_format_detection_corrupted_file() -> Result<()> {
        let store = Arc::new(InMemory::new());
        let path = Path::from("corrupted.arrow");

        store
            .put(&path, Bytes::from(vec![0x43, 0x4f, 0x52, 0x41]).into())
            .await?;

        let is_file = is_object_in_arrow_ipc_file_format(store.clone(), &path).await?;

        assert!(
            !is_file,
            "Corrupted file should not be detected as file format"
        );

        Ok(())
    }

    #[tokio::test]
    async fn test_format_detection_empty_file() -> Result<()> {
        let store = Arc::new(InMemory::new());
        let path = Path::from("empty.arrow");

        store.put(&path, Bytes::new().into()).await?;

        let result = is_object_in_arrow_ipc_file_format(store.clone(), &path).await;

        assert!(result.is_err(), "Empty file should error");

        Ok(())
    }
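
    /// Illustrative test (a sketch, not part of the original suite) for the
    /// chunked-read helper used by schema inference: whole chunks are
    /// appended, so the result may exceed `n` bytes, but never fall short.
    #[tokio::test]
    async fn test_extend_bytes_to_n_length_from_stream() -> Result<()> {
        let chunks: Vec<object_store::Result<Bytes>> = vec![
            Ok(Bytes::from_static(&[1, 2])),
            Ok(Bytes::from_static(&[3, 4])),
            Ok(Bytes::from_static(&[5, 6])),
        ];
        let mut stream: BoxStream<'static, object_store::Result<Bytes>> =
            Box::pin(futures::stream::iter(chunks));

        // Ask for 5 bytes; the helper reads chunk by chunk until satisfied,
        // so it returns all 6 buffered bytes here.
        let buf = extend_bytes_to_n_length_from_stream(vec![], 5, &mut stream).await?;
        assert_eq!(buf, vec![1, 2, 3, 4, 5, 6]);

        // Requesting more bytes than the stream holds is a parse error.
        let mut empty: BoxStream<'static, object_store::Result<Bytes>> =
            Box::pin(futures::stream::iter(
                Vec::<object_store::Result<Bytes>>::new(),
            ));
        let err = extend_bytes_to_n_length_from_stream(vec![], 1, &mut empty).await;
        assert!(err.is_err());
        Ok(())
    }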
}