use arrow::{
array::{ArrayBuilder, Int16Builder, Int8Builder, RecordBatch},
datatypes::{DataType, Field, Schema, TimeUnit},
};
use arrow_array::{
builder::DurationMicrosecondBuilder, DurationMicrosecondArray, Int16Array, Int8Array,
};
use std::sync::Arc;
use crate::ev_formats::{streaming::Event, EventFormat};
type Events = Vec<Event>;
#[derive(Debug, thiserror::Error)]
pub enum ArrowBuilderError {
#[error("Arrow array construction failed: {0}")]
ArrayConstruction(String),
#[error("Invalid event data: {message}")]
InvalidData { message: String },
#[error("Memory allocation failed for {event_count} events")]
MemoryAllocation { event_count: usize },
#[error("Schema validation failed: {0}")]
SchemaValidation(String),
#[error("Feature not enabled: Arrow support requires 'arrow' feature flag")]
FeatureNotEnabled,
}
impl From<arrow::error::ArrowError> for ArrowBuilderError {
fn from(e: arrow::error::ArrowError) -> Self {
ArrowBuilderError::ArrayConstruction(e.to_string())
}
}
pub fn create_event_arrow_schema() -> Schema {
Schema::new(vec![
Field::new("x", DataType::Int16, false),
Field::new("y", DataType::Int16, false),
Field::new("t", DataType::Duration(TimeUnit::Microsecond), false),
Field::new("polarity", DataType::Int8, false),
])
}
pub struct ArrowEventBuilder {
x_builder: Int16Builder,
y_builder: Int16Builder,
timestamp_builder: DurationMicrosecondBuilder,
polarity_builder: Int8Builder,
format: EventFormat,
capacity: usize,
schema: Arc<Schema>,
}
impl ArrowEventBuilder {
pub fn new(capacity: usize, format: EventFormat) -> Self {
Self {
x_builder: Int16Builder::with_capacity(capacity),
y_builder: Int16Builder::with_capacity(capacity),
timestamp_builder: DurationMicrosecondBuilder::with_capacity(capacity),
polarity_builder: Int8Builder::with_capacity(capacity),
format,
capacity,
schema: Arc::new(create_event_arrow_schema()),
}
}
pub fn for_events(events: &[Event], format: EventFormat) -> Self {
Self::new(events.len(), format)
}
pub fn append_event(&mut self, event: &Event) -> Result<(), ArrowBuilderError> {
self.x_builder.append_value(event.x as i16);
self.y_builder.append_value(event.y as i16);
let timestamp_us = self.convert_timestamp(event.t);
self.timestamp_builder.append_value(timestamp_us);
let polarity_value = self.convert_polarity(event.polarity);
self.polarity_builder.append_value(polarity_value);
Ok(())
}
pub fn from_events_zero_copy(
events: &[Event],
format: EventFormat,
) -> Result<RecordBatch, ArrowBuilderError> {
if events.is_empty() {
return Self::create_empty_batch();
}
let mut builder = Self::new(events.len(), format);
for event in events {
builder.append_event(event)?;
}
builder.finish()
}
pub fn from_events_iter<I>(
events: I,
format: EventFormat,
size_hint: Option<usize>,
) -> Result<RecordBatch, ArrowBuilderError>
where
I: Iterator<Item = Event>,
{
let capacity = size_hint.unwrap_or(1000);
let mut builder = Self::new(capacity, format);
for event in events {
builder.append_event(&event)?;
}
builder.finish()
}
pub fn finish(mut self) -> Result<RecordBatch, ArrowBuilderError> {
let x_array = self.x_builder.finish();
let y_array = self.y_builder.finish();
let timestamp_array = self.timestamp_builder.finish();
let polarity_array = self.polarity_builder.finish();
let batch = RecordBatch::try_new(
self.schema,
vec![
Arc::new(x_array),
Arc::new(y_array),
Arc::new(timestamp_array),
Arc::new(polarity_array),
],
)?;
Ok(batch)
}
pub fn create_empty_batch() -> Result<RecordBatch, ArrowBuilderError> {
let schema = Arc::new(create_event_arrow_schema());
let x_array = Int16Array::from(Vec::<i16>::new());
let y_array = Int16Array::from(Vec::<i16>::new());
let timestamp_array = DurationMicrosecondArray::from(Vec::<i64>::new());
let polarity_array = Int8Array::from(Vec::<i8>::new());
let batch = RecordBatch::try_new(
schema,
vec![
Arc::new(x_array),
Arc::new(y_array),
Arc::new(timestamp_array),
Arc::new(polarity_array),
],
)?;
Ok(batch)
}
fn convert_polarity(&self, polarity: i8) -> i8 {
match self.format {
EventFormat::EVT2 | EventFormat::EVT21 | EventFormat::EVT3 | EventFormat::HDF5 => {
if polarity > 0 {
1i8
} else {
-1i8
}
}
_ => {
if polarity > 0 {
1i8
} else {
0i8
}
}
}
}
fn convert_timestamp(&self, timestamp: f64) -> i64 {
if timestamp >= 1_000_000_000.0 {
(timestamp / 1_000.0) as i64
} else if timestamp >= 1_000.0 {
timestamp as i64
} else {
(timestamp * 1_000_000.0) as i64
}
}
pub fn schema(&self) -> &Schema {
&self.schema
}
pub fn capacity(&self) -> usize {
self.capacity
}
pub fn len(&self) -> usize {
self.x_builder.len()
}
pub fn is_empty(&self) -> bool {
self.len() == 0
}
}
pub struct ArrowEventStreamer {
chunk_size: usize,
format: EventFormat,
schema: Arc<Schema>,
}
impl ArrowEventStreamer {
pub fn new(chunk_size: usize, format: EventFormat) -> Self {
Self {
chunk_size,
format,
schema: Arc::new(create_event_arrow_schema()),
}
}
pub fn stream_to_arrow<I>(&self, events: I) -> Result<RecordBatch, ArrowBuilderError>
where
I: Iterator<Item = Event>,
{
let mut record_batches = Vec::new();
let mut chunk_buffer = Vec::with_capacity(self.chunk_size);
for event in events {
chunk_buffer.push(event);
if chunk_buffer.len() >= self.chunk_size {
let chunk_batch =
ArrowEventBuilder::from_events_zero_copy(&chunk_buffer, self.format)?;
if chunk_batch.num_rows() > 0 {
record_batches.push(chunk_batch);
}
chunk_buffer.clear();
}
}
if !chunk_buffer.is_empty() {
let chunk_batch = ArrowEventBuilder::from_events_zero_copy(&chunk_buffer, self.format)?;
if chunk_batch.num_rows() > 0 {
record_batches.push(chunk_batch);
}
}
if record_batches.is_empty() {
return ArrowEventBuilder::create_empty_batch();
}
if record_batches.len() == 1 {
Ok(record_batches.into_iter().next().unwrap())
} else {
self.concatenate_batches(&record_batches)
}
}
fn concatenate_batches(
&self,
batches: &[RecordBatch],
) -> Result<RecordBatch, ArrowBuilderError> {
use arrow::compute::concat_batches;
concat_batches(&self.schema, batches.iter())
.map_err(|e| ArrowBuilderError::ArrayConstruction(e.to_string()))
}
pub fn chunk_size(&self) -> usize {
self.chunk_size
}
pub fn schema(&self) -> &Schema {
&self.schema
}
}
pub fn arrow_to_events(batch: &RecordBatch) -> Result<Events, ArrowBuilderError> {
use arrow::array::{Array, DurationMicrosecondArray, Int16Array, Int8Array};
let expected_schema = create_event_arrow_schema();
if !batch.schema().fields().eq(expected_schema.fields()) {
return Err(ArrowBuilderError::SchemaValidation(
"RecordBatch schema does not match expected event schema".to_string(),
));
}
let num_rows = batch.num_rows();
if num_rows == 0 {
return Ok(Vec::new());
}
let x_array = batch
.column(0)
.as_any()
.downcast_ref::<Int16Array>()
.ok_or_else(|| ArrowBuilderError::InvalidData {
message: "x column is not Int16Array".to_string(),
})?;
let y_array = batch
.column(1)
.as_any()
.downcast_ref::<Int16Array>()
.ok_or_else(|| ArrowBuilderError::InvalidData {
message: "y column is not Int16Array".to_string(),
})?;
let timestamp_array = batch
.column(2)
.as_any()
.downcast_ref::<DurationMicrosecondArray>()
.ok_or_else(|| ArrowBuilderError::InvalidData {
message: "timestamp column is not DurationMicrosecondArray".to_string(),
})?;
let polarity_array = batch
.column(3)
.as_any()
.downcast_ref::<Int8Array>()
.ok_or_else(|| ArrowBuilderError::InvalidData {
message: "polarity column is not Int8Array".to_string(),
})?;
let mut events = Vec::with_capacity(num_rows);
for i in 0..num_rows {
let x = x_array.value(i) as u16;
let y = y_array.value(i) as u16;
let timestamp_us = timestamp_array.value(i);
let polarity_raw = polarity_array.value(i);
let t = timestamp_us as f64 / 1_000_000.0;
let polarity = polarity_raw;
events.push(Event { t, x, y, polarity });
}
Ok(events)
}
#[cfg(test)]
mod tests {
use super::*;
use crate::ev_formats::EventFormat;
fn create_test_events() -> Vec<Event> {
vec![
Event {
t: 0.001,
x: 100,
y: 200,
polarity: true,
},
Event {
t: 0.002,
x: 101,
y: 201,
polarity: false,
},
Event {
t: 0.003,
x: 102,
y: 202,
polarity: true,
},
]
}
#[test]
fn test_create_event_arrow_schema() {
let schema = create_event_arrow_schema();
assert_eq!(schema.fields().len(), 4);
let field_names: Vec<&str> = schema.fields().iter().map(|f| f.name().as_str()).collect();
assert_eq!(field_names, vec!["x", "y", "t", "polarity"]);
}
#[test]
fn test_arrow_event_builder_empty() {
let builder = ArrowEventBuilder::new(0, EventFormat::HDF5);
assert_eq!(builder.len(), 0);
assert!(builder.is_empty());
assert_eq!(builder.capacity(), 0);
}
#[test]
fn test_arrow_event_builder_basic() {
let events = create_test_events();
let batch = ArrowEventBuilder::from_events_zero_copy(&events, EventFormat::HDF5)
.expect("Failed to create Arrow batch");
assert_eq!(batch.num_rows(), 3);
assert_eq!(batch.num_columns(), 4);
assert_eq!(batch.schema().fields().len(), 4);
}
#[test]
fn test_polarity_encoding_evt2() {
let events = create_test_events();
let batch = ArrowEventBuilder::from_events_zero_copy(&events, EventFormat::EVT2)
.expect("Failed to create Arrow batch");
let polarity_column = batch.column(3);
let polarity_array = polarity_column
.as_any()
.downcast_ref::<Int8Array>()
.unwrap();
assert_eq!(polarity_array.value(0), 1i8); assert_eq!(polarity_array.value(1), -1i8); assert_eq!(polarity_array.value(2), 1i8); }
#[test]
fn test_polarity_encoding_text() {
let events = create_test_events();
let batch = ArrowEventBuilder::from_events_zero_copy(&events, EventFormat::Text)
.expect("Failed to create Arrow batch");
let polarity_column = batch.column(3);
let polarity_array = polarity_column
.as_any()
.downcast_ref::<Int8Array>()
.unwrap();
assert_eq!(polarity_array.value(0), 1i8); assert_eq!(polarity_array.value(1), 0i8); assert_eq!(polarity_array.value(2), 1i8); }
#[test]
fn test_timestamp_conversion() {
let events = vec![
Event {
t: 1.0,
x: 100,
y: 200,
polarity: true,
}, Event {
t: 0.001,
x: 101,
y: 201,
polarity: false,
}, Event {
t: 1_000_000.0,
x: 102,
y: 202,
polarity: true,
}, ];
let batch = ArrowEventBuilder::from_events_zero_copy(&events, EventFormat::HDF5)
.expect("Failed to create Arrow batch");
let timestamp_column = batch.column(2);
let timestamp_array = timestamp_column
.as_any()
.downcast_ref::<DurationMicrosecondArray>()
.unwrap();
assert_eq!(timestamp_array.value(0), 1_000_000i64); assert_eq!(timestamp_array.value(1), 1_000i64); assert_eq!(timestamp_array.value(2), 1_000_000i64); }
#[test]
fn test_arrow_to_events_conversion() {
let events = create_test_events();
let batch = ArrowEventBuilder::from_events_zero_copy(&events, EventFormat::HDF5)
.expect("Failed to create Arrow batch");
let converted_events =
arrow_to_events(&batch).expect("Failed to convert Arrow batch to events");
assert_eq!(converted_events.len(), 3);
assert!((converted_events[0].t - 0.001).abs() < 1e-9);
assert_eq!(converted_events[0].x, 100);
assert_eq!(converted_events[0].y, 200);
assert_eq!(converted_events[0].polarity, true);
}
#[test]
fn test_arrow_event_streamer() {
let events = create_test_events();
let streamer = ArrowEventStreamer::new(2, EventFormat::HDF5);
let batch = streamer
.stream_to_arrow(events.into_iter())
.expect("Failed to stream events to Arrow");
assert_eq!(batch.num_rows(), 3);
assert_eq!(batch.num_columns(), 4);
}
#[test]
fn test_arrow_disabled() {
let result = create_event_arrow_schema();
assert!(matches!(result, Err(ArrowBuilderError::FeatureNotEnabled)));
}
}