use crate::codec::AvroFieldBuilder;
use crate::compression::CompressionCodec;
use crate::errors::AvroError;
use crate::schema::{
AvroSchema, Fingerprint, FingerprintAlgorithm, FingerprintStrategy, SCHEMA_METADATA_KEY,
};
use crate::writer::encoder::{RecordEncoder, RecordEncoderBuilder, write_long};
use crate::writer::format::{AvroFormat, AvroOcfFormat, AvroSoeFormat};
use arrow_array::RecordBatch;
use arrow_schema::{Schema, SchemaRef};
use bytes::{Bytes, BytesMut};
use std::io::Write;
use std::sync::Arc;
mod encoder;
pub mod format;
/// Zero-copy container for a batch of Avro-encoded rows.
///
/// All rows are stored back-to-back in `data`; `offsets` holds
/// `row_count + 1` byte positions, so row `i` spans
/// `offsets[i]..offsets[i + 1]` within `data`.
#[derive(Debug, Clone)]
pub struct EncodedRows {
    // Concatenated encoded row bytes (shared, reference-counted buffer).
    data: Bytes,
    // Byte offsets into `data`; always one more entry than the row count.
    offsets: Vec<usize>,
}
impl EncodedRows {
    /// Creates a new [`EncodedRows`] from concatenated row bytes and
    /// `row_count + 1` byte offsets into `data`.
    pub fn new(data: Bytes, offsets: Vec<usize>) -> Self {
        Self { data, offsets }
    }

    /// Number of encoded rows.
    #[inline]
    pub fn len(&self) -> usize {
        // `offsets` carries one more entry than the number of rows;
        // `saturating_sub` keeps an empty offsets vec from underflowing.
        self.offsets.len().saturating_sub(1)
    }

    /// Returns `true` when no rows are present.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Underlying buffer holding all rows back-to-back.
    #[inline]
    pub fn bytes(&self) -> &Bytes {
        &self.data
    }

    /// Row boundaries as byte offsets into [`Self::bytes`].
    #[inline]
    pub fn offsets(&self) -> &[usize] {
        &self.offsets
    }

    /// Returns the bytes of row `n` as a zero-copy slice of the buffer.
    ///
    /// # Errors
    /// Returns `AvroError::General` when `n` is out of bounds or the stored
    /// offsets are internally inconsistent.
    pub fn row(&self, n: usize) -> Result<Bytes, AvroError> {
        if n >= self.len() {
            return Err(AvroError::General(format!(
                "Row index {n} out of bounds for len {}",
                self.len()
            )));
        }
        // The bounds check above guarantees `n + 1 < self.offsets.len()`,
        // so safe indexing cannot panic here — no `unsafe` required.
        let (start, end) = (self.offsets[n], self.offsets[n + 1]);
        if start > end || end > self.data.len() {
            return Err(AvroError::General(format!(
                "Invalid row offsets for row {n}: start={start}, end={end}, data_len={}",
                self.data.len()
            )));
        }
        Ok(self.data.slice(start..end))
    }

    /// Iterates over all rows as zero-copy `Bytes` slices.
    #[inline]
    pub fn iter(&self) -> impl ExactSizeIterator<Item = Bytes> + '_ {
        self.offsets.windows(2).map(|w| self.data.slice(w[0]..w[1]))
    }
}
/// Builder used to configure an Avro [`Writer`] or standalone [`Encoder`].
#[derive(Debug, Clone)]
pub struct WriterBuilder {
    // Arrow schema the produced writer/encoder validates batches against.
    schema: Schema,
    // Optional compression codec, applied to OCF data blocks.
    codec: Option<CompressionCodec>,
    // Optional per-row capacity hint forwarded to the row encoder.
    row_capacity: Option<usize>,
    // Initial capacity, in bytes, of the encode buffer (default 1024).
    capacity: usize,
    // How to fingerprint the schema for prefixed (single-object) formats.
    fingerprint_strategy: Option<FingerprintStrategy>,
}
impl WriterBuilder {
    /// Creates a builder for `schema` with defaults: no compression, a
    /// 1024-byte initial buffer capacity, and no fingerprint strategy.
    pub fn new(schema: Schema) -> Self {
        Self {
            schema,
            codec: None,
            row_capacity: None,
            capacity: 1024,
            fingerprint_strategy: None,
        }
    }

    /// Sets the fingerprint strategy used for the per-row prefix of
    /// prefixed stream formats (ignored by formats without a prefix).
    pub fn with_fingerprint_strategy(mut self, strategy: FingerprintStrategy) -> Self {
        self.fingerprint_strategy = Some(strategy);
        self
    }

    /// Sets (or clears, with `None`) the compression codec applied to
    /// OCF data blocks.
    pub fn with_compression(mut self, codec: Option<CompressionCodec>) -> Self {
        self.codec = codec;
        self
    }

    /// Sets the initial capacity, in bytes, of the internal encode buffer.
    pub fn with_capacity(mut self, capacity: usize) -> Self {
        self.capacity = capacity;
        self
    }

    /// Sets a per-row capacity hint forwarded to the row encoder.
    pub fn with_row_capacity(mut self, capacity: usize) -> Self {
        self.row_capacity = Some(capacity);
        self
    }

    /// Resolves the Avro schema and constructs the [`RecordEncoder`].
    ///
    /// The Avro schema is taken verbatim from the Arrow schema's
    /// `SCHEMA_METADATA_KEY` metadata when present, otherwise derived from
    /// the Arrow schema. For formats that require a per-row prefix, a
    /// fingerprint is selected: `Id`/`Id64` strategies are used as-is, any
    /// other strategy is hashed over the schema, and the default (no
    /// strategy configured) is Rabin. The returned Arrow schema carries the
    /// resolved Avro JSON in its metadata.
    fn prepare_encoder<F: AvroFormat>(&self) -> Result<(Arc<Schema>, RecordEncoder), AvroError> {
        let avro_schema = match self.schema.metadata.get(SCHEMA_METADATA_KEY) {
            Some(json) => AvroSchema::new(json.clone()),
            None => AvroSchema::try_from(&self.schema)?,
        };
        let maybe_fingerprint = if F::NEEDS_PREFIX {
            match &self.fingerprint_strategy {
                Some(FingerprintStrategy::Id(id)) => Some(Fingerprint::Id(*id)),
                Some(FingerprintStrategy::Id64(id)) => Some(Fingerprint::Id64(*id)),
                Some(strategy) => {
                    Some(avro_schema.fingerprint(FingerprintAlgorithm::from(*strategy))?)
                }
                // No strategy configured: default to the Rabin fingerprint.
                None => Some(
                    avro_schema
                        .fingerprint(FingerprintAlgorithm::from(FingerprintStrategy::Rabin))?,
                ),
            }
        } else {
            None
        };
        // Propagate the resolved Avro JSON into the schema metadata so
        // consumers of the returned schema see the exact writer schema.
        let mut md = self.schema.metadata().clone();
        md.insert(
            SCHEMA_METADATA_KEY.to_string(),
            avro_schema.clone().json_string,
        );
        let schema = Arc::new(Schema::new_with_metadata(self.schema.fields().clone(), md));
        let avro_root = AvroFieldBuilder::new(&avro_schema.schema()?).build()?;
        let encoder = RecordEncoderBuilder::new(&avro_root, schema.as_ref())
            .with_fingerprint(maybe_fingerprint)
            .build()?;
        Ok((schema, encoder))
    }

    /// Builds a standalone [`Encoder`] for stream (non-OCF) formats.
    ///
    /// # Errors
    /// Returns `AvroError::InvalidArgument` for OCF formats (detected via
    /// the presence of a sync marker); use [`Self::build`] for those.
    pub fn build_encoder<F: AvroFormat>(self) -> Result<Encoder, AvroError> {
        if F::default().sync_marker().is_some() {
            return Err(AvroError::InvalidArgument(
                "Encoder only supports stream formats (no OCF header/sync marker)".to_string(),
            ));
        }
        let (schema, encoder) = self.prepare_encoder::<F>()?;
        Ok(Encoder {
            schema,
            encoder,
            row_capacity: self.row_capacity,
            buffer: BytesMut::with_capacity(self.capacity),
            // Offsets always start with a single leading `0` sentinel.
            offsets: vec![0],
        })
    }

    /// Builds a [`Writer`] over `writer`, emitting the format's stream
    /// start (e.g. the OCF header) before returning.
    ///
    /// # Errors
    /// Returns `AvroError::InvalidArgument` for prefix-less stream formats,
    /// which are only supported through [`Self::build_encoder`].
    pub fn build<W, F>(self, mut writer: W) -> Result<Writer<W, F>, AvroError>
    where
        W: Write,
        F: AvroFormat,
    {
        let mut format = F::default();
        if format.sync_marker().is_none() && !F::NEEDS_PREFIX {
            return Err(AvroError::InvalidArgument(
                "AvroBinaryFormat is only supported with Encoder, use build_encoder instead"
                    .to_string(),
            ));
        }
        let (schema, encoder) = self.prepare_encoder::<F>()?;
        format.start_stream(&mut writer, &schema, self.codec)?;
        Ok(Writer {
            writer,
            schema,
            format,
            compression: self.codec,
            capacity: self.capacity,
            encoder,
        })
    }
}
/// Streaming row encoder that buffers Avro-encoded rows until
/// [`Encoder::flush`] is called.
#[derive(Debug)]
pub struct Encoder {
    // Schema batches must match (carries the Avro JSON in its metadata).
    schema: SchemaRef,
    // Row-level Avro encoder.
    encoder: RecordEncoder,
    // Optional per-row capacity hint passed through to `encode_rows`.
    row_capacity: Option<usize>,
    // Accumulated encoded bytes for all buffered rows.
    buffer: BytesMut,
    // Row boundaries into `buffer`; initialized to `[0]`.
    offsets: Vec<usize>,
}
impl Encoder {
    /// Encodes all rows of `batch`, appending them to the internal buffer.
    ///
    /// # Errors
    /// Returns `AvroError::SchemaError` if the batch's fields differ from
    /// the encoder's schema.
    pub fn encode(&mut self, batch: &RecordBatch) -> Result<(), AvroError> {
        if batch.schema().fields() != self.schema.fields() {
            return Err(AvroError::SchemaError(
                "Schema of RecordBatch differs from Writer schema".to_string(),
            ));
        }
        self.encoder.encode_rows(
            batch,
            self.row_capacity.unwrap_or(0),
            &mut self.buffer,
            &mut self.offsets,
        )?;
        Ok(())
    }

    /// Encodes every batch in `batches`, stopping at the first error.
    pub fn encode_batches(&mut self, batches: &[RecordBatch]) -> Result<(), AvroError> {
        for b in batches {
            self.encode(b)?;
        }
        Ok(())
    }

    /// Takes all buffered rows, resetting the encoder for further use.
    pub fn flush(&mut self) -> EncodedRows {
        let data = self.buffer.split().freeze();
        // Swap the accumulated offsets out and re-seed with the initial `0`
        // sentinel in one step, instead of manually draining into a fresh
        // vec and pushing `0` back afterwards.
        let offsets = std::mem::replace(&mut self.offsets, vec![0]);
        EncodedRows::new(data, offsets)
    }

    /// Returns the Arrow schema rows are encoded against.
    pub fn schema(&self) -> SchemaRef {
        self.schema.clone()
    }

    /// Number of rows currently buffered and not yet flushed.
    pub fn buffered_len(&self) -> usize {
        self.offsets.len().saturating_sub(1)
    }
}
/// Avro writer combining an output sink `W` with a container format `F`.
#[derive(Debug)]
pub struct Writer<W: Write, F: AvroFormat> {
    // Destination for all encoded output.
    writer: W,
    // Schema (with Avro JSON metadata) that batches are validated against.
    schema: SchemaRef,
    // Format state; holds the sync marker for OCF output.
    format: F,
    // Optional compression applied per OCF block in `write_ocf_block`.
    compression: Option<CompressionCodec>,
    // Initial capacity for the per-block scratch buffer.
    capacity: usize,
    // Converts `RecordBatch` columns into Avro binary.
    encoder: RecordEncoder,
}
/// Writer producing Avro Object Container Files (OCF).
pub type AvroWriter<W> = Writer<W, AvroOcfFormat>;
/// Writer producing a single-object-encoded Avro stream.
pub type AvroStreamWriter<W> = Writer<W, AvroSoeFormat>;
impl<W: Write> Writer<W, AvroOcfFormat> {
    /// Creates an OCF writer with default options; the file header is
    /// written immediately via `start_stream`.
    pub fn new(writer: W, schema: Schema) -> Result<Self, AvroError> {
        WriterBuilder::new(schema).build::<W, AvroOcfFormat>(writer)
    }

    /// Returns the 16-byte sync marker used between data blocks, if any.
    pub fn sync_marker(&self) -> Option<&[u8; 16]> {
        self.format.sync_marker()
    }
}
impl<W: Write> Writer<W, AvroSoeFormat> {
    /// Creates a single-object-encoding stream writer with default options
    /// (the prefix fingerprint defaults to Rabin).
    pub fn new(writer: W, schema: Schema) -> Result<Self, AvroError> {
        WriterBuilder::new(schema).build::<W, AvroSoeFormat>(writer)
    }
}
impl<W: Write, F: AvroFormat> Writer<W, F> {
    /// Writes one batch, either as an OCF data block or as raw stream rows
    /// depending on the format.
    ///
    /// # Errors
    /// Returns `AvroError::SchemaError` if the batch's fields differ from
    /// the writer's schema.
    pub fn write(&mut self, batch: &RecordBatch) -> Result<(), AvroError> {
        if batch.schema().fields() != self.schema.fields() {
            return Err(AvroError::SchemaError(
                "Schema of RecordBatch differs from Writer schema".to_string(),
            ));
        }
        // The presence of a sync marker distinguishes OCF from stream formats.
        match self.format.sync_marker() {
            Some(&sync) => self.write_ocf_block(batch, &sync),
            None => self.write_stream(batch),
        }
    }

    /// Writes each batch in order, stopping at the first error.
    pub fn write_batches(&mut self, batches: &[&RecordBatch]) -> Result<(), AvroError> {
        for b in batches {
            self.write(b)?;
        }
        Ok(())
    }

    /// Flushes the underlying writer.
    pub fn finish(&mut self) -> Result<(), AvroError> {
        self.writer
            .flush()
            .map_err(|e| AvroError::IoError(format!("Error flushing writer: {e}"), e))
    }

    /// Consumes the writer and returns the inner sink. Does not flush;
    /// call [`Self::finish`] first if buffered output must be written.
    pub fn into_inner(self) -> W {
        self.writer
    }

    /// Writes `batch` as one OCF data block:
    /// row count (long) + encoded size (long) + data + sync marker.
    fn write_ocf_block(&mut self, batch: &RecordBatch, sync: &[u8; 16]) -> Result<(), AvroError> {
        let mut buf = Vec::<u8>::with_capacity(self.capacity);
        self.encoder.encode(&mut buf, batch)?;
        // Compress the whole block body if a codec was configured.
        let encoded = match self.compression {
            Some(codec) => codec.compress(&buf)?,
            None => buf,
        };
        write_long(&mut self.writer, batch.num_rows() as i64)?;
        write_long(&mut self.writer, encoded.len() as i64)?;
        self.writer
            .write_all(&encoded)
            .map_err(|e| AvroError::IoError(format!("Error writing Avro block: {e}"), e))?;
        self.writer
            .write_all(sync)
            .map_err(|e| AvroError::IoError(format!("Error writing Avro sync: {e}"), e))?;
        Ok(())
    }

