use super::arrow_array_reader::AvroArrowArrayReader;
use crate::arrow::datatypes::SchemaRef;
use crate::arrow::record_batch::RecordBatch;
use crate::error::Result;
use arrow::error::Result as ArrowResult;
use std::io::{Read, Seek};
use std::sync::Arc;
/// Builder that collects the settings for a [`Reader`] and constructs it
/// via [`ReaderBuilder::build`].
#[derive(Debug)]
pub struct ReaderBuilder {
/// Schema to hand to the reader. When `None`, `build` reads the schema
/// from the source instead.
schema: Option<SchemaRef>,
/// Batch size the resulting reader passes to its array reader on every
/// `next` call.
batch_size: usize,
/// Optional projection forwarded unchanged to the array reader.
/// NOTE(review): presumably a list of column names to keep - confirm
/// against `AvroArrowArrayReader::try_new`.
projection: Option<Vec<String>>,
}
impl Default for ReaderBuilder {
fn default() -> Self {
Self {
schema: None,
batch_size: 1024,
projection: None,
}
}
}
impl ReaderBuilder {
    /// Creates a builder holding the default settings (see the `Default`
    /// implementation of this type).
    pub fn new() -> Self {
        Default::default()
    }

    /// Uses `schema` for the reader instead of reading one from the source.
    pub fn with_schema(self, schema: SchemaRef) -> Self {
        Self {
            schema: Some(schema),
            ..self
        }
    }

    /// Clears any previously set schema so that [`ReaderBuilder::build`]
    /// reads the schema from the source.
    pub fn read_schema(self) -> Self {
        Self {
            schema: None,
            ..self
        }
    }

    /// Sets the batch size the reader passes to its array reader on every
    /// `next` call.
    pub fn with_batch_size(self, batch_size: usize) -> Self {
        Self { batch_size, ..self }
    }

    /// Sets the projection forwarded to the array reader.
    pub fn with_projection(self, projection: Vec<String>) -> Self {
        Self {
            projection: Some(projection),
            ..self
        }
    }

    /// Consumes the builder and creates a [`Reader`] over `source`.
    ///
    /// When no schema was supplied, it is read from `source` with
    /// `read_avro_schema_from_reader`. In both cases `source` is rewound
    /// before being handed to the reader.
    ///
    /// # Errors
    ///
    /// Propagates any error from reading the schema, from rewinding
    /// `source`, or from [`Reader::try_new`].
    pub fn build<'a, R>(self, mut source: R) -> Result<Reader<'a, R>>
    where
        R: Read + Seek,
    {
        let Self {
            schema,
            batch_size,
            projection,
        } = self;

        // Only touch the source for the schema when none was given.
        let schema = if let Some(explicit) = schema {
            explicit
        } else {
            Arc::new(super::read_avro_schema_from_reader(&mut source)?)
        };

        // Rewind unconditionally; presumably so the array reader starts at
        // the beginning of the stream after the schema read.
        source.rewind()?;
        Reader::try_new(source, schema, batch_size, projection)
    }
}
/// Reader that pulls [`RecordBatch`]es from an underlying
/// `AvroArrowArrayReader` over a source implementing `Read`.
pub struct Reader<'a, R: Read> {
/// Array reader that produces the batches.
array_reader: AvroArrowArrayReader<'a, R>,
/// Schema given at construction; a clone was handed to the array reader
/// and it is what `schema()` returns.
schema: SchemaRef,
/// Batch size passed to the array reader on every `next` call.
batch_size: usize,
}
impl<'a, R: Read> Reader<'a, R> {
pub fn try_new(
reader: R,
schema: SchemaRef,
batch_size: usize,
projection: Option<Vec<String>>,
) -> Result<Self> {
Ok(Self {
array_reader: AvroArrowArrayReader::try_new(
reader,
schema.clone(),
projection,
)?,
schema,
batch_size,
})
}
pub fn schema(&self) -> SchemaRef {
self.schema.clone()
}
#[allow(clippy::should_implement_trait)]
pub fn next(&mut self) -> ArrowResult<Option<RecordBatch>> {
self.array_reader.next_batch(self.batch_size)
}
}
impl<'a, R: Read> Iterator for Reader<'a, R> {
    type Item = ArrowResult<RecordBatch>;

