use std::sync::Arc;
use crate::array::*;
use crate::datatypes::*;
use crate::error::{ArrowError, Result};
/// Shared, reference-counted handle to a type-erased `Array` trait object.
type ArrayRef = Arc<dyn Array>;
/// A schema together with the columns of data it describes.
///
/// Batches built through `try_new` / `try_new_with_options` are validated. Those constructors
/// require that at least one column exists and that the column count equals the schema's field
/// count. All columns must share one length, and each column's data type must match its schema
/// field.
#[derive(Clone, Debug, PartialEq)]
pub struct RecordBatch {
/// Schema describing `columns`: field `i` describes column `i`.
schema: Arc<Schema>,
/// The column data, one array per schema field.
columns: Vec<ArrayRef>,
}
impl RecordBatch {
pub fn try_new(schema: Arc<Schema>, columns: Vec<ArrayRef>) -> Result<Self> {
let options = RecordBatchOptions::default();
Self::validate_new_batch(&schema, columns.as_slice(), &options)?;
Ok(RecordBatch { schema, columns })
}
pub fn try_new_with_options(
schema: Arc<Schema>,
columns: Vec<ArrayRef>,
options: &RecordBatchOptions,
) -> Result<Self> {
Self::validate_new_batch(&schema, columns.as_slice(), options)?;
Ok(RecordBatch { schema, columns })
}
pub fn new_empty(schema: Arc<Schema>) -> Self {
let columns = schema
.fields()
.iter()
.map(|field| new_empty_array(field.data_type().clone()).into())
.collect();
RecordBatch { schema, columns }
}
fn validate_new_batch(
schema: &Schema,
columns: &[ArrayRef],
options: &RecordBatchOptions,
) -> Result<()> {
if columns.is_empty() {
return Err(ArrowError::InvalidArgumentError(
"at least one column must be defined to create a record batch".to_string(),
));
}
if schema.fields().len() != columns.len() {
return Err(ArrowError::InvalidArgumentError(format!(
"number of columns({}) must match number of fields({}) in schema",
columns.len(),
schema.fields().len(),
)));
}
let len = columns[0].len();
if options.match_field_names {
for (i, column) in columns.iter().enumerate() {
if column.len() != len {
return Err(ArrowError::InvalidArgumentError(
"all columns in a record batch must have the same length".to_string(),
));
}
if column.data_type() != schema.field(i).data_type() {
return Err(ArrowError::InvalidArgumentError(format!(
"column types must match schema types, expected {:?} but found {:?} at column index {}",
schema.field(i).data_type(),
column.data_type(),
i)));
}
}
} else {
for (i, column) in columns.iter().enumerate() {
if column.len() != len {
return Err(ArrowError::InvalidArgumentError(
"all columns in a record batch must have the same length".to_string(),
));
}
if !column
.data_type()
.equals_datatype(schema.field(i).data_type())
{
return Err(ArrowError::InvalidArgumentError(format!(
"column types must match schema types, expected {:?} but found {:?} at column index {}",
schema.field(i).data_type(),
column.data_type(),
i)));
}
}
}
Ok(())
}
pub fn schema(&self) -> &Arc<Schema> {
&self.schema
}
pub fn num_columns(&self) -> usize {
self.columns.len()
}
pub fn num_rows(&self) -> usize {
self.columns[0].len()
}
pub fn column(&self, index: usize) -> &ArrayRef {
&self.columns[index]
}
pub fn columns(&self) -> &[ArrayRef] {
&self.columns[..]
}
pub fn try_from_iter<I, F>(value: I) -> Result<Self>
where
I: IntoIterator<Item = (F, ArrayRef)>,
F: AsRef<str>,
{
let iter = value.into_iter().map(|(field_name, array)| {
let nullable = array.null_count() > 0;
(field_name, array, nullable)
});
Self::try_from_iter_with_nullable(iter)
}
pub fn try_from_iter_with_nullable<I, F>(value: I) -> Result<Self>
where
I: IntoIterator<Item = (F, ArrayRef, bool)>,
F: AsRef<str>,
{
let (fields, columns) = value
.into_iter()
.map(|(field_name, array, nullable)| {
let field_name = field_name.as_ref();
let field = Field::new(field_name, array.data_type().clone(), nullable);
(field, array)
})
.unzip();
let schema = Arc::new(Schema::new(fields));
RecordBatch::try_new(schema, columns)
}
}
/// Options controlling how `RecordBatch::try_new_with_options` validates its input.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RecordBatchOptions {
    /// When `true` (the default), each column's data type must be exactly equal (`==`) to its
    /// schema field's data type. When `false`, the `DataType::equals_datatype` comparison is
    /// used instead.
    pub match_field_names: bool,
}

impl Default for RecordBatchOptions {
    /// Returns options with `match_field_names` enabled.
    fn default() -> Self {
        Self {
            match_field_names: true,
        }
    }
}
impl From<StructArray> for RecordBatch {
    /// Builds a batch whose schema is the struct's fields and whose columns are its children.
    ///
    /// # Panics
    ///
    /// Panics if `array` has any nulls (`null_count() != 0`). The third component returned by
    /// `into_data` is discarded without being checked further.
    fn from(array: StructArray) -> Self {
        assert!(array.null_count() == 0);
        let (fields, columns, _) = array.into_data();
        let schema = Arc::new(Schema::new(fields));
        Self { schema, columns }
    }
}
impl From<RecordBatch> for StructArray {
fn from(batch: RecordBatch) -> Self {
let (fields, values) = batch
.schema
.fields
.iter()
.zip(batch.columns.iter())
.map(|t| (t.0.clone(), t.1.clone()))
.unzip();
StructArray::from_data(DataType::Struct(fields), values, None)
}
}
/// An iterator of `Result<RecordBatch>` items that also reports a schema.
pub trait RecordBatchReader: Iterator<Item = Result<RecordBatch>> {
    /// Returns the schema reported by this reader (presumably that of the batches it yields).
    fn schema(&self) -> &Schema;

    /// Pulls the next item from the iterator, returning `Ok(None)` when it is exhausted.
    ///
    /// This swaps the `Option`/`Result` layers of `Iterator::next`: an `Err` item is returned
    /// as `Err`, and a batch as `Ok(Some(batch))`.
    #[deprecated(
        since = "2.0.0",
        note = "This method is deprecated in favour of `next` from the trait Iterator."
    )]
    fn next_batch(&mut self) -> Result<Option<RecordBatch>> {
        let upcoming = Iterator::next(self);
        upcoming.transpose()
    }
}