use std::sync::Arc;
use crate::array::*;
use crate::compute::kernels::concat::concat;
use crate::datatypes::*;
use crate::error::{ArrowError, Result};
/// A two-dimensional batch of column-oriented data: a set of equal-length
/// arrays (`columns`) described by a `schema`.
///
/// Invariants (enforced by `try_new` / `try_new_with_options`): the number
/// of columns equals the number of schema fields, all columns have the same
/// length, and each column's data type matches its field's data type.
#[derive(Clone, Debug, PartialEq)]
pub struct RecordBatch {
    // Shared, immutable description of the columns (names, types, nullability).
    schema: SchemaRef,
    // The column data; index i corresponds to schema field i.
    columns: Vec<Arc<dyn Array>>,
}
impl RecordBatch {
pub fn try_new(schema: SchemaRef, columns: Vec<ArrayRef>) -> Result<Self> {
let options = RecordBatchOptions::default();
Self::validate_new_batch(&schema, columns.as_slice(), &options)?;
Ok(RecordBatch { schema, columns })
}
pub fn try_new_with_options(
schema: SchemaRef,
columns: Vec<ArrayRef>,
options: &RecordBatchOptions,
) -> Result<Self> {
Self::validate_new_batch(&schema, columns.as_slice(), options)?;
Ok(RecordBatch { schema, columns })
}
pub fn new_empty(schema: SchemaRef) -> Self {
let columns = schema
.fields()
.iter()
.map(|field| new_empty_array(field.data_type()))
.collect();
RecordBatch { schema, columns }
}
fn validate_new_batch(
schema: &SchemaRef,
columns: &[ArrayRef],
options: &RecordBatchOptions,
) -> Result<()> {
if columns.is_empty() {
return Err(ArrowError::InvalidArgumentError(
"at least one column must be defined to create a record batch"
.to_string(),
));
}
if schema.fields().len() != columns.len() {
return Err(ArrowError::InvalidArgumentError(format!(
"number of columns({}) must match number of fields({}) in schema",
columns.len(),
schema.fields().len(),
)));
}
let len = columns[0].data().len();
if options.match_field_names {
for (i, column) in columns.iter().enumerate() {
if column.len() != len {
return Err(ArrowError::InvalidArgumentError(
"all columns in a record batch must have the same length"
.to_string(),
));
}
if column.data_type() != schema.field(i).data_type() {
return Err(ArrowError::InvalidArgumentError(format!(
"column types must match schema types, expected {:?} but found {:?} at column index {}",
schema.field(i).data_type(),
column.data_type(),
i)));
}
}
} else {
for (i, column) in columns.iter().enumerate() {
if column.len() != len {
return Err(ArrowError::InvalidArgumentError(
"all columns in a record batch must have the same length"
.to_string(),
));
}
if !column
.data_type()
.equals_datatype(schema.field(i).data_type())
{
return Err(ArrowError::InvalidArgumentError(format!(
"column types must match schema types, expected {:?} but found {:?} at column index {}",
schema.field(i).data_type(),
column.data_type(),
i)));
}
}
}
Ok(())
}
pub fn schema(&self) -> SchemaRef {
self.schema.clone()
}
pub fn num_columns(&self) -> usize {
self.columns.len()
}
pub fn num_rows(&self) -> usize {
self.columns[0].data().len()
}
pub fn column(&self, index: usize) -> &ArrayRef {
&self.columns[index]
}
pub fn columns(&self) -> &[ArrayRef] {
&self.columns[..]
}
pub fn slice(&self, offset: usize, length: usize) -> RecordBatch {
if self.schema.fields().is_empty() {
assert!((offset + length) == 0);
return RecordBatch::new_empty(self.schema.clone());
}
assert!((offset + length) <= self.num_rows());
let columns = self
.columns()
.iter()
.map(|column| column.slice(offset, length))
.collect();
Self {
schema: self.schema.clone(),
columns,
}
}
pub fn try_from_iter<I, F>(value: I) -> Result<Self>
where
I: IntoIterator<Item = (F, ArrayRef)>,
F: AsRef<str>,
{
let iter = value.into_iter().map(|(field_name, array)| {
let nullable = array.null_count() > 0;
(field_name, array, nullable)
});
Self::try_from_iter_with_nullable(iter)
}
pub fn try_from_iter_with_nullable<I, F>(value: I) -> Result<Self>
where
I: IntoIterator<Item = (F, ArrayRef, bool)>,
F: AsRef<str>,
{
let (fields, columns) = value
.into_iter()
.map(|(field_name, array, nullable)| {
let field_name = field_name.as_ref();
let field = Field::new(field_name, array.data_type().clone(), nullable);
(field, array)
})
.unzip();
let schema = Arc::new(Schema::new(fields));
RecordBatch::try_new(schema, columns)
}
pub fn concat(schema: &SchemaRef, batches: &[Self]) -> Result<Self> {
if batches.is_empty() {
return Ok(RecordBatch::new_empty(schema.clone()));
}
if let Some((i, _)) = batches
.iter()
.enumerate()
.find(|&(_, batch)| batch.schema() != *schema)
{
return Err(ArrowError::InvalidArgumentError(format!(
"batches[{}] schema is different with argument schema.",
i
)));
}
let field_num = schema.fields().len();
let mut arrays = Vec::with_capacity(field_num);
for i in 0..field_num {
let array = concat(
&batches
.iter()
.map(|batch| batch.column(i).as_ref())
.collect::<Vec<_>>(),
)?;
arrays.push(array);
}
Self::try_new(schema.clone(), arrays)
}
}
/// Options that control how `RecordBatch` creation validates columns
/// against the schema.
#[derive(Debug)]
pub struct RecordBatchOptions {
    /// When `true` (the default), column data types must match schema field
    /// types exactly, including nested field names; when `false`, the
    /// comparison uses `equals_datatype`, which ignores nested field names.
    pub match_field_names: bool,
}
impl Default for RecordBatchOptions {
fn default() -> Self {
Self {
match_field_names: true,
}
}
}
impl From<&StructArray> for RecordBatch {
    /// Converts a struct array into a batch: the struct's fields become the
    /// schema and its child arrays become the columns.
    fn from(struct_array: &StructArray) -> Self {
        match struct_array.data_type() {
            DataType::Struct(fields) => RecordBatch {
                schema: Arc::new(Schema::new(fields.clone())),
                columns: struct_array.boxed_fields.clone(),
            },
            _ => unreachable!("unable to get datatype as struct"),
        }
    }
}
impl From<RecordBatch> for StructArray {
fn from(batch: RecordBatch) -> Self {
batch
.schema
.fields
.iter()
.zip(batch.columns.iter())
.map(|t| (t.0.clone(), t.1.clone()))
.collect::<Vec<(Field, ArrayRef)>>()
.into()
}
}
/// An iterator over `RecordBatch`es that also exposes the schema the
/// produced batches conform to.
pub trait RecordBatchReader: Iterator<Item = Result<RecordBatch>> {
    /// Returns the schema of the batches produced by this reader.
    fn schema(&self) -> SchemaRef;

