use crate::error::{DbxError, DbxResult};
use arrow::ipc::{reader, writer};
use arrow::record_batch::RecordBatch;
use std::io::Cursor;
pub fn write_ipc_batch(batch: &RecordBatch) -> DbxResult<Vec<u8>> {
let mut buffer = Vec::new();
{
let mut writer = writer::FileWriter::try_new(&mut buffer, &batch.schema())
.map_err(|e| DbxError::Storage(format!("Arrow IPC write error: {}", e)))?;
writer
.write(batch)
.map_err(|e| DbxError::Storage(format!("Arrow IPC batch write error: {}", e)))?;
writer
.finish()
.map_err(|e| DbxError::Storage(format!("Arrow IPC finish error: {}", e)))?;
}
Ok(buffer)
}
pub fn read_ipc_batch(bytes: &[u8]) -> DbxResult<RecordBatch> {
let cursor = Cursor::new(bytes);
let mut reader = reader::FileReader::try_new(cursor, None)
.map_err(|e| DbxError::Storage(format!("Arrow IPC read error: {}", e)))?;
let batch = reader
.next()
.ok_or_else(|| DbxError::Storage("No batch in Arrow IPC file".to_string()))?
.map_err(|e| DbxError::Storage(format!("Arrow IPC batch read error: {}", e)))?;
Ok(batch)
}
#[cfg(test)]
mod tests {
use super::*;
use arrow::array::{Int64Array, StringArray};
use arrow::datatypes::{DataType, Field, Schema};
use std::sync::Arc;
#[test]
fn test_write_read_ipc_batch() {
let schema = Arc::new(Schema::new(vec![
Field::new("id", DataType::Int64, false),
Field::new("name", DataType::Utf8, false),
]));
let batch = RecordBatch::try_new(
schema.clone(),
vec![
Arc::new(Int64Array::from(vec![1, 2, 3])),
Arc::new(StringArray::from(vec!["Alice", "Bob", "Charlie"])),
],
)
.unwrap();
let ipc_bytes = write_ipc_batch(&batch).unwrap();
let restored_batch = read_ipc_batch(&ipc_bytes).unwrap();
assert_eq!(batch.num_rows(), restored_batch.num_rows());
assert_eq!(batch.num_columns(), restored_batch.num_columns());
assert_eq!(batch.schema(), restored_batch.schema());
}
}