use std::io::Cursor;
use std::sync::Arc;
use anyhow::{Context, Result};
use arrow::array::{ArrayRef, BooleanArray, Float64Array, Int64Array, RecordBatch, StringArray};
use arrow::datatypes::{DataType, Field, Schema};
use arrow::ipc::reader::StreamReader;
use arrow::ipc::writer::StreamWriter;
/// Serializes a single [`RecordBatch`] into an in-memory Arrow IPC stream.
///
/// The batch's own schema is used for the stream, and the writer is finished before the
/// bytes are returned.
///
/// # Errors
/// Returns an error annotated with the failing step if the stream writer cannot be created,
/// the batch cannot be written, or the stream cannot be finished.
pub fn encode_record_batch(batch: &RecordBatch) -> Result<Vec<u8>> {
    // Initial capacity hint only; the buffer grows as needed.
    let mut out: Vec<u8> = Vec::with_capacity(8 * 1024);
    let schema = batch.schema_ref();
    let mut writer =
        StreamWriter::try_new(&mut out, schema).context("create Arrow IPC stream writer")?;
    writer
        .write(batch)
        .context("write record batch to Arrow IPC stream")?;
    writer.finish().context("finish Arrow IPC stream writer")?;
    // End the writer's `&mut out` borrow before handing the buffer back.
    drop(writer);
    Ok(out)
}
/// Decodes every record batch contained in an Arrow IPC stream held in memory.
///
/// Batches are returned in stream order.
///
/// # Errors
/// Fails if the stream header cannot be read. Decoding stops at the first batch that fails
/// to decode, and that error is returned.
pub fn decode_record_batches(bytes: &[u8]) -> Result<Vec<RecordBatch>> {
    let reader = StreamReader::try_new(Cursor::new(bytes), None)
        .context("open Arrow IPC stream reader")?;
    // Collecting into `Result<Vec<_>>` short-circuits on the first failed batch.
    reader
        .map(|item| item.context("read record batch from Arrow IPC stream"))
        .collect()
}
/// Outcome of successfully encoding a JSON table payload (see `try_encode_tabular_json`).
#[derive(Debug, Clone)]
pub struct TabularEncoding {
    /// Encoded payload; `try_encode_tabular_json` fills this with a single-batch Arrow IPC stream.
    pub bytes: Vec<u8>,
    /// Number of rows in the encoded batch.
    pub row_count: usize,
    /// Schema identifier. NOTE(review): `try_encode_tabular_json` always sets this to the
    /// fixed string `"arrow"`, not a digest of the actual schema — confirm what consumers expect.
    pub schema_digest: String,
    /// Media type describing `bytes`, e.g. [`ARROW_STREAM_MEDIA_TYPE`].
    pub media_type: &'static str,
}
/// Media type for the Arrow IPC streaming format, as stored in `TabularEncoding::media_type`.
pub const ARROW_STREAM_MEDIA_TYPE: &str = "application/vnd.apache.arrow.stream";
pub fn try_encode_tabular_json(value: &serde_json::Value) -> Option<TabularEncoding> {
let payload = if value.get("rows").is_some() {
value
} else if let Some(nested) = value.get("data") {
if nested.get("rows").is_some() {
nested
} else {
return None;
}
} else {
return None;
};
let rows = payload.get("rows").and_then(|r| r.as_array())?;
if rows.is_empty() {
return None;
}
let column_names: Vec<String> =
if let Some(cols) = payload.get("columns").and_then(|c| c.as_array()) {
cols.iter()
.map(|v| v.as_str().unwrap_or_default().to_string())
.collect()
} else if let Some(first_obj) = rows.first().and_then(|r| r.as_object()) {
first_obj.keys().cloned().collect()
} else {
return None;
};
if column_names.is_empty() {
return None;
}
let mut columns_inferred: Vec<(String, ColumnType, Vec<Option<serde_json::Value>>)> =
column_names
.into_iter()
.map(|name| (name, ColumnType::Unset, Vec::with_capacity(rows.len())))
.collect();
for row in rows {
for (idx, (name, ctype, values)) in columns_inferred.iter_mut().enumerate() {
let cell = if let Some(obj) = row.as_object() {
obj.get(name).cloned()
} else if let Some(arr) = row.as_array() {
arr.get(idx).cloned()
} else {
return None;
};
let cell = cell.unwrap_or(serde_json::Value::Null);
if !cell.is_null() {
ctype.observe(&cell);
}
values.push(if cell.is_null() { None } else { Some(cell) });
}
}
let mut fields: Vec<Field> = Vec::with_capacity(columns_inferred.len());
let mut arrays: Vec<ArrayRef> = Vec::with_capacity(columns_inferred.len());
for (name, ctype, values) in columns_inferred {
let data_type = ctype.to_arrow();
let array: ArrayRef = match data_type {
DataType::Int64 => Arc::new(Int64Array::from_iter(
values
.iter()
.map(|v| v.as_ref().and_then(|cell| cell.as_i64())),
)),
DataType::Float64 => Arc::new(Float64Array::from_iter(
values
.iter()
.map(|v| v.as_ref().and_then(|cell| cell.as_f64())),
)),
DataType::Boolean => Arc::new(BooleanArray::from_iter(
values
.iter()
.map(|v| v.as_ref().and_then(|cell| cell.as_bool())),
)),
DataType::Utf8 => Arc::new(StringArray::from_iter(values.iter().map(|v| {
v.as_ref().map(|cell| match cell {
serde_json::Value::String(s) => s.clone(),
other => other.to_string(),
})
}))),
_ => unreachable!("ColumnType::to_arrow yields only Int64/Float64/Boolean/Utf8"),
};
fields.push(Field::new(name.clone(), data_type, true));
arrays.push(array);
}
let schema = Arc::new(Schema::new(fields));
let row_count = arrays.first().map(|a| a.len()).unwrap_or(0);
let batch = RecordBatch::try_new(schema, arrays).ok()?;
let bytes = encode_record_batch(&batch).ok()?;
Some(TabularEncoding {
bytes,
row_count,
schema_digest: "arrow".to_string(),
media_type: ARROW_STREAM_MEDIA_TYPE,
})
}
/// Arrow type inferred for one JSON column, widened as non-null cells are observed.
///
/// Widening rules:
/// - numbers that fit in `i64` keep an `Unset`/`Int64` column at `Int64`;
/// - any other number (fractional, or an integer beyond `i64`) widens `Unset`/`Int64`/`Float64`
///   to `Float64`;
/// - booleans keep an `Unset`/`Boolean` column at `Boolean`;
/// - every other combination, and any string/array/object cell, yields `Utf8`.
///
/// Once a column is `Utf8` it stays `Utf8`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ColumnType {
    /// No non-null cell seen yet; encoded as `Utf8`.
    Unset,
    Int64,
    Float64,
    Boolean,
    Utf8,
}
impl ColumnType {
    /// Widens `self` to accommodate `value`; JSON `null` leaves the type unchanged.
    fn observe(&mut self, value: &serde_json::Value) {
        let next = match value {
            serde_json::Value::Null => *self,
            serde_json::Value::Bool(_) => match *self {
                ColumnType::Unset | ColumnType::Boolean => ColumnType::Boolean,
                _ => ColumnType::Utf8,
            },
            // Only numbers representable as i64 may stay integral. Int64 columns are
            // materialized with `Value::as_i64`, which returns `None` for u64 values above
            // `i64::MAX`. Such values previously turned into silent nulls, so they now widen
            // to Float64 like any other non-i64 number.
            serde_json::Value::Number(n) if n.is_i64() => match *self {
                ColumnType::Unset | ColumnType::Int64 => ColumnType::Int64,
                ColumnType::Float64 => ColumnType::Float64,
                _ => ColumnType::Utf8,
            },
            serde_json::Value::Number(_) => match *self {
                ColumnType::Unset | ColumnType::Int64 | ColumnType::Float64 => {
                    ColumnType::Float64
                }
                _ => ColumnType::Utf8,
            },
            serde_json::Value::String(_)
            | serde_json::Value::Array(_)
            | serde_json::Value::Object(_) => ColumnType::Utf8,
        };
        *self = next;
    }
    /// Arrow data type used for the column; an all-null (`Unset`) column becomes `Utf8`.
    fn to_arrow(self) -> DataType {
        match self {
            ColumnType::Int64 => DataType::Int64,
            ColumnType::Float64 => DataType::Float64,
            ColumnType::Boolean => DataType::Boolean,
            ColumnType::Unset | ColumnType::Utf8 => DataType::Utf8,
        }
    }
}
// Unit tests for the Arrow IPC round trip and the JSON-table encoder.
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::{Int64Array, StringArray};
    use arrow::datatypes::{DataType, Field, Schema};
    use std::sync::Arc;

    // Three-row batch with non-nullable Int64 and Utf8 columns, shared by the IPC tests.
    fn sample_batch() -> RecordBatch {
        let schema = Arc::new(Schema::new(vec![
            Field::new("event_id", DataType::Int64, false),
            Field::new("event_type", DataType::Utf8, false),
        ]));
        let event_ids = Int64Array::from(vec![1, 2, 3]);
        let event_types = StringArray::from(vec![
            "batch.accepted",
            "batch.processing",
            "batch.completed",
        ]);
        RecordBatch::try_new(schema, vec![Arc::new(event_ids), Arc::new(event_types)])
            .expect("build record batch")
    }

    // Encode then decode yields one batch with the same shape and schema.
    #[test]
    fn round_trip_single_batch_preserves_rows_and_schema() {
        let original = sample_batch();
        let encoded = encode_record_batch(&original).expect("encode");
        let decoded = decode_record_batches(&encoded).expect("decode");
        assert_eq!(decoded.len(), 1, "single-batch stream");
        let decoded_batch = &decoded[0];
        assert_eq!(decoded_batch.num_rows(), 3);
        assert_eq!(decoded_batch.num_columns(), 2);
        assert_eq!(
            decoded_batch.schema_ref().as_ref(),
            original.schema_ref().as_ref(),
        );
    }

    // The encoded stream starts with the 0xFFFFFFFF continuation marker.
    #[test]
    fn encoded_bytes_have_arrow_ipc_magic() {
        let batch = sample_batch();
        let encoded = encode_record_batch(&batch).expect("encode");
        assert!(encoded.len() > 8, "non-trivial output");
        assert_eq!(
            &encoded[0..4],
            &[0xFF, 0xFF, 0xFF, 0xFF],
            "Arrow IPC continuation marker",
        );
    }

    // Arbitrary bytes must produce an error rather than an empty batch list.
    #[test]
    fn decode_rejects_garbage_bytes() {
        let result = decode_record_batches(b"not an arrow ipc stream");
        assert!(result.is_err(), "garbage must not decode");
    }

    // Positional rows with explicit columns: one column per inferred type.
    #[test]
    fn tabular_array_rows_round_trip_via_arrow_ipc() {
        let payload = serde_json::json!({
            "columns": ["id", "name", "score", "active"],
            "rows": [
                [1, "alice", 0.95, true],
                [2, "bob", 0.72, false],
                [3, "carol", 0.88, true],
            ],
            "row_count": 3,
        });
        let encoded = try_encode_tabular_json(&payload).expect("must encode");
        assert_eq!(encoded.row_count, 3);
        assert_eq!(encoded.media_type, ARROW_STREAM_MEDIA_TYPE);
        assert_eq!(encoded.schema_digest, "arrow");
        let batches = decode_record_batches(&encoded.bytes).expect("decode");
        assert_eq!(batches.len(), 1);
        let batch = &batches[0];
        assert_eq!(batch.num_rows(), 3);
        assert_eq!(batch.num_columns(), 4);
        assert_eq!(batch.schema().field(0).data_type(), &DataType::Int64);
        assert_eq!(batch.schema().field(1).data_type(), &DataType::Utf8);
        assert_eq!(batch.schema().field(2).data_type(), &DataType::Float64);
        assert_eq!(batch.schema().field(3).data_type(), &DataType::Boolean);
    }

    // Keyed rows without "columns": names come from the first row's keys.
    #[test]
    fn tabular_object_rows_round_trip_via_arrow_ipc() {
        let payload = serde_json::json!({
            "rows": [
                {"id": 1, "label": "x"},
                {"id": 2, "label": "y"},
            ],
            "row_count": 2,
        });
        let encoded = try_encode_tabular_json(&payload).expect("must encode");
        assert_eq!(encoded.row_count, 2);
        let batches = decode_record_batches(&encoded.bytes).expect("decode");
        let batch = &batches[0];
        assert_eq!(batch.num_rows(), 2);
        assert_eq!(batch.num_columns(), 2);
    }

    // A table nested one level under "data" is also found.
    #[test]
    fn tabular_nested_under_data_round_trip_via_arrow_ipc() {
        let payload = serde_json::json!({
            "status": "Success",
            "data": {
                "columns": ["a"],
                "rows": [[1], [2], [3]],
                "row_count": 3,
            },
            "duration_ms": 12,
        });
        let encoded = try_encode_tabular_json(&payload).expect("must encode nested-under-data");
        assert_eq!(encoded.row_count, 3);
    }

    // Int, string and float cells in one column widen to Utf8.
    #[test]
    fn tabular_mixed_type_column_collapses_to_utf8() {
        let payload = serde_json::json!({
            "columns": ["mixed"],
            "rows": [
                [1],
                ["two"],
                [4.2],
            ],
        });
        let encoded = try_encode_tabular_json(&payload).expect("must encode");
        let batches = decode_record_batches(&encoded.bytes).expect("decode");
        assert_eq!(batches[0].schema().field(0).data_type(), &DataType::Utf8);
        assert_eq!(batches[0].num_rows(), 3);
    }

    // JSON null cells become Arrow nulls at the same position.
    #[test]
    fn tabular_null_cells_round_trip_as_arrow_nulls() {
        let payload = serde_json::json!({
            "columns": ["id"],
            "rows": [
                [1],
                [null],
                [3],
            ],
        });
        let encoded = try_encode_tabular_json(&payload).expect("must encode");
        let batches = decode_record_batches(&encoded.bytes).expect("decode");
        let batch = &batches[0];
        let col = batch.column(0);
        assert_eq!(col.len(), 3);
        assert!(col.is_null(1));
    }

    // An empty "rows" array is not encoded.
    #[test]
    fn tabular_empty_rows_returns_none() {
        let payload = serde_json::json!({
            "columns": ["id"],
            "rows": [],
        });
        assert!(try_encode_tabular_json(&payload).is_none());
    }

    // Objects without "rows" (here or under "data") are not tables.
    #[test]
    fn tabular_missing_rows_returns_none() {
        let payload = serde_json::json!({
            "stdout": "hello",
            "exit_code": 0,
        });
        assert!(try_encode_tabular_json(&payload).is_none());
    }

    // Non-object top-level JSON values are never tables.
    #[test]
    fn tabular_non_object_input_returns_none() {
        assert!(try_encode_tabular_json(&serde_json::json!([1, 2, 3])).is_none());
        assert!(try_encode_tabular_json(&serde_json::json!("hello")).is_none());
        assert!(try_encode_tabular_json(&serde_json::json!(42)).is_none());
        assert!(try_encode_tabular_json(&serde_json::Value::Null).is_none());
    }

    // A column with no non-null cells falls back to a nullable Utf8 column of nulls.
    #[test]
    fn tabular_all_null_column_infers_utf8() {
        let payload = serde_json::json!({
            "columns": ["nullable"],
            "rows": [
                [null],
                [null],
            ],
        });
        let encoded = try_encode_tabular_json(&payload).expect("must encode");
        let batches = decode_record_batches(&encoded.bytes).expect("decode");
        let batch = &batches[0];
        assert_eq!(batch.schema().field(0).data_type(), &DataType::Utf8);
        assert!(batch.column(0).is_null(0));
        assert!(batch.column(0).is_null(1));
    }
}