    /// Writes `batch` rows directly to the sink (stream formats only).
    fn write_stream(&mut self, batch: &RecordBatch) -> Result<(), AvroError> {
        self.encoder.encode(&mut self.writer, batch)?;
        Ok(())
    }
}
#[cfg(test)]
mod tests {
use super::*;
use crate::compression::CompressionCodec;
use crate::reader::ReaderBuilder;
use crate::schema::AVRO_NAME_METADATA_KEY;
use crate::schema::{AvroSchema, SchemaStore};
use crate::test_util::arrow_test_data;
use arrow::datatypes::TimeUnit;
use arrow::util::pretty::pretty_format_batches;
#[cfg(not(feature = "avro_custom_types"))]
use arrow_array::Float32Array;
#[cfg(feature = "avro_custom_types")]
use arrow_array::RunArray;
use arrow_array::builder::{Int32Builder, ListBuilder};
use arrow_array::cast::AsArray;
#[cfg(feature = "avro_custom_types")]
use arrow_array::types::{Int16Type, Int64Type};
use arrow_array::types::{
Int32Type, Time32MillisecondType, Time64MicrosecondType, TimestampMicrosecondType,
TimestampMillisecondType, TimestampNanosecondType,
};
use arrow_array::{
Array, ArrayRef, BinaryArray, BooleanArray, Date32Array, Date64Array, Float16Array,
Int8Array, Int16Array, Int32Array, Int64Array, IntervalDayTimeArray,
IntervalMonthDayNanoArray, IntervalYearMonthArray, PrimitiveArray, RecordBatch,
StringArray, StructArray, Time32MillisecondArray, Time32SecondArray,
Time64MicrosecondArray, Time64NanosecondArray, TimestampMillisecondArray,
TimestampSecondArray, UInt8Array, UInt16Array, UInt32Array, UInt64Array, UnionArray,
};
use arrow_buffer::{IntervalDayTime, IntervalMonthDayNano};
#[cfg(not(feature = "avro_custom_types"))]
use arrow_schema::{DataType, Field, Schema};
#[cfg(feature = "avro_custom_types")]
use arrow_schema::{DataType, Field, Schema};
use arrow_schema::{IntervalUnit, UnionMode};
use bytes::BytesMut;
use half::f16;
use serde_json::{Value, json};
use std::collections::HashMap;
use std::collections::HashSet;
use std::fs::File;
use std::io::{BufReader, Cursor};
use std::path::PathBuf;
use std::sync::Arc;
use tempfile::NamedTempFile;
fn files() -> impl Iterator<Item = &'static str> {
[
#[cfg(feature = "snappy")]
"avro/alltypes_plain.avro",
#[cfg(feature = "snappy")]
"avro/alltypes_plain.snappy.avro",
#[cfg(feature = "zstd")]
"avro/alltypes_plain.zstandard.avro",
#[cfg(feature = "bzip2")]
"avro/alltypes_plain.bzip2.avro",
#[cfg(feature = "xz")]
"avro/alltypes_plain.xz.avro",
]
.into_iter()
}
fn make_schema() -> Schema {
Schema::new(vec![
Field::new("id", DataType::Int32, false),
Field::new("name", DataType::Binary, false),
])
}
fn make_batch() -> RecordBatch {
let ids = Int32Array::from(vec![1, 2, 3]);
let names = BinaryArray::from_vec(vec![b"a".as_ref(), b"b".as_ref(), b"c".as_ref()]);
RecordBatch::try_new(
Arc::new(make_schema()),
vec![Arc::new(ids) as ArrayRef, Arc::new(names) as ArrayRef],
)
.expect("failed to build test RecordBatch")
}
#[test]
fn test_stream_writer_writes_prefix_per_row_rt() -> Result<(), AvroError> {
let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
let batch = RecordBatch::try_new(
Arc::new(schema.clone()),
vec![Arc::new(Int32Array::from(vec![10, 20])) as ArrayRef],
)?;
let buf: Vec<u8> = Vec::new();
let mut writer = AvroStreamWriter::new(buf, schema.clone())?;
writer.write(&batch)?;
let encoded = writer.into_inner();
let mut store = SchemaStore::new(); let avro_schema = AvroSchema::try_from(&schema)?;
let _fp = store.register(avro_schema)?;
let mut decoder = ReaderBuilder::new()
.with_writer_schema_store(store)
.build_decoder()?;
let _consumed = decoder.decode(&encoded)?;
let decoded = decoder
.flush()?
.expect("expected at least one batch from decoder");
assert_eq!(decoded.num_columns(), 1);
assert_eq!(decoded.num_rows(), 2);
let col = decoded.column(0).as_primitive::<Int32Type>();
assert_eq!(col, &Int32Array::from(vec![10, 20]));
Ok(())
}
#[test]
fn test_nullable_struct_with_nonnullable_field_sliced_encoding() {
use arrow_array::{ArrayRef, Int32Array, StringArray, StructArray};
use arrow_buffer::NullBuffer;
use arrow_schema::{DataType, Field, Fields, Schema};
use std::sync::Arc;
let inner_fields = Fields::from(vec![
Field::new("id", DataType::Int32, false), Field::new("name", DataType::Utf8, true), ]);
let inner_struct_type = DataType::Struct(inner_fields.clone());
let schema = Schema::new(vec![
Field::new("before", inner_struct_type.clone(), true), Field::new("after", inner_struct_type.clone(), true), Field::new("op", DataType::Utf8, false), ]);
let before_ids = Int32Array::from(vec![None, None]);
let before_names = StringArray::from(vec![None::<&str>, None]);
let before_struct = StructArray::new(
inner_fields.clone(),
vec![
Arc::new(before_ids) as ArrayRef,
Arc::new(before_names) as ArrayRef,
],
Some(NullBuffer::from(vec![false, false])),
);
let after_ids = Int32Array::from(vec![1, 2]); let after_names = StringArray::from(vec![Some("Alice"), Some("Bob")]);
let after_struct = StructArray::new(
inner_fields.clone(),
vec![
Arc::new(after_ids) as ArrayRef,
Arc::new(after_names) as ArrayRef,
],
Some(NullBuffer::from(vec![true, true])),
);
let op_col = StringArray::from(vec!["r", "r"]);
let batch = RecordBatch::try_new(
Arc::new(schema.clone()),
vec![
Arc::new(before_struct) as ArrayRef,
Arc::new(after_struct) as ArrayRef,
Arc::new(op_col) as ArrayRef,
],
)
.expect("failed to create test batch");
let mut sink = Vec::new();
let mut writer = WriterBuilder::new(schema)
.with_fingerprint_strategy(FingerprintStrategy::Id(1))
.build::<_, AvroSoeFormat>(&mut sink)
.expect("failed to create writer");
for row_idx in 0..batch.num_rows() {
let single_row = batch.slice(row_idx, 1);
let after_col = single_row.column(1);
assert_eq!(
after_col.null_count(),
0,
"after column should have no nulls in sliced row"
);
writer
.write(&single_row)
.unwrap_or_else(|e| panic!("Failed to encode row {row_idx}: {e}"));
}
writer.finish().expect("failed to finish writer");
assert!(!sink.is_empty(), "encoded output should not be empty");
}
#[test]
fn test_nullable_struct_with_decimal_and_timestamp_sliced() {
use arrow_array::{
ArrayRef, Decimal128Array, Int32Array, StringArray, StructArray,
TimestampMicrosecondArray,
};
use arrow_buffer::NullBuffer;
use arrow_schema::{DataType, Field, Fields, Schema};
use std::sync::Arc;
let row_fields = Fields::from(vec![
Field::new("id", DataType::Int32, false),
Field::new("name", DataType::Utf8, true),
Field::new("category", DataType::Utf8, true),
Field::new("price", DataType::Decimal128(10, 2), true),
Field::new("stock_quantity", DataType::Int32, true),
Field::new(
"created_at",
DataType::Timestamp(TimeUnit::Microsecond, None),
true,
),
]);
let row_struct_type = DataType::Struct(row_fields.clone());
let schema = Schema::new(vec![
Field::new("before", row_struct_type.clone(), true),
Field::new("after", row_struct_type.clone(), true),
Field::new("op", DataType::Utf8, false),
]);
let before_struct = StructArray::new_null(row_fields.clone(), 2);
let ids = Int32Array::from(vec![1, 2]);
let names = StringArray::from(vec![Some("Widget"), Some("Gadget")]);
let categories = StringArray::from(vec![Some("Electronics"), Some("Electronics")]);
let prices = Decimal128Array::from(vec![Some(1999), Some(2999)])
.with_precision_and_scale(10, 2)
.unwrap();
let quantities = Int32Array::from(vec![Some(100), Some(50)]);
let timestamps = TimestampMicrosecondArray::from(vec![
Some(1700000000000000i64),
Some(1700000001000000i64),
]);
let after_struct = StructArray::new(
row_fields.clone(),
vec![
Arc::new(ids) as ArrayRef,
Arc::new(names) as ArrayRef,
Arc::new(categories) as ArrayRef,
Arc::new(prices) as ArrayRef,
Arc::new(quantities) as ArrayRef,
Arc::new(timestamps) as ArrayRef,
],
Some(NullBuffer::from(vec![true, true])),
);
let op_col = StringArray::from(vec!["r", "r"]);
let batch = RecordBatch::try_new(
Arc::new(schema.clone()),
vec![
Arc::new(before_struct) as ArrayRef,
Arc::new(after_struct) as ArrayRef,
Arc::new(op_col) as ArrayRef,
],
)
.expect("failed to create products batch");
let mut sink = Vec::new();
let mut writer = WriterBuilder::new(schema)
.with_fingerprint_strategy(FingerprintStrategy::Id(1))
.build::<_, AvroSoeFormat>(&mut sink)
.expect("failed to create writer");
for row_idx in 0..batch.num_rows() {
let single_row = batch.slice(row_idx, 1);
writer
.write(&single_row)
.unwrap_or_else(|e| panic!("Failed to encode product row {row_idx}: {e}"));
}
writer.finish().expect("failed to finish writer");
assert!(!sink.is_empty());
}
#[test]
fn non_nullable_child_in_nullable_struct_should_encode_per_row() {
use arrow_array::{
ArrayRef, Int32Array, Int64Array, RecordBatch, StringArray, StructArray,
};
use arrow_schema::{DataType, Field, Fields, Schema};
use std::sync::Arc;
let row_fields = Fields::from(vec![
Field::new("id", DataType::Int32, false),
Field::new("name", DataType::Utf8, true),
]);
let row_struct_dt = DataType::Struct(row_fields.clone());
let before: ArrayRef = Arc::new(StructArray::new_null(row_fields.clone(), 1));
let id_col: ArrayRef = Arc::new(Int32Array::from(vec![1]));
let name_col: ArrayRef = Arc::new(StringArray::from(vec![None::<&str>]));
let after: ArrayRef = Arc::new(StructArray::new(
row_fields.clone(),
vec![id_col, name_col],
None,
));
let schema = Arc::new(Schema::new(vec![
Field::new("before", row_struct_dt.clone(), true),
Field::new("after", row_struct_dt, true),
Field::new("op", DataType::Utf8, false),
Field::new("ts_ms", DataType::Int64, false),
]));
let op = Arc::new(StringArray::from(vec!["r"])) as ArrayRef;
let ts_ms = Arc::new(Int64Array::from(vec![1732900000000_i64])) as ArrayRef;
let batch = RecordBatch::try_new(schema.clone(), vec![before, after, op, ts_ms]).unwrap();
let mut buf = Vec::new();
let mut writer = WriterBuilder::new(schema.as_ref().clone())
.build::<_, AvroSoeFormat>(&mut buf)
.unwrap();
let single = batch.slice(0, 1);
let res = writer.write(&single);
assert!(
res.is_ok(),
"expected to encode successfully, got: {:?}",
res.err()
);
}
#[test]
fn test_union_nonzero_type_ids() -> Result<(), AvroError> {
use arrow_array::UnionArray;
use arrow_buffer::Buffer;
use arrow_schema::UnionFields;
let union_fields = UnionFields::try_new(
vec![2, 5],
vec![
Field::new("v_str", DataType::Utf8, true),
Field::new("v_int", DataType::Int32, true),
],
)
.unwrap();
let strings = StringArray::from(vec!["hello", "world"]);
let ints = Int32Array::from(vec![10, 20, 30]);
let type_ids = Buffer::from_slice_ref([2_i8, 5, 5, 2, 5]);
let offsets = Buffer::from_slice_ref([0_i32, 0, 1, 1, 2]);
let union_array = UnionArray::try_new(
union_fields.clone(),
type_ids.into(),
Some(offsets.into()),
vec![Arc::new(strings) as ArrayRef, Arc::new(ints) as ArrayRef],
)?;
let schema = Schema::new(vec![Field::new(
"union_col",
DataType::Union(union_fields, UnionMode::Dense),
false,
)]);
let batch = RecordBatch::try_new(
Arc::new(schema.clone()),
vec![Arc::new(union_array) as ArrayRef],
)?;
let mut writer = AvroWriter::new(Vec::<u8>::new(), schema.clone())?;
assert!(
writer.write(&batch).is_ok(),
"Expected no error from writing"
);
writer.finish()?;
assert!(
writer.finish().is_ok(),
"Expected no error from finishing writer"
);
Ok(())
}
#[test]
fn test_stream_writer_with_id_fingerprint_rt() -> Result<(), AvroError> {
let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
let batch = RecordBatch::try_new(
Arc::new(schema.clone()),
vec![Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef],
)?;
let schema_id: u32 = 42;
let mut writer = WriterBuilder::new(schema.clone())
.with_fingerprint_strategy(FingerprintStrategy::Id(schema_id))
.build::<_, AvroSoeFormat>(Vec::new())?;
writer.write(&batch)?;
let encoded = writer.into_inner();
let mut store = SchemaStore::new_with_type(FingerprintAlgorithm::Id);
let avro_schema = AvroSchema::try_from(&schema)?;
let _ = store.set(Fingerprint::Id(schema_id), avro_schema)?;
let mut decoder = ReaderBuilder::new()
.with_writer_schema_store(store)
.build_decoder()?;
let _ = decoder.decode(&encoded)?;
let decoded = decoder
.flush()?
.expect("expected at least one batch from decoder");
assert_eq!(decoded.num_columns(), 1);
assert_eq!(decoded.num_rows(), 3);
let col = decoded.column(0).as_primitive::<Int32Type>();
assert_eq!(col, &Int32Array::from(vec![1, 2, 3]));
Ok(())
}
#[test]
fn test_stream_writer_with_id64_fingerprint_rt() -> Result<(), AvroError> {
let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
let batch = RecordBatch::try_new(
Arc::new(schema.clone()),
vec![Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef],
)?;
let schema_id: u64 = 42;
let mut writer = WriterBuilder::new(schema.clone())
.with_fingerprint_strategy(FingerprintStrategy::Id64(schema_id))
.build::<_, AvroSoeFormat>(Vec::new())?;
writer.write(&batch)?;
let encoded = writer.into_inner();
let mut store = SchemaStore::new_with_type(FingerprintAlgorithm::Id64);
let avro_schema = AvroSchema::try_from(&schema)?;
let _ = store.set(Fingerprint::Id64(schema_id), avro_schema)?;
let mut decoder = ReaderBuilder::new()
.with_writer_schema_store(store)
.build_decoder()?;
let _ = decoder.decode(&encoded)?;
let decoded = decoder
.flush()?
.expect("expected at least one batch from decoder");
assert_eq!(decoded.num_columns(), 1);
assert_eq!(decoded.num_rows(), 3);
let col = decoded.column(0).as_primitive::<Int32Type>();
assert_eq!(col, &Int32Array::from(vec![1, 2, 3]));
Ok(())
}
#[test]
fn test_ocf_writer_generates_header_and_sync() -> Result<(), AvroError> {
let batch = make_batch();
let buffer: Vec<u8> = Vec::new();
let mut writer = AvroWriter::new(buffer, make_schema())?;
writer.write(&batch)?;
writer.finish()?;
let out = writer.into_inner();
assert_eq!(&out[..4], b"Obj\x01", "OCF magic bytes missing/incorrect");
let trailer = &out[out.len() - 16..];
assert_eq!(trailer.len(), 16, "expected 16‑byte sync marker");
Ok(())
}
#[test]
fn test_schema_mismatch_yields_error() {
let batch = make_batch();
let alt_schema = Schema::new(vec![Field::new("x", DataType::Int32, false)]);
let buffer = Vec::<u8>::new();
let mut writer = AvroWriter::new(buffer, alt_schema).unwrap();
let err = writer.write(&batch).unwrap_err();
assert!(matches!(err, AvroError::SchemaError(_)));
}
#[test]
fn test_write_batches_accumulates_multiple() -> Result<(), AvroError> {
let batch1 = make_batch();
let batch2 = make_batch();
let buffer = Vec::<u8>::new();
let mut writer = AvroWriter::new(buffer, make_schema())?;
writer.write_batches(&[&batch1, &batch2])?;
writer.finish()?;
let out = writer.into_inner();
assert!(out.len() > 4, "combined batches produced tiny file");
Ok(())
}
#[test]
fn test_finish_without_write_adds_header() -> Result<(), AvroError> {
let buffer = Vec::<u8>::new();
let mut writer = AvroWriter::new(buffer, make_schema())?;
writer.finish()?;
let out = writer.into_inner();
assert_eq!(&out[..4], b"Obj\x01", "finish() should emit OCF header");
Ok(())
}
#[test]
fn test_write_long_encodes_zigzag_varint() -> Result<(), AvroError> {
let mut buf = Vec::new();
write_long(&mut buf, 0)?;
write_long(&mut buf, -1)?;
write_long(&mut buf, 1)?;
write_long(&mut buf, -2)?;
write_long(&mut buf, 2147483647)?;
assert!(
buf.starts_with(&[0x00, 0x01, 0x02, 0x03]),
"zig‑zag varint encodings incorrect: {buf:?}"
);
Ok(())
}
#[test]
fn test_roundtrip_alltypes_roundtrip_writer() -> Result<(), AvroError> {
for rel in files() {
let path = arrow_test_data(rel);
let rdr_file = File::open(&path).expect("open input avro");
let reader = ReaderBuilder::new()
.build(BufReader::new(rdr_file))
.expect("build reader");
let schema = reader.schema();
let input_batches = reader.collect::<Result<Vec<_>, _>>()?;
let original =
arrow::compute::concat_batches(&schema, &input_batches).expect("concat input");
let tmp = NamedTempFile::new().expect("create temp file");
let out_path = tmp.into_temp_path();
let out_file = File::create(&out_path).expect("create temp avro");
let codec = if rel.contains(".snappy.") {
Some(CompressionCodec::Snappy)
} else if rel.contains(".zstandard.") {
Some(CompressionCodec::ZStandard)
} else if rel.contains(".bzip2.") {
Some(CompressionCodec::Bzip2)
} else if rel.contains(".xz.") {
Some(CompressionCodec::Xz)
} else {
None
};
let mut writer = WriterBuilder::new(original.schema().as_ref().clone())
.with_compression(codec)
.build::<_, AvroOcfFormat>(out_file)?;
writer.write(&original)?;
writer.finish()?;
drop(writer);
let rt_file = File::open(&out_path).expect("open roundtrip avro");
let rt_reader = ReaderBuilder::new()
.build(BufReader::new(rt_file))
.expect("build roundtrip reader");
let rt_schema = rt_reader.schema();
let rt_batches = rt_reader.collect::<Result<Vec<_>, _>>()?;
let roundtrip =
arrow::compute::concat_batches(&rt_schema, &rt_batches).expect("concat roundtrip");
assert_eq!(
roundtrip, original,
"Round-trip batch mismatch for file: {}",
rel
);
}
Ok(())
}
#[test]
fn test_roundtrip_nested_records_writer() -> Result<(), AvroError> {
let path = arrow_test_data("avro/nested_records.avro");
let rdr_file = File::open(&path).expect("open nested_records.avro");
let reader = ReaderBuilder::new()
.build(BufReader::new(rdr_file))
.expect("build reader for nested_records.avro");
let schema = reader.schema();
let batches = reader.collect::<Result<Vec<_>, _>>()?;
let original = arrow::compute::concat_batches(&schema, &batches).expect("concat original");
let tmp = NamedTempFile::new().expect("create temp file");
let out_path = tmp.into_temp_path();
{
let out_file = File::create(&out_path).expect("create output avro");
let mut writer = AvroWriter::new(out_file, original.schema().as_ref().clone())?;
writer.write(&original)?;
writer.finish()?;
}
let rt_file = File::open(&out_path).expect("open round_trip avro");
let rt_reader = ReaderBuilder::new()
.build(BufReader::new(rt_file))
.expect("build round_trip reader");
let rt_schema = rt_reader.schema();
let rt_batches = rt_reader.collect::<Result<Vec<_>, _>>()?;
let round_trip =
arrow::compute::concat_batches(&rt_schema, &rt_batches).expect("concat round_trip");
assert_eq!(
round_trip, original,
"Round-trip batch mismatch for nested_records.avro"
);
Ok(())
}
#[test]
#[cfg(feature = "snappy")]
fn test_roundtrip_nested_lists_writer() -> Result<(), AvroError> {
let path = arrow_test_data("avro/nested_lists.snappy.avro");
let rdr_file = File::open(&path).expect("open nested_lists.snappy.avro");
let reader = ReaderBuilder::new()
.build(BufReader::new(rdr_file))
.expect("build reader for nested_lists.snappy.avro");
let schema = reader.schema();
let batches = reader.collect::<Result<Vec<_>, _>>()?;
let original = arrow::compute::concat_batches(&schema, &batches).expect("concat original");
let tmp = NamedTempFile::new().expect("create temp file");
let out_path = tmp.into_temp_path();
{
let out_file = File::create(&out_path).expect("create output avro");
let mut writer = WriterBuilder::new(original.schema().as_ref().clone())
.with_compression(Some(CompressionCodec::Snappy))
.build::<_, AvroOcfFormat>(out_file)?;
writer.write(&original)?;
writer.finish()?;
}
let rt_file = File::open(&out_path).expect("open round_trip avro");
let rt_reader = ReaderBuilder::new()
.build(BufReader::new(rt_file))
.expect("build round_trip reader");
let rt_schema = rt_reader.schema();
let rt_batches = rt_reader.collect::<Result<Vec<_>, _>>()?;
let round_trip =
arrow::compute::concat_batches(&rt_schema, &rt_batches).expect("concat round_trip");
assert_eq!(
round_trip, original,
"Round-trip batch mismatch for nested_lists.snappy.avro"
);
Ok(())
}
#[test]
fn test_round_trip_simple_fixed_ocf() -> Result<(), AvroError> {
let path = arrow_test_data("avro/simple_fixed.avro");
let rdr_file = File::open(&path).expect("open avro/simple_fixed.avro");
let reader = ReaderBuilder::new()
.build(BufReader::new(rdr_file))
.expect("build avro reader");
let schema = reader.schema();
let input_batches = reader.collect::<Result<Vec<_>, _>>()?;
let original =
arrow::compute::concat_batches(&schema, &input_batches).expect("concat input");
let tmp = NamedTempFile::new().expect("create temp file");
let out_file = File::create(tmp.path()).expect("create temp avro");
let mut writer = AvroWriter::new(out_file, original.schema().as_ref().clone())?;
writer.write(&original)?;
writer.finish()?;
drop(writer);
let rt_file = File::open(tmp.path()).expect("open round_trip avro");
let rt_reader = ReaderBuilder::new()
.build(BufReader::new(rt_file))
.expect("build round_trip reader");
let rt_schema = rt_reader.schema();
let rt_batches = rt_reader.collect::<Result<Vec<_>, _>>()?;
let round_trip =
arrow::compute::concat_batches(&rt_schema, &rt_batches).expect("concat round_trip");
assert_eq!(round_trip, original);
Ok(())
}
#[test]
#[cfg(feature = "canonical_extension_types")]
fn test_round_trip_duration_and_uuid_ocf() -> Result<(), AvroError> {
use arrow_schema::{DataType, IntervalUnit};
let in_file =
File::open("test/data/duration_uuid.avro").expect("open test/data/duration_uuid.avro");
let reader = ReaderBuilder::new()
.build(BufReader::new(in_file))
.expect("build reader for duration_uuid.avro");
let in_schema = reader.schema();
let has_mdn = in_schema.fields().iter().any(|f| {
matches!(
f.data_type(),
DataType::Interval(IntervalUnit::MonthDayNano)
)
});
assert!(
has_mdn,
"expected at least one Interval(MonthDayNano) field in duration_uuid.avro"
);
let has_uuid_fixed = in_schema
.fields()
.iter()
.any(|f| matches!(f.data_type(), DataType::FixedSizeBinary(16)));
assert!(
has_uuid_fixed,
"expected at least one FixedSizeBinary(16) (uuid) field in duration_uuid.avro"
);
let input_batches = reader.collect::<Result<Vec<_>, _>>()?;
let input =
arrow::compute::concat_batches(&in_schema, &input_batches).expect("concat input");
let mut writer = AvroWriter::new(Vec::<u8>::new(), in_schema.as_ref().clone())?;
writer.write(&input)?;
writer.finish()?;
let bytes = writer.into_inner();
let rt_reader = ReaderBuilder::new()
.build(Cursor::new(bytes))
.expect("build round_trip reader");
let rt_schema = rt_reader.schema();
let rt_batches = rt_reader.collect::<Result<Vec<_>, _>>()?;
let round_trip =
arrow::compute::concat_batches(&rt_schema, &rt_batches).expect("concat round_trip");
assert_eq!(round_trip, input);
Ok(())
}
#[test]
#[cfg(not(feature = "canonical_extension_types"))]
fn test_duration_and_uuid_ocf_without_extensions_round_trips_values() -> Result<(), AvroError> {
use arrow::datatypes::{DataType, IntervalUnit};
use std::io::BufReader;
let in_file =
File::open("test/data/duration_uuid.avro").expect("open test/data/duration_uuid.avro");
let reader = ReaderBuilder::new()
.build(BufReader::new(in_file))
.expect("build reader for duration_uuid.avro");
let in_schema = reader.schema();
assert!(
in_schema.fields().iter().any(|f| {
matches!(
f.data_type(),
DataType::Interval(IntervalUnit::MonthDayNano)
)
}),
"expected at least one Interval(MonthDayNano) field"
);
assert!(
in_schema
.fields()
.iter()
.any(|f| matches!(f.data_type(), DataType::FixedSizeBinary(16))),
"expected a FixedSizeBinary(16) field (uuid)"
);
let input_batches = reader.collect::<Result<Vec<_>, _>>()?;
let input =
arrow::compute::concat_batches(&in_schema, &input_batches).expect("concat input");
let mut writer = AvroWriter::new(Vec::<u8>::new(), in_schema.as_ref().clone())?;
writer.write(&input)?;
writer.finish()?;
let bytes = writer.into_inner();
let rt_reader = ReaderBuilder::new()
.build(Cursor::new(bytes))
.expect("build round_trip reader");
let rt_schema = rt_reader.schema();
let rt_batches = rt_reader.collect::<Result<Vec<_>, _>>()?;
let round_trip =
arrow::compute::concat_batches(&rt_schema, &rt_batches).expect("concat round_trip");
assert_eq!(
round_trip.column(0),
input.column(0),
"duration column values differ"
);
assert_eq!(round_trip.column(1), input.column(1), "uuid bytes differ");
let uuid_rt = rt_schema.field_with_name("uuid_field")?;
assert_eq!(uuid_rt.data_type(), &DataType::FixedSizeBinary(16));
assert_eq!(
uuid_rt.metadata().get("logicalType").map(|s| s.as_str()),
Some("uuid"),
"expected `logicalType = \"uuid\"` on round-tripped field metadata"
);
let dur_rt = rt_schema.field_with_name("duration_field")?;
assert!(matches!(
dur_rt.data_type(),
DataType::Interval(IntervalUnit::MonthDayNano)
));
Ok(())
}
#[test]
#[cfg(feature = "snappy")]
fn test_nonnullable_impala_roundtrip_writer() -> Result<(), AvroError> {
    // Read the Impala-generated fixture, which must contain at least one Map column.
    let src = File::open(arrow_test_data("avro/nonnullable.impala.avro"))
        .expect("open avro/nonnullable.impala.avro");
    let reader = ReaderBuilder::new()
        .build(BufReader::new(src))
        .expect("build reader for nonnullable.impala.avro");
    let in_schema = reader.schema();
    assert!(
        in_schema
            .fields()
            .iter()
            .any(|f| matches!(f.data_type(), DataType::Map(_, _))),
        "expected at least one Map field in avro/nonnullable.impala.avro"
    );
    // Concatenate every decoded batch into a single reference batch.
    let input_batches = reader.collect::<Result<Vec<_>, _>>()?;
    let original =
        arrow::compute::concat_batches(&in_schema, &input_batches).expect("concat input");
    // Encode the batch into an in-memory OCF.
    let mut writer = AvroWriter::new(Vec::<u8>::new(), in_schema.as_ref().clone())?;
    writer.write(&original)?;
    writer.finish()?;
    // Decode what was written and require an exact match.
    let rt_reader = ReaderBuilder::new()
        .build(Cursor::new(writer.into_inner()))
        .expect("build reader for round-tripped in-memory OCF");
    let rt_schema = rt_reader.schema();
    let rt_batches = rt_reader.collect::<Result<Vec<_>, _>>()?;
    let roundtrip =
        arrow::compute::concat_batches(&rt_schema, &rt_batches).expect("concat roundtrip");
    assert_eq!(
        roundtrip, original,
        "Round-trip Avro map data mismatch for nonnullable.impala.avro"
    );
    Ok(())
}
#[test]
#[cfg(feature = "snappy")]
fn test_roundtrip_decimals_via_writer() -> Result<(), AvroError> {
    // (relative path, whether it lives in the shared arrow-testing data dir).
    let files: [(&str, bool); 8] = [
        ("avro/fixed_length_decimal.avro", true),
        ("avro/fixed_length_decimal_legacy.avro", true),
        ("avro/int32_decimal.avro", true),
        ("avro/int64_decimal.avro", true),
        ("test/data/int256_decimal.avro", false),
        ("test/data/fixed256_decimal.avro", false),
        ("test/data/fixed_length_decimal_legacy_32.avro", false),
        ("test/data/int128_decimal.avro", false),
    ];
    for (rel, in_test_data_dir) in files {
        let path: String = if in_test_data_dir {
            arrow_test_data(rel)
        } else {
            PathBuf::from(env!("CARGO_MANIFEST_DIR"))
                .join(rel)
                .to_string_lossy()
                .into_owned()
        };
        // Decode the fixture into a single reference batch.
        let f_in = File::open(&path).expect("open input avro");
        let rdr = ReaderBuilder::new().build(BufReader::new(f_in))?;
        let in_schema = rdr.schema();
        let in_batches = rdr.collect::<Result<Vec<_>, _>>()?;
        let original =
            arrow::compute::concat_batches(&in_schema, &in_batches).expect("concat input");
        // Re-encode through the writer into a temp file (deleted when
        // `out_path` is dropped at the end of the iteration).
        let out_path = NamedTempFile::new().expect("create temp file").into_temp_path();
        let out_file = File::create(&out_path).expect("create temp avro");
        let mut writer = AvroWriter::new(out_file, original.schema().as_ref().clone())?;
        writer.write(&original)?;
        writer.finish()?;
        // Decode the freshly written file and compare with the original.
        let f_rt = File::open(&out_path).expect("open roundtrip avro");
        let rt_rdr = ReaderBuilder::new().build(BufReader::new(f_rt))?;
        let rt_schema = rt_rdr.schema();
        let rt_batches = rt_rdr.collect::<Result<Vec<_>, _>>()?;
        let roundtrip =
            arrow::compute::concat_batches(&rt_schema, &rt_batches).expect("concat rt");
        assert_eq!(roundtrip, original, "decimal round-trip mismatch for {rel}");
    }
    Ok(())
}
#[test]
// Round-trips a fixture that reuses named Avro types (a record, an enum, and a
// fixed type) in multiple fields, verifying both that the reader resolves each
// reuse to an identical Arrow DataType and that the writer round-trips the
// data exactly.
fn test_named_types_complex_roundtrip() -> Result<(), AvroError> {
// Fixture ships with this crate, not the shared arrow-testing data repo.
let path =
PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("test/data/named_types_complex.avro");
let rdr_file = File::open(&path).expect("open avro/named_types_complex.avro");
let reader = ReaderBuilder::new()
.build(BufReader::new(rdr_file))
.expect("build reader for named_types_complex.avro");
let in_schema = reader.schema();
let input_batches = reader.collect::<Result<Vec<_>, _>>()?;
let original =
arrow::compute::concat_batches(&in_schema, &input_batches).expect("concat input");
{
// Sanity-check the decoded schema before writing anything: each reuse of a
// named Avro type must map to the same Arrow DataType.
let arrow_schema = original.schema();
let author_field = arrow_schema.field_with_name("author")?;
let author_type = author_field.data_type();
let editors_field = arrow_schema.field_with_name("editors")?;
// 'editors' must be a list; grab its item type for comparison with 'author'.
let editors_item_type = match editors_field.data_type() {
DataType::List(item_field) => item_field.data_type(),
other => panic!("Editors field should be a List, but was {:?}", other),
};
assert_eq!(
author_type, editors_item_type,
"The DataType for the 'author' struct and the 'editors' list items must be identical"
);
// Avro enums decode as Dictionary types.
let status_field = arrow_schema.field_with_name("status")?;
let status_type = status_field.data_type();
assert!(
matches!(status_type, DataType::Dictionary(_, _)),
"Status field should be a Dictionary (Enum)"
);
// 'previous_status' reuses the same named enum as 'status'.
let prev_status_field = arrow_schema.field_with_name("previous_status")?;
let prev_status_type = prev_status_field.data_type();
assert_eq!(
status_type, prev_status_type,
"The DataType for 'status' and 'previous_status' enums must be identical"
);
// Avro fixed(16) decodes as FixedSizeBinary(16).
let content_hash_field = arrow_schema.field_with_name("content_hash")?;
let content_hash_type = content_hash_field.data_type();
assert!(
matches!(content_hash_type, DataType::FixedSizeBinary(16)),
"Content hash should be FixedSizeBinary(16)"
);
// 'thumbnail_hash' reuses the same named fixed type as 'content_hash'.
let thumb_hash_field = arrow_schema.field_with_name("thumbnail_hash")?;
let thumb_hash_type = thumb_hash_field.data_type();
assert_eq!(
content_hash_type, thumb_hash_type,
"The DataType for 'content_hash' and 'thumbnail_hash' fixed types must be identical"
);
}
// Write to an in-memory OCF and read it back.
let buffer: Vec<u8> = Vec::new();
let mut writer = AvroWriter::new(buffer, original.schema().as_ref().clone())?;
writer.write(&original)?;
writer.finish()?;
let bytes = writer.into_inner();
let rt_reader = ReaderBuilder::new()
.build(Cursor::new(bytes))
.expect("build reader for round-trip");
let rt_schema = rt_reader.schema();
let rt_batches = rt_reader.collect::<Result<Vec<_>, _>>()?;
let roundtrip =
arrow::compute::concat_batches(&rt_schema, &rt_batches).expect("concat roundtrip");
assert_eq!(
roundtrip, original,
"Avro complex named types round-trip mismatch"
);
Ok(())
}
/// Asserts that `actual` is semantically equivalent to `expected`: same field
/// count, pairwise semantically-equivalent fields, and metadata that is a
/// superset of the expected metadata (modulo writer-added keys).
fn assert_schema_is_semantically_equivalent(expected: &Schema, actual: &Schema) {
    assert_metadata_is_superset(expected.metadata(), actual.metadata(), "Schema");
    let exp_fields = expected.fields();
    let act_fields = actual.fields();
    assert_eq!(
        exp_fields.len(),
        act_fields.len(),
        "Schema must have the same number of fields"
    );
    for (e, a) in exp_fields.iter().zip(act_fields.iter()) {
        assert_field_is_semantically_equivalent(e, a);
    }
}
/// Asserts that two fields agree on name, nullability, data type
/// (semantically), and metadata (modulo writer-added keys).
fn assert_field_is_semantically_equivalent(expected: &Field, actual: &Field) {
    let context = format!("Field '{}'", expected.name());
    assert_eq!(expected.name(), actual.name(), "{context}: names must match");
    assert_eq!(
        expected.is_nullable(),
        actual.is_nullable(),
        "{context}: nullability must match"
    );
    assert_datatype_is_semantically_equivalent(expected.data_type(), actual.data_type(), &context);
    assert_metadata_is_superset(expected.metadata(), actual.metadata(), &context);
}
/// Compares two DataTypes for semantic equivalence: nested types recurse
/// through `assert_field_is_semantically_equivalent` (which tolerates
/// writer-added field metadata), while all other types must be exactly equal.
fn assert_datatype_is_semantically_equivalent(
expected: &DataType,
actual: &DataType,
context: &str,
) {
match (expected, actual) {
// List-like types: compare the single child field. The Map arm ignores
// the second tuple element (the keys-sorted flag).
(DataType::List(expected_field), DataType::List(actual_field))
| (DataType::LargeList(expected_field), DataType::LargeList(actual_field))
| (DataType::Map(expected_field, _), DataType::Map(actual_field, _)) => {
assert_field_is_semantically_equivalent(expected_field, actual_field);
}
// Structs: same arity, then pairwise field comparison.
(DataType::Struct(expected_fields), DataType::Struct(actual_fields)) => {
assert_eq!(
expected_fields.len(),
actual_fields.len(),
"{context}: struct must have same number of fields"
);
for (ef, af) in expected_fields.iter().zip(actual_fields.iter()) {
assert_field_is_semantically_equivalent(ef, af);
}
}
// Unions: mode, arity, and per-variant (type id, field) must all line up.
(
DataType::Union(expected_fields, expected_mode),
DataType::Union(actual_fields, actual_mode),
) => {
assert_eq!(
expected_mode, actual_mode,
"{context}: union mode must match"
);
assert_eq!(
expected_fields.len(),
actual_fields.len(),
"{context}: union must have same number of variants"
);
for ((exp_id, exp_field), (act_id, act_field)) in
expected_fields.iter().zip(actual_fields.iter())
{
assert_eq!(exp_id, act_id, "{context}: union type ids must match");
assert_field_is_semantically_equivalent(exp_field, act_field);
}
}
// Everything else (including mismatched nesting kinds) must be identical.
_ => {
assert_eq!(expected, actual, "{context}: data types must be identical");
}
}
}
/// Asserts that two record batches have the same shape and byte-identical
/// column data, delegating per-column comparison to
/// `assert_array_data_is_identical`.
fn assert_batch_data_is_identical(expected: &RecordBatch, actual: &RecordBatch) {
    assert_eq!(
        expected.num_columns(),
        actual.num_columns(),
        "RecordBatches must have the same number of columns"
    );
    assert_eq!(
        expected.num_rows(),
        actual.num_rows(),
        "RecordBatches must have the same number of rows"
    );
    for col_idx in 0..expected.num_columns() {
        assert_array_data_is_identical(
            expected.column(col_idx),
            actual.column(col_idx),
            &format!("Column {col_idx}"),
        );
    }
}
/// Asserts that two arrays contain identical data, recursing through union and
/// struct children so a mismatch is reported with a precise path in `context`.
fn assert_array_data_is_identical(expected: &dyn Array, actual: &dyn Array, context: &str) {
    assert_eq!(
        expected.nulls(),
        actual.nulls(),
        "{context}: null buffers must match"
    );
    assert_eq!(
        expected.len(),
        actual.len(),
        "{context}: array lengths must match"
    );
    match (expected.data_type(), actual.data_type()) {
        (DataType::Union(expected_fields, _), DataType::Union(..)) => {
            let expected_union = expected.as_any().downcast_ref::<UnionArray>().unwrap();
            let actual_union = actual.as_any().downcast_ref::<UnionArray>().unwrap();
            // Materialize ArrayData once per side: the previous code called
            // `to_data()` up to four times per union level, rebuilding the
            // ArrayData descriptor on every call.
            let expected_data = expected.to_data();
            let actual_data = actual.to_data();
            assert_eq!(
                &expected_data.buffers()[0],
                &actual_data.buffers()[0],
                "{context}: union type_ids buffer mismatch"
            );
            // A second buffer (value offsets) is only present for some union
            // layouts; compare it when the expected side carries one.
            if expected_data.buffers().len() > 1 {
                assert_eq!(
                    &expected_data.buffers()[1],
                    &actual_data.buffers()[1],
                    "{context}: union value_offsets buffer mismatch"
                );
            }
            // Recurse into each variant's child array by type id.
            for (type_id, _) in expected_fields.iter() {
                let child_context = format!("{context} -> child variant {type_id}");
                assert_array_data_is_identical(
                    expected_union.child(type_id),
                    actual_union.child(type_id),
                    &child_context,
                );
            }
        }
        (DataType::Struct(_), DataType::Struct(_)) => {
            let expected_struct = expected.as_any().downcast_ref::<StructArray>().unwrap();
            let actual_struct = actual.as_any().downcast_ref::<StructArray>().unwrap();
            // Recurse into each struct child by position.
            for i in 0..expected_struct.num_columns() {
                let child_context = format!("{context} -> struct child {i}");
                assert_array_data_is_identical(
                    expected_struct.column(i),
                    actual_struct.column(i),
                    &child_context,
                );
            }
        }
        _ => {
            // Flat types: raw data buffers must be byte-identical.
            assert_eq!(
                expected.to_data().buffers(),
                actual.to_data().buffers(),
                "{context}: data buffers must match"
            );
        }
    }
}
/// Asserts that `actual_meta` preserves every entry of `expected_meta`
/// verbatim and introduces no keys beyond those the Avro writer is known to
/// add during a round-trip.
fn assert_metadata_is_superset(
    expected_meta: &HashMap<String, String>,
    actual_meta: &HashMap<String, String>,
    context: &str,
) {
    // Keys the writer legitimately adds; anything else new is a failure.
    // A const slice avoids the Vec + HashSet allocated on every call by the
    // previous implementation (clippy::useless_vec territory).
    const ALLOWED_ADDITIONS: [&str; 3] = ["arrowUnionMode", "arrowUnionTypeIds", "avro.name"];
    // Every expected entry must survive with an unchanged value.
    for (key, expected_value) in expected_meta {
        match actual_meta.get(key) {
            Some(actual_value) => assert_eq!(
                expected_value, actual_value,
                "{context}: preserved metadata for key '{key}' must have the same value"
            ),
            None => panic!("{context}: metadata key '{key}' was lost during roundtrip"),
        }
    }
    // Any additional key must be on the allow-list.
    for key in actual_meta.keys() {
        if !expected_meta.contains_key(key) && !ALLOWED_ADDITIONS.contains(&key.as_str()) {
            panic!("{context}: unexpected metadata key '{key}' was added during roundtrip");
        }
    }
}
#[test]
fn test_union_roundtrip() -> Result<(), AvroError> {
    // Fixture with Avro union fields, shipped in this crate's test data.
    let file_path = std::path::PathBuf::from(env!("CARGO_MANIFEST_DIR"))
        .join("test/data/union_fields.avro")
        .to_string_lossy()
        .into_owned();
    let src = File::open(&file_path).expect("open avro/union_fields.avro");
    let reader = ReaderBuilder::new()
        .build(BufReader::new(src))
        .expect("build reader for union_fields.avro");
    let schema = reader.schema();
    let input_batches = reader.collect::<Result<Vec<_>, _>>()?;
    let original =
        arrow::compute::concat_batches(&schema, &input_batches).expect("concat input");
    // Encode to an in-memory OCF and decode it again.
    let mut writer = AvroWriter::new(Vec::<u8>::new(), original.schema().as_ref().clone())?;
    writer.write(&original)?;
    writer.finish()?;
    let rt_reader = ReaderBuilder::new()
        .build(Cursor::new(writer.into_inner()))
        .expect("build round_trip reader");
    let rt_schema = rt_reader.schema();
    let rt_batches = rt_reader.collect::<Result<Vec<_>, _>>()?;
    let round_trip =
        arrow::compute::concat_batches(&rt_schema, &rt_batches).expect("concat round_trip");
    // Compare semantically rather than with RecordBatch equality, since the
    // round-trip may add allowed metadata keys (see assert_metadata_is_superset).
    assert_schema_is_semantically_equivalent(&original.schema(), &round_trip.schema());
    assert_batch_data_is_identical(&original, &round_trip);
    Ok(())
}
#[test]
fn test_enum_roundtrip_uses_reader_fixture() -> Result<(), AvroError> {
    let src = File::open(arrow_test_data("avro/simple_enum.avro"))
        .expect("open avro/simple_enum.avro");
    let reader = ReaderBuilder::new()
        .build(BufReader::new(src))
        .expect("build reader for simple_enum.avro");
    let in_schema = reader.schema();
    // The fixture must contain an enum column, which the reader maps to
    // Dictionary<Int32, Utf8>.
    assert!(
        in_schema.fields().iter().any(|f| matches!(
            f.data_type(),
            DataType::Dictionary(k, v) if **k == DataType::Int32 && **v == DataType::Utf8
        )),
        "Expected at least one enum-mapped Dictionary<Int32, Utf8> field"
    );
    let input_batches = reader.collect::<Result<Vec<_>, _>>()?;
    let original =
        arrow::compute::concat_batches(&in_schema, &input_batches).expect("concat input");
    // Write to an in-memory OCF and decode it again.
    let mut writer = AvroWriter::new(Vec::<u8>::new(), in_schema.as_ref().clone())?;
    writer.write(&original)?;
    writer.finish()?;
    let rt_reader = ReaderBuilder::new()
        .build(Cursor::new(writer.into_inner()))
        .expect("reader for round-trip");
    let rt_schema = rt_reader.schema();
    let rt_batches = rt_reader.collect::<Result<Vec<_>, _>>()?;
    let roundtrip =
        arrow::compute::concat_batches(&rt_schema, &rt_batches).expect("concat roundtrip");
    assert_eq!(roundtrip, original, "Avro enum round-trip mismatch");
    Ok(())
}
#[test]
fn test_builder_propagates_capacity_to_writer() -> Result<(), AvroError> {
    let cap = 64 * 1024;
    // Build an OCF writer with an explicit buffer capacity.
    let mut writer = WriterBuilder::new(make_schema())
        .with_capacity(cap)
        .build::<_, AvroOcfFormat>(Vec::<u8>::new())?;
    assert_eq!(writer.capacity, cap, "builder capacity not propagated");
    // The writer must still produce a valid OCF: check the 4-byte magic.
    writer.write(&make_batch())?;
    writer.finish()?;
    let out = writer.into_inner();
    assert_eq!(&out[..4], b"Obj\x01", "OCF magic missing/incorrect");
    Ok(())
}
#[test]
fn test_stream_writer_stores_capacity_direct_writes() -> Result<(), AvroError> {
    use arrow_array::{ArrayRef, Int32Array};
    use arrow_schema::{DataType, Field, Schema};
    // Minimal single-column batch: three non-null Int32 values.
    let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
    let column: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3]));
    let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![column])?;
    // The SOE (single-object encoding) writer must record the configured
    // capacity and accept direct writes.
    let cap = 8192;
    let mut writer = WriterBuilder::new(schema)
        .with_capacity(cap)
        .build::<_, AvroSoeFormat>(Vec::new())?;
    assert_eq!(writer.capacity, cap);
    writer.write(&batch)?;
    let _bytes = writer.into_inner();
    Ok(())
}
#[cfg(feature = "avro_custom_types")]
#[test]
// Round-trips a fixture containing duration logical types at all four Arrow
// TimeUnits through an on-disk OCF file and requires exact equality.
fn test_roundtrip_duration_logical_types_ocf() -> Result<(), AvroError> {
let file_path = std::path::PathBuf::from(env!("CARGO_MANIFEST_DIR"))
.join("test/data/duration_logical_types.avro")
.to_string_lossy()
.into_owned();
let in_file = File::open(&file_path)
.unwrap_or_else(|_| panic!("Failed to open test file: {}", file_path));
let reader = ReaderBuilder::new()
.build(BufReader::new(in_file))
.expect("build reader for duration_logical_types.avro");
let in_schema = reader.schema();
// The fixture must exercise every Duration granularity exactly once each.
let expected_units: HashSet<TimeUnit> = [
TimeUnit::Nanosecond,
TimeUnit::Microsecond,
TimeUnit::Millisecond,
TimeUnit::Second,
]
.into_iter()
.collect();
let found_units: HashSet<TimeUnit> = in_schema
.fields()
.iter()
.filter_map(|f| match f.data_type() {
DataType::Duration(unit) => Some(*unit),
_ => None,
})
.collect();
assert_eq!(
found_units, expected_units,
"Expected to find all four Duration TimeUnits in the schema from the initial read"
);
let input_batches = reader.collect::<Result<Vec<_>, _>>()?;
let input =
arrow::compute::concat_batches(&in_schema, &input_batches).expect("concat input");
let tmp = NamedTempFile::new().expect("create temp file");
// Scope the writer so the output file is flushed and closed before re-reading.
{
let out_file = File::create(tmp.path()).expect("create temp avro");
let mut writer = AvroWriter::new(out_file, in_schema.as_ref().clone())?;
writer.write(&input)?;
writer.finish()?;
}
let rt_file = File::open(tmp.path()).expect("open round_trip avro");
let rt_reader = ReaderBuilder::new()
.build(BufReader::new(rt_file))
.expect("build round_trip reader");
let rt_schema = rt_reader.schema();
let rt_batches = rt_reader.collect::<Result<Vec<_>, _>>()?;
let round_trip =
arrow::compute::concat_batches(&rt_schema, &rt_batches).expect("concat round_trip");
assert_eq!(round_trip, input);
Ok(())
}
#[cfg(feature = "avro_custom_types")]
#[test]
fn test_run_end_encoded_roundtrip_writer() -> Result<(), AvroError> {
    // Runs: 1 x3, 2 x2, null x2, 3 x1 => 8 logical rows.
    let ends = Int32Array::from(vec![3, 5, 7, 8]);
    let vals = Int32Array::from(vec![Some(1), Some(2), None, Some(3)]);
    let ree = RunArray::<Int32Type>::try_new(&ends, &vals)?;
    let schema = Schema::new(vec![Field::new("x", ree.data_type().clone(), true)]);
    let batch = RecordBatch::try_new(
        Arc::new(schema.clone()),
        vec![Arc::new(ree.clone()) as ArrayRef],
    )?;
    // Encode to an in-memory OCF, then decode it again.
    let mut writer = AvroWriter::new(Vec::<u8>::new(), schema.clone())?;
    writer.write(&batch)?;
    writer.finish()?;
    let reader = ReaderBuilder::new().build(Cursor::new(writer.into_inner()))?;
    let out_schema = reader.schema();
    let out_batches = reader.collect::<Result<Vec<_>, _>>()?;
    let out = arrow::compute::concat_batches(&out_schema, &out_batches).expect("concat output");
    assert_eq!(out.num_columns(), 1);
    assert_eq!(out.num_rows(), 8);
    // With the custom-types feature on, the column must round-trip as REE
    // with the canonical child names and an equal RunArray.
    match out.schema().field(0).data_type() {
        DataType::RunEndEncoded(ends_field, vals_field) => {
            assert_eq!(ends_field.name(), "run_ends");
            assert_eq!(ends_field.data_type(), &DataType::Int32);
            assert_eq!(vals_field.name(), "values");
            assert_eq!(vals_field.data_type(), &DataType::Int32);
            assert!(vals_field.is_nullable());
            let got_ree = out
                .column(0)
                .as_any()
                .downcast_ref::<RunArray<Int32Type>>()
                .expect("RunArray<Int32Type>");
            assert_eq!(got_ree, &ree);
        }
        other => panic!(
            "Unexpected DataType for round-tripped RunEndEncoded column: {:?}",
            other
        ),
    }
    Ok(())
}
#[cfg(feature = "avro_custom_types")]
#[test]
fn test_run_end_encoded_string_values_int16_run_ends_roundtrip_writer() -> Result<(), AvroError>
{
    // Runs: "a" x2, null x3, "c" x2 => 7 logical rows, with Int16 run ends.
    let ends = Int16Array::from(vec![2, 5, 7]);
    let vals = StringArray::from(vec![Some("a"), None, Some("c")]);
    let ree = RunArray::<Int16Type>::try_new(&ends, &vals)?;
    let schema = Schema::new(vec![Field::new("s", ree.data_type().clone(), true)]);
    let batch = RecordBatch::try_new(
        Arc::new(schema.clone()),
        vec![Arc::new(ree.clone()) as ArrayRef],
    )?;
    // Encode to an in-memory OCF, then decode it again.
    let mut writer = AvroWriter::new(Vec::<u8>::new(), schema.clone())?;
    writer.write(&batch)?;
    writer.finish()?;
    let reader = ReaderBuilder::new().build(Cursor::new(writer.into_inner()))?;
    let out_schema = reader.schema();
    let out_batches = reader.collect::<Result<Vec<_>, _>>()?;
    let out = arrow::compute::concat_batches(&out_schema, &out_batches).expect("concat output");
    assert_eq!(out.num_columns(), 1);
    assert_eq!(out.num_rows(), 7);
    // The column must come back as REE<Int16, Utf8> equal to the input.
    match out.schema().field(0).data_type() {
        DataType::RunEndEncoded(ends_field, vals_field) => {
            assert_eq!(ends_field.data_type(), &DataType::Int16);
            assert_eq!(vals_field.data_type(), &DataType::Utf8);
            assert!(
                vals_field.is_nullable(),
                "REE 'values' child should be nullable"
            );
            let got = out
                .column(0)
                .as_any()
                .downcast_ref::<RunArray<Int16Type>>()
                .expect("RunArray<Int16Type>");
            assert_eq!(got, &ree);
        }
        other => panic!("Unexpected DataType: {:?}", other),
    }
    Ok(())
}
#[cfg(feature = "avro_custom_types")]
#[test]
fn test_run_end_encoded_int64_run_ends_numeric_values_roundtrip_writer() -> Result<(), AvroError>
{
    // Two runs with Int64 run ends: 999 x4 then -5 x4 => 8 logical rows.
    let ends = Int64Array::from(vec![4_i64, 8_i64]);
    let vals = Int32Array::from(vec![Some(999), Some(-5)]);
    let ree = RunArray::<Int64Type>::try_new(&ends, &vals)?;
    let schema = Schema::new(vec![Field::new("y", ree.data_type().clone(), true)]);
    let batch = RecordBatch::try_new(
        Arc::new(schema.clone()),
        vec![Arc::new(ree.clone()) as ArrayRef],
    )?;
    // Encode to an in-memory OCF, then decode it again.
    let mut writer = AvroWriter::new(Vec::<u8>::new(), schema.clone())?;
    writer.write(&batch)?;
    writer.finish()?;
    let reader = ReaderBuilder::new().build(Cursor::new(writer.into_inner()))?;
    let out_schema = reader.schema();
    let out_batches = reader.collect::<Result<Vec<_>, _>>()?;
    let out = arrow::compute::concat_batches(&out_schema, &out_batches).expect("concat output");
    assert_eq!(out.num_columns(), 1);
    assert_eq!(out.num_rows(), 8);
    // The column must come back as REE<Int64, Int32> equal to the input.
    match out.schema().field(0).data_type() {
        DataType::RunEndEncoded(ends_field, vals_field) => {
            assert_eq!(ends_field.data_type(), &DataType::Int64);
            assert_eq!(vals_field.data_type(), &DataType::Int32);
            assert!(vals_field.is_nullable());
            let got = out
                .column(0)
                .as_any()
                .downcast_ref::<RunArray<Int64Type>>()
                .expect("RunArray<Int64Type>");
            assert_eq!(got, &ree);
        }
        other => panic!("Unexpected DataType for REE column: {:?}", other),
    }
    Ok(())
}
#[cfg(feature = "avro_custom_types")]
#[test]
// Emulates writing a logical slice of a run-end-encoded array: takes a window
// of the base REE array, re-compresses it into an owned REE array, round-trips
// that through the writer, and compares logical (expanded) values.
fn test_run_end_encoded_sliced_roundtrip_writer() -> Result<(), AvroError> {
// Base runs: 1 x3, 2 x2, null x2, 3 x1 => 8 logical rows.
let run_ends = Int32Array::from(vec![3, 5, 7, 8]);
let run_values = Int32Array::from(vec![Some(1), Some(2), None, Some(3)]);
let base = RunArray::<Int32Type>::try_new(&run_ends, &run_values)?;
// Logical window [1, 7) of the base array.
let offset = 1usize;
let length = 6usize;
let base_values = base.values().as_primitive::<Int32Type>();
// Expand the window into plain Option<i32> values via physical-index lookups.
let mut logical_window: Vec<Option<i32>> = Vec::with_capacity(length);
for i in offset..offset + length {
let phys = base.get_physical_index(i);
let v = if base_values.is_null(phys) {
None
} else {
Some(base_values.value(phys))
};
logical_window.push(v);
}
// Re-compresses a flat Option<i32> sequence into (run_ends, run_values)
// arrays, merging equal adjacent values (nulls compare equal to nulls).
fn compress_run_ends_i32(vals: &[Option<i32>]) -> (Int32Array, Int32Array) {
if vals.is_empty() {
return (Int32Array::new_null(0), Int32Array::new_null(0));
}
let mut run_ends_out: Vec<i32> = Vec::new();
let mut run_vals_out: Vec<Option<i32>> = Vec::new();
let mut cur = vals[0];
let mut len = 1i32;
for v in &vals[1..] {
if *v == cur {
len += 1;
} else {
// Close the current run and start a new one.
let last_end = run_ends_out.last().copied().unwrap_or(0);
run_ends_out.push(last_end + len);
run_vals_out.push(cur);
cur = *v;
len = 1;
}
}
// Close the final run.
let last_end = run_ends_out.last().copied().unwrap_or(0);
run_ends_out.push(last_end + len);
run_vals_out.push(cur);
(
Int32Array::from(run_ends_out),
Int32Array::from(run_vals_out),
)
}
let (owned_run_ends, owned_run_values) = compress_run_ends_i32(&logical_window);
let owned_slice = RunArray::<Int32Type>::try_new(&owned_run_ends, &owned_run_values)?;
let field = Field::new("x", owned_slice.data_type().clone(), true);
let schema = Schema::new(vec![field]);
let batch = RecordBatch::try_new(
Arc::new(schema.clone()),
vec![Arc::new(owned_slice.clone()) as ArrayRef],
)?;
// Round-trip through an in-memory OCF.
let mut writer = AvroWriter::new(Vec::<u8>::new(), schema.clone())?;
writer.write(&batch)?;
writer.finish()?;
let bytes = writer.into_inner();
let reader = ReaderBuilder::new().build(Cursor::new(bytes))?;
let out_schema = reader.schema();
let batches = reader.collect::<Result<Vec<_>, _>>()?;
let out = arrow::compute::concat_batches(&out_schema, &batches).expect("concat output");
assert_eq!(out.num_columns(), 1);
assert_eq!(out.num_rows(), length);
match out.schema().field(0).data_type() {
DataType::RunEndEncoded(run_ends_field, values_field) => {
assert_eq!(run_ends_field.data_type(), &DataType::Int32);
assert_eq!(values_field.data_type(), &DataType::Int32);
assert!(values_field.is_nullable());
let got = out
.column(0)
.as_any()
.downcast_ref::<RunArray<Int32Type>>()
.expect("RunArray<Int32Type>");
// Expands a REE array back into a flat Int32Array for logical comparison.
fn expand_ree_to_int32(a: &RunArray<Int32Type>) -> Int32Array {
let vals = a.values().as_primitive::<Int32Type>();
let mut out: Vec<Option<i32>> = Vec::with_capacity(a.len());
for i in 0..a.len() {
let phys = a.get_physical_index(i);
out.push(if vals.is_null(phys) {
None
} else {
Some(vals.value(phys))
});
}
Int32Array::from(out)
}
// Compare expanded logical values, not the physical run layout.
let got_logical = expand_ree_to_int32(got);
let expected_logical = Int32Array::from(logical_window);
assert_eq!(
got_logical, expected_logical,
"Logical values differ after REE slice round-trip"
);
}
other => panic!("Unexpected DataType for REE column: {:?}", other),
}
Ok(())
}
#[cfg(not(feature = "avro_custom_types"))]
#[test]
fn test_run_end_encoded_roundtrip_writer_feature_off() -> Result<(), AvroError> {
    use arrow_schema::{DataType, Field, Schema};
    // Runs: 1 x3, 2 x2, null x2, 3 x1 => 8 logical rows.
    let run_ends = arrow_array::Int32Array::from(vec![3, 5, 7, 8]);
    let run_values = arrow_array::Int32Array::from(vec![Some(1), Some(2), None, Some(3)]);
    let ree = arrow_array::RunArray::<arrow_array::types::Int32Type>::try_new(
        &run_ends,
        &run_values,
    )?;
    let schema = Schema::new(vec![Field::new("x", ree.data_type().clone(), true)]);
    let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(ree) as ArrayRef])?;
    let mut writer = AvroWriter::new(Vec::<u8>::new(), schema.clone())?;
    writer.write(&batch)?;
    writer.finish()?;
    // With the feature off, the round-trip comes back as a plain Int32 column
    // carrying the fully expanded run values.
    let reader = ReaderBuilder::new().build(Cursor::new(writer.into_inner()))?;
    let out_schema = reader.schema();
    let out_batches = reader.collect::<Result<Vec<_>, _>>()?;
    let out = arrow::compute::concat_batches(&out_schema, &out_batches).expect("concat output");
    assert_eq!(out.num_columns(), 1);
    assert_eq!(out.num_rows(), 8);
    assert_eq!(out.schema().field(0).data_type(), &DataType::Int32);
    let expected = Int32Array::from(vec![
        Some(1),
        Some(1),
        Some(1),
        Some(2),
        Some(2),
        None,
        None,
        Some(3),
    ]);
    assert_eq!(out.column(0).as_primitive::<Int32Type>(), &expected);
    Ok(())
}
#[cfg(not(feature = "avro_custom_types"))]
#[test]
fn test_run_end_encoded_string_values_int16_run_ends_roundtrip_writer_feature_off()
-> Result<(), AvroError> {
    use arrow_schema::{DataType, Field, Schema};
    // Runs: "a" x2, null x3, "c" x2 => 7 logical rows, with Int16 run ends.
    let run_ends = arrow_array::Int16Array::from(vec![2, 5, 7]);
    let run_values = arrow_array::StringArray::from(vec![Some("a"), None, Some("c")]);
    let ree = arrow_array::RunArray::<arrow_array::types::Int16Type>::try_new(
        &run_ends,
        &run_values,
    )?;
    let schema = Schema::new(vec![Field::new("s", ree.data_type().clone(), true)]);
    let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(ree) as ArrayRef])?;
    let mut writer = AvroWriter::new(Vec::<u8>::new(), schema.clone())?;
    writer.write(&batch)?;
    writer.finish()?;
    // With the feature off, the round-trip comes back as a plain Utf8 column
    // carrying the fully expanded run values.
    let reader = ReaderBuilder::new().build(Cursor::new(writer.into_inner()))?;
    let out_schema = reader.schema();
    let out_batches = reader.collect::<Result<Vec<_>, _>>()?;
    let out = arrow::compute::concat_batches(&out_schema, &out_batches).expect("concat output");
    assert_eq!(out.num_columns(), 1);
    assert_eq!(out.num_rows(), 7);
    assert_eq!(out.schema().field(0).data_type(), &DataType::Utf8);
    let got = out
        .column(0)
        .as_any()
        .downcast_ref::<arrow_array::StringArray>()
        .expect("StringArray");
    let expected = arrow_array::StringArray::from(vec![
        Some("a"),
        Some("a"),
        None,
        None,
        None,
        Some("c"),
        Some("c"),
    ]);
    assert_eq!(got, &expected);
    Ok(())
}
#[cfg(not(feature = "avro_custom_types"))]
#[test]
fn test_run_end_encoded_int64_run_ends_numeric_values_roundtrip_writer_feature_off()
-> Result<(), AvroError> {
    use arrow_schema::{DataType, Field, Schema};
    // Two runs with Int64 run ends: 999 x4 then -5 x4 => 8 logical rows.
    let run_ends = arrow_array::Int64Array::from(vec![4_i64, 8_i64]);
    let run_values = Int32Array::from(vec![Some(999), Some(-5)]);
    let ree = arrow_array::RunArray::<arrow_array::types::Int64Type>::try_new(
        &run_ends,
        &run_values,
    )?;
    let schema = Schema::new(vec![Field::new("y", ree.data_type().clone(), true)]);
    let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(ree) as ArrayRef])?;
    let mut writer = AvroWriter::new(Vec::<u8>::new(), schema.clone())?;
    writer.write(&batch)?;
    writer.finish()?;
    // With the feature off, the round-trip comes back as a plain Int32 column
    // carrying the fully expanded run values.
    let reader = ReaderBuilder::new().build(Cursor::new(writer.into_inner()))?;
    let out_schema = reader.schema();
    let out_batches = reader.collect::<Result<Vec<_>, _>>()?;
    let out = arrow::compute::concat_batches(&out_schema, &out_batches).expect("concat output");
    assert_eq!(out.num_columns(), 1);
    assert_eq!(out.num_rows(), 8);
    assert_eq!(out.schema().field(0).data_type(), &DataType::Int32);
    let expected = Int32Array::from(vec![
        Some(999),
        Some(999),
        Some(999),
        Some(999),
        Some(-5),
        Some(-5),
        Some(-5),
        Some(-5),
    ]);
    assert_eq!(out.column(0).as_primitive::<Int32Type>(), &expected);
    Ok(())
}
#[cfg(not(feature = "avro_custom_types"))]
#[test]
fn test_run_end_encoded_sliced_roundtrip_writer_feature_off() -> Result<(), AvroError> {
    use arrow_schema::{DataType, Field, Schema};
    // Runs: 1 x2, 2 x2, null x2 => 6 logical rows.
    let run_ends = Int32Array::from(vec![2, 4, 6]);
    let run_values = Int32Array::from(vec![Some(1), Some(2), None]);
    let ree = arrow_array::RunArray::<arrow_array::types::Int32Type>::try_new(
        &run_ends,
        &run_values,
    )?;
    let schema = Schema::new(vec![Field::new("x", ree.data_type().clone(), true)]);
    let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(ree) as ArrayRef])?;
    let mut writer = AvroWriter::new(Vec::<u8>::new(), schema.clone())?;
    writer.write(&batch)?;
    writer.finish()?;
    // With the feature off, the round-trip comes back as a plain Int32 column
    // carrying the fully expanded run values.
    let reader = ReaderBuilder::new().build(Cursor::new(writer.into_inner()))?;
    let out_schema = reader.schema();
    let out_batches = reader.collect::<Result<Vec<_>, _>>()?;
    let out = arrow::compute::concat_batches(&out_schema, &out_batches).expect("concat output");
    assert_eq!(out.num_columns(), 1);
    assert_eq!(out.num_rows(), 6);
    assert_eq!(out.schema().field(0).data_type(), &DataType::Int32);
    let expected = Int32Array::from(vec![Some(1), Some(1), Some(2), Some(2), None, None]);
    assert_eq!(out.column(0).as_primitive::<Int32Type>(), &expected);
    Ok(())
}
#[test]
#[cfg(feature = "snappy")]
fn test_nullable_impala_roundtrip() -> Result<(), AvroError> {
    // Fixture must contain at least one nullable column.
    let src = File::open(arrow_test_data("avro/nullable.impala.avro"))
        .expect("open avro/nullable.impala.avro");
    let reader = ReaderBuilder::new()
        .build(BufReader::new(src))
        .expect("build reader for nullable.impala.avro");
    let in_schema = reader.schema();
    assert!(
        in_schema.fields().iter().any(|f| f.is_nullable()),
        "expected at least one nullable field in avro/nullable.impala.avro"
    );
    let input_batches = reader.collect::<Result<Vec<_>, _>>()?;
    let original =
        arrow::compute::concat_batches(&in_schema, &input_batches).expect("concat input");
    // Encode to an in-memory OCF and decode it again.
    let mut writer = AvroWriter::new(Vec::<u8>::new(), in_schema.as_ref().clone())?;
    writer.write(&original)?;
    writer.finish()?;
    let rt_reader = ReaderBuilder::new()
        .build(Cursor::new(writer.into_inner()))
        .expect("build reader for round-tripped in-memory OCF");
    let rt_schema = rt_reader.schema();
    let rt_batches = rt_reader.collect::<Result<Vec<_>, _>>()?;
    let roundtrip =
        arrow::compute::concat_batches(&rt_schema, &rt_batches).expect("concat roundtrip");
    assert_eq!(
        roundtrip, original,
        "Round-trip Avro data mismatch for nullable.impala.avro"
    );
    Ok(())
}
#[test]
#[cfg(feature = "snappy")]
fn test_datapage_v2_roundtrip() -> Result<(), AvroError> {
    // Snappy-compressed fixture (per the filename); gated on the snappy feature.
    let src = File::open(arrow_test_data("avro/datapage_v2.snappy.avro"))
        .expect("open avro/datapage_v2.snappy.avro");
    let reader = ReaderBuilder::new()
        .build(BufReader::new(src))
        .expect("build reader for datapage_v2.snappy.avro");
    let in_schema = reader.schema();
    let input_batches = reader.collect::<Result<Vec<_>, _>>()?;
    let original =
        arrow::compute::concat_batches(&in_schema, &input_batches).expect("concat input");
    // Encode to an in-memory OCF and decode it again.
    let mut writer = AvroWriter::new(Vec::<u8>::new(), in_schema.as_ref().clone())?;
    writer.write(&original)?;
    writer.finish()?;
    let rt_reader = ReaderBuilder::new()
        .build(Cursor::new(writer.into_inner()))
        .expect("build round-trip reader");
    let rt_schema = rt_reader.schema();
    let rt_batches = rt_reader.collect::<Result<Vec<_>, _>>()?;
    let round_trip =
        arrow::compute::concat_batches(&rt_schema, &rt_batches).expect("concat round_trip");
    assert_eq!(
        round_trip, original,
        "Round-trip batch mismatch for datapage_v2.snappy.avro"
    );
    Ok(())
}
#[test]
#[cfg(feature = "snappy")]
fn test_single_nan_roundtrip() -> Result<(), AvroError> {
    // Fixture containing a NaN value; RecordBatch equality compares the
    // underlying buffers, so NaN round-trips are still checkable.
    let src = File::open(arrow_test_data("avro/single_nan.avro"))
        .expect("open avro/single_nan.avro");
    let reader = ReaderBuilder::new()
        .build(BufReader::new(src))
        .expect("build reader for single_nan.avro");
    let in_schema = reader.schema();
    let in_batches = reader.collect::<Result<Vec<_>, _>>()?;
    let original =
        arrow::compute::concat_batches(&in_schema, &in_batches).expect("concat input");
    // Encode to an in-memory OCF and decode it again.
    let mut writer = AvroWriter::new(Vec::<u8>::new(), original.schema().as_ref().clone())?;
    writer.write(&original)?;
    writer.finish()?;
    let rt_reader = ReaderBuilder::new()
        .build(Cursor::new(writer.into_inner()))
        .expect("build round_trip reader");
    let rt_schema = rt_reader.schema();
    let rt_batches = rt_reader.collect::<Result<Vec<_>, _>>()?;
    let round_trip =
        arrow::compute::concat_batches(&rt_schema, &rt_batches).expect("concat round_trip");
    assert_eq!(
        round_trip, original,
        "Round-trip batch mismatch for avro/single_nan.avro"
    );
    Ok(())
}
// Round-trips avro/dict-page-offset-zero.avro through the OCF writer and
// asserts the data survives unchanged.
#[test]
#[cfg(feature = "snappy")]
fn test_dict_pages_offset_zero_roundtrip() -> Result<(), AvroError> {
    let path = arrow_test_data("avro/dict-page-offset-zero.avro");
    let source_file = File::open(&path).expect("open avro/dict-page-offset-zero.avro");
    let source_reader = ReaderBuilder::new()
        .build(BufReader::new(source_file))
        .expect("build reader for dict-page-offset-zero.avro");
    let source_schema = source_reader.schema();
    let source_batches = source_reader.collect::<Result<Vec<_>, _>>()?;
    // One combined batch makes the final comparison chunking-independent.
    let expected =
        arrow::compute::concat_batches(&source_schema, &source_batches).expect("concat input");
    let mut ocf_writer = AvroWriter::new(Vec::new(), expected.schema().as_ref().clone())?;
    ocf_writer.write(&expected)?;
    ocf_writer.finish()?;
    let encoded: Vec<u8> = ocf_writer.into_inner();
    let reread = ReaderBuilder::new()
        .build(Cursor::new(encoded))
        .expect("build reader for round-trip");
    let reread_schema = reread.schema();
    let reread_batches = reread.collect::<Result<Vec<_>, _>>()?;
    let actual =
        arrow::compute::concat_batches(&reread_schema, &reread_batches).expect("concat roundtrip");
    assert_eq!(
        actual, expected,
        "Round-trip batch mismatch for avro/dict-page-offset-zero.avro"
    );
    Ok(())
}
// Round-trips avro/repeated_no_annotation.avro (nested/repeated structures)
// through `AvroWriter` and asserts the re-read batch equals the input.
#[test]
#[cfg(feature = "snappy")]
fn test_repeated_no_annotation_roundtrip() -> Result<(), AvroError> {
let path = arrow_test_data("avro/repeated_no_annotation.avro");
let in_file = File::open(&path).expect("open avro/repeated_no_annotation.avro");
let reader = ReaderBuilder::new()
.build(BufReader::new(in_file))
.expect("build reader for repeated_no_annotation.avro");
let in_schema = reader.schema();
let in_batches = reader.collect::<Result<Vec<_>, _>>()?;
// Merge all reader batches into one so chunking does not affect equality.
let original =
arrow::compute::concat_batches(&in_schema, &in_batches).expect("concat input");
let mut writer = AvroWriter::new(Vec::<u8>::new(), original.schema().as_ref().clone())?;
writer.write(&original)?;
writer.finish()?;
let bytes = writer.into_inner();
let rt_reader = ReaderBuilder::new()
.build(Cursor::new(bytes))
.expect("build reader for round-trip buffer");
let rt_schema = rt_reader.schema();
let rt_batches = rt_reader.collect::<Result<Vec<_>, _>>()?;
let round_trip =
arrow::compute::concat_batches(&rt_schema, &rt_batches).expect("concat round-trip");
assert_eq!(
round_trip, original,
"Round-trip batch mismatch for avro/repeated_no_annotation.avro"
);
Ok(())
}
// Round-trips test/data/nested_record_reuse.avro, which reuses a named record
// type in more than one place, and asserts the data survives unchanged.
#[test]
fn test_nested_record_type_reuse_roundtrip() -> Result<(), AvroError> {
// This fixture lives in the crate's own test data, not in arrow-testing.
let path = std::path::PathBuf::from(env!("CARGO_MANIFEST_DIR"))
.join("test/data/nested_record_reuse.avro")
.to_string_lossy()
.into_owned();
let in_file = File::open(&path).expect("open avro/nested_record_reuse.avro");
let reader = ReaderBuilder::new()
.build(BufReader::new(in_file))
.expect("build reader for nested_record_reuse.avro");
let in_schema = reader.schema();
let input_batches = reader.collect::<Result<Vec<_>, _>>()?;
let input =
arrow::compute::concat_batches(&in_schema, &input_batches).expect("concat input");
let mut writer = AvroWriter::new(Vec::<u8>::new(), in_schema.as_ref().clone())?;
writer.write(&input)?;
writer.finish()?;
let bytes = writer.into_inner();
let rt_reader = ReaderBuilder::new()
.build(Cursor::new(bytes))
.expect("build round_trip reader");
let rt_schema = rt_reader.schema();
let rt_batches = rt_reader.collect::<Result<Vec<_>, _>>()?;
let round_trip =
arrow::compute::concat_batches(&rt_schema, &rt_batches).expect("concat round_trip");
assert_eq!(
round_trip, input,
"Round-trip batch mismatch for nested_record_reuse.avro"
);
Ok(())
}
// Round-trips test/data/enum_reuse.avro, which reuses a named enum type in
// more than one place, and asserts the data survives unchanged.
#[test]
fn test_enum_type_reuse_roundtrip() -> Result<(), AvroError> {
    let path =
        std::path::PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("test/data/enum_reuse.avro");
    // Consistency fix: use the `File`/`BufReader`/`Cursor` names the sibling
    // tests in this module use, instead of fully-qualified `std::` paths.
    let rdr_file = File::open(&path).expect("open test/data/enum_reuse.avro");
    let reader = ReaderBuilder::new()
        .build(BufReader::new(rdr_file))
        .expect("build reader for enum_reuse.avro");
    let in_schema = reader.schema();
    let input_batches = reader.collect::<Result<Vec<_>, _>>()?;
    // Merge all reader batches into one so chunking does not affect equality.
    let original =
        arrow::compute::concat_batches(&in_schema, &input_batches).expect("concat input");
    let mut writer = AvroWriter::new(Vec::<u8>::new(), original.schema().as_ref().clone())?;
    writer.write(&original)?;
    writer.finish()?;
    let bytes = writer.into_inner();
    let rt_reader = ReaderBuilder::new()
        .build(Cursor::new(bytes))
        .expect("build round_trip reader");
    let rt_schema = rt_reader.schema();
    let rt_batches = rt_reader.collect::<Result<Vec<_>, _>>()?;
    let round_trip =
        arrow::compute::concat_batches(&rt_schema, &rt_batches).expect("concat round_trip");
    assert_eq!(
        round_trip, original,
        "Avro enum type reuse round-trip mismatch"
    );
    Ok(())
}
// Broad end-to-end round-trip over the crate's comprehensive fixture, which
// exercises many Avro types at once through `AvroWriter` and back.
#[test]
fn comprehensive_e2e_test_roundtrip() -> Result<(), AvroError> {
let path = std::path::PathBuf::from(env!("CARGO_MANIFEST_DIR"))
.join("test/data/comprehensive_e2e.avro");
let rdr_file = File::open(&path).expect("open test/data/comprehensive_e2e.avro");
let reader = ReaderBuilder::new()
.build(BufReader::new(rdr_file))
.expect("build reader for comprehensive_e2e.avro");
let in_schema = reader.schema();
let in_batches = reader.collect::<Result<Vec<_>, _>>()?;
let original =
arrow::compute::concat_batches(&in_schema, &in_batches).expect("concat input");
let sink: Vec<u8> = Vec::new();
let mut writer = AvroWriter::new(sink, original.schema().as_ref().clone())?;
writer.write(&original)?;
writer.finish()?;
let bytes = writer.into_inner();
let rt_reader = ReaderBuilder::new()
.build(Cursor::new(bytes))
.expect("build round-trip reader");
let rt_schema = rt_reader.schema();
let rt_batches = rt_reader.collect::<Result<Vec<_>, _>>()?;
let roundtrip =
arrow::compute::concat_batches(&rt_schema, &rt_batches).expect("concat roundtrip");
assert_eq!(
roundtrip, original,
"Round-trip batch mismatch for comprehensive_e2e.avro"
);
Ok(())
}
// Round-trips a hand-built batch covering the date/time/timestamp encoders:
// Date32, Time32(ms), Time64(us), and Timestamp in ms/us/ns (no timezone).
// Values include zero, negatives, and end-of-day boundaries.
#[test]
fn test_roundtrip_new_time_encoders_writer() -> Result<(), AvroError> {
let schema = Schema::new(vec![
Field::new("d32", DataType::Date32, false),
Field::new("t32_ms", DataType::Time32(TimeUnit::Millisecond), false),
Field::new("t64_us", DataType::Time64(TimeUnit::Microsecond), false),
Field::new(
"ts_ms",
DataType::Timestamp(TimeUnit::Millisecond, None),
false,
),
Field::new(
"ts_us",
DataType::Timestamp(TimeUnit::Microsecond, None),
false,
),
Field::new(
"ts_ns",
DataType::Timestamp(TimeUnit::Nanosecond, None),
false,
),
]);
let d32 = Date32Array::from(vec![0, 1, -1]);
// 86_399_999 ms is the last representable millisecond of a day.
let t32_ms: PrimitiveArray<Time32MillisecondType> =
vec![0_i32, 12_345_i32, 86_399_999_i32].into();
let t64_us: PrimitiveArray<Time64MicrosecondType> =
vec![0_i64, 1_234_567_i64, 86_399_999_999_i64].into();
let ts_ms: PrimitiveArray<TimestampMillisecondType> =
vec![0_i64, -1_i64, 1_700_000_000_000_i64].into();
let ts_us: PrimitiveArray<TimestampMicrosecondType> = vec![0_i64, 1_i64, -1_i64].into();
let ts_ns: PrimitiveArray<TimestampNanosecondType> = vec![0_i64, 1_i64, -1_i64].into();
let batch = RecordBatch::try_new(
Arc::new(schema.clone()),
vec![
Arc::new(d32) as ArrayRef,
Arc::new(t32_ms) as ArrayRef,
Arc::new(t64_us) as ArrayRef,
Arc::new(ts_ms) as ArrayRef,
Arc::new(ts_us) as ArrayRef,
Arc::new(ts_ns) as ArrayRef,
],
)?;
let mut writer = AvroWriter::new(Vec::<u8>::new(), schema.clone())?;
writer.write(&batch)?;
writer.finish()?;
let bytes = writer.into_inner();
let rt_reader = ReaderBuilder::new()
.build(std::io::Cursor::new(bytes))
.expect("build reader for round-trip of new time encoders");
let rt_schema = rt_reader.schema();
let rt_batches = rt_reader.collect::<Result<Vec<_>, _>>()?;
let roundtrip =
arrow::compute::concat_batches(&rt_schema, &rt_batches).expect("concat roundtrip");
assert_eq!(roundtrip, batch);
Ok(())
}
// Builds the two-column (non-nullable Int32 "a" and "b") schema shared by the
// row-encoder tests below.
fn make_encoder_schema() -> Schema {
    let fields = vec![
        Field::new("a", DataType::Int32, false),
        Field::new("b", DataType::Int32, false),
    ];
    Schema::new(fields)
}
// Builds a three-row batch (a = [1, 2, 3], b = [10, 20, 30]) matching
// `make_encoder_schema`.
fn make_encoder_batch(schema: &Schema) -> RecordBatch {
    let columns: Vec<ArrayRef> = vec![
        Arc::new(Int32Array::from(vec![1, 2, 3])),
        Arc::new(Int32Array::from(vec![10, 20, 30])),
    ];
    RecordBatch::try_new(Arc::new(schema.clone()), columns)
        .expect("failed to build test RecordBatch")
}
// Builds a realistic 3-row test fixture: an Avro record schema (JSON), the
// matching Arrow schema with the Avro JSON embedded under SCHEMA_METADATA_KEY,
// and a populated RecordBatch. Returns (arrow_schema, batch, avro_schema).
fn make_real_avro_schema_and_batch() -> Result<(Schema, RecordBatch, AvroSchema), AvroError> {
let avro_json = r#"
{
"type": "record",
"name": "User",
"fields": [
{ "name": "id", "type": "long" },
{ "name": "name", "type": "string" },
{ "name": "active", "type": "boolean" },
{ "name": "tags", "type": { "type": "array", "items": "int" } },
{ "name": "opt", "type": ["null", "string"], "default": null }
]
}"#;
let avro_schema = AvroSchema::new(avro_json.to_string());
// Embed the Avro schema JSON in the Arrow schema metadata so the writer
// uses this exact writer schema.
let mut md = HashMap::new();
md.insert(
SCHEMA_METADATA_KEY.to_string(),
avro_schema.json_string.clone(),
);
// The list's item field is shared between the schema and the ListBuilder
// so both agree on the item field name and nullability.
let item_field = Arc::new(Field::new(
Field::LIST_FIELD_DEFAULT_NAME,
DataType::Int32,
false,
));
let schema = Schema::new_with_metadata(
vec![
Field::new("id", DataType::Int64, false),
Field::new("name", DataType::Utf8, false),
Field::new("active", DataType::Boolean, false),
Field::new("tags", DataType::List(item_field.clone()), false),
Field::new("opt", DataType::Utf8, true),
],
md,
);
let id = Int64Array::from(vec![1, 2, 3]);
let name = StringArray::from(vec!["alice", "bob", "carol"]);
let active = BooleanArray::from(vec![true, false, true]);
// tags rows: [1, 2], [] (empty), [3].
let mut tags_builder = ListBuilder::new(Int32Builder::new()).with_field(item_field);
tags_builder.values().append_value(1);
tags_builder.values().append_value(2);
tags_builder.append(true);
tags_builder.append(true);
tags_builder.values().append_value(3);
tags_builder.append(true);
let tags = tags_builder.finish();
let opt = StringArray::from(vec![Some("x"), None, Some("z")]);
let batch = RecordBatch::try_new(
Arc::new(schema.clone()),
vec![
Arc::new(id) as ArrayRef,
Arc::new(name) as ArrayRef,
Arc::new(active) as ArrayRef,
Arc::new(tags) as ArrayRef,
Arc::new(opt) as ArrayRef,
],
)?;
Ok((schema, batch, avro_schema))
}
// The per-row SOE encoder's concatenated output must be byte-identical to
// what `AvroStreamWriter` produces for the same batch.
#[test]
fn test_row_writer_matches_stream_writer_soe() -> Result<(), AvroError> {
let schema = make_encoder_schema();
let batch = make_encoder_batch(&schema);
let mut stream = AvroStreamWriter::new(Vec::<u8>::new(), schema.clone())?;
stream.write(&batch)?;
stream.finish()?;
let stream_bytes = stream.into_inner();
let mut row_writer = WriterBuilder::new(schema).build_encoder::<AvroSoeFormat>()?;
row_writer.encode(&batch)?;
let rows = row_writer.flush();
// `bytes()` is the concatenation of all encoded rows.
let row_bytes: Vec<u8> = rows.bytes().to_vec();
assert_eq!(stream_bytes, row_bytes);
Ok(())
}
// `flush` must hand back every buffered row and reset the encoder, so a
// second flush with no intervening encode yields zero rows.
#[test]
fn test_row_writer_flush_clears_buffer() -> Result<(), AvroError> {
let schema = make_encoder_schema();
let batch = make_encoder_batch(&schema);
let mut row_writer = WriterBuilder::new(schema).build_encoder::<AvroSoeFormat>()?;
row_writer.encode(&batch)?;
assert_eq!(row_writer.buffered_len(), batch.num_rows());
let out1 = row_writer.flush();
assert_eq!(out1.len(), batch.num_rows());
assert_eq!(row_writer.buffered_len(), 0);
let out2 = row_writer.flush();
assert_eq!(out2.len(), 0);
Ok(())
}
// SOE-encoded rows fed one at a time to the `Decoder` (with the writer schema
// registered in a `SchemaStore`) must reproduce the original batch.
#[test]
fn test_row_writer_roundtrip_decoder_soe_real_avro_data() -> Result<(), AvroError> {
let (schema, batch, avro_schema) = make_real_avro_schema_and_batch()?;
let mut store = SchemaStore::new();
store.register(avro_schema.clone())?;
let mut row_writer = WriterBuilder::new(schema).build_encoder::<AvroSoeFormat>()?;
row_writer.encode(&batch)?;
let rows = row_writer.flush();
let mut decoder = ReaderBuilder::new()
.with_writer_schema_store(store)
.with_batch_size(1024)
.build_decoder()?;
for row in rows.iter() {
let consumed = decoder.decode(row.as_ref())?;
assert_eq!(
consumed,
row.len(),
"decoder should consume the full row frame"
);
}
let out = decoder.flush()?.expect("decoded batch");
// Compare pretty-printed output so differences surface readably on failure.
let expected = pretty_format_batches(std::slice::from_ref(&batch))?.to_string();
let actual = pretty_format_batches(&[out])?.to_string();
assert_eq!(expected, actual);
Ok(())
}
// Streams SOE rows to the `Decoder` in frame-aligned chunks of varying row
// counts, verifying the decoder consumes whole frames and reproduces the
// original batch.
#[test]
fn test_row_writer_roundtrip_decoder_soe_streaming_chunks() -> Result<(), AvroError> {
let (schema, batch, avro_schema) = make_real_avro_schema_and_batch()?;
let mut store = SchemaStore::new();
store.register(avro_schema.clone())?;
let mut row_writer = WriterBuilder::new(schema).build_encoder::<AvroSoeFormat>()?;
row_writer.encode(&batch)?;
let rows = row_writer.flush();
// Concatenate all rows into one byte stream and remember each row's
// byte boundary so chunks can be cut on frame edges.
let mut stream: Vec<u8> = Vec::new();
let mut boundaries: Vec<usize> = Vec::with_capacity(rows.len() + 1);
boundaries.push(0usize);
for row in rows.iter() {
stream.extend_from_slice(row.as_ref());
boundaries.push(stream.len());
}
let mut decoder = ReaderBuilder::new()
.with_writer_schema_store(store)
.with_batch_size(1024)
.build_decoder()?;
let mut buffered = BytesMut::new();
// Varying chunk sizes (in rows) to exercise different feeding patterns.
let chunk_rows = [1usize, 2, 3, 1, 4, 2];
let mut row_idx = 0usize;
let mut i = 0usize;
let n_rows = rows.len();
while row_idx < n_rows {
let take = chunk_rows[i % chunk_rows.len()];
i += 1;
let end_row = (row_idx + take).min(n_rows);
let byte_start = boundaries[row_idx];
let byte_end = boundaries[end_row];
buffered.extend_from_slice(&stream[byte_start..byte_end]);
// Drain the buffer; decode returns 0 once no complete frame remains.
loop {
let consumed = decoder.decode(&buffered)?;
if consumed == 0 {
break;
}
let _ = buffered.split_to(consumed);
}
assert!(
buffered.is_empty(),
"expected decoder to consume the entire frame-aligned chunk"
);
row_idx = end_row;
}
let out = decoder.flush()?.expect("decoded batch");
let expected = pretty_format_batches(std::slice::from_ref(&batch))?.to_string();
let actual = pretty_format_batches(&[out])?.to_string();
assert_eq!(expected, actual);
Ok(())
}
// Round-trips rows framed with a numeric schema id (Confluent-style wire
// format) instead of a Rabin fingerprint, via `FingerprintStrategy::Id`.
#[test]
fn test_row_writer_roundtrip_decoder_confluent_wire_format_id() -> Result<(), AvroError> {
let (schema, batch, avro_schema) = make_real_avro_schema_and_batch()?;
let schema_id: u32 = 42;
// The store must key schemas by id so the decoder can resolve frames
// carrying this id.
let mut store = SchemaStore::new_with_type(FingerprintAlgorithm::Id);
store.set(Fingerprint::Id(schema_id), avro_schema.clone())?;
let mut row_writer = WriterBuilder::new(schema)
.with_fingerprint_strategy(FingerprintStrategy::Id(schema_id))
.build_encoder::<AvroSoeFormat>()?;
row_writer.encode(&batch)?;
let rows = row_writer.flush();
let mut decoder = ReaderBuilder::new()
.with_writer_schema_store(store)
.with_batch_size(1024)
.build_decoder()?;
for row in rows.iter() {
let consumed = decoder.decode(row.as_ref())?;
assert_eq!(consumed, row.len());
}
let out = decoder.flush()?.expect("decoded batch");
let expected = pretty_format_batches(std::slice::from_ref(&batch))?.to_string();
let actual = pretty_format_batches(&[out])?.to_string();
assert_eq!(expected, actual);
Ok(())
}
// Exercises the full public surface of `Encoder` and `EncodedRows` with
// `AvroBinaryFormat`. Calls are deliberately written in fully-qualified
// (UFCS) form to pin down exactly which inherent methods exist.
#[test]
fn test_encoder_encode_batches_flush_and_encoded_rows_methods_with_avro_binary_format()
-> Result<(), AvroError> {
use crate::writer::format::AvroBinaryFormat;
use arrow_array::{ArrayRef, Int32Array, RecordBatch};
use arrow_schema::{DataType, Field, Schema};
use std::sync::Arc;
let schema = Schema::new(vec![
Field::new("a", DataType::Int32, false),
Field::new("b", DataType::Int32, false),
]);
let schema_ref = Arc::new(schema.clone());
let batch1 = RecordBatch::try_new(
schema_ref.clone(),
vec![
Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef,
Arc::new(Int32Array::from(vec![10, 20, 30])) as ArrayRef,
],
)?;
let batch2 = RecordBatch::try_new(
schema_ref,
vec![
Arc::new(Int32Array::from(vec![4, 5])) as ArrayRef,
Arc::new(Int32Array::from(vec![40, 50])) as ArrayRef,
],
)?;
let mut encoder = WriterBuilder::new(schema).build_encoder::<AvroBinaryFormat>()?;
// Flushing a fresh encoder yields an empty EncodedRows whose offsets
// slice still holds the single leading 0.
let empty = Encoder::flush(&mut encoder);
assert_eq!(EncodedRows::len(&empty), 0);
assert!(EncodedRows::is_empty(&empty));
assert_eq!(EncodedRows::bytes(&empty).as_ref(), &[] as &[u8]);
assert_eq!(EncodedRows::offsets(&empty), &[0usize]);
assert_eq!(EncodedRows::iter(&empty).count(), 0);
let empty_vecs: Vec<Vec<u8>> = empty.iter().map(|b| b.to_vec()).collect();
assert!(empty_vecs.is_empty());
let batches = vec![batch1, batch2];
Encoder::encode_batches(&mut encoder, &batches)?;
assert_eq!(encoder.buffered_len(), 5);
let rows = Encoder::flush(&mut encoder);
assert_eq!(
encoder.buffered_len(),
0,
"Encoder::flush should reset the internal offsets"
);
assert_eq!(EncodedRows::len(&rows), 5);
assert!(!EncodedRows::is_empty(&rows));
// Each row is two Avro zigzag varints (small non-negative n encodes as
// 2n in one byte), so every row occupies exactly 2 bytes.
let expected_offsets: &[usize] = &[0, 2, 4, 6, 8, 10];
assert_eq!(EncodedRows::offsets(&rows), expected_offsets);
let expected_rows: Vec<Vec<u8>> = vec![
vec![2, 20],
vec![4, 40],
vec![6, 60],
vec![8, 80],
vec![10, 100],
];
let expected_stream: Vec<u8> = expected_rows.concat();
assert_eq!(
EncodedRows::bytes(&rows).as_ref(),
expected_stream.as_slice()
);
for (i, expected) in expected_rows.iter().enumerate() {
assert_eq!(EncodedRows::row(&rows, i)?.as_ref(), expected.as_slice());
}
let iter_rows: Vec<Vec<u8>> = EncodedRows::iter(&rows).map(|b| b.to_vec()).collect();
assert_eq!(iter_rows, expected_rows);
// `EncodedRows::new` must reconstruct an equivalent value from the
// parts exposed by `bytes()` and `offsets()`.
let recreated = EncodedRows::new(
EncodedRows::bytes(&rows).clone(),
EncodedRows::offsets(&rows).to_vec(),
);
assert_eq!(EncodedRows::len(&recreated), EncodedRows::len(&rows));
assert_eq!(EncodedRows::bytes(&recreated), EncodedRows::bytes(&rows));
assert_eq!(
EncodedRows::offsets(&recreated),
EncodedRows::offsets(&rows)
);
let rec_vecs: Vec<Vec<u8>> = recreated.iter().map(|b| b.to_vec()).collect();
assert_eq!(rec_vecs, iter_rows);
let empty_again = Encoder::flush(&mut encoder);
assert!(EncodedRows::is_empty(&empty_again));
Ok(())
}
// `WriterBuilder::build` must reject `AvroBinaryFormat` (per the error text,
// it is only supported through `build_encoder`) with `InvalidArgument`.
#[test]
fn test_writer_builder_build_rejects_avro_binary_format() {
    use crate::writer::format::AvroBinaryFormat;
    use arrow_schema::{DataType, Field, Schema};
    let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
    let err = WriterBuilder::new(schema)
        .build::<_, AvroBinaryFormat>(Vec::<u8>::new())
        .unwrap_err();
    match err {
        AvroError::InvalidArgument(msg) => assert_eq!(
            msg,
            "AvroBinaryFormat is only supported with Encoder, use build_encoder instead"
        ),
        // Fixed: the old message named a nonexistent "InvalidArgumentError"
        // variant; use the real variant name and inline format args as the
        // rest of this module does.
        other => panic!("expected AvroError::InvalidArgument, got {other:?}"),
    }
}
// Verifies that an `AvroBinaryFormat` row is exactly the corresponding
// `AvroSoeFormat` row minus the single-object prefix (magic C3 01 plus the
// little-endian Rabin fingerprint), and that manually re-framing the binary
// rows with that prefix yields frames the `Decoder` accepts and decodes back
// to the original data.
#[test]
fn test_row_encoder_avro_binary_format_roundtrip_decoder_with_soe_framing()
-> Result<(), AvroError> {
    use crate::writer::format::AvroBinaryFormat;
    let (schema, batch, avro_schema) = make_real_avro_schema_and_batch()?;
    let batches: Vec<RecordBatch> = vec![batch.clone(), batch.slice(1, 2)];
    let expected = arrow::compute::concat_batches(&batch.schema(), &batches)?;
    let mut binary_encoder =
        WriterBuilder::new(schema.clone()).build_encoder::<AvroBinaryFormat>()?;
    binary_encoder.encode_batches(&batches)?;
    let binary_rows = binary_encoder.flush();
    assert_eq!(
        binary_rows.len(),
        expected.num_rows(),
        "binary encoder row count mismatch"
    );
    let mut soe_encoder = WriterBuilder::new(schema).build_encoder::<AvroSoeFormat>()?;
    soe_encoder.encode_batches(&batches)?;
    let soe_rows = soe_encoder.flush();
    assert_eq!(
        soe_rows.len(),
        binary_rows.len(),
        "SOE vs binary row count mismatch"
    );
    // Fixed rustfmt violation: these two statements were jammed on one line.
    let mut store = SchemaStore::new();
    let fp = store.register(avro_schema)?;
    let fp_le_bytes = match fp {
        Fingerprint::Rabin(v) => v.to_le_bytes(),
        other => panic!("expected Rabin fingerprint from SchemaStore::new(), got {other:?}"),
    };
    const SOE_MAGIC: [u8; 2] = [0xC3, 0x01];
    const SOE_PREFIX_LEN: usize = 2 + 8;
    // Every SOE row must be magic + fingerprint + the matching binary body.
    for i in 0..binary_rows.len() {
        let body = binary_rows.row(i)?;
        let soe = soe_rows.row(i)?;
        assert!(
            soe.len() >= SOE_PREFIX_LEN,
            "expected SOE row to include prefix"
        );
        assert_eq!(&soe.as_ref()[..2], &SOE_MAGIC);
        assert_eq!(&soe.as_ref()[2..SOE_PREFIX_LEN], &fp_le_bytes);
        assert_eq!(
            &soe.as_ref()[SOE_PREFIX_LEN..],
            body.as_ref(),
            "SOE body bytes differ from AvroBinaryFormat body bytes (row {i})"
        );
    }
    let mut decoder = ReaderBuilder::new()
        .with_writer_schema_store(store)
        .with_batch_size(1024)
        .build_decoder()?;
    // Hand-frame each binary body and confirm the decoder consumes it whole.
    for body in binary_rows.iter() {
        let mut framed = Vec::with_capacity(SOE_PREFIX_LEN + body.len());
        framed.extend_from_slice(&SOE_MAGIC);
        framed.extend_from_slice(&fp_le_bytes);
        framed.extend_from_slice(body.as_ref());
        let consumed = decoder.decode(&framed)?;
        assert_eq!(
            consumed,
            framed.len(),
            "decoder should consume the full SOE-framed message"
        );
    }
    let out = decoder.flush()?.expect("expected a decoded RecordBatch");
    let expected_str = pretty_format_batches(&[expected])?.to_string();
    let actual_str = pretty_format_batches(&[out])?.to_string();
    assert_eq!(expected_str, actual_str);
    Ok(())
}
// Simulates a transport that carries SOE messages behind a 4-byte
// little-endian length prefix, delivers bytes in irregular chunk sizes, and
// verifies the decoder reproduces the original batch frame by frame.
#[test]
fn test_row_encoder_avro_binary_format_roundtrip_decoder_streaming_chunks()
-> Result<(), AvroError> {
use crate::writer::format::AvroBinaryFormat;
let (schema, batch, avro_schema) = make_real_avro_schema_and_batch()?;
let mut encoder = WriterBuilder::new(schema).build_encoder::<AvroBinaryFormat>()?;
encoder.encode(&batch)?;
let rows = encoder.flush();
let mut store = SchemaStore::new();
let fp = store.register(avro_schema)?;
let fp_le_bytes = match fp {
Fingerprint::Rabin(v) => v.to_le_bytes(),
other => panic!("expected Rabin fingerprint from SchemaStore::new(), got {other:?}"),
};
const SOE_MAGIC: [u8; 2] = [0xC3, 0x01];
const SOE_PREFIX_LEN: usize = 2 + 8;
// Build the wire stream: [u32 LE length][magic][fingerprint][body] per row.
let mut stream: Vec<u8> = Vec::new();
for body in rows.iter() {
let msg_len: u32 = (SOE_PREFIX_LEN + body.len())
.try_into()
.expect("message length must fit in u32");
stream.extend_from_slice(&msg_len.to_le_bytes());
stream.extend_from_slice(&SOE_MAGIC);
stream.extend_from_slice(&fp_le_bytes);
stream.extend_from_slice(body.as_ref());
}
let mut decoder = ReaderBuilder::new()
.with_writer_schema_store(store)
.with_batch_size(1024)
.build_decoder()?;
// Fibonacci-style chunk sizes so frames straddle chunk boundaries.
let chunk_sizes = [1usize, 2, 3, 5, 8, 13, 21, 34];
let mut pos = 0usize;
let mut i = 0usize;
let mut buffered = BytesMut::new();
let mut decoded_frames = 0usize;
while pos < stream.len() {
let take = chunk_sizes[i % chunk_sizes.len()];
i += 1;
let end = (pos + take).min(stream.len());
buffered.extend_from_slice(&stream[pos..end]);
pos = end;
// Peel off every complete length-prefixed frame currently buffered.
loop {
if buffered.len() < 4 {
break;
}
let msg_len =
u32::from_le_bytes([buffered[0], buffered[1], buffered[2], buffered[3]])
as usize;
if buffered.len() < 4 + msg_len {
break;
}
let frame = buffered.split_to(4 + msg_len);
let payload = &frame[4..];
let consumed = decoder.decode(payload)?;
assert_eq!(
consumed,
payload.len(),
"decoder should consume the full SOE-framed message"
);
decoded_frames += 1;
}
}
assert!(
buffered.is_empty(),
"expected transport framer to consume all bytes; leftover = {}",
buffered.len()
);
assert_eq!(
decoded_frames,
rows.len(),
"expected to decode exactly one frame per encoded row"
);
let out = decoder.flush()?.expect("expected decoded RecordBatch");
let expected_str = pretty_format_batches(std::slice::from_ref(&batch))?.to_string();
let actual_str = pretty_format_batches(&[out])?.to_string();
assert_eq!(expected_str, actual_str);
Ok(())
}
// Writes `batch` to an in-memory OCF and reads it back as a single batch.
// The Avro schema JSON embedded in the OCF header is copied back into the
// returned batch's schema metadata under SCHEMA_METADATA_KEY.
fn roundtrip_ocf(batch: &RecordBatch) -> Result<RecordBatch, AvroError> {
let schema = batch.schema();
let mut buffer = Vec::<u8>::new();
let mut writer = AvroWriter::new(&mut buffer, schema.as_ref().clone())?;
writer.write(batch)?;
writer.finish()?;
// Drop the writer to release the &mut borrow on `buffer`.
drop(writer);
let reader = ReaderBuilder::new()
.build(Cursor::new(buffer))
.expect("build reader for roundtrip OCF");
// Pull the writer's Avro schema JSON out of the OCF header, if present.
let avro_schema_json = reader
.avro_header()
.get(SCHEMA_METADATA_KEY)
.map(|raw| std::str::from_utf8(raw).expect("valid UTF-8").to_string());
let arrow_schema = reader.schema();
let rt_schema = if let Some(json) = avro_schema_json {
let mut metadata = arrow_schema.metadata().clone();
metadata.insert(SCHEMA_METADATA_KEY.to_string(), json);
Arc::new(Schema::new_with_metadata(
arrow_schema.fields().clone(),
metadata,
))
} else {
arrow_schema
};
let rt_batches: Vec<RecordBatch> = reader.collect::<Result<Vec<_>, _>>()?;
Ok(arrow::compute::concat_batches(&rt_schema, &rt_batches).expect("concat roundtrip"))
}
// With custom types enabled, a value must round-trip to exactly itself, so
// input and expected are the same array.
#[cfg(feature = "avro_custom_types")]
fn assert_round_trip(array: ArrayRef) {
    let expected = array.clone();
    assert_round_trip_widened(array, expected);
}
// Wraps `input` in a single nullable column "val", round-trips it through an
// OCF, and asserts the result has `expected`'s data type and data. Used both
// for exact round-trips and for the documented type-widening paths.
fn assert_round_trip_widened(input: ArrayRef, expected: ArrayRef) {
    let fields = vec![Field::new("val", input.data_type().clone(), true)];
    let batch = RecordBatch::try_new(Arc::new(Schema::new(fields)), vec![input])
        .expect("failed to create batch");
    let roundtrip = roundtrip_ocf(&batch).expect("roundtrip failed");
    let actual = roundtrip.column(0);
    assert_eq!(
        actual.data_type(),
        expected.data_type(),
        "output data type mismatch"
    );
    assert_eq!(actual.to_data(), expected.to_data(), "output data mismatch");
}
// With avro_custom_types: Int8 round-trips as Int8, including extremes.
#[cfg(feature = "avro_custom_types")]
#[test]
fn test_roundtrip_int8_custom_types() {
assert_round_trip(Arc::new(Int8Array::from(vec![
Some(i8::MIN),
Some(-1),
Some(0),
None,
Some(1),
Some(i8::MAX),
])));
}
// Without avro_custom_types: Int8 comes back widened to Int32.
#[cfg(not(feature = "avro_custom_types"))]
#[test]
fn test_roundtrip_int8_no_custom_widens_to_int32() {
assert_round_trip_widened(
Arc::new(Int8Array::from(vec![
Some(i8::MIN),
Some(-1),
Some(0),
None,
Some(1),
Some(i8::MAX),
])),
Arc::new(Int32Array::from(vec![
Some(i8::MIN as i32),
Some(-1),
Some(0),
None,
Some(1),
Some(i8::MAX as i32),
])),
);
}
// With avro_custom_types: Int16 round-trips as Int16, including extremes.
#[cfg(feature = "avro_custom_types")]
#[test]
fn test_roundtrip_int16_custom_types() {
assert_round_trip(Arc::new(Int16Array::from(vec![
Some(i16::MIN),
Some(-1),
Some(0),
None,
Some(1),
Some(i16::MAX),
])));
}
// Without avro_custom_types: Int16 comes back widened to Int32.
#[cfg(not(feature = "avro_custom_types"))]
#[test]
fn test_roundtrip_int16_no_custom_widens_to_int32() {
assert_round_trip_widened(
Arc::new(Int16Array::from(vec![
Some(i16::MIN),
Some(-1),
Some(0),
None,
Some(1),
Some(i16::MAX),
])),
Arc::new(Int32Array::from(vec![
Some(i16::MIN as i32),
Some(-1),
Some(0),
None,
Some(1),
Some(i16::MAX as i32),
])),
);
}
// With avro_custom_types: UInt8 round-trips as UInt8, including u8::MAX.
#[cfg(feature = "avro_custom_types")]
#[test]
fn test_roundtrip_uint8_custom_types() {
assert_round_trip(Arc::new(UInt8Array::from(vec![
Some(0u8),
Some(1),
None,
Some(127),
Some(u8::MAX),
])));
}
// Without avro_custom_types: UInt8 comes back widened to Int32.
#[cfg(not(feature = "avro_custom_types"))]
#[test]
fn test_roundtrip_uint8_no_custom_widens_to_int32() {
assert_round_trip_widened(
Arc::new(UInt8Array::from(vec![
Some(0u8),
Some(1),
None,
Some(127),
Some(u8::MAX),
])),
Arc::new(Int32Array::from(vec![
Some(0i32),
Some(1),
None,
Some(127),
Some(u8::MAX as i32),
])),
);
}
// With avro_custom_types: UInt16 round-trips as UInt16, including u16::MAX.
#[cfg(feature = "avro_custom_types")]
#[test]
fn test_roundtrip_uint16_custom_types() {
assert_round_trip(Arc::new(UInt16Array::from(vec![
Some(0u16),
Some(1),
None,
Some(32767),
Some(u16::MAX),
])));
}
// Without avro_custom_types: UInt16 comes back widened to Int32.
#[cfg(not(feature = "avro_custom_types"))]
#[test]
fn test_roundtrip_uint16_no_custom_widens_to_int32() {
assert_round_trip_widened(
Arc::new(UInt16Array::from(vec![
Some(0u16),
Some(1),
None,
Some(32767),
Some(u16::MAX),
])),
Arc::new(Int32Array::from(vec![
Some(0i32),
Some(1),
None,
Some(32767),
Some(u16::MAX as i32),
])),
);
}
// With avro_custom_types: UInt32 round-trips as UInt32, including u32::MAX.
#[cfg(feature = "avro_custom_types")]
#[test]
fn test_roundtrip_uint32_custom_types() {
assert_round_trip(Arc::new(UInt32Array::from(vec![
Some(0u32),
Some(1),
None,
Some(i32::MAX as u32),
Some(u32::MAX),
])));
}
// Without avro_custom_types: UInt32 comes back widened to Int64 (Int32 could
// not hold values above i32::MAX).
#[cfg(not(feature = "avro_custom_types"))]
#[test]
fn test_roundtrip_uint32_no_custom_widens_to_int64() {
assert_round_trip_widened(
Arc::new(UInt32Array::from(vec![
Some(0u32),
Some(1),
None,
Some(i32::MAX as u32),
Some(u32::MAX),
])),
Arc::new(Int64Array::from(vec![
Some(0i64),
Some(1),
None,
Some(i32::MAX as i64),
Some(u32::MAX as i64),
])),
);
}
// With avro_custom_types: UInt64 round-trips as UInt64, including u64::MAX.
#[cfg(feature = "avro_custom_types")]
#[test]
fn test_roundtrip_uint64_custom_types() {
assert_round_trip(Arc::new(UInt64Array::from(vec![
Some(0u64),
Some(1),
None,
Some(i64::MAX as u64),
Some(u64::MAX),
])));
}
// Without avro_custom_types: UInt64 values that fit in i64 come back as
// Int64. Values above i64::MAX are covered by the overflow test below.
#[cfg(not(feature = "avro_custom_types"))]
#[test]
fn test_roundtrip_uint64_no_custom_widens_to_int64() {
assert_round_trip_widened(
Arc::new(UInt64Array::from(vec![
Some(0u64),
Some(1),
None,
Some(i64::MAX as u64),
])),
Arc::new(Int64Array::from(vec![
Some(0i64),
Some(1),
None,
Some(i64::MAX),
])),
);
}
// Without avro_custom_types: a UInt64 above i64::MAX cannot be represented
// as an Avro long, so the round-trip must fail rather than wrap silently.
#[cfg(not(feature = "avro_custom_types"))]
#[test]
fn test_roundtrip_uint64_overflow_errors_without_custom() {
use arrow_array::UInt64Array;
let schema = Schema::new(vec![Field::new("val", DataType::UInt64, false)]);
let values: Vec<u64> = vec![u64::MAX];
let array = UInt64Array::from(values);
let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(array) as ArrayRef])
.expect("create batch");
let result = roundtrip_ocf(&batch);
assert!(
result.is_err(),
"Expected error when encoding UInt64 > i64::MAX without avro_custom_types"
);
}
// With avro_custom_types: Float16 round-trips as Float16, including extremes.
#[cfg(feature = "avro_custom_types")]
#[test]
fn test_roundtrip_float16_custom_types() {
assert_round_trip(Arc::new(Float16Array::from(vec![
Some(f16::ZERO),
Some(f16::ONE),
None,
Some(f16::NEG_ONE),
Some(f16::MAX),
Some(f16::MIN),
])));
}
// Without avro_custom_types: Float16 comes back widened to Float32.
#[cfg(not(feature = "avro_custom_types"))]
#[test]
fn test_roundtrip_float16_no_custom_widens_to_float32() {
assert_round_trip_widened(
Arc::new(Float16Array::from(vec![
Some(f16::ZERO),
Some(f16::ONE),
None,
Some(f16::NEG_ONE),
])),
Arc::new(Float32Array::from(vec![
Some(0.0f32),
Some(1.0),
None,
Some(-1.0),
])),
);
}
// With avro_custom_types: Date64 round-trips as Date64.
#[cfg(feature = "avro_custom_types")]
#[test]
fn test_roundtrip_date64_custom_types() {
assert_round_trip(Arc::new(Date64Array::from(vec![
Some(0i64),
Some(86_400_000),
None,
Some(1_609_459_200_000),
])));
}
// Without avro_custom_types: Date64 comes back as Timestamp(ms); the raw
// i64 millisecond values are unchanged.
#[cfg(not(feature = "avro_custom_types"))]
#[test]
fn test_roundtrip_date64_no_custom_as_timestamp_millis() {
assert_round_trip_widened(
Arc::new(Date64Array::from(vec![
Some(0i64),
Some(86_400_000),
None,
Some(1_609_459_200_000),
])),
Arc::new(TimestampMillisecondArray::from(vec![
Some(0i64),
Some(86_400_000),
None,
Some(1_609_459_200_000),
])),
);
}
// With avro_custom_types: Time64(ns) round-trips losslessly as Time64(ns).
#[cfg(feature = "avro_custom_types")]
#[test]
fn test_roundtrip_time64_nanosecond_custom_types() {
assert_round_trip(Arc::new(Time64NanosecondArray::from(vec![
Some(0i64),
Some(1_000_000_000),
None,
Some(86_399_999_999_999),
])));
}
// Without avro_custom_types: Time64(ns) comes back as Time64(us) — the
// expected values here are the nanosecond inputs divided by 1000.
#[cfg(not(feature = "avro_custom_types"))]
#[test]
fn test_roundtrip_time64_nanos_no_custom_truncates_to_micros() {
assert_round_trip_widened(
Arc::new(Time64NanosecondArray::from(vec![
Some(0i64),
Some(1_000_000_000),
None,
Some(86_399_999_000_000),
])),
Arc::new(Time64MicrosecondArray::from(vec![
Some(0i64),
Some(1_000_000),
None,
Some(86_399_999_000),
])),
);
}
// With avro_custom_types: Time32(s) round-trips as Time32(s).
#[cfg(feature = "avro_custom_types")]
#[test]
fn test_roundtrip_time32_second_custom_types() {
assert_round_trip(Arc::new(Time32SecondArray::from(vec![
Some(0i32),
Some(3600),
None,
Some(86399),
])));
}
// Without avro_custom_types: Time32(s) comes back as Time32(ms) — the
// expected values are the second inputs multiplied by 1000.
#[cfg(not(feature = "avro_custom_types"))]
#[test]
fn test_roundtrip_time32_second_no_custom_scales_to_millis() {
assert_round_trip_widened(
Arc::new(Time32SecondArray::from(vec![
Some(0i32),
Some(3600),
None,
Some(86399),
])),
Arc::new(Time32MillisecondArray::from(vec![
Some(0i32),
Some(3_600_000),
None,
Some(86_399_000),
])),
);
}
// With avro_custom_types: Timestamp(s, "+00:00") round-trips unchanged.
#[cfg(feature = "avro_custom_types")]
#[test]
fn test_roundtrip_timestamp_second_custom_types() {
assert_round_trip(Arc::new(
TimestampSecondArray::from(vec![Some(0i64), Some(1609459200), None, Some(1735689600)])
.with_timezone("+00:00"),
));
}
// Without avro_custom_types: Timestamp(s) comes back as Timestamp(ms) with
// the timezone preserved; expected values are seconds multiplied by 1000.
#[cfg(not(feature = "avro_custom_types"))]
#[test]
fn test_roundtrip_timestamp_second_no_custom_scales_to_millis() {
assert_round_trip_widened(
Arc::new(
TimestampSecondArray::from(vec![
Some(0i64),
Some(1609459200),
None,
Some(1735689600),
])
.with_timezone("+00:00"),
),
Arc::new(
TimestampMillisecondArray::from(vec![
Some(0i64),
Some(1_609_459_200_000),
None,
Some(1_735_689_600_000),
])
.with_timezone("+00:00"),
),
);
}
// With avro_custom_types: Interval(YearMonth) round-trips unchanged,
// including a negative month count.
#[cfg(feature = "avro_custom_types")]
#[test]
fn test_roundtrip_interval_year_month_custom_types() {
assert_round_trip(Arc::new(IntervalYearMonthArray::from(vec![
Some(0i32),
Some(12),
None,
Some(-6),
Some(25),
])));
}
// Without avro_custom_types: Interval(YearMonth) comes back as
// Interval(MonthDayNano) with months carried over and day/nanos zero.
#[cfg(not(feature = "avro_custom_types"))]
#[test]
fn test_roundtrip_interval_year_month_no_custom() {
assert_round_trip_widened(
Arc::new(IntervalYearMonthArray::from(vec![
Some(0i32),
Some(12),
None,
Some(25),
])),
Arc::new(IntervalMonthDayNanoArray::from(vec![
Some(IntervalMonthDayNano::new(0, 0, 0)),
Some(IntervalMonthDayNano::new(12, 0, 0)),
None,
Some(IntervalMonthDayNano::new(25, 0, 0)),
])),
);
}
// With avro_custom_types: Interval(DayTime) round-trips unchanged.
#[cfg(feature = "avro_custom_types")]
#[test]
fn test_roundtrip_interval_day_time_custom_types() {
assert_round_trip(Arc::new(IntervalDayTimeArray::from(vec![
Some(IntervalDayTime::new(0, 0)),
Some(IntervalDayTime::new(1, 1000)),
None,
Some(IntervalDayTime::new(30, 3600000)),
])));
}
// Without avro_custom_types: Interval(DayTime) comes back as
// Interval(MonthDayNano) with milliseconds converted to nanoseconds.
#[cfg(not(feature = "avro_custom_types"))]
#[test]
fn test_roundtrip_interval_day_time_no_custom() {
assert_round_trip_widened(
Arc::new(IntervalDayTimeArray::from(vec![
Some(IntervalDayTime::new(0, 0)),
Some(IntervalDayTime::new(1, 1000)),
None,
Some(IntervalDayTime::new(30, 3600000)),
])),
Arc::new(IntervalMonthDayNanoArray::from(vec![
Some(IntervalMonthDayNano::new(0, 0, 0)),
Some(IntervalMonthDayNano::new(0, 1, 1_000_000_000)),
None,
Some(IntervalMonthDayNano::new(0, 30, 3_600_000_000_000)),
])),
);
}
#[cfg(feature = "avro_custom_types")]
#[test]
fn test_roundtrip_interval_month_day_nano_custom_types() {
    // MonthDayNano intervals (negative components included) round-trip
    // losslessly with custom logical types.
    let values = vec![
        Some(IntervalMonthDayNano::new(0, 0, 0)),
        Some(IntervalMonthDayNano::new(1, 2, 3)),
        None,
        Some(IntervalMonthDayNano::new(-4, -5, -6)),
    ];
    assert_round_trip(Arc::new(IntervalMonthDayNanoArray::from(values)));
}
#[cfg(not(feature = "avro_custom_types"))]
#[test]
fn test_roundtrip_interval_month_day_nano_no_custom() {
    // All nanosecond components here are whole milliseconds, so the
    // fallback encoding is lossless: input and expected output are equal.
    let values = vec![
        Some(IntervalMonthDayNano::new(0, 0, 0)),
        Some(IntervalMonthDayNano::new(1, 2, 3_000_000)),
        None,
        Some(IntervalMonthDayNano::new(4, 5, 6_000_000)),
    ];
    let input = IntervalMonthDayNanoArray::from(values.clone());
    let expected = IntervalMonthDayNanoArray::from(values);
    assert_round_trip_widened(Arc::new(input), Arc::new(expected));
}
/// Returns `true` when the two schemas have the same fields by name, data
/// type, and nullability, ignoring schema- and field-level metadata.
fn schemas_equal_ignoring_metadata(left: &Schema, right: &Schema) -> bool {
    left.fields().len() == right.fields().len()
        && left
            .fields()
            .iter()
            .zip(right.fields().iter())
            .all(|(l, r)| {
                l.name() == r.name()
                    && l.data_type() == r.data_type()
                    && l.is_nullable() == r.is_nullable()
            })
}
/// Looks up the Avro `type` value of the record field `name` inside a JSON
/// Avro record schema. Panics (acceptable in a test helper) when the schema
/// has no `fields` array, the field is absent, or the field lacks `type`.
fn avro_field_type<'a>(avro_schema: &'a Value, name: &str) -> &'a Value {
    let fields = avro_schema
        .get("fields")
        .and_then(|v| v.as_array())
        .expect("avro schema has 'fields' array");
    let field = fields
        .iter()
        .find(|f| f.get("name").and_then(|n| n.as_str()) == Some(name));
    match field {
        Some(f) => f.get("type").expect("field has 'type'"),
        None => panic!("avro schema missing field '{name}'"),
    }
}
/// End-to-end test: builds a batch containing every "extension" Arrow type
/// (small/unsigned ints, f16, Date64, second-resolution time/timestamps, all
/// three interval units), round-trips it through OCF, and asserts the Arrow
/// schema, the embedded Avro schema JSON, and the column data — for both the
/// `avro_custom_types` and fallback code paths.
#[test]
fn e2e_types_and_schema_alignment() -> Result<(), AvroError> {
    // Input values. Some vectors differ per feature flag because the fallback
    // encodings cannot represent every value (u64 above i64::MAX, negative
    // intervals, sub-millisecond nanoseconds).
    let i8_values: Vec<Option<i8>> = vec![Some(i8::MIN), Some(-1), Some(i8::MAX)];
    let i16_values: Vec<Option<i16>> = vec![Some(i16::MIN), Some(-1), Some(i16::MAX)];
    let u8_values: Vec<Option<u8>> = vec![Some(0), Some(1), Some(u8::MAX)];
    let u16_values: Vec<Option<u16>> = vec![Some(0), Some(1), Some(u16::MAX)];
    let u32_values: Vec<Option<u32>> = vec![Some(0), Some(1), Some(u32::MAX)];
    let u64_values: Vec<Option<u64>> = if cfg!(feature = "avro_custom_types") {
        vec![Some(0), Some(i64::MAX as u64), Some((i64::MAX as u64) + 1)]
    } else {
        // Fallback maps UInt64 to Avro long, so stay within i64 range here.
        vec![Some(0), Some((i64::MAX as u64) - 1), Some(i64::MAX as u64)]
    };
    let f16_values: Vec<Option<f16>> = vec![
        Some(f16::from_f32(1.5)),
        Some(f16::from_f32(-2.0)),
        Some(f16::from_f32(0.0)),
    ];
    let date64_values: Vec<Option<i64>> = vec![Some(-86_400_000), Some(0), Some(86_400_000)];
    let time32s_values: Vec<Option<i32>> = vec![Some(0), Some(1), Some(86_399)];
    let time64ns_values: Vec<Option<i64>> = vec![
        Some(0),
        Some(1_234_567_890), Some(86_399_000_000_123_i64), ];
    let ts_s_local_values: Vec<Option<i64>> = vec![Some(-1), Some(0), Some(1)];
    let ts_s_utc_values: Vec<Option<i64>> = vec![Some(1), Some(2), Some(3)];
    let iv_ym_values: Vec<Option<i32>> = if cfg!(feature = "avro_custom_types") {
        vec![Some(0), Some(-6), Some(25)]
    } else {
        // Fallback uses Avro 'duration', which cannot hold negative months.
        vec![Some(0), Some(12), Some(25)]
    };
    let iv_dt_values: Vec<Option<IntervalDayTime>> = if cfg!(feature = "avro_custom_types") {
        vec![
            Some(IntervalDayTime::new(0, 0)),
            Some(IntervalDayTime::new(1, 1000)),
            Some(IntervalDayTime::new(-1, -1000)),
        ]
    } else {
        vec![
            Some(IntervalDayTime::new(0, 0)),
            Some(IntervalDayTime::new(1, 1000)),
            Some(IntervalDayTime::new(30, 3_600_000)),
        ]
    };
    let iv_mdn_values: Vec<Option<IntervalMonthDayNano>> =
        if cfg!(feature = "avro_custom_types") {
            vec![
                Some(IntervalMonthDayNano::new(0, 0, 0)),
                Some(IntervalMonthDayNano::new(1, 2, 3)), Some(IntervalMonthDayNano::new(-1, -2, -3)), ]
        } else {
            // Fallback needs non-negative, whole-millisecond nanos.
            vec![
                Some(IntervalMonthDayNano::new(0, 0, 0)),
                Some(IntervalMonthDayNano::new(1, 2, 3_000_000)), Some(IntervalMonthDayNano::new(10, 20, 30_000_000_000)), ]
        };
    // One non-nullable field per exercised type; column order below matches.
    let schema = Schema::new(vec![
        Field::new("i8", DataType::Int8, false),
        Field::new("i16", DataType::Int16, false),
        Field::new("u8", DataType::UInt8, false),
        Field::new("u16", DataType::UInt16, false),
        Field::new("u32", DataType::UInt32, false),
        Field::new("u64", DataType::UInt64, false),
        Field::new("f16", DataType::Float16, false),
        Field::new("date64", DataType::Date64, false),
        Field::new("time32s", DataType::Time32(TimeUnit::Second), false),
        Field::new("time64ns", DataType::Time64(TimeUnit::Nanosecond), false),
        Field::new(
            "ts_s_local",
            DataType::Timestamp(TimeUnit::Second, None),
            false,
        ),
        Field::new(
            "ts_s_utc",
            DataType::Timestamp(TimeUnit::Second, Some("+00:00".into())),
            false,
        ),
        Field::new("iv_ym", DataType::Interval(IntervalUnit::YearMonth), false),
        Field::new("iv_dt", DataType::Interval(IntervalUnit::DayTime), false),
        Field::new(
            "iv_mdn",
            DataType::Interval(IntervalUnit::MonthDayNano),
            false,
        ),
    ]);
    let batch = RecordBatch::try_new(
        Arc::new(schema.clone()),
        vec![
            Arc::new(Int8Array::from(i8_values.clone())) as ArrayRef,
            Arc::new(Int16Array::from(i16_values.clone())) as ArrayRef,
            Arc::new(UInt8Array::from(u8_values.clone())) as ArrayRef,
            Arc::new(UInt16Array::from(u16_values.clone())) as ArrayRef,
            Arc::new(UInt32Array::from(u32_values.clone())) as ArrayRef,
            Arc::new(UInt64Array::from(u64_values.clone())) as ArrayRef,
            Arc::new(Float16Array::from(f16_values.clone())) as ArrayRef,
            Arc::new(Date64Array::from(date64_values.clone())) as ArrayRef,
            Arc::new(Time32SecondArray::from(time32s_values.clone())) as ArrayRef,
            Arc::new(Time64NanosecondArray::from(time64ns_values.clone())) as ArrayRef,
            Arc::new(TimestampSecondArray::from(ts_s_local_values.clone())) as ArrayRef,
            Arc::new(
                TimestampSecondArray::from(ts_s_utc_values.clone()).with_timezone("+00:00"),
            ) as ArrayRef,
            Arc::new(IntervalYearMonthArray::from(iv_ym_values.clone())) as ArrayRef,
            Arc::new(IntervalDayTimeArray::from(iv_dt_values.clone())) as ArrayRef,
            Arc::new(IntervalMonthDayNanoArray::from(iv_mdn_values.clone())) as ArrayRef,
        ],
    )?;
    // Write to OCF and read back; the writer embeds the Avro schema JSON in
    // batch metadata under SCHEMA_METADATA_KEY.
    let rt = roundtrip_ocf(&batch)?;
    let rt_schema = rt.schema();
    let avro_schema_json = rt_schema
        .metadata()
        .get(SCHEMA_METADATA_KEY)
        .expect("avro.schema missing in round-tripped batch metadata");
    let avro_schema: Value =
        serde_json::from_str(avro_schema_json).expect("valid avro schema json");
    let rt_arrow_schema = rt.schema();
    // 1) Arrow schema assertions: custom types preserve the original schema;
    //    fallback widens each field to the expected substitute type.
    if cfg!(feature = "avro_custom_types") {
        assert!(
            schemas_equal_ignoring_metadata(rt_arrow_schema.as_ref(), &schema),
            "Schema fields mismatch.\nExpected: {:?}\nGot: {:?}",
            schema,
            rt_arrow_schema
        );
        // Fields encoded via named Avro fixed types must carry avro.name.
        for field_name in ["u64", "f16", "iv_ym", "iv_dt", "iv_mdn"] {
            let field = rt_arrow_schema
                .field_with_name(field_name)
                .expect("field exists");
            assert!(
                field.metadata().get(AVRO_NAME_METADATA_KEY).is_some(),
                "Field '{}' should have avro.name metadata",
                field_name
            );
        }
    } else {
        // Expected widened schema for the fallback path.
        let exp_schema = Schema::new(vec![
            Field::new("i8", DataType::Int32, false),
            Field::new("i16", DataType::Int32, false),
            Field::new("u8", DataType::Int32, false),
            Field::new("u16", DataType::Int32, false),
            Field::new("u32", DataType::Int64, false),
            Field::new("u64", DataType::Int64, false),
            Field::new("f16", DataType::Float32, false),
            Field::new(
                "date64",
                DataType::Timestamp(TimeUnit::Millisecond, None),
                false,
            ),
            Field::new("time32s", DataType::Time32(TimeUnit::Millisecond), false),
            Field::new("time64ns", DataType::Time64(TimeUnit::Microsecond), false),
            Field::new(
                "ts_s_local",
                DataType::Timestamp(TimeUnit::Millisecond, None),
                false,
            ),
            Field::new(
                "ts_s_utc",
                DataType::Timestamp(TimeUnit::Millisecond, Some("+00:00".into())),
                false,
            ),
            Field::new(
                "iv_ym",
                DataType::Interval(IntervalUnit::MonthDayNano),
                false,
            ),
            Field::new(
                "iv_dt",
                DataType::Interval(IntervalUnit::MonthDayNano),
                false,
            ),
            Field::new(
                "iv_mdn",
                DataType::Interval(IntervalUnit::MonthDayNano),
                false,
            ),
        ]);
        assert!(
            schemas_equal_ignoring_metadata(rt_arrow_schema.as_ref(), &exp_schema),
            "Schema fields mismatch.\nExpected: {:?}\nGot: {:?}",
            exp_schema,
            rt_arrow_schema
        );
        for field_name in ["iv_ym", "iv_dt", "iv_mdn"] {
            let field = rt_arrow_schema
                .field_with_name(field_name)
                .expect("field exists");
            assert!(
                field.metadata().get(AVRO_NAME_METADATA_KEY).is_some(),
                "Field '{}' should have avro.name metadata",
                field_name
            );
        }
    }
    // 2) Avro schema JSON assertions: per-field encoded Avro type.
    if cfg!(feature = "avro_custom_types") {
        assert_eq!(
            avro_field_type(&avro_schema, "i8"),
            &json!({"type":"int","logicalType":"arrow.int8"})
        );
        assert_eq!(
            avro_field_type(&avro_schema, "i16"),
            &json!({"type":"int","logicalType":"arrow.int16"})
        );
        assert_eq!(
            avro_field_type(&avro_schema, "u8"),
            &json!({"type":"int","logicalType":"arrow.uint8"})
        );
        assert_eq!(
            avro_field_type(&avro_schema, "u16"),
            &json!({"type":"int","logicalType":"arrow.uint16"})
        );
        assert_eq!(
            avro_field_type(&avro_schema, "u32"),
            &json!({"type":"long","logicalType":"arrow.uint32"})
        );
        assert_eq!(
            avro_field_type(&avro_schema, "u64"),
            &json!({"type":"fixed","name":"u64","size":8,"logicalType":"arrow.uint64"})
        );
        assert_eq!(
            avro_field_type(&avro_schema, "f16"),
            &json!({"type":"fixed","name":"f16","size":2,"logicalType":"arrow.float16"})
        );
        assert_eq!(
            avro_field_type(&avro_schema, "date64"),
            &json!({"type":"long","logicalType":"arrow.date64"})
        );
        assert_eq!(
            avro_field_type(&avro_schema, "time32s"),
            &json!({"type":"int","logicalType":"arrow.time32-second"})
        );
        assert_eq!(
            avro_field_type(&avro_schema, "time64ns"),
            &json!({"type":"long","logicalType":"arrow.time64-nanosecond"})
        );
        assert_eq!(
            avro_field_type(&avro_schema, "ts_s_local"),
            &json!({"type":"long","logicalType":"arrow.local-timestamp-second"})
        );
        assert_eq!(
            avro_field_type(&avro_schema, "ts_s_utc"),
            &json!({"type":"long","logicalType":"arrow.timestamp-second"})
        );
        assert_eq!(
            avro_field_type(&avro_schema, "iv_ym"),
            &json!({"type":"fixed","name":"iv_ym","size":4,"logicalType":"arrow.interval-year-month"})
        );
        assert_eq!(
            avro_field_type(&avro_schema, "iv_dt"),
            &json!({"type":"fixed","name":"iv_dt","size":8,"logicalType":"arrow.interval-day-time"})
        );
        assert_eq!(
            avro_field_type(&avro_schema, "iv_mdn"),
            &json!({"type":"fixed","name":"iv_mdn","size":16,"logicalType":"arrow.interval-month-day-nano"})
        );
    } else {
        // Fallback path: plain Avro types / standard logical types only.
        assert_eq!(avro_field_type(&avro_schema, "i8"), &json!("int"));
        assert_eq!(avro_field_type(&avro_schema, "i16"), &json!("int"));
        assert_eq!(avro_field_type(&avro_schema, "u8"), &json!("int"));
        assert_eq!(avro_field_type(&avro_schema, "u16"), &json!("int"));
        assert_eq!(avro_field_type(&avro_schema, "u32"), &json!("long"));
        assert_eq!(avro_field_type(&avro_schema, "u64"), &json!("long"));
        assert_eq!(avro_field_type(&avro_schema, "f16"), &json!("float"));
        assert_eq!(
            avro_field_type(&avro_schema, "date64"),
            &json!({"type":"long","logicalType":"local-timestamp-millis"})
        );
        assert_eq!(
            avro_field_type(&avro_schema, "time32s"),
            &json!({"type":"int","logicalType":"time-millis"})
        );
        assert_eq!(
            avro_field_type(&avro_schema, "time64ns"),
            &json!({"type":"long","logicalType":"time-micros"})
        );
        assert_eq!(
            avro_field_type(&avro_schema, "ts_s_local"),
            &json!({"type":"long","logicalType":"local-timestamp-millis"})
        );
        assert_eq!(
            avro_field_type(&avro_schema, "ts_s_utc"),
            &json!({"type":"long","logicalType":"timestamp-millis"})
        );
        assert_eq!(
            avro_field_type(&avro_schema, "iv_ym"),
            &json!({"type":"fixed","name":"iv_ym","size":12,"logicalType":"duration"})
        );
        assert_eq!(
            avro_field_type(&avro_schema, "iv_dt"),
            &json!({"type":"fixed","name":"iv_dt","size":12,"logicalType":"duration"})
        );
        assert_eq!(
            avro_field_type(&avro_schema, "iv_mdn"),
            &json!({"type":"fixed","name":"iv_mdn","size":12,"logicalType":"duration"})
        );
    }
    // 3) Data assertions: custom path returns the original arrays verbatim;
    //    fallback path returns the widened/scaled equivalents.
    if cfg!(feature = "avro_custom_types") {
        assert_eq!(
            rt.column(0).as_ref(),
            &Int8Array::from(i8_values) as &dyn Array
        );
        assert_eq!(
            rt.column(1).as_ref(),
            &Int16Array::from(i16_values) as &dyn Array
        );
        assert_eq!(
            rt.column(2).as_ref(),
            &UInt8Array::from(u8_values) as &dyn Array
        );
        assert_eq!(
            rt.column(3).as_ref(),
            &UInt16Array::from(u16_values) as &dyn Array
        );
        assert_eq!(
            rt.column(4).as_ref(),
            &UInt32Array::from(u32_values) as &dyn Array
        );
        assert_eq!(
            rt.column(5).as_ref(),
            &UInt64Array::from(u64_values) as &dyn Array
        );
        assert_eq!(
            rt.column(6).as_ref(),
            &Float16Array::from(f16_values) as &dyn Array
        );
        assert_eq!(
            rt.column(7).as_ref(),
            &Date64Array::from(date64_values) as &dyn Array
        );
        assert_eq!(
            rt.column(8).as_ref(),
            &Time32SecondArray::from(time32s_values) as &dyn Array
        );
        assert_eq!(
            rt.column(9).as_ref(),
            &Time64NanosecondArray::from(time64ns_values) as &dyn Array
        );
        assert_eq!(
            rt.column(10).as_ref(),
            &TimestampSecondArray::from(ts_s_local_values) as &dyn Array
        );
        assert_eq!(
            rt.column(11).as_ref(),
            &TimestampSecondArray::from(ts_s_utc_values).with_timezone("+00:00") as &dyn Array
        );
        assert_eq!(
            rt.column(12).as_ref(),
            &IntervalYearMonthArray::from(iv_ym_values) as &dyn Array
        );
        assert_eq!(
            rt.column(13).as_ref(),
            &IntervalDayTimeArray::from(iv_dt_values) as &dyn Array
        );
        assert_eq!(
            rt.column(14).as_ref(),
            &IntervalMonthDayNanoArray::from(iv_mdn_values) as &dyn Array
        );
    } else {
        // Build the expected widened values from the inputs, mirroring the
        // writer's fallback conversions (sign-extend, scale, unit change).
        let exp_i8: Vec<Option<i32>> = i8_values.iter().map(|v| v.map(|x| x as i32)).collect();
        let exp_i16: Vec<Option<i32>> =
            i16_values.iter().map(|v| v.map(|x| x as i32)).collect();
        let exp_u8: Vec<Option<i32>> = u8_values.iter().map(|v| v.map(|x| x as i32)).collect();
        let exp_u16: Vec<Option<i32>> =
            u16_values.iter().map(|v| v.map(|x| x as i32)).collect();
        let exp_u32: Vec<Option<i64>> =
            u32_values.iter().map(|v| v.map(|x| x as i64)).collect();
        let exp_u64: Vec<Option<i64>> =
            u64_values.iter().map(|v| v.map(|x| x as i64)).collect();
        let exp_f16: Vec<Option<f32>> =
            f16_values.iter().map(|v| v.map(|x| x.to_f32())).collect();
        let exp_time32_ms: Vec<Option<i32>> = time32s_values
            .iter()
            .map(|v| v.map(|x| x.saturating_mul(1000)))
            .collect();
        // Nanoseconds truncate toward zero when scaled down to microseconds.
        let exp_time64_us: Vec<Option<i64>> = time64ns_values
            .iter()
            .map(|v| v.map(|x| x / 1000))
            .collect();
        let exp_ts_local_ms: Vec<Option<i64>> = ts_s_local_values
            .iter()
            .map(|v| v.map(|x| x * 1000))
            .collect();
        let exp_ts_utc_ms: Vec<Option<i64>> = ts_s_utc_values
            .iter()
            .map(|v| v.map(|x| x * 1000))
            .collect();
        let exp_iv_ym: Vec<Option<IntervalMonthDayNano>> = iv_ym_values
            .iter()
            .map(|v| v.map(|months| IntervalMonthDayNano::new(months, 0, 0)))
            .collect();
        let exp_iv_dt: Vec<Option<IntervalMonthDayNano>> = iv_dt_values
            .iter()
            .map(|v| {
                v.map(|dt| {
                    IntervalMonthDayNano::new(0, dt.days, (dt.milliseconds as i64) * 1_000_000)
                })
            })
            .collect();
        assert_eq!(
            rt.column(0).as_ref(),
            &Int32Array::from(exp_i8) as &dyn Array
        );
        assert_eq!(
            rt.column(1).as_ref(),
            &Int32Array::from(exp_i16) as &dyn Array
        );
        assert_eq!(
            rt.column(2).as_ref(),
            &Int32Array::from(exp_u8) as &dyn Array
        );
        assert_eq!(
            rt.column(3).as_ref(),
            &Int32Array::from(exp_u16) as &dyn Array
        );
        assert_eq!(
            rt.column(4).as_ref(),
            &arrow_array::Int64Array::from(exp_u32) as &dyn Array
        );
        assert_eq!(
            rt.column(5).as_ref(),
            &arrow_array::Int64Array::from(exp_u64) as &dyn Array
        );
        assert_eq!(
            rt.column(6).as_ref(),
            &arrow_array::Float32Array::from(exp_f16) as &dyn Array
        );
        assert_eq!(
            rt.column(7).as_ref(),
            &TimestampMillisecondArray::from(date64_values) as &dyn Array
        );
        assert_eq!(
            rt.column(8).as_ref(),
            &Time32MillisecondArray::from(exp_time32_ms) as &dyn Array
        );
        assert_eq!(
            rt.column(9).as_ref(),
            &Time64MicrosecondArray::from(exp_time64_us) as &dyn Array
        );
        assert_eq!(
            rt.column(10).as_ref(),
            &TimestampMillisecondArray::from(exp_ts_local_ms) as &dyn Array
        );
        assert_eq!(
            rt.column(11).as_ref(),
            &TimestampMillisecondArray::from(exp_ts_utc_ms).with_timezone("+00:00")
                as &dyn Array
        );
        assert_eq!(
            rt.column(12).as_ref(),
            &IntervalMonthDayNanoArray::from(exp_iv_ym) as &dyn Array
        );
        assert_eq!(
            rt.column(13).as_ref(),
            &IntervalMonthDayNanoArray::from(exp_iv_dt) as &dyn Array
        );
        assert_eq!(
            rt.column(14).as_ref(),
            &IntervalMonthDayNanoArray::from(iv_mdn_values) as &dyn Array
        );
    }
    Ok(())
}
#[cfg(not(feature = "avro_custom_types"))]
#[test]
fn non_custom_uint64_overflow_errors() -> Result<(), AvroError> {
    // Without the custom uint64 logical type, UInt64 maps to Avro `long`;
    // values above i64::MAX must be rejected with a clear error.
    let schema = Schema::new(vec![Field::new("u64", DataType::UInt64, false)]);
    let column: ArrayRef = Arc::new(UInt64Array::from(vec![Some((i64::MAX as u64) + 1)]));
    let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![column])?;
    let mut writer = AvroWriter::new(Vec::<u8>::new(), schema)?;
    let err = writer
        .write(&batch)
        .expect_err("expected UInt64 overflow error when avro_custom_types is disabled");
    let msg = match err {
        AvroError::InvalidArgument(m) => m,
        other => panic!("expected InvalidArgument, got {other:?}"),
    };
    assert_eq!(
        msg,
        "UInt64 value 9223372036854775808 exceeds i64::MAX; enable avro_custom_types feature for full UInt64 support"
    );
    Ok(())
}
#[cfg(not(feature = "avro_custom_types"))]
#[test]
fn non_custom_interval_year_month_negative_errors() -> Result<(), AvroError> {
    // Avro `duration` stores unsigned months, so a negative YearMonth
    // interval must fail on the fallback path.
    let schema = Schema::new(vec![Field::new(
        "iv_ym",
        DataType::Interval(IntervalUnit::YearMonth),
        false,
    )]);
    let column: ArrayRef = Arc::new(IntervalYearMonthArray::from(vec![Some(-1)]));
    let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![column])?;
    let mut writer = AvroWriter::new(Vec::<u8>::new(), schema)?;
    let err = writer
        .write(&batch)
        .expect_err("expected negative Interval(YearMonth) error");
    let msg = match err {
        AvroError::InvalidArgument(m) => m,
        other => panic!("expected InvalidArgument, got {other:?}"),
    };
    assert_eq!(
        msg,
        "Avro 'duration' cannot encode negative months; enable `avro_custom_types` to round-trip signed Arrow Interval(YearMonth)"
    );
    Ok(())
}
#[cfg(not(feature = "avro_custom_types"))]
#[test]
fn non_custom_interval_day_time_negative_errors() -> Result<(), AvroError> {
    // Avro `duration` stores unsigned days/milliseconds, so a negative
    // DayTime interval must fail on the fallback path.
    let schema = Schema::new(vec![Field::new(
        "iv_dt",
        DataType::Interval(IntervalUnit::DayTime),
        false,
    )]);
    let column: ArrayRef =
        Arc::new(IntervalDayTimeArray::from(vec![Some(IntervalDayTime::new(-1, 0))]));
    let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![column])?;
    let mut writer = AvroWriter::new(Vec::<u8>::new(), schema)?;
    let err = writer
        .write(&batch)
        .expect_err("expected negative Interval(DayTime) error");
    let msg = match err {
        AvroError::InvalidArgument(m) => m,
        other => panic!("expected InvalidArgument, got {other:?}"),
    };
    assert_eq!(
        msg,
        "Avro 'duration' cannot encode negative days or milliseconds; enable `avro_custom_types` to round-trip signed Arrow Interval(DayTime)"
    );
    Ok(())
}
#[cfg(not(feature = "avro_custom_types"))]
#[test]
fn non_custom_interval_month_day_nano_negative_errors() -> Result<(), AvroError> {
    // Avro `duration` components are unsigned; a negative MonthDayNano
    // interval must fail on the fallback path.
    let schema = Schema::new(vec![Field::new(
        "iv_mdn",
        DataType::Interval(IntervalUnit::MonthDayNano),
        false,
    )]);
    let column: ArrayRef = Arc::new(IntervalMonthDayNanoArray::from(vec![Some(
        IntervalMonthDayNano::new(-1, 0, 0),
    )]));
    let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![column])?;
    let mut writer = AvroWriter::new(Vec::<u8>::new(), schema)?;
    let err = writer
        .write(&batch)
        .expect_err("expected negative Interval(MonthDayNano) error");
    let msg = match err {
        AvroError::InvalidArgument(m) => m,
        other => panic!("expected InvalidArgument, got {other:?}"),
    };
    assert_eq!(
        msg,
        "Avro 'duration' cannot encode negative months/days/nanoseconds; enable `avro_custom_types` to round-trip signed Arrow intervals"
    );
    Ok(())
}
#[cfg(not(feature = "avro_custom_types"))]
#[test]
fn non_custom_interval_month_day_nano_sub_millis_errors() -> Result<(), AvroError> {
    // Avro `duration` has millisecond precision; nanoseconds that are not a
    // whole number of milliseconds must fail on the fallback path.
    let schema = Schema::new(vec![Field::new(
        "iv_mdn",
        DataType::Interval(IntervalUnit::MonthDayNano),
        false,
    )]);
    let column: ArrayRef = Arc::new(IntervalMonthDayNanoArray::from(vec![Some(
        IntervalMonthDayNano::new(0, 0, 1),
    )]));
    let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![column])?;
    let mut writer = AvroWriter::new(Vec::<u8>::new(), schema)?;
    let err = writer
        .write(&batch)
        .expect_err("expected sub-millisecond Interval(MonthDayNano) error");
    let msg = match err {
        AvroError::InvalidArgument(m) => m,
        other => panic!("expected InvalidArgument, got {other:?}"),
    };
    assert_eq!(
        msg,
        "Avro 'duration' requires whole milliseconds; nanoseconds must be divisible by 1_000_000 (enable `avro_custom_types` to preserve nanosecond intervals)"
    );
    Ok(())
}
#[cfg(not(feature = "avro_custom_types"))]
#[test]
fn non_custom_time32_second_scaling_overflow_errors() -> Result<(), AvroError> {
    // Scaling Time32(Second) to millis multiplies by 1000; a value whose
    // product exceeds i32::MAX must be rejected, not wrapped.
    let schema = Schema::new(vec![Field::new(
        "time32s",
        DataType::Time32(TimeUnit::Second),
        false,
    )]);
    let column: ArrayRef = Arc::new(Time32SecondArray::from(vec![Some((i32::MAX / 1000) + 1)]));
    let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![column])?;
    let mut writer = AvroWriter::new(Vec::<u8>::new(), schema)?;
    let err = writer
        .write(&batch)
        .expect_err("expected time32 seconds->millis overflow error");
    let msg = match err {
        AvroError::InvalidArgument(m) => m,
        other => panic!("expected InvalidArgument, got {other:?}"),
    };
    assert_eq!(msg, "time32(secs) * 1000 overflowed");
    Ok(())
}
#[cfg(not(feature = "avro_custom_types"))]
#[test]
fn non_custom_timestamp_second_scaling_overflow_errors() -> Result<(), AvroError> {
    // Scaling Timestamp(Second) to millis multiplies by 1000; a value whose
    // product exceeds i64::MAX must be rejected, not wrapped.
    let schema = Schema::new(vec![Field::new(
        "ts_s_local",
        DataType::Timestamp(TimeUnit::Second, None),
        false,
    )]);
    let column: ArrayRef =
        Arc::new(TimestampSecondArray::from(vec![Some((i64::MAX / 1000) + 1)]));
    let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![column])?;
    let mut writer = AvroWriter::new(Vec::<u8>::new(), schema)?;
    let err = writer
        .write(&batch)
        .expect_err("expected timestamp seconds->millis overflow error");
    let msg = match err {
        AvroError::InvalidArgument(m) => m,
        other => panic!("expected InvalidArgument, got {other:?}"),
    };
    assert_eq!(msg, "timestamp(secs) * 1000 overflowed");
    Ok(())
}
}