use std::sync::Arc;
use crate::array::*;
use crate::datatypes::*;
use crate::error::{ArrowError, Result};
/// A reference-counted, dynamically-typed Arrow array.
type ArrayRef = Arc<dyn Array>;
/// A two-dimensional dataset: a collection of equal-length columns ([`Array`]s)
/// conforming to a [`Schema`].
#[derive(Clone, Debug, PartialEq)]
pub struct RecordBatch {
    /// Describes each column; construction validates one field per column.
    schema: Arc<Schema>,
    /// The columns; construction validates that all share the same length.
    columns: Vec<ArrayRef>,
}
impl RecordBatch {
    /// Creates a [`RecordBatch`] from a schema and columns, using the default
    /// [`RecordBatchOptions`] (strict field-name matching).
    ///
    /// # Errors
    /// Errors iff `columns` is empty, its length differs from the number of
    /// schema fields, the columns' datatypes do not match the schema's, or the
    /// columns do not all have the same length.
    pub fn try_new(schema: Arc<Schema>, columns: Vec<ArrayRef>) -> Result<Self> {
        let options = RecordBatchOptions::default();
        Self::validate_new_batch(&schema, columns.as_slice(), &options)?;
        Ok(RecordBatch { schema, columns })
    }

    /// Creates a [`RecordBatch`] like [`RecordBatch::try_new`], but with
    /// explicit validation `options`.
    ///
    /// # Errors
    /// Same conditions as [`RecordBatch::try_new`], with the datatype
    /// comparison controlled by `options.match_field_names`.
    pub fn try_new_with_options(
        schema: Arc<Schema>,
        columns: Vec<ArrayRef>,
        options: &RecordBatchOptions,
    ) -> Result<Self> {
        Self::validate_new_batch(&schema, columns.as_slice(), options)?;
        Ok(RecordBatch { schema, columns })
    }

    /// Creates a [`RecordBatch`] with zero rows: one empty array per field of
    /// `schema`.
    pub fn new_empty(schema: Arc<Schema>) -> Self {
        let columns = schema
            .fields()
            .iter()
            .map(|field| new_empty_array(field.data_type().clone()).into())
            .collect();
        RecordBatch { schema, columns }
    }

    /// Validates that `columns` are compatible with `schema`:
    /// non-empty, one column per field, equal lengths, and matching datatypes.
    fn validate_new_batch(
        schema: &Schema,
        columns: &[ArrayRef],
        options: &RecordBatchOptions,
    ) -> Result<()> {
        if columns.is_empty() {
            return Err(ArrowError::InvalidArgumentError(
                "at least one column must be defined to create a record batch".to_string(),
            ));
        }
        if schema.fields().len() != columns.len() {
            return Err(ArrowError::InvalidArgumentError(format!(
                "number of columns({}) must match number of fields({}) in schema",
                columns.len(),
                schema.fields().len(),
            )));
        }
        // The batch's row count is defined by the first column; all others must agree.
        let len = columns[0].len();
        for (i, column) in columns.iter().enumerate() {
            if column.len() != len {
                return Err(ArrowError::InvalidArgumentError(
                    "all columns in a record batch must have the same length".to_string(),
                ));
            }
            // With `match_field_names` the strict `PartialEq` comparison is used;
            // otherwise `equals_datatype` compares the datatypes more leniently.
            let matches = if options.match_field_names {
                column.data_type() == schema.field(i).data_type()
            } else {
                column
                    .data_type()
                    .equals_datatype(schema.field(i).data_type())
            };
            if !matches {
                return Err(ArrowError::InvalidArgumentError(format!(
                    "column types must match schema types, expected {:?} but found {:?} at column index {}",
                    schema.field(i).data_type(),
                    column.data_type(),
                    i)));
            }
        }
        Ok(())
    }

    /// Returns the batch's [`Schema`].
    pub fn schema(&self) -> &Arc<Schema> {
        &self.schema
    }

    /// Returns the number of columns.
    pub fn num_columns(&self) -> usize {
        self.columns.len()
    }

    /// Returns the number of rows (the length of the first column; validation
    /// guarantees all columns share it).
    pub fn num_rows(&self) -> usize {
        self.columns[0].len()
    }

