// orc_rust/arrow_writer.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

use std::io::Write;

use arrow::{
    array::RecordBatch,
    datatypes::{DataType as ArrowDataType, SchemaRef},
};
use prost::Message;
use snafu::{ensure, ResultExt};

use crate::{
    error::{IoSnafu, Result, UnexpectedSnafu},
    memory::EstimateMemory,
    proto,
    writer::stripe::{StripeInformation, StripeWriter},
};

/// Construct an [`ArrowWriter`] to encode [`RecordBatch`]es into a single
/// ORC file.
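///
/// A minimal usage sketch (the `orc_rust::ArrowWriterBuilder` import path is
/// an assumption here; adjust it to wherever the builder is exported):
///
/// ```ignore
/// use std::sync::Arc;
/// use arrow::array::{Int64Array, RecordBatch};
/// use arrow::datatypes::{DataType, Field, Schema};
/// use orc_rust::ArrowWriterBuilder;
///
/// let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int64, false)]));
/// let batch = RecordBatch::try_new(
///     schema.clone(),
///     vec![Arc::new(Int64Array::from(vec![1, 2, 3]))],
/// )
/// .unwrap();
///
/// // Any `std::io::Write` sink works; a `Vec<u8>` keeps the file in memory.
/// let mut f = vec![];
/// let mut writer = ArrowWriterBuilder::new(&mut f, schema)
///     .with_batch_size(1024) // default, shown for illustration
///     .with_stripe_byte_size(64 * 1024 * 1024) // default, shown for illustration
///     .try_build()
///     .unwrap();
/// writer.write(&batch).unwrap();
/// writer.close().unwrap();
/// ```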
pub struct ArrowWriterBuilder<W> {
    writer: W,
    schema: SchemaRef,
    batch_size: usize,
    stripe_byte_size: usize,
}

impl<W: Write> ArrowWriterBuilder<W> {
    /// Create a new [`ArrowWriterBuilder`], which will write an ORC file to
    /// the provided writer, with the expected Arrow schema.
    pub fn new(writer: W, schema: SchemaRef) -> Self {
        Self {
            writer,
            schema,
            batch_size: 1024,
            // 64 MiB
            stripe_byte_size: 64 * 1024 * 1024,
        }
    }

    /// Batch size controls how many rows are encoded at a time.
    /// Default is `1024`.
    pub fn with_batch_size(mut self, batch_size: usize) -> Self {
        self.batch_size = batch_size;
        self
    }

    /// The approximate target size of each stripe, in bytes. Default is 64 MiB.
    pub fn with_stripe_byte_size(mut self, stripe_byte_size: usize) -> Self {
        self.stripe_byte_size = stripe_byte_size;
        self
    }

    /// Construct an [`ArrowWriter`] ready to encode [`RecordBatch`]es into
    /// an ORC file.
    pub fn try_build(mut self) -> Result<ArrowWriter<W>> {
        // Required magic "ORC" bytes at start of file
        self.writer.write_all(b"ORC").context(IoSnafu)?;
        let writer = StripeWriter::new(self.writer, &self.schema);
        Ok(ArrowWriter {
            writer,
            schema: self.schema,
            batch_size: self.batch_size,
            stripe_byte_size: self.stripe_byte_size,
            written_stripes: vec![],
            // Accounting for the 3 magic bytes above
            total_bytes_written: 3,
        })
    }
}

/// Encodes [`RecordBatch`]es into an ORC file, `batch_size` rows at a time,
/// flushing the in-progress stripe to the underlying writer once its estimated
/// memory footprint exceeds the configured `stripe_byte_size`.
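///
/// A sketch of the flushing behaviour, forcing multiple stripes with a
/// deliberately tiny (illustrative, not recommended) stripe size; the import
/// path is assumed as in the builder example:
///
/// ```ignore
/// # use std::sync::Arc;
/// # use arrow::array::{Int64Array, RecordBatch};
/// # use arrow::datatypes::{DataType, Field, Schema};
/// # use orc_rust::ArrowWriterBuilder;
/// let schema = Arc::new(Schema::new(vec![Field::new("x", DataType::Int64, true)]));
/// let batch = RecordBatch::try_new(
///     schema.clone(),
///     vec![Arc::new(Int64Array::from((0..10_000).collect::<Vec<i64>>()))],
/// )
/// .unwrap();
///
/// let mut f = vec![];
/// let mut writer = ArrowWriterBuilder::new(&mut f, schema)
///     .with_stripe_byte_size(1024) // tiny, so several stripes get flushed
///     .try_build()
///     .unwrap();
/// writer.write(&batch).unwrap(); // flushes full stripes as it goes
/// writer.close().unwrap(); // flushes the final in-progress stripe
/// ```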
pub struct ArrowWriter<W> {
    writer: StripeWriter<W>,
    schema: SchemaRef,
    batch_size: usize,
    stripe_byte_size: usize,
    written_stripes: Vec<StripeInformation>,
    /// Tracks how many bytes have been written so far (avoids requiring `Seek` on the writer)
    total_bytes_written: u64,
}

impl<W: Write> ArrowWriter<W> {
    /// Encode the provided batch, `batch_size` rows at a time, flushing any
    /// stripe that exceeds the configured stripe size.
    pub fn write(&mut self, batch: &RecordBatch) -> Result<()> {
        ensure!(
            batch.schema() == self.schema,
            UnexpectedSnafu {
                msg: "RecordBatch doesn't match expected schema"
            }
        );

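        // Slice the batch into `batch_size`-row chunks, encoding each chunk
        // and re-checking the stripe size estimate after each one.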
        for offset in (0..batch.num_rows()).step_by(self.batch_size) {
            let length = self.batch_size.min(batch.num_rows() - offset);
            let batch = batch.slice(offset, length);
            self.writer.encode_batch(&batch)?;

            // TODO: be able to flush whilst writing a batch (instead of between batches)
            // Flush the stripe when its estimated size exceeds the configured size
            if self.writer.estimate_memory_size() > self.stripe_byte_size {
                self.flush_stripe()?;
            }
        }
        Ok(())
    }

    /// Flush any buffered data that hasn't been written, and write the stripe
    /// footer metadata.
    pub fn flush_stripe(&mut self) -> Result<()> {
        let info = self.writer.finish_stripe(self.total_bytes_written)?;
        self.total_bytes_written += info.total_byte_size();
        self.written_stripes.push(info);
        Ok(())
    }

    /// Flush the current stripe if one is still in progress, then write the
    /// tail metadata and close the writer.
    pub fn close(mut self) -> Result<()> {
        // Flush in-progress stripe
        if self.writer.row_count > 0 {
            self.flush_stripe()?;
        }
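
        // ORC file tail layout, written below in order:
        //   - Footer (protobuf): stripe directory, schema types, row count
        //   - PostScript (protobuf): footer length, compression kind, version
        //   - last byte: length of the PostScript
        // Readers locate the tail by reading the final byte first.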
        let footer = serialize_footer(&self.written_stripes, &self.schema);
        let footer = footer.encode_to_vec();
        let postscript = serialize_postscript(footer.len() as u64);
        let postscript = postscript.encode_to_vec();
        let postscript_len = postscript.len() as u8;

        let mut writer = self.writer.finish();
        writer.write_all(&footer).context(IoSnafu)?;
        writer.write_all(&postscript).context(IoSnafu)?;
        // Postscript length as last byte
        writer.write_all(&[postscript_len]).context(IoSnafu)?;

        // TODO: return file metadata
        Ok(())
    }
}

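/// Serialize the Arrow schema as a flat list of ORC `proto::Type`s: a root
/// `Struct` type at index 0 whose `subtypes` reference one primitive type per
/// top-level field (nested types are not yet handled).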
fn serialize_schema(schema: &SchemaRef) -> Vec<proto::Type> {
    let mut types = vec![];

    let field_names = schema
        .fields()
        .iter()
        .map(|f| f.name().to_owned())
        .collect();
    // TODO: consider nested types
    let subtypes = (1..(schema.fields().len() as u32 + 1)).collect();
    let root_type = proto::Type {
        kind: Some(proto::r#type::Kind::Struct.into()),
        subtypes,
        field_names,
        maximum_length: None,
        precision: None,
        scale: None,
        attributes: vec![],
    };
    types.push(root_type);
    for field in schema.fields() {
        let t = match field.data_type() {
            ArrowDataType::Float32 => proto::Type {
                kind: Some(proto::r#type::Kind::Float.into()),
                ..Default::default()
            },
            ArrowDataType::Float64 => proto::Type {
                kind: Some(proto::r#type::Kind::Double.into()),
                ..Default::default()
            },
            ArrowDataType::Int8 => proto::Type {
                kind: Some(proto::r#type::Kind::Byte.into()),
                ..Default::default()
            },
            ArrowDataType::Int16 => proto::Type {
                kind: Some(proto::r#type::Kind::Short.into()),
                ..Default::default()
            },
            ArrowDataType::Int32 => proto::Type {
                kind: Some(proto::r#type::Kind::Int.into()),
                ..Default::default()
            },
            ArrowDataType::Int64 => proto::Type {
                kind: Some(proto::r#type::Kind::Long.into()),
                ..Default::default()
            },
            ArrowDataType::Utf8 | ArrowDataType::LargeUtf8 => proto::Type {
                kind: Some(proto::r#type::Kind::String.into()),
                ..Default::default()
            },
            ArrowDataType::Binary | ArrowDataType::LargeBinary => proto::Type {
                kind: Some(proto::r#type::Kind::Binary.into()),
                ..Default::default()
            },
            ArrowDataType::Boolean => proto::Type {
                kind: Some(proto::r#type::Kind::Boolean.into()),
                ..Default::default()
            },
            // TODO: support more types
            _ => unimplemented!("unsupported datatype"),
        };
        types.push(t);
    }
    types
}

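/// Build the file Footer: the stripe directory, the serialized schema and the
/// total row count. `content_length` covers the leading "ORC" magic bytes
/// (`header_length`) plus every stripe's index, data and footer sections.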
fn serialize_footer(stripes: &[StripeInformation], schema: &SchemaRef) -> proto::Footer {
    let body_length = stripes
        .iter()
        .map(|s| s.index_length + s.data_length + s.footer_length)
        .sum::<u64>();
    let number_of_rows = stripes.iter().map(|s| s.row_count as u64).sum::<u64>();
    let stripes = stripes.iter().map(From::from).collect();
    let types = serialize_schema(schema);
    proto::Footer {
        header_length: Some(3),
        content_length: Some(body_length + 3),
        stripes,
        types,
        metadata: vec![],
        number_of_rows: Some(number_of_rows),
        statistics: vec![],
        row_index_stride: None,
        writer: Some(u32::MAX),
        encryption: None,
        calendar: None,
        software_version: None,
    }
}

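/// Build the PostScript, the final protobuf in the file (followed only by its
/// own length byte). It records the footer length and compression kind so a
/// reader can bootstrap from the file tail.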
fn serialize_postscript(footer_length: u64) -> proto::PostScript {
    proto::PostScript {
        footer_length: Some(footer_length),
        compression: Some(proto::CompressionKind::None.into()), // TODO: support compression
        compression_block_size: None,
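        // ORC file format version 0.12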
        version: vec![0, 12],
        metadata_length: Some(0),       // TODO: statistics
        writer_version: Some(u32::MAX), // TODO: check which version to use
        stripe_statistics_length: None,
        magic: Some("ORC".to_string()),
    }
}

#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow::{
        array::{
            Array, BinaryArray, BooleanArray, Float32Array, Float64Array, Int16Array, Int32Array,
            Int64Array, Int8Array, LargeBinaryArray, LargeStringArray, RecordBatchReader,
            StringArray,
        },
        buffer::NullBuffer,
        compute::concat_batches,
        datatypes::{DataType as ArrowDataType, Field, Schema},
    };
    use bytes::Bytes;

    use crate::{stripe::Stripe, ArrowReaderBuilder};

    use super::*;

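    // Write the batches to an in-memory ORC file, then read the file back.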
    fn roundtrip(batches: &[RecordBatch]) -> Vec<RecordBatch> {
        let mut f = vec![];
        let mut writer = ArrowWriterBuilder::new(&mut f, batches[0].schema())
            .try_build()
            .unwrap();
        for batch in batches {
            writer.write(batch).unwrap();
        }
        writer.close().unwrap();

        let f = Bytes::from(f);
        let reader = ArrowReaderBuilder::try_new(f).unwrap().build();
        reader.collect::<Result<Vec<_>, _>>().unwrap()
    }

    #[test]
    fn test_roundtrip_write() {
        let f32_array = Arc::new(Float32Array::from(vec![0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0]));
        let f64_array = Arc::new(Float64Array::from(vec![0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0]));
        let int8_array = Arc::new(Int8Array::from(vec![0, 1, 2, 3, 4, 5, 6]));
        let int16_array = Arc::new(Int16Array::from(vec![0, 1, 2, 3, 4, 5, 6]));
        let int32_array = Arc::new(Int32Array::from(vec![0, 1, 2, 3, 4, 5, 6]));
        let int64_array = Arc::new(Int64Array::from(vec![0, 1, 2, 3, 4, 5, 6]));
        let utf8_array = Arc::new(StringArray::from(vec![
            "Hello",
            "there",
            "楡井希実",
            "💯",
            "ORC",
            "",
            "123",
        ]));
        let binary_array = Arc::new(BinaryArray::from(vec![
            "Hello".as_bytes(),
            "there".as_bytes(),
            "楡井希実".as_bytes(),
            "💯".as_bytes(),
            "ORC".as_bytes(),
            "".as_bytes(),
            "123".as_bytes(),
        ]));
        let boolean_array = Arc::new(BooleanArray::from(vec![
            true, false, true, false, true, true, false,
        ]));
        let schema = Schema::new(vec![
            Field::new("f32", ArrowDataType::Float32, false),
            Field::new("f64", ArrowDataType::Float64, false),
            Field::new("int8", ArrowDataType::Int8, false),
            Field::new("int16", ArrowDataType::Int16, false),
            Field::new("int32", ArrowDataType::Int32, false),
            Field::new("int64", ArrowDataType::Int64, false),
            Field::new("utf8", ArrowDataType::Utf8, false),
            Field::new("binary", ArrowDataType::Binary, false),
            Field::new("boolean", ArrowDataType::Boolean, false),
        ]);

        let batch = RecordBatch::try_new(
            Arc::new(schema),
            vec![
                f32_array,
                f64_array,
                int8_array,
                int16_array,
                int32_array,
                int64_array,
                utf8_array,
                binary_array,
                boolean_array,
            ],
        )
        .unwrap();

        let rows = roundtrip(std::slice::from_ref(&batch));
        assert_eq!(batch, rows[0]);
    }

    #[test]
    fn test_roundtrip_write_large_type() {
        let large_utf8_array = Arc::new(LargeStringArray::from(vec![
            "Hello",
            "there",
            "楡井希実",
            "💯",
            "ORC",
            "",
            "123",
        ]));
        let large_binary_array = Arc::new(LargeBinaryArray::from(vec![
            "Hello".as_bytes(),
            "there".as_bytes(),
            "楡井希実".as_bytes(),
            "💯".as_bytes(),
            "ORC".as_bytes(),
            "".as_bytes(),
            "123".as_bytes(),
        ]));
        let schema = Schema::new(vec![
            Field::new("large_utf8", ArrowDataType::LargeUtf8, false),
            Field::new("large_binary", ArrowDataType::LargeBinary, false),
        ]);
        let batch =
            RecordBatch::try_new(Arc::new(schema), vec![large_utf8_array, large_binary_array])
                .unwrap();

        let rows = roundtrip(&[batch]);

        // Currently we read all String/Binary columns from ORC as plain StringArray/BinaryArray
        let utf8_array = Arc::new(StringArray::from(vec![
            "Hello",
            "there",
            "楡井希実",
            "💯",
            "ORC",
            "",
            "123",
        ]));
        let binary_array = Arc::new(BinaryArray::from(vec![
            "Hello".as_bytes(),
            "there".as_bytes(),
            "楡井希実".as_bytes(),
            "💯".as_bytes(),
            "ORC".as_bytes(),
            "".as_bytes(),
            "123".as_bytes(),
        ]));
        let schema = Schema::new(vec![
            Field::new("large_utf8", ArrowDataType::Utf8, false),
            Field::new("large_binary", ArrowDataType::Binary, false),
        ]);
        let batch = RecordBatch::try_new(Arc::new(schema), vec![utf8_array, binary_array]).unwrap();
        assert_eq!(batch, rows[0]);
    }

    #[test]
    fn test_write_small_stripes() {
        // Set small stripe size to ensure writing across multiple stripes works
        let data: Vec<i64> = (0..1_000_000).collect();
        let int64_array = Arc::new(Int64Array::from(data));
        let schema = Schema::new(vec![Field::new("int64", ArrowDataType::Int64, true)]);

        let batch = RecordBatch::try_new(Arc::new(schema), vec![int64_array]).unwrap();

        let mut f = vec![];
        let mut writer = ArrowWriterBuilder::new(&mut f, batch.schema())
            .with_stripe_byte_size(256)
            .try_build()
            .unwrap();
        writer.write(&batch).unwrap();
        writer.close().unwrap();

        let f = Bytes::from(f);
        let reader = ArrowReaderBuilder::try_new(f).unwrap().build();
        let schema = reader.schema();
        // The current reader doesn't read a batch across stripe boundaries, so
        // more than one batch here proves multiple stripes were written
        let rows = reader.collect::<Result<Vec<_>, _>>().unwrap();
        assert!(
            rows.len() > 1,
            "must have written more than 1 stripe (each stripe read as separate recordbatch)"
        );
        let actual = concat_batches(&schema, rows.iter()).unwrap();
        assert_eq!(batch, actual);
    }

    #[test]
    fn test_write_inconsistent_null_buffers() {
        // Arrays whose null buffer can appear or disappear between writes
        let schema = Arc::new(Schema::new(vec![Field::new(
            "int64",
            ArrowDataType::Int64,
            true,
        )]));

        // Ensure first batch has array with no null buffer
        let array_no_nulls = Arc::new(Int64Array::from(vec![1, 2, 3]));
        assert!(array_no_nulls.nulls().is_none());
        // But subsequent batch has array with null buffer
        let array_with_nulls = Arc::new(Int64Array::from(vec![None, Some(4), None]));
        assert!(array_with_nulls.nulls().is_some());

        let batch1 = RecordBatch::try_new(schema.clone(), vec![array_no_nulls]).unwrap();
        let batch2 = RecordBatch::try_new(schema.clone(), vec![array_with_nulls]).unwrap();

        // ORC writer should be able to handle this gracefully
        let expected_array = Arc::new(Int64Array::from(vec![
            Some(1),
            Some(2),
            Some(3),
            None,
            Some(4),
            None,
        ]));
        let expected_batch = RecordBatch::try_new(schema, vec![expected_array]).unwrap();

        let rows = roundtrip(&[batch1, batch2]);
        assert_eq!(expected_batch, rows[0]);
    }

    #[test]
    fn test_empty_null_buffers() {
        // Create an ORC file with present streams, but which has no nulls.
        // When this file is read, the resulting Arrow arrays should have
        // NO null buffer, even though there is a present stream.
        let schema = Arc::new(Schema::new(vec![Field::new(
            "int64",
            ArrowDataType::Int64,
            true,
        )]));

        // Array with a null buffer, but no nulls
        let array_empty_nulls = Arc::new(Int64Array::from_iter_values_with_nulls(
            vec![1],
            Some(NullBuffer::from_iter(vec![true])),
        ));
        assert!(array_empty_nulls.nulls().is_some());
        assert!(array_empty_nulls.null_count() == 0);

        let batch = RecordBatch::try_new(schema, vec![array_empty_nulls]).unwrap();

        // Encoding to bytes
        let mut f = vec![];
        let mut writer = ArrowWriterBuilder::new(&mut f, batch.schema())
            .try_build()
            .unwrap();
        writer.write(&batch).unwrap();
        writer.close().unwrap();
        let mut f = Bytes::from(f);
        let builder = ArrowReaderBuilder::try_new(f.clone()).unwrap();

        // Ensure the ORC file we wrote indeed has a present stream
        let stripe = Stripe::new(
            &mut f,
            &builder.file_metadata,
            builder.file_metadata().root_data_type(),
            &builder.file_metadata().stripe_metadatas()[0],
        )
        .unwrap();
        assert_eq!(stripe.columns().len(), 1);
        // Make sure we're getting the right column
        assert_eq!(stripe.columns()[0].name(), "int64");
        // Then check present stream
        let present_stream = stripe
            .stream_map()
            .get_opt(&stripe.columns()[0], proto::stream::Kind::Present);
        assert!(present_stream.is_some());

        // Decoding from bytes
        let reader = builder.build();
        let rows = reader.collect::<Result<Vec<_>, _>>().unwrap();

        assert_eq!(rows.len(), 1);
        assert_eq!(rows[0].num_columns(), 1);
        // Ensure read array has no null buffer
        assert!(rows[0].column(0).nulls().is_none());
    }
}
536}