use arrow_array::builder::{
    ArrayBuilder, BooleanBuilder, Float64Builder, Int64Builder, StringBuilder,
};
use arrow_array::{ArrayRef, RecordBatch, RecordBatchOptions};
use arrow_schema::{ArrowError, DataType, Field, Schema, SchemaRef};
use std::sync::Arc;
/// A single dynamically typed cell value accepted by [`ArrowBatchBuilder::push_row`].
///
/// Each variant pairs with exactly one supported Arrow column type:
/// `Int64`, `Float64`, `Boolean`, and `Utf8`.
#[derive(Debug, Clone, PartialEq)]
pub enum ArrowValue {
    /// A 64-bit signed integer, for an `Int64` column.
    Int64(i64),
    /// A 64-bit float, for a `Float64` column.
    Float64(f64),
    /// A boolean, for a `Boolean` column.
    Boolean(bool),
    /// An owned UTF-8 string, for a `Utf8` column.
    Utf8(String),
}

impl From<i64> for ArrowValue {
    fn from(value: i64) -> Self {
        Self::Int64(value)
    }
}

impl From<f64> for ArrowValue {
    fn from(value: f64) -> Self {
        Self::Float64(value)
    }
}

impl From<bool> for ArrowValue {
    fn from(value: bool) -> Self {
        Self::Boolean(value)
    }
}

impl From<String> for ArrowValue {
    fn from(value: String) -> Self {
        Self::Utf8(value)
    }
}

impl From<&str> for ArrowValue {
    fn from(value: &str) -> Self {
        Self::Utf8(value.to_owned())
    }
}
pub struct ArrowBatchBuilder {
schema: SchemaRef,
builders: Vec<Box<dyn ArrayBuilder>>,
len: usize,
}
impl ArrowBatchBuilder {
pub fn new(schema: SchemaRef) -> Result<Self, ArrowError> {
let builders = build_builders(&schema)?;
Ok(Self {
schema,
builders,
len: 0,
})
}
pub fn schema(&self) -> &SchemaRef {
&self.schema
}
pub fn len(&self) -> usize {
self.len
}
pub fn is_empty(&self) -> bool {
self.len == 0
}
pub fn push_row(&mut self, values: &[ArrowValue]) -> Result<(), ArrowError> {
if values.len() != self.builders.len() {
return Err(ArrowError::SchemaError(
"row length does not match schema".to_string(),
));
}
for (value, builder) in values.iter().zip(self.builders.iter_mut()) {
append_value(builder.as_mut(), value)?;
}
self.len += 1;
Ok(())
}
pub fn finish(&mut self) -> Result<RecordBatch, ArrowError> {
let arrays = self
.builders
.iter_mut()
.map(|builder| builder.finish())
.collect::<Vec<ArrayRef>>();
let batch = RecordBatch::try_new(self.schema.clone(), arrays)?;
self.builders = build_builders(&self.schema)?;
self.len = 0;
Ok(batch)
}
}
/// Accumulates finished [`RecordBatch`]es in the order they were pushed.
#[derive(Debug, Default)]
pub struct RecordBatchCollector {
    batches: Vec<RecordBatch>,
}

impl RecordBatchCollector {
    /// Creates a collector holding no batches.
    pub fn new() -> Self {
        Self::default()
    }

    /// Appends `batch` after every batch already collected.
    pub fn push(&mut self, batch: RecordBatch) {
        let batches = &mut self.batches;
        batches.push(batch);
    }

    /// Moves all collected batches out, leaving the collector empty.
    pub fn take(&mut self) -> Vec<RecordBatch> {
        let drained = std::mem::take(&mut self.batches);
        drained
    }

    /// Borrows the batches collected so far, oldest first.
    pub fn batches(&self) -> &[RecordBatch] {
        self.batches.as_slice()
    }
}
pub fn schema_from_fields(fields: Vec<Field>) -> SchemaRef {
Arc::new(Schema::new(fields))
}
/// Creates one column builder per field of `schema`, preserving field order.
///
/// Stops at and returns the first error from [`builder_for_field`].
fn build_builders(schema: &SchemaRef) -> Result<Vec<Box<dyn ArrayBuilder>>, ArrowError> {
    let fields = schema.fields();
    let mut builders: Vec<Box<dyn ArrayBuilder>> = Vec::with_capacity(fields.len());
    for field in fields.iter() {
        builders.push(builder_for_field(field)?);
    }
    Ok(builders)
}
/// Returns an empty Arrow builder matching `field`'s data type.
///
/// Only `Int64`, `Float64`, `Boolean`, and `Utf8` are supported; any other type
/// yields [`ArrowError::SchemaError`].
fn builder_for_field(field: &Field) -> Result<Box<dyn ArrayBuilder>, ArrowError> {
    let builder: Box<dyn ArrayBuilder> = match field.data_type() {
        DataType::Int64 => Box::new(Int64Builder::new()),
        DataType::Float64 => Box::new(Float64Builder::new()),
        DataType::Boolean => Box::new(BooleanBuilder::new()),
        DataType::Utf8 => Box::new(StringBuilder::new()),
        unsupported => {
            return Err(ArrowError::SchemaError(format!(
                "unsupported data type {unsupported:?}"
            )))
        }
    };
    Ok(builder)
}
/// Appends `value` to `builder`, which must be the concrete builder type that
/// corresponds to the value's variant.
///
/// # Errors
///
/// Returns [`ArrowError::SchemaError`] when `builder` is not the builder type
/// for `value`'s variant; nothing is appended in that case.
fn append_value(builder: &mut dyn ArrayBuilder, value: &ArrowValue) -> Result<(), ArrowError> {
    let any = builder.as_any_mut();
    // Each variant has exactly one compatible builder type; a failed downcast
    // means the column and the value disagree.
    let appended = match value {
        ArrowValue::Int64(v) => any
            .downcast_mut::<Int64Builder>()
            .map(|typed| typed.append_value(*v)),
        ArrowValue::Float64(v) => any
            .downcast_mut::<Float64Builder>()
            .map(|typed| typed.append_value(*v)),
        ArrowValue::Boolean(v) => any
            .downcast_mut::<BooleanBuilder>()
            .map(|typed| typed.append_value(*v)),
        ArrowValue::Utf8(v) => any
            .downcast_mut::<StringBuilder>()
            .map(|typed| typed.append_value(v)),
    };
    appended.ok_or_else(|| {
        ArrowError::SchemaError("value does not match builder type".to_string())
    })
}