    /// Returns the column at position `index`.
    ///
    /// # Panics
    /// Panics iff `index` is out of bounds.
    pub fn column(&self, index: usize) -> &ArrayRef {
        &self.columns[index]
    }

    /// Returns all columns as a slice.
    pub fn columns(&self) -> &[ArrayRef] {
        &self.columns[..]
    }

    /// Creates a [`RecordBatch`] from `(field_name, array)` pairs, inferring
    /// each field's nullability from whether its array contains any nulls.
    ///
    /// # Errors
    /// Same conditions as [`RecordBatch::try_new`].
    pub fn try_from_iter<I, F>(value: I) -> Result<Self>
    where
        I: IntoIterator<Item = (F, ArrayRef)>,
        F: AsRef<str>,
    {
        let iter = value.into_iter().map(|(field_name, array)| {
            // A field is marked nullable only if the array actually holds nulls.
            let nullable = array.null_count() > 0;
            (field_name, array, nullable)
        });
        Self::try_from_iter_with_nullable(iter)
    }

    /// Creates a [`RecordBatch`] from `(field_name, array, nullable)` triples,
    /// taking each field's nullability from the iterator instead of the data.
    ///
    /// # Errors
    /// Same conditions as [`RecordBatch::try_new`].
    pub fn try_from_iter_with_nullable<I, F>(value: I) -> Result<Self>
    where
        I: IntoIterator<Item = (F, ArrayRef, bool)>,
        F: AsRef<str>,
    {
        let (fields, columns) = value
            .into_iter()
            .map(|(field_name, array, nullable)| {
                let field_name = field_name.as_ref();
                let field = Field::new(field_name, array.data_type().clone(), nullable);
                (field, array)
            })
            .unzip();
        let schema = Arc::new(Schema::new(fields));
        RecordBatch::try_new(schema, columns)
    }
}
/// Options that control how [`RecordBatch::try_new_with_options`] validates
/// columns against the schema.
#[derive(Debug)]
pub struct RecordBatchOptions {
    /// When `true`, datatypes are compared with strict `PartialEq`; when
    /// `false`, they are compared with the more lenient `equals_datatype`.
    pub match_field_names: bool,
}
impl Default for RecordBatchOptions {
fn default() -> Self {
Self {
match_field_names: true,
}
}
}
impl From<&StructArray> for RecordBatch {
    /// Builds a [`RecordBatch`] whose columns are the struct array's child
    /// arrays and whose schema is derived from the struct's fields.
    ///
    /// # Panics
    /// Panics iff the array's datatype is not `DataType::Struct`.
    fn from(struct_array: &StructArray) -> Self {
        match struct_array.data_type() {
            DataType::Struct(fields) => RecordBatch {
                schema: Arc::new(Schema::new(fields.clone())),
                columns: struct_array.values().to_vec(),
            },
            _ => unreachable!("unable to get datatype as struct"),
        }
    }
}
impl From<RecordBatch> for StructArray {
    /// Converts a [`RecordBatch`] into a [`StructArray`] with no validity
    /// bitmap, pairing each schema field with its column.
    fn from(batch: RecordBatch) -> Self {
        let mut fields = Vec::with_capacity(batch.columns.len());
        let mut values = Vec::with_capacity(batch.columns.len());
        for (field, column) in batch.schema.fields.iter().zip(batch.columns.iter()) {
            fields.push(field.clone());
            values.push(column.clone());
        }
        StructArray::from_data(fields, values, None)
    }
}
/// Trait for types that produce a stream of [`RecordBatch`]es as an iterator
/// of `Result<RecordBatch>`.
pub trait RecordBatchReader: Iterator<Item = Result<RecordBatch>> {
    /// Returns the schema of the batches produced by this reader.
    fn schema(&self) -> &Schema;
    /// Reads the next batch; `Ok(None)` signals exhaustion.
    /// Kept for backwards compatibility; delegates to `Iterator::next`.
    #[deprecated(
        since = "2.0.0",
        note = "This method is deprecated in favour of `next` from the trait Iterator."
    )]
    fn next_batch(&mut self) -> Result<Option<RecordBatch>> {
        self.next().transpose()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn basic() {
        let schema = Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::Utf8, false),
        ]);
        let ints = Int32Array::from_slice(&[1, 2, 3, 4, 5]);
        let strings = Utf8Array::<i32>::from_slice(&["a", "b", "c", "d", "e"]);
        let batch =
            RecordBatch::try_new(Arc::new(schema), vec![Arc::new(ints), Arc::new(strings)])
                .unwrap();
        check_batch(batch)
    }

    // Shared assertions for the 5-row `(Int32, Utf8)` batches built in these tests.
    fn check_batch(batch: RecordBatch) {
        assert_eq!(batch.num_rows(), 5);
        assert_eq!(batch.num_columns(), 2);
        assert_eq!(batch.schema().field(0).data_type(), &DataType::Int32);
        assert_eq!(batch.schema().field(1).data_type(), &DataType::Utf8);
        assert_eq!(batch.column(0).len(), 5);
        assert_eq!(batch.column(1).len(), 5);
    }

    #[test]
    fn try_from_iter() {
        // Column "a" contains a null, so its field must be inferred as nullable.
        let ints: ArrayRef = Arc::new(Int32Array::from(vec![
            Some(1),
            Some(2),
            None,
            Some(4),
            Some(5),
        ]));
        let strings: ArrayRef = Arc::new(Utf8Array::<i32>::from_slice(&["a", "b", "c", "d", "e"]));
        let batch =
            RecordBatch::try_from_iter(vec![("a", ints), ("b", strings)]).expect("valid conversion");
        let expected = Schema::new(vec![
            Field::new("a", DataType::Int32, true),
            Field::new("b", DataType::Utf8, false),
        ]);
        assert_eq!(batch.schema().as_ref(), &expected);
        check_batch(batch);
    }

    #[test]
    fn try_from_iter_with_nullable() {
        let ints: ArrayRef = Arc::new(Int32Array::from_slice(&[1, 2, 3, 4, 5]));
        let strings: ArrayRef = Arc::new(Utf8Array::<i32>::from_slice(&["a", "b", "c", "d", "e"]));
        // Nullability is taken from the iterator, not inferred from the data.
        let batch =
            RecordBatch::try_from_iter_with_nullable(vec![("a", ints, false), ("b", strings, true)])
                .expect("valid conversion");
        let expected = Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::Utf8, true),
        ]);
        assert_eq!(batch.schema().as_ref(), &expected);
        check_batch(batch);
    }

    #[test]
    fn type_mismatch() {
        // Schema declares Int32 but the column is Int64: construction must fail.
        let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
        let column = Int64Array::from_slice(&[1, 2, 3, 4, 5]);
        assert!(RecordBatch::try_new(Arc::new(schema), vec![Arc::new(column)]).is_err());
    }

    #[test]
    fn number_of_fields_mismatch() {
        // One field in the schema but two columns: construction must fail.
        let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
        let first = Int32Array::from_slice(&[1, 2, 3, 4, 5]);
        let second = Int32Array::from_slice(&[1, 2, 3, 4, 5]);
        assert!(
            RecordBatch::try_new(Arc::new(schema), vec![Arc::new(first), Arc::new(second)])
                .is_err()
        );
    }

    #[test]
    fn from_struct_array() {
        let boolean = Arc::new(BooleanArray::from_slice(&[false, false, true, true])) as ArrayRef;
        let int = Arc::new(Int32Array::from_slice(&[42, 28, 19, 31])) as ArrayRef;
        let struct_array = StructArray::from_data(
            vec![
                Field::new("b", DataType::Boolean, false),
                Field::new("c", DataType::Int32, false),
            ],
            vec![boolean.clone(), int.clone()],
            None,
        );
        let batch = RecordBatch::from(&struct_array);
        assert_eq!(batch.num_columns(), 2);
        assert_eq!(batch.num_rows(), 4);
        // The round-tripped schema must reconstruct the original struct datatype.
        assert_eq!(
            &DataType::Struct(batch.schema().fields().to_vec()),
            struct_array.data_type()
        );
        assert_eq!(batch.column(0).as_ref(), boolean.as_ref());
        assert_eq!(batch.column(1).as_ref(), int.as_ref());
    }
}