use arrow::array::RecordBatch;
use arrow::ipc::reader::StreamReader;
use arrow::ipc::writer::StreamWriter;
use arrow_schema::SchemaRef;
use crate::error::IpcError;
/// Extension-type name that marks a field as a secret handle.
///
/// Any field whose metadata carries this value under the extension-name key is
/// refused by the encode and decode entry points in this file.
pub const SECRET_HANDLE_EXTENSION: &str = "uni-db.secret-handle";
/// Field-metadata key under which the extension-type name is looked up.
const ARROW_EXTENSION_KEY: &str = "ARROW:extension:name";
/// Fails if any field of `batch`'s schema, at any nesting depth, is tagged with
/// the [`SECRET_HANDLE_EXTENSION`] extension name.
///
/// Only the schema is inspected; column data is never read.
///
/// # Errors
///
/// Returns [`IpcError::SecretLeakAttempt`] for the first tagged field found. The
/// reported `column` is the name of the tagged field itself, which for a nested
/// field is the child's name rather than the enclosing top-level column.
fn reject_secret_handles(batch: &RecordBatch) -> Result<(), IpcError> {
    /// Checks one field's own metadata, then everything reachable through its type.
    fn check_field(field: &arrow_schema::Field) -> Result<(), IpcError> {
        let is_secret = field
            .metadata()
            .get(ARROW_EXTENSION_KEY)
            .map(String::as_str)
            == Some(SECRET_HANDLE_EXTENSION);
        if is_secret {
            return Err(IpcError::SecretLeakAttempt {
                column: field.name().clone(),
            });
        }
        check_type(field.data_type())
    }

    /// Descends into every container type that can carry child fields.
    fn check_type(data_type: &arrow_schema::DataType) -> Result<(), IpcError> {
        use arrow_schema::DataType;

        match data_type {
            DataType::Struct(children) => {
                children.iter().try_for_each(|child| check_field(child.as_ref()))
            }
            DataType::List(child)
            | DataType::LargeList(child)
            | DataType::FixedSizeList(child, _)
            | DataType::Map(child, _) => check_field(child.as_ref()),
            // Union variants are full fields and can carry extension metadata.
            DataType::Union(variants, _) => variants
                .iter()
                .try_for_each(|(_, child)| check_field(child.as_ref())),
            DataType::RunEndEncoded(run_ends, values) => {
                check_field(run_ends.as_ref())?;
                check_field(values.as_ref())
            }
            // A dictionary has no child field of its own, but its value type may
            // be a container that does.
            DataType::Dictionary(_, value_type) => check_type(value_type.as_ref()),
            // NOTE(review): list-view variants are not matched here; add them if
            // the arrow version in use provides them.
            _ => Ok(()),
        }
    }

    batch
        .schema()
        .fields()
        .iter()
        .try_for_each(|field| check_field(field.as_ref()))
}
fn reject_all(batches: &[RecordBatch]) -> Result<(), IpcError> {
batches.iter().try_for_each(reject_secret_handles)
}
/// Serializes one record batch into an Arrow IPC stream held in memory.
///
/// # Errors
///
/// Returns [`IpcError::SecretLeakAttempt`] if the schema contains a
/// secret-handle field, or [`IpcError::Arrow`] if the IPC writer fails.
pub fn encode_batch(batch: &RecordBatch) -> Result<Vec<u8>, IpcError> {
    // The guard runs before any bytes are produced.
    reject_secret_handles(batch)?;

    let capacity = estimate_size(batch);
    let mut out: Vec<u8> = Vec::with_capacity(capacity);
    let single = std::slice::from_ref(batch);
    write_stream(&mut out, batch.schema(), single).map(|()| out)
}
/// Serializes several record batches into a single Arrow IPC stream, using the
/// first batch's schema as the stream schema.
///
/// # Errors
///
/// Returns [`IpcError::EmptyBatchInput`] when `batches` is empty,
/// [`IpcError::SecretLeakAttempt`] if any batch's schema contains a
/// secret-handle field, or [`IpcError::Arrow`] if the IPC writer fails.
pub fn encode_batches(batches: &[RecordBatch]) -> Result<Vec<u8>, IpcError> {
    let head = match batches.first() {
        Some(batch) => batch,
        None => return Err(IpcError::EmptyBatchInput),
    };
    reject_all(batches)?;

    // Size the buffer from the first batch, scaled by the number of batches.
    let capacity = estimate_size(head).saturating_mul(batches.len());
    let mut out: Vec<u8> = Vec::with_capacity(capacity);
    write_stream(&mut out, head.schema(), batches).map(|()| out)
}
/// Writes `batches` to `buf` as one Arrow IPC stream with the given schema and
/// finishes the stream.
///
/// # Errors
///
/// Every writer failure is mapped to [`IpcError::Arrow`], prefixed with the
/// stage that failed (`writer setup`, `write batch`, or `finish`).
fn write_stream(
    buf: &mut Vec<u8>,
    schema: SchemaRef,
    batches: &[RecordBatch],
) -> Result<(), IpcError> {
    let mut writer = match StreamWriter::try_new(buf, schema.as_ref()) {
        Ok(writer) => writer,
        Err(e) => return Err(IpcError::Arrow(format!("writer setup: {e}"))),
    };

    batches.iter().try_for_each(|batch| {
        writer
            .write(batch)
            .map_err(|e| IpcError::Arrow(format!("write batch: {e}")))
    })?;

    writer
        .finish()
        .map_err(|e| IpcError::Arrow(format!("finish: {e}")))
}
/// Decodes an Arrow IPC stream that is expected to hold at most one batch.
///
/// Returns `Ok(None)` when the stream contains no batches.
///
/// # Errors
///
/// Returns [`IpcError::Arrow`] if the bytes cannot be read or the stream holds
/// more than one batch, and [`IpcError::SecretLeakAttempt`] if a decoded batch
/// has a secret-handle field. The secret-handle check runs before the
/// batch-count check.
pub fn decode_batch(bytes: &[u8]) -> Result<Option<RecordBatch>, IpcError> {
    let decoded = read_stream(bytes, "read batch")?;
    reject_all(&decoded)?;

    let count = decoded.len();
    if count > 1 {
        return Err(IpcError::Arrow(format!(
            "decode_batch expects a single-batch stream, got {count} batches"
        )));
    }
    // Zero batches yields `None`, exactly one yields `Some`.
    Ok(decoded.into_iter().next())
}
/// Decodes every batch in an Arrow IPC stream.
///
/// # Errors
///
/// Returns [`IpcError::Arrow`] if the bytes cannot be read, and
/// [`IpcError::SecretLeakAttempt`] if any decoded batch has a secret-handle
/// field.
pub fn decode_batches(bytes: &[u8]) -> Result<Vec<RecordBatch>, IpcError> {
    read_stream(bytes, "read batches").and_then(|decoded| {
        reject_all(&decoded)?;
        Ok(decoded)
    })
}
/// Reads all batches from an in-memory Arrow IPC stream.
///
/// `read_label` prefixes the error message when reading a batch fails; a
/// failure to open the stream is always prefixed with `reader setup`.
fn read_stream(bytes: &[u8], read_label: &str) -> Result<Vec<RecordBatch>, IpcError> {
    let reader = match StreamReader::try_new(bytes, None) {
        Ok(reader) => reader,
        Err(e) => return Err(IpcError::Arrow(format!("reader setup: {e}"))),
    };

    let mut decoded = Vec::new();
    for item in reader {
        // Stop at the first batch that fails to decode.
        let batch = item.map_err(|e| IpcError::Arrow(format!("{read_label}: {e}")))?;
        decoded.push(batch);
    }
    Ok(decoded)
}
/// Produces a rough initial buffer capacity for encoding `batch`: a fixed
/// per-cell allowance times rows times columns, plus a fixed base amount.
///
/// The multiplications saturate; the final addition is a plain `+`.
fn estimate_size(batch: &RecordBatch) -> usize {
    const PER_CELL_ALLOWANCE: usize = 16;
    const BASE_ALLOWANCE: usize = 4096;

    let cells = batch.num_rows().saturating_mul(batch.num_columns());
    cells.saturating_mul(PER_CELL_ALLOWANCE) + BASE_ALLOWANCE
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    use arrow::array::{
        Array, ArrayRef, BooleanArray, FixedSizeBinaryArray, Float32Array, Float64Array,
        Int32Array, Int64Array, LargeBinaryArray, ListArray, StringArray, StructArray,
        TimestampMillisecondArray,
    };
    use arrow::buffer::OffsetBuffer;
    use arrow_schema::{DataType, Field, Fields, Schema, TimeUnit};

    // ---- shared fixtures -------------------------------------------------

    /// Single nullable field schema.
    fn schema_for(name: &str, dt: DataType) -> SchemaRef {
        Arc::new(Schema::new(vec![Field::new(name, dt, true)]))
    }

    /// Wraps one array in a one-column batch whose field type matches the array.
    fn one_col_batch(name: &str, col: ArrayRef) -> RecordBatch {
        let schema = schema_for(name, col.data_type().clone());
        RecordBatch::try_new(schema, vec![col]).unwrap()
    }

    /// Encodes then decodes a batch through the public single-batch API.
    fn round_trip(batch: &RecordBatch) -> RecordBatch {
        let encoded = encode_batch(batch).expect("encode_batch failed");
        decode_batch(&encoded)
            .expect("decode_batch failed")
            .expect("stream contained no batch")
    }

    /// Non-nullable 8-byte field tagged with the secret-handle extension name.
    /// The metadata key is spelled out literally so the test pins the wire value.
    fn secret_tagged_field(name: &str) -> Field {
        Field::new(name, DataType::FixedSizeBinary(8), false).with_metadata(
            std::collections::HashMap::from([(
                "ARROW:extension:name".to_owned(),
                SECRET_HANDLE_EXTENSION.to_owned(),
            )]),
        )
    }

    /// Builds a `FixedSizeBinary(8)` array from the given rows.
    fn handle_array(rows: &[[u8; 8]]) -> ArrayRef {
        let arr = FixedSizeBinaryArray::try_from_iter(rows.iter().map(|b| b.as_slice())).unwrap();
        Arc::new(arr)
    }

    /// One-column batch whose only field carries the secret-handle tag.
    fn tagged_batch(column: &str, rows: &[[u8; 8]]) -> RecordBatch {
        let schema = Arc::new(Schema::new(vec![secret_tagged_field(column)]));
        RecordBatch::try_new(schema, vec![handle_array(rows)]).unwrap()
    }

    /// Writes a batch with a raw `StreamWriter`, bypassing the encode-side guard,
    /// so the decode-side guard can be exercised on its own.
    fn raw_stream(batch: &RecordBatch) -> Vec<u8> {
        let mut buf: Vec<u8> = Vec::new();
        {
            let mut w = StreamWriter::try_new(&mut buf, batch.schema().as_ref()).unwrap();
            w.write(batch).unwrap();
            w.finish().unwrap();
        }
        buf
    }

    /// Asserts that `result` is a `SecretLeakAttempt` naming `expected_column`.
    fn assert_secret_leak<T>(result: Result<T, IpcError>, what: &str, expected_column: &str) {
        match result {
            Ok(_) => panic!("{what} must reject secret-handle columns"),
            Err(IpcError::SecretLeakAttempt { column }) => {
                assert_eq!(column, expected_column);
            }
            Err(other) => panic!("expected SecretLeakAttempt, got {other:?}"),
        }
    }

    // ---- round trips -----------------------------------------------------

    #[test]
    fn round_trip_int64() {
        let batch = one_col_batch("x", Arc::new(Int64Array::from(vec![1, 2, 3])));
        let decoded = round_trip(&batch);
        assert_eq!(decoded.num_rows(), 3);
    }

    #[test]
    fn round_trip_int32_float32_float64() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("i32", DataType::Int32, true),
            Field::new("f32", DataType::Float32, true),
            Field::new("f64", DataType::Float64, true),
        ]));
        let i: ArrayRef = Arc::new(Int32Array::from(vec![1, 2]));
        let f32a: ArrayRef = Arc::new(Float32Array::from(vec![1.5_f32, 2.5]));
        let f64a: ArrayRef = Arc::new(Float64Array::from(vec![10.5_f64, 20.5]));
        let batch = RecordBatch::try_new(schema, vec![i, f32a, f64a]).unwrap();

        let decoded = round_trip(&batch);
        assert_eq!(decoded.num_rows(), 2);
        let f64_out = decoded
            .column(2)
            .as_any()
            .downcast_ref::<Float64Array>()
            .unwrap();
        assert!((f64_out.value(1) - 20.5).abs() < f64::EPSILON);
    }

    #[test]
    fn round_trip_utf8_strings_including_unicode() {
        let batch = one_col_batch(
            "s",
            Arc::new(StringArray::from(vec!["hello", "naïve", "🌳", ""])),
        );
        let decoded = round_trip(&batch);
        let col = decoded
            .column(0)
            .as_any()
            .downcast_ref::<StringArray>()
            .unwrap();
        assert_eq!(col.value(2), "🌳");
        assert_eq!(col.value(3), "");
    }

    #[test]
    fn round_trip_booleans_with_nulls() {
        let batch = one_col_batch(
            "b",
            Arc::new(BooleanArray::from(vec![Some(true), None, Some(false)])),
        );
        let decoded = round_trip(&batch);
        let col = decoded
            .column(0)
            .as_any()
            .downcast_ref::<BooleanArray>()
            .unwrap();
        assert!(col.is_null(1));
        assert!(col.value(0));
        assert!(!col.value(2));
    }

    #[test]
    fn round_trip_timestamp_ms() {
        let arr: ArrayRef = Arc::new(
            TimestampMillisecondArray::from(vec![1_700_000_000_000_i64, 1_800_000_000_000])
                .with_timezone_opt::<&str>(None),
        );
        let decoded = round_trip(&one_col_batch("ts", arr));
        assert!(matches!(
            decoded.schema().field(0).data_type(),
            DataType::Timestamp(TimeUnit::Millisecond, _)
        ));
    }

    #[test]
    fn round_trip_large_binary_for_cypher_values() {
        let arr: ArrayRef = Arc::new(LargeBinaryArray::from(vec![
            &[1_u8, 2, 3][..],
            &[4, 5, 6, 7],
        ]));
        let decoded = round_trip(&one_col_batch("v", arr));
        let col = decoded
            .column(0)
            .as_any()
            .downcast_ref::<LargeBinaryArray>()
            .unwrap();
        assert_eq!(col.value(0), &[1, 2, 3]);
        assert_eq!(col.value(1), &[4, 5, 6, 7]);
    }

    #[test]
    fn round_trip_list_of_int64() {
        let values: ArrayRef = Arc::new(Int64Array::from(vec![1_i64, 2, 3, 4, 5, 6]));
        let offsets = OffsetBuffer::new(vec![0_i32, 2, 5, 6].into());
        let field = Arc::new(Field::new("item", DataType::Int64, true));
        let arr: ArrayRef = Arc::new(ListArray::new(field, offsets, values, None));

        let decoded = round_trip(&one_col_batch("xs", arr));
        let col = decoded
            .column(0)
            .as_any()
            .downcast_ref::<ListArray>()
            .unwrap();
        assert_eq!(col.len(), 3);
        assert_eq!(col.value_length(1), 3);
    }

    #[test]
    fn round_trip_struct_array() {
        let id: ArrayRef = Arc::new(Int64Array::from(vec![10, 20]));
        let label: ArrayRef = Arc::new(StringArray::from(vec!["a", "b"]));
        let fields = Fields::from(vec![
            Field::new("id", DataType::Int64, false),
            Field::new("label", DataType::Utf8, false),
        ]);
        let arr: ArrayRef = Arc::new(StructArray::new(fields, vec![id, label], None));

        let decoded = round_trip(&one_col_batch("rec", arr));
        assert_eq!(decoded.num_rows(), 2);
        assert!(matches!(
            decoded.schema().field(0).data_type(),
            DataType::Struct(_)
        ));
    }

    // ---- stream shape and error paths ------------------------------------

    #[test]
    fn decode_empty_stream_returns_none() {
        let schema = Arc::new(Schema::new(vec![Field::new("x", DataType::Int64, false)]));
        let mut buf: Vec<u8> = Vec::new();
        {
            let mut w = StreamWriter::try_new(&mut buf, schema.as_ref()).unwrap();
            w.finish().unwrap();
        }
        assert!(decode_batch(&buf).unwrap().is_none());
    }

    #[test]
    fn decode_garbage_bytes_is_arrow_ipc_error() {
        let err = decode_batch(b"not arrow ipc").unwrap_err();
        assert!(matches!(err, IpcError::Arrow(_)));
    }

    #[test]
    fn encode_batches_emits_multiple_in_one_stream() {
        let schema = Arc::new(Schema::new(vec![Field::new("x", DataType::Int64, true)]));
        let a: ArrayRef = Arc::new(Int64Array::from(vec![1_i64, 2]));
        let b: ArrayRef = Arc::new(Int64Array::from(vec![3_i64, 4, 5]));
        let ba = RecordBatch::try_new(schema.clone(), vec![a]).unwrap();
        let bb = RecordBatch::try_new(schema, vec![b]).unwrap();

        let encoded = encode_batches(&[ba, bb]).unwrap();
        let all = decode_batches(&encoded).unwrap();
        assert_eq!(all.len(), 2);
        assert_eq!(all[0].num_rows(), 2);
        assert_eq!(all[1].num_rows(), 3);
    }

    #[test]
    fn encode_batches_rejects_empty_input() {
        let err = encode_batches(&[]).unwrap_err();
        assert!(matches!(err, IpcError::EmptyBatchInput));
    }

    // ---- secret-handle guard ---------------------------------------------

    /// A column with the same name and type but no extension tag is accepted,
    /// so rejection is driven by the metadata rather than by the column name.
    #[test]
    fn untagged_handle_column_round_trips() {
        let plain_field = Field::new("api_key_handle", DataType::FixedSizeBinary(8), false);
        let schema = Arc::new(Schema::new(vec![plain_field]));
        let batch = RecordBatch::try_new(schema, vec![handle_array(&[[0; 8]])]).unwrap();

        let encoded = encode_batch(&batch).unwrap();
        assert!(!encoded.is_empty());
        assert_eq!(decode_batches(&encoded).unwrap().len(), 1);
    }

    #[test]
    fn encode_batch_rejects_secret_handle_column() {
        let batch = tagged_batch("api_key_handle", &[[0; 8], [1; 8]]);
        assert_secret_leak(encode_batch(&batch), "encode_batch", "api_key_handle");
    }

    #[test]
    fn decode_batches_rejects_secret_handle_column() {
        let bytes = raw_stream(&tagged_batch("api_key_handle", &[[0; 8]]));
        assert_secret_leak(decode_batches(&bytes), "decode_batches", "api_key_handle");
    }

    #[test]
    fn decode_batch_rejects_secret_handle_column() {
        let bytes = raw_stream(&tagged_batch("api_key_handle", &[[0; 8]]));
        assert_secret_leak(decode_batch(&bytes), "decode_batch", "api_key_handle");
    }

    #[test]
    fn encode_batch_rejects_secret_handle_inside_struct() {
        // The same `Fields` value backs both the schema and the array so the
        // child metadata is guaranteed to match.
        let children = Fields::from(vec![
            Field::new("id", DataType::Int64, false),
            secret_tagged_field("handle"),
        ]);
        let schema = Arc::new(Schema::new(vec![Field::new(
            "rec",
            DataType::Struct(children.clone()),
            false,
        )]));
        let id_arr: ArrayRef = Arc::new(Int64Array::from(vec![1, 2]));
        let rec: ArrayRef = Arc::new(StructArray::new(
            children,
            vec![id_arr, handle_array(&[[0; 8], [1; 8]])],
            None,
        ));
        let batch = RecordBatch::try_new(schema, vec![rec]).unwrap();

        // The reported column is the nested child's name, not the struct's.
        assert_secret_leak(encode_batch(&batch), "encode_batch", "handle");
    }
}