1use std::io::Write;
19
20use arrow::{
21 array::RecordBatch,
22 datatypes::{DataType as ArrowDataType, SchemaRef},
23};
24use prost::Message;
25use snafu::{ensure, ResultExt};
26
27use crate::{
28 error::{IoSnafu, Result, UnexpectedSnafu},
29 memory::EstimateMemory,
30 proto,
31 writer::stripe::{StripeInformation, StripeWriter},
32};
33
34pub struct ArrowWriterBuilder<W> {
37 writer: W,
38 schema: SchemaRef,
39 batch_size: usize,
40 stripe_byte_size: usize,
41}
42
43impl<W: Write> ArrowWriterBuilder<W> {
44 pub fn new(writer: W, schema: SchemaRef) -> Self {
47 Self {
48 writer,
49 schema,
50 batch_size: 1024,
51 stripe_byte_size: 64 * 1024 * 1024,
53 }
54 }
55
56 pub fn with_batch_size(mut self, batch_size: usize) -> Self {
59 self.batch_size = batch_size;
60 self
61 }
62
63 pub fn with_stripe_byte_size(mut self, stripe_byte_size: usize) -> Self {
65 self.stripe_byte_size = stripe_byte_size;
66 self
67 }
68
69 pub fn try_build(mut self) -> Result<ArrowWriter<W>> {
72 self.writer.write_all(b"ORC").context(IoSnafu)?;
74 let writer = StripeWriter::new(self.writer, &self.schema);
75 Ok(ArrowWriter {
76 writer,
77 schema: self.schema,
78 batch_size: self.batch_size,
79 stripe_byte_size: self.stripe_byte_size,
80 written_stripes: vec![],
81 total_bytes_written: 3,
83 })
84 }
85}
86
87pub struct ArrowWriter<W> {
91 writer: StripeWriter<W>,
92 schema: SchemaRef,
93 batch_size: usize,
94 stripe_byte_size: usize,
95 written_stripes: Vec<StripeInformation>,
96 total_bytes_written: u64,
98}
99
100impl<W: Write> ArrowWriter<W> {
101 pub fn write(&mut self, batch: &RecordBatch) -> Result<()> {
104 ensure!(
105 batch.schema() == self.schema,
106 UnexpectedSnafu {
107 msg: "RecordBatch doesn't match expected schema"
108 }
109 );
110
111 for offset in (0..batch.num_rows()).step_by(self.batch_size) {
112 let length = self.batch_size.min(batch.num_rows() - offset);
113 let batch = batch.slice(offset, length);
114 self.writer.encode_batch(&batch)?;
115
116 if self.writer.estimate_memory_size() > self.stripe_byte_size {
119 self.flush_stripe()?;
120 }
121 }
122 Ok(())
123 }
124
125 pub fn flush_stripe(&mut self) -> Result<()> {
128 let info = self.writer.finish_stripe(self.total_bytes_written)?;
129 self.total_bytes_written += info.total_byte_size();
130 self.written_stripes.push(info);
131 Ok(())
132 }
133
134 pub fn close(mut self) -> Result<()> {
137 if self.writer.row_count > 0 {
139 self.flush_stripe()?;
140 }
141 let footer = serialize_footer(&self.written_stripes, &self.schema);
142 let footer = footer.encode_to_vec();
143 let postscript = serialize_postscript(footer.len() as u64);
144 let postscript = postscript.encode_to_vec();
145 let postscript_len = postscript.len() as u8;
146
147 let mut writer = self.writer.finish();
148 writer.write_all(&footer).context(IoSnafu)?;
149 writer.write_all(&postscript).context(IoSnafu)?;
150 writer.write_all(&[postscript_len]).context(IoSnafu)?;
152
153 Ok(())
155 }
156}
157
158fn serialize_schema(schema: &SchemaRef) -> Vec<proto::Type> {
159 let mut types = vec![];
160
161 let field_names = schema
162 .fields()
163 .iter()
164 .map(|f| f.name().to_owned())
165 .collect();
166 let subtypes = (1..(schema.fields().len() as u32 + 1)).collect();
168 let root_type = proto::Type {
169 kind: Some(proto::r#type::Kind::Struct.into()),
170 subtypes,
171 field_names,
172 maximum_length: None,
173 precision: None,
174 scale: None,
175 attributes: vec![],
176 };
177 types.push(root_type);
178 for field in schema.fields() {
179 let t = match field.data_type() {
180 ArrowDataType::Float32 => proto::Type {
181 kind: Some(proto::r#type::Kind::Float.into()),
182 ..Default::default()
183 },
184 ArrowDataType::Float64 => proto::Type {
185 kind: Some(proto::r#type::Kind::Double.into()),
186 ..Default::default()
187 },
188 ArrowDataType::Int8 => proto::Type {
189 kind: Some(proto::r#type::Kind::Byte.into()),
190 ..Default::default()
191 },
192 ArrowDataType::Int16 => proto::Type {
193 kind: Some(proto::r#type::Kind::Short.into()),
194 ..Default::default()
195 },
196 ArrowDataType::Int32 => proto::Type {
197 kind: Some(proto::r#type::Kind::Int.into()),
198 ..Default::default()
199 },
200 ArrowDataType::Int64 => proto::Type {
201 kind: Some(proto::r#type::Kind::Long.into()),
202 ..Default::default()
203 },
204 ArrowDataType::Utf8 | ArrowDataType::LargeUtf8 => proto::Type {
205 kind: Some(proto::r#type::Kind::String.into()),
206 ..Default::default()
207 },
208 ArrowDataType::Binary | ArrowDataType::LargeBinary => proto::Type {
209 kind: Some(proto::r#type::Kind::Binary.into()),
210 ..Default::default()
211 },
212 ArrowDataType::Boolean => proto::Type {
213 kind: Some(proto::r#type::Kind::Boolean.into()),
214 ..Default::default()
215 },
216 _ => unimplemented!("unsupported datatype"),
218 };
219 types.push(t);
220 }
221 types
222}
223
224fn serialize_footer(stripes: &[StripeInformation], schema: &SchemaRef) -> proto::Footer {
225 let body_length = stripes
226 .iter()
227 .map(|s| s.index_length + s.data_length + s.footer_length)
228 .sum::<u64>();
229 let number_of_rows = stripes.iter().map(|s| s.row_count as u64).sum::<u64>();
230 let stripes = stripes.iter().map(From::from).collect();
231 let types = serialize_schema(schema);
232 proto::Footer {
233 header_length: Some(3),
234 content_length: Some(body_length + 3),
235 stripes,
236 types,
237 metadata: vec![],
238 number_of_rows: Some(number_of_rows),
239 statistics: vec![],
240 row_index_stride: None,
241 writer: Some(u32::MAX),
242 encryption: None,
243 calendar: None,
244 software_version: None,
245 }
246}
247
248fn serialize_postscript(footer_length: u64) -> proto::PostScript {
249 proto::PostScript {
250 footer_length: Some(footer_length),
251 compression: Some(proto::CompressionKind::None.into()), compression_block_size: None,
253 version: vec![0, 12],
254 metadata_length: Some(0), writer_version: Some(u32::MAX), stripe_statistics_length: None,
257 magic: Some("ORC".to_string()),
258 }
259}
260
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow::{
        array::{
            Array, BinaryArray, BooleanArray, Float32Array, Float64Array, Int16Array, Int32Array,
            Int64Array, Int8Array, LargeBinaryArray, LargeStringArray, RecordBatchReader,
            StringArray,
        },
        buffer::NullBuffer,
        compute::concat_batches,
        datatypes::{DataType as ArrowDataType, Field, Schema},
    };
    use bytes::Bytes;

    use crate::{stripe::Stripe, ArrowReaderBuilder};

    use super::*;

    /// Write `batches` to an in-memory ORC file using default builder
    /// settings, then read the file back and return every batch the reader
    /// yields.
    ///
    /// Uses the schema of `batches[0]`, so `batches` must be non-empty.
    /// Panics (via `unwrap`) on any write or read error.
    fn roundtrip(batches: &[RecordBatch]) -> Vec<RecordBatch> {
        let mut f = vec![];
        let mut writer = ArrowWriterBuilder::new(&mut f, batches[0].schema())
            .try_build()
            .unwrap();
        for batch in batches {
            writer.write(batch).unwrap();
        }
        writer.close().unwrap();

        let f = Bytes::from(f);
        let reader = ArrowReaderBuilder::try_new(f).unwrap().build();
        reader.collect::<Result<Vec<_>, _>>().unwrap()
    }

    /// Every supported primitive type (non-nullable) must round-trip unchanged
    /// through a single write/read.
    #[test]
    fn test_roundtrip_write() {
        let f32_array = Arc::new(Float32Array::from(vec![0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0]));
        let f64_array = Arc::new(Float64Array::from(vec![0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0]));
        let int8_array = Arc::new(Int8Array::from(vec![0, 1, 2, 3, 4, 5, 6]));
        let int16_array = Arc::new(Int16Array::from(vec![0, 1, 2, 3, 4, 5, 6]));
        let int32_array = Arc::new(Int32Array::from(vec![0, 1, 2, 3, 4, 5, 6]));
        let int64_array = Arc::new(Int64Array::from(vec![0, 1, 2, 3, 4, 5, 6]));
        // Includes multi-byte UTF-8 and an empty string.
        let utf8_array = Arc::new(StringArray::from(vec![
            "Hello",
            "there",
            "楡井希実",
            "💯",
            "ORC",
            "",
            "123",
        ]));
        let binary_array = Arc::new(BinaryArray::from(vec![
            "Hello".as_bytes(),
            "there".as_bytes(),
            "楡井希実".as_bytes(),
            "💯".as_bytes(),
            "ORC".as_bytes(),
            "".as_bytes(),
            "123".as_bytes(),
        ]));
        let boolean_array = Arc::new(BooleanArray::from(vec![
            true, false, true, false, true, true, false,
        ]));
        let schema = Schema::new(vec![
            Field::new("f32", ArrowDataType::Float32, false),
            Field::new("f64", ArrowDataType::Float64, false),
            Field::new("int8", ArrowDataType::Int8, false),
            Field::new("int16", ArrowDataType::Int16, false),
            Field::new("int32", ArrowDataType::Int32, false),
            Field::new("int64", ArrowDataType::Int64, false),
            Field::new("utf8", ArrowDataType::Utf8, false),
            Field::new("binary", ArrowDataType::Binary, false),
            Field::new("boolean", ArrowDataType::Boolean, false),
        ]);

        let batch = RecordBatch::try_new(
            Arc::new(schema),
            vec![
                f32_array,
                f64_array,
                int8_array,
                int16_array,
                int32_array,
                int64_array,
                utf8_array,
                binary_array,
                boolean_array,
            ],
        )
        .unwrap();

        let rows = roundtrip(std::slice::from_ref(&batch));
        assert_eq!(batch, rows[0]);
    }

    /// LargeUtf8/LargeBinary columns are written as ORC string/binary and are
    /// read back as the regular (32-bit offset) Utf8/Binary Arrow types.
    #[test]
    fn test_roundtrip_write_large_type() {
        let large_utf8_array = Arc::new(LargeStringArray::from(vec![
            "Hello",
            "there",
            "楡井希実",
            "💯",
            "ORC",
            "",
            "123",
        ]));
        let large_binary_array = Arc::new(LargeBinaryArray::from(vec![
            "Hello".as_bytes(),
            "there".as_bytes(),
            "楡井希実".as_bytes(),
            "💯".as_bytes(),
            "ORC".as_bytes(),
            "".as_bytes(),
            "123".as_bytes(),
        ]));
        let schema = Schema::new(vec![
            Field::new("large_utf8", ArrowDataType::LargeUtf8, false),
            Field::new("large_binary", ArrowDataType::LargeBinary, false),
        ]);
        let batch =
            RecordBatch::try_new(Arc::new(schema), vec![large_utf8_array, large_binary_array])
                .unwrap();

        let rows = roundtrip(&[batch]);

        // Expected result: same values and field names, but non-large types.
        let utf8_array = Arc::new(StringArray::from(vec![
            "Hello",
            "there",
            "楡井希実",
            "💯",
            "ORC",
            "",
            "123",
        ]));
        let binary_array = Arc::new(BinaryArray::from(vec![
            "Hello".as_bytes(),
            "there".as_bytes(),
            "楡井希実".as_bytes(),
            "💯".as_bytes(),
            "ORC".as_bytes(),
            "".as_bytes(),
            "123".as_bytes(),
        ]));
        let schema = Schema::new(vec![
            Field::new("large_utf8", ArrowDataType::Utf8, false),
            Field::new("large_binary", ArrowDataType::Binary, false),
        ]);
        let batch = RecordBatch::try_new(Arc::new(schema), vec![utf8_array, binary_array]).unwrap();
        assert_eq!(batch, rows[0]);
    }

    /// A tiny stripe byte size forces `write` to flush several stripes for one
    /// large batch; concatenating what is read back must reproduce the input.
    #[test]
    fn test_write_small_stripes() {
        let data: Vec<i64> = (0..1_000_000).collect();
        let int64_array = Arc::new(Int64Array::from(data));
        let schema = Schema::new(vec![Field::new("int64", ArrowDataType::Int64, true)]);

        let batch = RecordBatch::try_new(Arc::new(schema), vec![int64_array]).unwrap();

        let mut f = vec![];
        let mut writer = ArrowWriterBuilder::new(&mut f, batch.schema())
            .with_stripe_byte_size(256)
            .try_build()
            .unwrap();
        writer.write(&batch).unwrap();
        writer.close().unwrap();

        let f = Bytes::from(f);
        let reader = ArrowReaderBuilder::try_new(f).unwrap().build();
        let schema = reader.schema();
        let rows = reader.collect::<Result<Vec<_>, _>>().unwrap();
        assert!(
            rows.len() > 1,
            "must have written more than 1 stripe (each stripe read as separate recordbatch)"
        );
        let actual = concat_batches(&schema, rows.iter()).unwrap();
        assert_eq!(batch, actual);
    }

    /// Batches without and with a null buffer, written into the same stripe,
    /// must read back as one batch with nulls in the right positions.
    #[test]
    fn test_write_inconsistent_null_buffers() {
        let schema = Arc::new(Schema::new(vec![Field::new(
            "int64",
            ArrowDataType::Int64,
            true,
        )]));

        // First batch has no null buffer at all.
        let array_no_nulls = Arc::new(Int64Array::from(vec![1, 2, 3]));
        assert!(array_no_nulls.nulls().is_none());
        // Second batch has a null buffer with actual nulls.
        let array_with_nulls = Arc::new(Int64Array::from(vec![None, Some(4), None]));
        assert!(array_with_nulls.nulls().is_some());

        let batch1 = RecordBatch::try_new(schema.clone(), vec![array_no_nulls]).unwrap();
        let batch2 = RecordBatch::try_new(schema.clone(), vec![array_with_nulls]).unwrap();

        let expected_array = Arc::new(Int64Array::from(vec![
            Some(1),
            Some(2),
            Some(3),
            None,
            Some(4),
            None,
        ]));
        let expected_batch = RecordBatch::try_new(schema, vec![expected_array]).unwrap();

        let rows = roundtrip(&[batch1, batch2]);
        assert_eq!(expected_batch, rows[0]);
    }

    /// A column whose null buffer is present but marks every value valid still
    /// gets a Present stream on write, while the reader returns the column
    /// without a null buffer.
    #[test]
    fn test_empty_null_buffers() {
        let schema = Arc::new(Schema::new(vec![Field::new(
            "int64",
            ArrowDataType::Int64,
            true,
        )]));

        // Null buffer exists but contains no nulls.
        let array_empty_nulls = Arc::new(Int64Array::from_iter_values_with_nulls(
            vec![1],
            Some(NullBuffer::from_iter(vec![true])),
        ));
        assert!(array_empty_nulls.nulls().is_some());
        assert!(array_empty_nulls.null_count() == 0);

        let batch = RecordBatch::try_new(schema, vec![array_empty_nulls]).unwrap();

        let mut f = vec![];
        let mut writer = ArrowWriterBuilder::new(&mut f, batch.schema())
            .try_build()
            .unwrap();
        writer.write(&batch).unwrap();
        writer.close().unwrap();
        let mut f = Bytes::from(f);
        let builder = ArrowReaderBuilder::try_new(f.clone()).unwrap();

        // Inspect the first stripe directly to check which streams were written.
        let stripe = Stripe::new(
            &mut f,
            &builder.file_metadata,
            builder.file_metadata().root_data_type(),
            &builder.file_metadata().stripe_metadatas()[0],
        )
        .unwrap();
        assert_eq!(stripe.columns().len(), 1);
        assert_eq!(stripe.columns()[0].name(), "int64");
        let present_stream = stripe
            .stream_map()
            .get_opt(&stripe.columns()[0], proto::stream::Kind::Present);
        assert!(present_stream.is_some());

        // Reading through the normal reader path drops the all-valid null buffer.
        let reader = builder.build();
        let rows = reader.collect::<Result<Vec<_>, _>>().unwrap();

        assert_eq!(rows.len(), 1);
        assert_eq!(rows[0].num_columns(), 1);
        assert!(rows[0].column(0).nulls().is_none());
    }
}