// datafusion/datasource/file_format/arrow.rs

use std::any::Any;
use std::borrow::Cow;
use std::collections::HashMap;
use std::fmt::{self, Debug};
use std::sync::Arc;

use super::file_compression_type::FileCompressionType;
use super::write::demux::DemuxedStreamReceiver;
use super::write::SharedBuffer;
use super::FileFormatFactory;
use crate::datasource::file_format::write::get_writer_schema;
use crate::datasource::file_format::FileFormat;
use crate::datasource::physical_plan::{ArrowSource, FileSink, FileSinkConfig};
use crate::error::Result;
use crate::physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan};

use arrow::datatypes::{Schema, SchemaRef};
use arrow::error::ArrowError;
use arrow::ipc::convert::fb_to_schema;
use arrow::ipc::reader::FileReader;
use arrow::ipc::writer::IpcWriteOptions;
use arrow::ipc::{root_as_message, CompressionType};
use datafusion_catalog::Session;
use datafusion_common::parsers::CompressionTypeVariant;
use datafusion_common::{
    not_impl_err, DataFusionError, GetExt, Statistics, DEFAULT_ARROW_EXTENSION,
};
use datafusion_common_runtime::{JoinSet, SpawnedTask};
use datafusion_datasource::display::FileGroupDisplay;
use datafusion_datasource::file::FileSource;
use datafusion_datasource::file_scan_config::{FileScanConfig, FileScanConfigBuilder};
use datafusion_datasource::sink::{DataSink, DataSinkExec};
use datafusion_datasource::write::ObjectWriterBuilder;
use datafusion_execution::{SendableRecordBatchStream, TaskContext};
use datafusion_expr::dml::InsertOp;
use datafusion_physical_expr_common::sort_expr::LexRequirement;

use async_trait::async_trait;
use bytes::Bytes;
use datafusion_datasource::source::DataSourceExec;
use futures::stream::BoxStream;
use futures::StreamExt;
use object_store::{GetResultPayload, ObjectMeta, ObjectStore};
use tokio::io::AsyncWriteExt;

/// Initial size of the in-memory write buffer for each output file. This is
/// only a hint; the buffer grows past it when needed.
const INITIAL_BUFFER_BYTES: usize = 1048576;

/// Once the buffered Arrow IPC bytes exceed this threshold they are flushed
/// to the object store.
const BUFFER_FLUSH_BYTES: usize = 1024000;

/// Factory used to create instances of [`ArrowFormat`].
#[derive(Default, Debug)]
pub struct ArrowFormatFactory;

impl ArrowFormatFactory {
    /// Creates an instance of [`ArrowFormatFactory`].
    pub fn new() -> Self {
        Self {}
    }
}

impl FileFormatFactory for ArrowFormatFactory {
    fn create(
        &self,
        _state: &dyn Session,
        _format_options: &HashMap<String, String>,
    ) -> Result<Arc<dyn FileFormat>> {
        Ok(Arc::new(ArrowFormat))
    }

    fn default(&self) -> Arc<dyn FileFormat> {
        Arc::new(ArrowFormat)
    }

    fn as_any(&self) -> &dyn Any {
        self
    }
}

impl GetExt for ArrowFormatFactory {
    fn get_ext(&self) -> String {
        // Remove the leading "." from DEFAULT_ARROW_EXTENSION (".arrow")
        DEFAULT_ARROW_EXTENSION[1..].to_string()
    }
}

/// The Arrow IPC file [`FileFormat`] implementation.
#[derive(Default, Debug)]
pub struct ArrowFormat;

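// A minimal end-to-end sketch of how this format is typically reached (not part
// of the original file): register an Arrow IPC file through SQL and query it.
// It assumes a `SessionContext` from the `datafusion` crate and an existing
// `data/example.arrow` file.
//
//     use datafusion::prelude::SessionContext;
//
//     async fn query_arrow_file() -> datafusion::error::Result<()> {
//         let ctx = SessionContext::new();
//         ctx.sql("CREATE EXTERNAL TABLE t STORED AS ARROW LOCATION 'data/example.arrow'")
//             .await?;
//         ctx.sql("SELECT * FROM t LIMIT 5").await?.show().await?;
//         Ok(())
//     }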
#[async_trait]
impl FileFormat for ArrowFormat {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn get_ext(&self) -> String {
        ArrowFormatFactory::new().get_ext()
    }

    fn get_ext_with_compression(
        &self,
        file_compression_type: &FileCompressionType,
    ) -> Result<String> {
        let ext = self.get_ext();
        match file_compression_type.get_variant() {
            CompressionTypeVariant::UNCOMPRESSED => Ok(ext),
            _ => Err(DataFusionError::Internal(
                "Arrow FileFormat does not support compression.".into(),
            )),
        }
    }

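    // Arrow IPC files are not wrapped in an outer compression codec such as
    // gzip, which is why `get_ext_with_compression` above only accepts
    // UNCOMPRESSED. Compression, when enabled, is applied per IPC buffer on the
    // write path instead (see the LZ4_FRAME option in `ArrowFileSink`).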
    fn compression_type(&self) -> Option<FileCompressionType> {
        None
    }

    async fn infer_schema(
        &self,
        _state: &dyn Session,
        store: &Arc<dyn ObjectStore>,
        objects: &[ObjectMeta],
    ) -> Result<SchemaRef> {
        let mut schemas = vec![];
        for object in objects {
            let r = store.as_ref().get(&object.location).await?;
            let schema = match r.payload {
                #[cfg(not(target_arch = "wasm32"))]
                GetResultPayload::File(mut file, _) => {
                    let reader = FileReader::try_new(&mut file, None)?;
                    reader.schema()
                }
                GetResultPayload::Stream(stream) => {
                    infer_schema_from_file_stream(stream).await?
                }
            };
            schemas.push(schema.as_ref().clone());
        }
        let merged_schema = Schema::try_merge(schemas)?;
        Ok(Arc::new(merged_schema))
    }

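    // When several objects are listed, their schemas are unioned field-by-field
    // with `Schema::try_merge`, which fails on conflicting types for the same
    // field name. A small illustration (not from the original file):
    //
    //     use arrow::datatypes::{DataType, Field, Schema};
    //
    //     let a = Schema::new(vec![Field::new("id", DataType::Int64, false)]);
    //     let b = Schema::new(vec![Field::new("name", DataType::Utf8, true)]);
    //     let merged = Schema::try_merge(vec![a, b])?; // contains "id" and "name"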
    async fn infer_stats(
        &self,
        _state: &dyn Session,
        _store: &Arc<dyn ObjectStore>,
        table_schema: SchemaRef,
        _object: &ObjectMeta,
    ) -> Result<Statistics> {
        Ok(Statistics::new_unknown(&table_schema))
    }

    async fn create_physical_plan(
        &self,
        _state: &dyn Session,
        conf: FileScanConfig,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        let source = Arc::new(ArrowSource::default());
        let config = FileScanConfigBuilder::from(conf)
            .with_source(source)
            .build();

        Ok(DataSourceExec::from_data_source(config))
    }

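    // The read path wraps the scan configuration in a `DataSourceExec` driven by
    // `ArrowSource`; the write path below mirrors it with a `DataSinkExec` driven
    // by `ArrowFileSink`, and currently only supports `InsertOp::Append`.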
    async fn create_writer_physical_plan(
        &self,
        input: Arc<dyn ExecutionPlan>,
        _state: &dyn Session,
        conf: FileSinkConfig,
        order_requirements: Option<LexRequirement>,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        if conf.insert_op != InsertOp::Append {
            return not_impl_err!("Overwrites are not implemented yet for Arrow format");
        }

        let sink = Arc::new(ArrowFileSink::new(conf));

        Ok(Arc::new(DataSinkExec::new(input, sink, order_requirements)) as _)
    }

    fn file_source(&self) -> Arc<dyn FileSource> {
        Arc::new(ArrowSource::default())
    }
}

/// Implements [`DataSink`] for writing Arrow IPC files to object storage.
struct ArrowFileSink {
    config: FileSinkConfig,
}

impl ArrowFileSink {
    fn new(config: FileSinkConfig) -> Self {
        Self { config }
    }
}

#[async_trait]
impl FileSink for ArrowFileSink {
    fn config(&self) -> &FileSinkConfig {
        &self.config
    }

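    /// Writes each demuxed partition to its own Arrow IPC file: every output
    /// file gets a writer task that serializes incoming record batches into an
    /// in-memory `SharedBuffer` and flushes that buffer to the object store once
    /// it grows past `BUFFER_FLUSH_BYTES`, so complete files are never held in
    /// memory. Returns the total number of rows written across all files.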
    async fn spawn_writer_tasks_and_join(
        &self,
        context: &Arc<TaskContext>,
        demux_task: SpawnedTask<Result<()>>,
        mut file_stream_rx: DemuxedStreamReceiver,
        object_store: Arc<dyn ObjectStore>,
    ) -> Result<u64> {
        let mut file_write_tasks: JoinSet<std::result::Result<usize, DataFusionError>> =
            JoinSet::new();

        // 64-byte aligned, non-legacy V5 metadata with LZ4-compressed IPC buffers
        let ipc_options =
            IpcWriteOptions::try_new(64, false, arrow_ipc::MetadataVersion::V5)?
                .try_with_compression(Some(CompressionType::LZ4_FRAME))?;
        while let Some((path, mut rx)) = file_stream_rx.recv().await {
            let shared_buffer = SharedBuffer::new(INITIAL_BUFFER_BYTES);
            let mut arrow_writer = arrow_ipc::writer::FileWriter::try_new_with_options(
                shared_buffer.clone(),
                &get_writer_schema(&self.config),
                ipc_options.clone(),
            )?;
            let mut object_store_writer = ObjectWriterBuilder::new(
                FileCompressionType::UNCOMPRESSED,
                &path,
                Arc::clone(&object_store),
            )
            .with_buffer_size(Some(
                context
                    .session_config()
                    .options()
                    .execution
                    .objectstore_writer_buffer_size,
            ))
            .build()?;
            file_write_tasks.spawn(async move {
                let mut row_count = 0;
                while let Some(batch) = rx.recv().await {
                    row_count += batch.num_rows();
                    arrow_writer.write(&batch)?;
                    let mut buff_to_flush = shared_buffer.buffer.try_lock().unwrap();
                    if buff_to_flush.len() > BUFFER_FLUSH_BYTES {
                        object_store_writer
                            .write_all(buff_to_flush.as_slice())
                            .await?;
                        buff_to_flush.clear();
                    }
                }
                arrow_writer.finish()?;
                let final_buff = shared_buffer.buffer.try_lock().unwrap();

                object_store_writer.write_all(final_buff.as_slice()).await?;
                object_store_writer.shutdown().await?;
                Ok(row_count)
            });
        }

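        // Drain the writer tasks, summing their row counts; a panic inside a
        // task is resumed here instead of being silently dropped.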
        let mut row_count = 0;
        while let Some(result) = file_write_tasks.join_next().await {
            match result {
                Ok(r) => {
                    row_count += r?;
                }
                Err(e) => {
                    if e.is_panic() {
                        std::panic::resume_unwind(e.into_panic());
                    } else {
                        unreachable!();
                    }
                }
            }
        }

        demux_task
            .join_unwind()
            .await
            .map_err(|e| DataFusionError::ExecutionJoin(Box::new(e)))??;
        Ok(row_count as u64)
    }
}

impl Debug for ArrowFileSink {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("ArrowFileSink").finish()
    }
}

impl DisplayAs for ArrowFileSink {
    fn fmt_as(&self, t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match t {
            DisplayFormatType::Default | DisplayFormatType::Verbose => {
                write!(f, "ArrowFileSink(file_groups=")?;
                FileGroupDisplay(&self.config.file_group).fmt_as(t, f)?;
                write!(f, ")")
            }
            DisplayFormatType::TreeRender => {
                writeln!(f, "format: arrow")?;
                write!(f, "file={}", &self.config.original_url)
            }
        }
    }
}

#[async_trait]
impl DataSink for ArrowFileSink {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn schema(&self) -> &SchemaRef {
        self.config.output_schema()
    }

    async fn write_all(
        &self,
        data: SendableRecordBatchStream,
        context: &Arc<TaskContext>,
    ) -> Result<u64> {
        FileSink::write_all(self, data, context).await
    }
}

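// An Arrow IPC file starts with (roughly):
//
//   "ARROW1" magic                   6 bytes
//   padding to an 8 byte boundary    2 bytes
//   continuation marker 0xFFFFFFFF   4 bytes (absent in pre-0.15 files)
//   metadata length                  4 bytes, little-endian i32
//   flatbuffer `Schema` message      `metadata length` bytes
//
// `infer_schema_from_file_stream` below reads just enough of the stream to
// decode that schema message instead of fetching the whole object.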
const ARROW_MAGIC: [u8; 6] = [b'A', b'R', b'R', b'O', b'W', b'1'];
const CONTINUATION_MARKER: [u8; 4] = [0xff; 4];

/// Infers the schema of an Arrow IPC file from the first bytes of a stream,
/// without fetching the entire object.
async fn infer_schema_from_file_stream(
    mut stream: BoxStream<'static, object_store::Result<Bytes>>,
) -> Result<SchemaRef> {
    // The first 16 bytes are enough to locate the schema metadata length
    let bytes = collect_at_least_n_bytes(&mut stream, 16, None).await?;

    // Files must start with the "ARROW1" magic bytes
    if bytes[0..6] != ARROW_MAGIC {
        return Err(ArrowError::ParseError(
            "Arrow file does not contain correct header".to_string(),
        ))?;
    }

    // The continuation marker was only added in later format versions
    let (meta_len, rest_of_bytes_start_index) = if bytes[8..12] == CONTINUATION_MARKER {
        (&bytes[12..16], 16)
    } else {
        (&bytes[8..12], 12)
    };

    let meta_len = [meta_len[0], meta_len[1], meta_len[2], meta_len[3]];
    let meta_len = i32::from_le_bytes(meta_len);

    // Collect the bytes of the flatbuffer Schema message
    let block_data = if bytes[rest_of_bytes_start_index..].len() < meta_len as usize {
        // Not enough bytes yet; keep any spare bytes from the initial read and
        // pull the remainder from the stream
        let mut block_data = Vec::with_capacity(meta_len as usize);
        block_data.extend_from_slice(&bytes[rest_of_bytes_start_index..]);
        let size_to_read = meta_len as usize - block_data.len();
        let block_data =
            collect_at_least_n_bytes(&mut stream, size_to_read, Some(block_data)).await?;
        Cow::Owned(block_data)
    } else {
        // The initial read already contains the whole message
        let end_index = meta_len as usize + rest_of_bytes_start_index;
        let block_data = &bytes[rest_of_bytes_start_index..end_index];
        Cow::Borrowed(block_data)
    };

    // Decode the IPC message and convert it into an Arrow `Schema`
    let message = root_as_message(&block_data).map_err(|err| {
        ArrowError::ParseError(format!("Unable to read IPC message as metadata: {err:?}"))
    })?;
    let ipc_schema = message.header_as_schema().ok_or_else(|| {
        ArrowError::IpcError("Unable to read IPC message as schema".to_string())
    })?;
    let schema = fb_to_schema(ipc_schema);

    Ok(Arc::new(schema))
}

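/// Reads whole chunks from `stream` until at least `n` more bytes are buffered,
/// so the result may be longer than `n`. When `extend_from` is given, the new
/// bytes are appended to it and `n` counts only the additional bytes. Errors if
/// the stream ends before enough bytes arrive.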
async fn collect_at_least_n_bytes(
    stream: &mut BoxStream<'static, object_store::Result<Bytes>>,
    n: usize,
    extend_from: Option<Vec<u8>>,
) -> Result<Vec<u8>> {
    let mut buf = extend_from.unwrap_or_else(|| Vec::with_capacity(n));
    // If extending an existing buffer, read `n` additional bytes
    let n = n + buf.len();
    while let Some(bytes) = stream.next().await.transpose()? {
        buf.extend_from_slice(&bytes);
        if buf.len() >= n {
            break;
        }
    }
    if buf.len() < n {
        return Err(ArrowError::ParseError(
            "Unexpected end of byte stream for Arrow IPC file".to_string(),
        ))?;
    }
    Ok(buf)
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::execution::context::SessionContext;

    use chrono::DateTime;
    use object_store::{chunked::ChunkedStore, memory::InMemory, path::Path};

    #[tokio::test]
    async fn test_infer_schema_stream() -> Result<()> {
        let mut bytes = std::fs::read("tests/data/example.arrow")?;
        // Mangle the end of the file so inference cannot rely on reading it all
        bytes.truncate(bytes.len() - 20);
        let location = Path::parse("example.arrow")?;
        let in_memory_store: Arc<dyn ObjectStore> = Arc::new(InMemory::new());
        in_memory_store.put(&location, bytes.into()).await?;

        let session_ctx = SessionContext::new();
        let state = session_ctx.state();
        let object_meta = ObjectMeta {
            location,
            last_modified: DateTime::default(),
            size: u64::MAX,
            e_tag: None,
            version: None,
        };

        let arrow_format = ArrowFormat {};
        let expected = vec!["f0: Int64", "f1: Utf8", "f2: Boolean"];

        // Exercise chunk sizes smaller and larger than the schema message
        for chunk_size in [7, 3000] {
            let store = Arc::new(ChunkedStore::new(in_memory_store.clone(), chunk_size));
            let inferred_schema = arrow_format
                .infer_schema(
                    &state,
                    &(store.clone() as Arc<dyn ObjectStore>),
                    std::slice::from_ref(&object_meta),
                )
                .await?;
            let actual_fields = inferred_schema
                .fields()
                .iter()
                .map(|f| format!("{}: {:?}", f.name(), f.data_type()))
                .collect::<Vec<_>>();
            assert_eq!(expected, actual_fields);
        }

        Ok(())
    }

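    // Illustrative checks for the helpers above, not exhaustive tests: the
    // chunked read semantics of `collect_at_least_n_bytes` and the default
    // extension handling of `ArrowFormat`.

    #[tokio::test]
    async fn test_collect_at_least_n_bytes_reads_whole_chunks() -> Result<()> {
        let chunks: Vec<object_store::Result<Bytes>> = vec![
            Ok(Bytes::from_static(b"hello")),
            Ok(Bytes::from_static(b" world")),
        ];
        let mut stream = futures::stream::iter(chunks).boxed();
        // 8 bytes are requested, but whole chunks are consumed, so 11 come back
        let collected = collect_at_least_n_bytes(&mut stream, 8, None).await?;
        assert_eq!(collected, b"hello world".to_vec());

        // A stream that ends too early surfaces as a parse error
        let short_chunks: Vec<object_store::Result<Bytes>> =
            vec![Ok(Bytes::from_static(b"AR"))];
        let mut short_stream = futures::stream::iter(short_chunks).boxed();
        assert!(collect_at_least_n_bytes(&mut short_stream, 16, None)
            .await
            .is_err());

        Ok(())
    }

    #[test]
    fn test_arrow_format_extension() -> Result<()> {
        let format = ArrowFormat {};
        assert_eq!(format.get_ext(), "arrow");
        // Only UNCOMPRESSED is accepted; any other codec is rejected
        assert!(format
            .get_ext_with_compression(&FileCompressionType::GZIP)
            .is_err());
        Ok(())
    }
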
    #[tokio::test]
    async fn test_infer_schema_short_stream() -> Result<()> {
        let mut bytes = std::fs::read("tests/data/example.arrow")?;
        // Keep only the first 20 bytes so the stream ends before the schema message
        bytes.truncate(20);
        let location = Path::parse("example.arrow")?;
        let in_memory_store: Arc<dyn ObjectStore> = Arc::new(InMemory::new());
        in_memory_store.put(&location, bytes.into()).await?;

        let session_ctx = SessionContext::new();
        let state = session_ctx.state();
        let object_meta = ObjectMeta {
            location,
            last_modified: DateTime::default(),
            size: u64::MAX,
            e_tag: None,
            version: None,
        };

        let arrow_format = ArrowFormat {};

        let store = Arc::new(ChunkedStore::new(in_memory_store.clone(), 7));
        let err = arrow_format
            .infer_schema(
                &state,
                &(store.clone() as Arc<dyn ObjectStore>),
                std::slice::from_ref(&object_meta),
            )
            .await;

        assert!(err.is_err());
        assert_eq!(
            "Arrow error: Parser error: Unexpected end of byte stream for Arrow IPC file",
            err.unwrap_err().to_string().lines().next().unwrap()
        );

        Ok(())
    }
}