    /// Reads the next batch; `Ok(None)` signals end of stream.
    #[deprecated(
        since = "2.0.0",
        note = "This method is deprecated in favour of `next` from the trait Iterator."
    )]
    fn next_batch(&mut self) -> Result<Option<RecordBatch>> {
        // Iterator yields Option<Result<_>>; transpose to Result<Option<_>>.
        self.next().transpose()
    }
}
// Unit tests for RecordBatch construction, slicing, conversion, concat,
// and equality semantics.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::buffer::Buffer;

    #[test]
    fn create_record_batch() {
        let schema = Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::Utf8, false),
        ]);
        let a = Int32Array::from(vec![1, 2, 3, 4, 5]);
        let b = StringArray::from(vec!["a", "b", "c", "d", "e"]);
        let record_batch =
            RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a), Arc::new(b)])
                .unwrap();
        check_batch(record_batch, 5)
    }

    // Shared assertions: the batch has `num_rows` rows and the two-column
    // (Int32, Utf8) layout used throughout these tests.
    fn check_batch(record_batch: RecordBatch, num_rows: usize) {
        assert_eq!(num_rows, record_batch.num_rows());
        assert_eq!(2, record_batch.num_columns());
        assert_eq!(&DataType::Int32, record_batch.schema().field(0).data_type());
        assert_eq!(&DataType::Utf8, record_batch.schema().field(1).data_type());
        assert_eq!(num_rows, record_batch.column(0).data().len());
        assert_eq!(num_rows, record_batch.column(1).data().len());
    }

    // In-bounds slices succeed; the final slice (offset 2, length 10 on an
    // 8-row batch) triggers the expected bounds-check panic.
    #[test]
    #[should_panic(expected = "assertion failed: (offset + length) <= self.num_rows()")]
    fn create_record_batch_slice() {
        let schema = Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::Utf8, false),
        ]);
        let expected_schema = schema.clone();
        let a = Int32Array::from(vec![1, 2, 3, 4, 5, 6, 7, 8]);
        let b = StringArray::from(vec!["a", "b", "c", "d", "e", "f", "h", "i"]);
        let record_batch =
            RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a), Arc::new(b)])
                .unwrap();
        let offset = 2;
        let length = 5;
        let record_batch_slice = record_batch.slice(offset, length);
        assert_eq!(record_batch_slice.schema().as_ref(), &expected_schema);
        check_batch(record_batch_slice, 5);
        let offset = 2;
        let length = 0;
        let record_batch_slice = record_batch.slice(offset, length);
        assert_eq!(record_batch_slice.schema().as_ref(), &expected_schema);
        check_batch(record_batch_slice, 0);
        let offset = 2;
        let length = 10;
        let _record_batch_slice = record_batch.slice(offset, length);
    }

    // A field-less batch only admits the empty slice; a non-zero slice
    // triggers the expected panic.
    #[test]
    #[should_panic(expected = "assertion failed: (offset + length) == 0")]
    fn create_record_batch_slice_empty_batch() {
        let schema = Schema::new(vec![]);
        let record_batch = RecordBatch::new_empty(Arc::new(schema));
        let offset = 0;
        let length = 0;
        let record_batch_slice = record_batch.slice(offset, length);
        assert_eq!(0, record_batch_slice.schema().fields().len());
        let offset = 1;
        let length = 2;
        let _record_batch_slice = record_batch.slice(offset, length);
    }

    // try_from_iter infers nullability from the data: "a" has a null so its
    // field is nullable, "b" has none so it is not.
    #[test]
    fn create_record_batch_try_from_iter() {
        let a: ArrayRef = Arc::new(Int32Array::from(vec![
            Some(1),
            Some(2),
            None,
            Some(4),
            Some(5),
        ]));
        let b: ArrayRef = Arc::new(StringArray::from(vec!["a", "b", "c", "d", "e"]));
        let record_batch = RecordBatch::try_from_iter(vec![("a", a), ("b", b)])
            .expect("valid conversion");
        let expected_schema = Schema::new(vec![
            Field::new("a", DataType::Int32, true),
            Field::new("b", DataType::Utf8, false),
        ]);
        assert_eq!(record_batch.schema().as_ref(), &expected_schema);
        check_batch(record_batch, 5);
    }

    // Explicit nullability overrides inference, even without nulls present.
    #[test]
    fn create_record_batch_try_from_iter_with_nullable() {
        let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5]));
        let b: ArrayRef = Arc::new(StringArray::from(vec!["a", "b", "c", "d", "e"]));
        let record_batch = RecordBatch::try_from_iter_with_nullable(vec![
            ("a", a, false),
            ("b", b, true),
        ])
        .expect("valid conversion");
        let expected_schema = Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::Utf8, true),
        ]);
        assert_eq!(record_batch.schema().as_ref(), &expected_schema);
        check_batch(record_batch, 5);
    }

    // Int64 column against an Int32 schema field must be rejected.
    #[test]
    fn create_record_batch_schema_mismatch() {
        let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
        let a = Int64Array::from(vec![1, 2, 3, 4, 5]);
        let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)]);
        assert!(!batch.is_ok());
    }

    // The column's struct type uses nested field names ("aa1", list item
    // "array") that differ from the schema's ("a1", list item "item"):
    // strict matching rejects it, relaxed matching accepts it.
    #[test]
    fn create_record_batch_field_name_mismatch() {
        let struct_fields = vec![
            Field::new("a1", DataType::Int32, false),
            Field::new(
                "a2",
                DataType::List(Box::new(Field::new("item", DataType::Int8, false))),
                false,
            ),
        ];
        let struct_type = DataType::Struct(struct_fields);
        let schema = Arc::new(Schema::new(vec![Field::new("a", struct_type, true)]));
        let a1: ArrayRef = Arc::new(Int32Array::from(vec![1, 2]));
        let a2_child = Int8Array::from(vec![1, 2, 3, 4]);
        let a2 = ArrayDataBuilder::new(DataType::List(Box::new(Field::new(
            "array",
            DataType::Int8,
            false,
        ))))
        .add_child_data(a2_child.data().clone())
        .len(2)
        .add_buffer(Buffer::from(vec![0i32, 3, 4].to_byte_slice()))
        .build()
        .unwrap();
        let a2: ArrayRef = Arc::new(ListArray::from(a2));
        let a = ArrayDataBuilder::new(DataType::Struct(vec![
            Field::new("aa1", DataType::Int32, false),
            Field::new("a2", a2.data_type().clone(), false),
        ]))
        .add_child_data(a1.data().clone())
        .add_child_data(a2.data().clone())
        .len(2)
        .build()
        .unwrap();
        let a: ArrayRef = Arc::new(StructArray::from(a));
        let batch = RecordBatch::try_new(schema.clone(), vec![a.clone()]);
        assert!(batch.is_err());
        let options = RecordBatchOptions {
            match_field_names: false,
        };
        let batch = RecordBatch::try_new_with_options(schema, vec![a], &options);
        assert!(batch.is_ok());
    }

    // Two columns against a one-field schema must be rejected.
    #[test]
    fn create_record_batch_record_mismatch() {
        let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
        let a = Int32Array::from(vec![1, 2, 3, 4, 5]);
        let b = Int32Array::from(vec![1, 2, 3, 4, 5]);
        let batch =
            RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a), Arc::new(b)]);
        assert!(!batch.is_ok());
    }

    #[test]
    fn create_record_batch_from_struct_array() {
        let boolean = Arc::new(BooleanArray::from(vec![false, false, true, true]));
        let int = Arc::new(Int32Array::from(vec![42, 28, 19, 31]));
        let struct_array = StructArray::from(vec![
            (
                Field::new("b", DataType::Boolean, false),
                boolean.clone() as ArrayRef,
            ),
            (
                Field::new("c", DataType::Int32, false),
                int.clone() as ArrayRef,
            ),
        ]);
        let batch = RecordBatch::from(&struct_array);
        assert_eq!(2, batch.num_columns());
        assert_eq!(4, batch.num_rows());
        assert_eq!(
            struct_array.data_type(),
            &DataType::Struct(batch.schema().fields().to_vec())
        );
        assert_eq!(batch.column(0).as_ref(), boolean.as_ref());
        assert_eq!(batch.column(1).as_ref(), int.as_ref());
    }

    #[test]
    fn concat_record_batches() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::Utf8, false),
        ]));
        let batch1 = RecordBatch::try_new(
            schema.clone(),
            vec![
                Arc::new(Int32Array::from(vec![1, 2])),
                Arc::new(StringArray::from(vec!["a", "b"])),
            ],
        )
        .unwrap();
        let batch2 = RecordBatch::try_new(
            schema.clone(),
            vec![
                Arc::new(Int32Array::from(vec![3, 4])),
                Arc::new(StringArray::from(vec!["c", "d"])),
            ],
        )
        .unwrap();
        let new_batch = RecordBatch::concat(&schema, &[batch1, batch2]).unwrap();
        assert_eq!(new_batch.schema().as_ref(), schema.as_ref());
        assert_eq!(2, new_batch.num_columns());
        assert_eq!(4, new_batch.num_rows());
    }

    // Concatenating zero batches yields an empty batch with the schema.
    #[test]
    fn concat_empty_record_batch() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::Utf8, false),
        ]));
        let batch = RecordBatch::concat(&schema, &[]).unwrap();
        assert_eq!(batch.schema().as_ref(), schema.as_ref());
        assert_eq!(0, batch.num_rows());
    }

    #[test]
    fn concat_record_batches_of_different_schemas() {
        let schema1 = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::Utf8, false),
        ]));
        let schema2 = Arc::new(Schema::new(vec![
            Field::new("c", DataType::Int32, false),
            Field::new("d", DataType::Utf8, false),
        ]));
        let batch1 = RecordBatch::try_new(
            schema1.clone(),
            vec![
                Arc::new(Int32Array::from(vec![1, 2])),
                Arc::new(StringArray::from(vec!["a", "b"])),
            ],
        )
        .unwrap();
        let batch2 = RecordBatch::try_new(
            schema2,
            vec![
                Arc::new(Int32Array::from(vec![3, 4])),
                Arc::new(StringArray::from(vec!["c", "d"])),
            ],
        )
        .unwrap();
        let error = RecordBatch::concat(&schema1, &[batch1, batch2]).unwrap_err();
        assert_eq!(
            error.to_string(),
            "Invalid argument error: batches[1] schema is different with argument schema.",
        );
    }

    // Derived PartialEq: equal schema and equal column data compare equal.
    #[test]
    fn record_batch_equality() {
        let id_arr1 = Int32Array::from(vec![1, 2, 3, 4]);
        let val_arr1 = Int32Array::from(vec![5, 6, 7, 8]);
        let schema1 = Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("val", DataType::Int32, false),
        ]);
        let id_arr2 = Int32Array::from(vec![1, 2, 3, 4]);
        let val_arr2 = Int32Array::from(vec![5, 6, 7, 8]);
        let schema2 = Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("val", DataType::Int32, false),
        ]);
        let batch1 = RecordBatch::try_new(
            Arc::new(schema1),
            vec![Arc::new(id_arr1), Arc::new(val_arr1)],
        )
        .unwrap();
        let batch2 = RecordBatch::try_new(
            Arc::new(schema2),
            vec![Arc::new(id_arr2), Arc::new(val_arr2)],
        )
        .unwrap();
        assert_eq!(batch1, batch2);
    }

    // Same schema, different column values -> not equal.
    #[test]
    fn record_batch_vals_ne() {
        let id_arr1 = Int32Array::from(vec![1, 2, 3, 4]);
        let val_arr1 = Int32Array::from(vec![5, 6, 7, 8]);
        let schema1 = Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("val", DataType::Int32, false),
        ]);
        let id_arr2 = Int32Array::from(vec![1, 2, 3, 4]);
        let val_arr2 = Int32Array::from(vec![1, 2, 3, 4]);
        let schema2 = Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("val", DataType::Int32, false),
        ]);
        let batch1 = RecordBatch::try_new(
            Arc::new(schema1),
            vec![Arc::new(id_arr1), Arc::new(val_arr1)],
        )
        .unwrap();
        let batch2 = RecordBatch::try_new(
            Arc::new(schema2),
            vec![Arc::new(id_arr2), Arc::new(val_arr2)],
        )
        .unwrap();
        assert_ne!(batch1, batch2);
    }

    // Same data, different field names -> not equal.
    #[test]
    fn record_batch_column_names_ne() {
        let id_arr1 = Int32Array::from(vec![1, 2, 3, 4]);
        let val_arr1 = Int32Array::from(vec![5, 6, 7, 8]);
        let schema1 = Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("val", DataType::Int32, false),
        ]);
        let id_arr2 = Int32Array::from(vec![1, 2, 3, 4]);
        let val_arr2 = Int32Array::from(vec![5, 6, 7, 8]);
        let schema2 = Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("num", DataType::Int32, false),
        ]);
        let batch1 = RecordBatch::try_new(
            Arc::new(schema1),
            vec![Arc::new(id_arr1), Arc::new(val_arr1)],
        )
        .unwrap();
        let batch2 = RecordBatch::try_new(
            Arc::new(schema2),
            vec![Arc::new(id_arr2), Arc::new(val_arr2)],
        )
        .unwrap();
        assert_ne!(batch1, batch2);
    }

    // Different column counts -> not equal.
    #[test]
    fn record_batch_column_number_ne() {
        let id_arr1 = Int32Array::from(vec![1, 2, 3, 4]);
        let val_arr1 = Int32Array::from(vec![5, 6, 7, 8]);
        let schema1 = Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("val", DataType::Int32, false),
        ]);
        let id_arr2 = Int32Array::from(vec![1, 2, 3, 4]);
        let val_arr2 = Int32Array::from(vec![5, 6, 7, 8]);
        let num_arr2 = Int32Array::from(vec![5, 6, 7, 8]);
        let schema2 = Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("val", DataType::Int32, false),
            Field::new("num", DataType::Int32, false),
        ]);
        let batch1 = RecordBatch::try_new(
            Arc::new(schema1),
            vec![Arc::new(id_arr1), Arc::new(val_arr1)],
        )
        .unwrap();
        let batch2 = RecordBatch::try_new(
            Arc::new(schema2),
            vec![Arc::new(id_arr2), Arc::new(val_arr2), Arc::new(num_arr2)],
        )
        .unwrap();
        assert_ne!(batch1, batch2);
    }

    // Different row counts -> not equal.
    #[test]
    fn record_batch_row_count_ne() {
        let id_arr1 = Int32Array::from(vec![1, 2, 3]);
        let val_arr1 = Int32Array::from(vec![5, 6, 7]);
        let schema1 = Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("val", DataType::Int32, false),
        ]);
        let id_arr2 = Int32Array::from(vec![1, 2, 3, 4]);
        let val_arr2 = Int32Array::from(vec![5, 6, 7, 8]);
        let schema2 = Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("num", DataType::Int32, false),
        ]);
        let batch1 = RecordBatch::try_new(
            Arc::new(schema1),
            vec![Arc::new(id_arr1), Arc::new(val_arr1)],
        )
        .unwrap();
        let batch2 = RecordBatch::try_new(
            Arc::new(schema2),
            vec![Arc::new(id_arr2), Arc::new(val_arr2)],
        )
        .unwrap();
        assert_ne!(batch1, batch2);
    }
}