Skip to main content

rustsim_io/
arrow.rs

//! Arrow batch building utilities.
//!
//! Provides [`ArrowBatchBuilder`] for incrementally constructing Arrow
//! [`RecordBatch`] values from dynamically-typed [`ArrowValue`] rows,
//! and [`RecordBatchCollector`] for accumulating completed batches.
6
use arrow_array::builder::{
    ArrayBuilder, BooleanBuilder, Float64Builder, Int64Builder, StringBuilder,
};
use arrow_array::{ArrayRef, RecordBatch, RecordBatchOptions};
use arrow_schema::{ArrowError, DataType, Field, Schema, SchemaRef};
use std::sync::Arc;
13
/// A dynamically-typed cell value for one column of an Arrow row.
///
/// Used with [`ArrowBatchBuilder::push_row`] to append rows without
/// requiring a compile-time schema. Each variant is accepted only by a
/// column of the matching Arrow data type (`Int64`, `Float64`, `Boolean`,
/// `Utf8` respectively).
///
/// `From` conversions exist for the corresponding Rust types, so cells can
/// be written as `42_i64.into()`, `"name".into()`, `true.into()`, etc.
#[derive(Debug, Clone, PartialEq)]
pub enum ArrowValue {
    /// 64-bit signed integer.
    Int64(i64),
    /// 64-bit float.
    Float64(f64),
    /// Boolean.
    Boolean(bool),
    /// UTF-8 string.
    Utf8(String),
}

impl From<i64> for ArrowValue {
    fn from(value: i64) -> Self {
        Self::Int64(value)
    }
}

impl From<f64> for ArrowValue {
    fn from(value: f64) -> Self {
        Self::Float64(value)
    }
}

impl From<bool> for ArrowValue {
    fn from(value: bool) -> Self {
        Self::Boolean(value)
    }
}

impl From<String> for ArrowValue {
    fn from(value: String) -> Self {
        Self::Utf8(value)
    }
}

impl From<&str> for ArrowValue {
    /// Copies the string slice into an owned `Utf8` value.
    fn from(value: &str) -> Self {
        Self::Utf8(value.to_owned())
    }
}
29
30/// Incremental Arrow [`RecordBatch`] builder.
31///
32/// Rows are pushed one at a time via [`push_row`](Self::push_row). Call
33/// [`finish`](Self::finish) to produce a `RecordBatch` and reset the builder.
34///
35/// The builder validates that each row's length matches the schema.
36pub struct ArrowBatchBuilder {
37    schema: SchemaRef,
38    builders: Vec<Box<dyn ArrayBuilder>>,
39    len: usize,
40}
41
42impl ArrowBatchBuilder {
43    /// Create a new builder for the given schema.
44    ///
45    /// Returns an error if the schema contains unsupported data types.
46    pub fn new(schema: SchemaRef) -> Result<Self, ArrowError> {
47        let builders = build_builders(&schema)?;
48        Ok(Self {
49            schema,
50            builders,
51            len: 0,
52        })
53    }
54
55    /// The schema this builder produces.
56    pub fn schema(&self) -> &SchemaRef {
57        &self.schema
58    }
59
60    /// Number of rows appended since the last `finish`.
61    pub fn len(&self) -> usize {
62        self.len
63    }
64
65    /// Returns `true` if no rows have been appended.
66    pub fn is_empty(&self) -> bool {
67        self.len == 0
68    }
69
70    /// Append a row of values.
71    ///
72    /// Returns [`ArrowError::SchemaError`] if the row length doesn't match
73    /// the schema, or if a value type doesn't match its column's data type.
74    pub fn push_row(&mut self, values: &[ArrowValue]) -> Result<(), ArrowError> {
75        if values.len() != self.builders.len() {
76            return Err(ArrowError::SchemaError(
77                "row length does not match schema".to_string(),
78            ));
79        }
80
81        for (value, builder) in values.iter().zip(self.builders.iter_mut()) {
82            append_value(builder.as_mut(), value)?;
83        }
84
85        self.len += 1;
86        Ok(())
87    }
88
89    /// Consume the accumulated rows and produce a [`RecordBatch`].
90    ///
91    /// The builder is reset and can be reused for the next batch.
92    pub fn finish(&mut self) -> Result<RecordBatch, ArrowError> {
93        let arrays = self
94            .builders
95            .iter_mut()
96            .map(|builder| builder.finish())
97            .collect::<Vec<ArrayRef>>();
98
99        let batch = RecordBatch::try_new(self.schema.clone(), arrays)?;
100        self.builders = build_builders(&self.schema)?;
101        self.len = 0;
102        Ok(batch)
103    }
104}
105
106/// Accumulator for completed [`RecordBatch`] values.
107#[derive(Debug, Default)]
108pub struct RecordBatchCollector {
109    batches: Vec<RecordBatch>,
110}
111
112impl RecordBatchCollector {
113    /// Create an empty collector.
114    pub fn new() -> Self {
115        Self {
116            batches: Vec::new(),
117        }
118    }
119
120    /// Add a batch to the collector.
121    pub fn push(&mut self, batch: RecordBatch) {
122        self.batches.push(batch);
123    }
124
125    /// Take all collected batches, leaving the collector empty.
126    pub fn take(&mut self) -> Vec<RecordBatch> {
127        std::mem::take(&mut self.batches)
128    }
129
130    /// Borrow the collected batches.
131    pub fn batches(&self) -> &[RecordBatch] {
132        &self.batches
133    }
134}
135
136/// Create an Arrow schema from a list of fields.
137pub fn schema_from_fields(fields: Vec<Field>) -> SchemaRef {
138    Arc::new(Schema::new(fields))
139}
140
141fn build_builders(schema: &SchemaRef) -> Result<Vec<Box<dyn ArrayBuilder>>, ArrowError> {
142    schema
143        .fields()
144        .iter()
145        .map(|field| builder_for_field(field))
146        .collect()
147}
148
149fn builder_for_field(field: &Field) -> Result<Box<dyn ArrayBuilder>, ArrowError> {
150    match field.data_type() {
151        DataType::Int64 => Ok(Box::new(Int64Builder::new())),
152        DataType::Float64 => Ok(Box::new(Float64Builder::new())),
153        DataType::Boolean => Ok(Box::new(BooleanBuilder::new())),
154        DataType::Utf8 => Ok(Box::new(StringBuilder::new())),
155        other => Err(ArrowError::SchemaError(format!(
156            "unsupported data type {other:?}"
157        ))),
158    }
159}
160
161fn append_value(builder: &mut dyn ArrayBuilder, value: &ArrowValue) -> Result<(), ArrowError> {
162    if let (Some(builder), ArrowValue::Int64(value)) =
163        (builder.as_any_mut().downcast_mut::<Int64Builder>(), value)
164    {
165        builder.append_value(*value);
166        return Ok(());
167    }
168
169    if let (Some(builder), ArrowValue::Float64(value)) =
170        (builder.as_any_mut().downcast_mut::<Float64Builder>(), value)
171    {
172        builder.append_value(*value);
173        return Ok(());
174    }
175
176    if let (Some(builder), ArrowValue::Boolean(value)) =
177        (builder.as_any_mut().downcast_mut::<BooleanBuilder>(), value)
178    {
179        builder.append_value(*value);
180        return Ok(());
181    }
182
183    if let (Some(builder), ArrowValue::Utf8(value)) =
184        (builder.as_any_mut().downcast_mut::<StringBuilder>(), value)
185    {
186        builder.append_value(value);
187        return Ok(());
188    }
189
190    Err(ArrowError::SchemaError(
191        "value does not match builder type".to_string(),
192    ))
193}