    /// Yields the next batch, an error, or `None` once the inherent
    /// `Reader::next` reports `Ok(None)`.
    fn next(&mut self) -> Option<Self::Item> {
        // `self.next()` resolves to the inherent `Reader::next`, which takes
        // precedence over this trait method, so this does not recurse.
        match self.next() {
            Ok(Some(batch)) => Some(Ok(batch)),
            Ok(None) => None,
            Err(err) => Some(Err(err)),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::arrow::array::*;
    use crate::arrow::datatypes::{DataType, Field};
    use arrow::datatypes::TimeUnit;
    use std::fs::File;

    /// Opens `avro/{name}` under the arrow test-data directory and builds a
    /// reader that reads the schema from the file, with a batch size of 64.
    fn build_reader(name: &str) -> Reader<File> {
        let testdata = crate::test_util::arrow_test_data();
        let filename = format!("{testdata}/avro/{name}");
        let builder = ReaderBuilder::new().read_schema().with_batch_size(64);
        builder.build(File::open(filename).unwrap()).unwrap()
    }

    /// Returns the column at index `col.0` of `batch` downcast to `T`, or
    /// `None` when the column is not a `T`.
    fn get_col<'a, T: 'static>(
        batch: &'a RecordBatch,
        col: (usize, &Field),
    ) -> Option<&'a T> {
        batch.column(col.0).as_any().downcast_ref::<T>()
    }

    /// Looks up `name` in `schema`, asserts that it sits at `index` with
    /// type `data_type`, and returns the matching column of `batch` downcast
    /// to `T`. Panics if the field is missing or the downcast fails.
    fn checked_col<'a, T: 'static>(
        batch: &'a RecordBatch,
        schema: &SchemaRef,
        name: &str,
        index: usize,
        data_type: &DataType,
    ) -> &'a T {
        let field = schema.column_with_name(name).unwrap();
        assert_eq!(index, field.0);
        assert_eq!(data_type, field.1.data_type());
        get_col::<T>(batch, field).unwrap()
    }

    /// Reads the first batch of `alltypes_dictionary.avro` and checks the
    /// position, type and first two values of every column.
    #[test]
    fn test_avro_basic() {
        let mut reader = build_reader("alltypes_dictionary.avro");
        let batch = reader.next().unwrap().unwrap();

        assert_eq!(11, batch.num_columns());
        assert_eq!(2, batch.num_rows());

        let schema = reader.schema();
        let batch_schema = batch.schema();
        assert_eq!(schema, batch_schema);

        let col = checked_col::<Int32Array>(&batch, &schema, "id", 0, &DataType::Int32);
        assert_eq!(0, col.value(0));
        assert_eq!(1, col.value(1));

        let col =
            checked_col::<BooleanArray>(&batch, &schema, "bool_col", 1, &DataType::Boolean);
        assert!(col.value(0));
        assert!(!col.value(1));

        let col =
            checked_col::<Int32Array>(&batch, &schema, "tinyint_col", 2, &DataType::Int32);
        assert_eq!(0, col.value(0));
        assert_eq!(1, col.value(1));

        let col =
            checked_col::<Int32Array>(&batch, &schema, "smallint_col", 3, &DataType::Int32);
        assert_eq!(0, col.value(0));
        assert_eq!(1, col.value(1));

        let col = checked_col::<Int32Array>(&batch, &schema, "int_col", 4, &DataType::Int32);
        assert_eq!(0, col.value(0));
        assert_eq!(1, col.value(1));

        let col =
            checked_col::<Int64Array>(&batch, &schema, "bigint_col", 5, &DataType::Int64);
        assert_eq!(0, col.value(0));
        assert_eq!(10, col.value(1));

        let col =
            checked_col::<Float32Array>(&batch, &schema, "float_col", 6, &DataType::Float32);
        assert_eq!(0.0, col.value(0));
        assert_eq!(1.1, col.value(1));

        let col =
            checked_col::<Float64Array>(&batch, &schema, "double_col", 7, &DataType::Float64);
        assert_eq!(0.0, col.value(0));
        assert_eq!(10.1, col.value(1));

        let col = checked_col::<BinaryArray>(
            &batch,
            &schema,
            "date_string_col",
            8,
            &DataType::Binary,
        );
        assert_eq!("01/01/09".as_bytes(), col.value(0));
        assert_eq!("01/01/09".as_bytes(), col.value(1));

        let col =
            checked_col::<BinaryArray>(&batch, &schema, "string_col", 9, &DataType::Binary);
        assert_eq!("0".as_bytes(), col.value(0));
        assert_eq!("1".as_bytes(), col.value(1));

        let col = checked_col::<TimestampMicrosecondArray>(
            &batch,
            &schema,
            "timestamp_col",
            10,
            &DataType::Timestamp(TimeUnit::Microsecond, None),
        );
        assert_eq!(1230768000000000, col.value(0));
        assert_eq!(1230768060000000, col.value(1));
    }
}