use std::io::Cursor;
use anyhow::{Context, Result};
use arrow::array::RecordBatch;
use arrow::ipc::reader::StreamReader;
use arrow::ipc::writer::StreamWriter;
/// Serializes one [`RecordBatch`] into an in-memory Arrow IPC stream.
///
/// A [`StreamWriter`] is opened over a fresh byte buffer using the batch's own
/// schema, the batch is written once, and the stream is finished before the
/// buffer is returned.
///
/// # Errors
///
/// Returns an error carrying a context message naming the failed step when
/// the writer cannot be created, the batch cannot be written, or the stream
/// cannot be finished.
pub fn encode_record_batch(batch: &RecordBatch) -> Result<Vec<u8>> {
    // Starting capacity only; the vector grows as needed for larger batches.
    const INITIAL_CAPACITY: usize = 8 * 1024;

    let mut encoded = Vec::<u8>::with_capacity(INITIAL_CAPACITY);

    let mut ipc = StreamWriter::try_new(&mut encoded, batch.schema_ref())
        .context("create Arrow IPC stream writer")?;
    ipc.write(batch)
        .context("write record batch to Arrow IPC stream")?;
    ipc.finish().context("finish Arrow IPC stream writer")?;
    // The writer holds the mutable borrow of `encoded`; release it before
    // moving the buffer out.
    drop(ipc);

    Ok(encoded)
}
/// Decodes every record batch found in an Arrow IPC stream held in `bytes`.
///
/// The slice is wrapped in a [`Cursor`] and read through a [`StreamReader`]
/// created with `None` as its second argument; batches are returned in the
/// order the reader yields them.
///
/// # Errors
///
/// Returns an error with a context message when the reader cannot be opened
/// over `bytes`, or when any batch fails to read. Reading stops at the first
/// failing batch.
pub fn decode_record_batches(bytes: &[u8]) -> Result<Vec<RecordBatch>> {
    let stream = StreamReader::try_new(Cursor::new(bytes), None)
        .context("open Arrow IPC stream reader")?;

    // Collecting into `Result` short-circuits on the first failed batch.
    stream
        .map(|item| item.context("read record batch from Arrow IPC stream"))
        .collect()
}
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::{Int64Array, StringArray};
    use arrow::datatypes::{DataType, Field, Schema};
    use std::sync::Arc;

    /// Builds a three-row fixture batch with a non-nullable `event_id`
    /// (`Int64`) column and a non-nullable `event_type` (`Utf8`) column.
    fn sample_batch() -> RecordBatch {
        let fields = vec![
            Field::new("event_id", DataType::Int64, false),
            Field::new("event_type", DataType::Utf8, false),
        ];
        let ids = Int64Array::from(vec![1, 2, 3]);
        let kinds =
            StringArray::from(vec!["batch.accepted", "batch.processing", "batch.completed"]);

        RecordBatch::try_new(
            Arc::new(Schema::new(fields)),
            vec![Arc::new(ids), Arc::new(kinds)],
        )
        .expect("build record batch")
    }

    /// Encoding then decoding the fixture yields exactly one batch with the
    /// same row count, column count, and schema.
    #[test]
    fn round_trip_single_batch_preserves_rows_and_schema() {
        let source = sample_batch();
        let bytes = encode_record_batch(&source).expect("encode");
        let batches = decode_record_batches(&bytes).expect("decode");

        assert_eq!(batches.len(), 1, "single-batch stream");

        let round_tripped = &batches[0];
        assert_eq!(round_tripped.num_rows(), 3);
        assert_eq!(round_tripped.num_columns(), 2);

        // Compare the schemas themselves rather than the `Arc` handles.
        let got: &Schema = round_tripped.schema_ref();
        let want: &Schema = source.schema_ref();
        assert_eq!(got, want);
    }

    /// The encoded output is longer than eight bytes and begins with four
    /// `0xFF` bytes.
    #[test]
    fn encoded_bytes_have_arrow_ipc_magic() {
        let bytes = encode_record_batch(&sample_batch()).expect("encode");

        assert!(bytes.len() > 8, "non-trivial output");
        assert_eq!(
            &bytes[..4],
            &[0xFF_u8; 4],
            "Arrow IPC continuation marker",
        );
    }

    /// Bytes that are not an Arrow IPC stream must produce an error, not
    /// batches.
    #[test]
    fn decode_rejects_garbage_bytes() {
        let outcome = decode_record_batches(b"not an arrow ipc stream");

        assert!(outcome.is_err(), "garbage must not decode");
    }
}