use std::sync::Arc;
use arrow_array::{RecordBatch, UInt64Array};
use arrow_ipc::writer::StreamWriter;
use arrow_schema::{DataType, Field, Schema, SchemaRef};
use serde::de::DeserializeOwned;
use crate::error::{NanoError, Result};
use crate::json_output::{record_batches_to_json_rows, record_batches_to_rust_json_rows};
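/// Raw mutation counters as reported by the execution layer; converted into
/// the public [`MutationResult`] via `From`.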
#[derive(Debug, Clone, Copy, Default)]
pub struct MutationExecResult {
pub affected_nodes: usize,
pub affected_edges: usize,
}
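/// The result of a read query: an Arrow schema plus the record batches
/// produced under it.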
#[derive(Debug, Clone)]
pub struct QueryResult {
schema: SchemaRef,
batches: Vec<RecordBatch>,
}
impl QueryResult {
pub fn new(schema: SchemaRef, batches: Vec<RecordBatch>) -> Self {
Self { schema, batches }
}
pub fn schema(&self) -> &SchemaRef {
&self.schema
}
pub fn batches(&self) -> &[RecordBatch] {
&self.batches
}
pub fn into_batches(self) -> Vec<RecordBatch> {
self.batches
}
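    /// Total number of rows across all batches.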
pub fn num_rows(&self) -> usize {
self.batches.iter().map(RecordBatch::num_rows).sum()
}
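    /// Concatenates all batches into a single [`RecordBatch`], returning an
    /// empty batch with the stored schema when there are no batches.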
pub fn concat_batches(&self) -> Result<RecordBatch> {
if self.batches.is_empty() {
return Ok(RecordBatch::new_empty(self.schema.clone()));
}
arrow_select::concat::concat_batches(&self.schema, &self.batches)
.map_err(|err| NanoError::Execution(err.to_string()))
}
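    /// Renders the rows as a JSON array using the SDK-facing conventions of
    /// [`record_batches_to_json_rows`].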
pub fn to_sdk_json(&self) -> serde_json::Value {
serde_json::Value::Array(record_batches_to_json_rows(&self.batches))
}
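    /// Renders the rows as a JSON array using the Rust-facing conventions of
    /// [`record_batches_to_rust_json_rows`], which preserve the full range of
    /// 64-bit integers.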
pub fn to_rust_json(&self) -> serde_json::Value {
serde_json::Value::Array(record_batches_to_rust_json_rows(&self.batches))
}
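    /// Deserializes the Rust-facing JSON rows into `T`, typically a `Vec` of
    /// row structs deriving [`serde::Deserialize`].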
pub fn deserialize<T: DeserializeOwned>(&self) -> Result<T> {
serde_json::from_value(self.to_rust_json()).map_err(|err| {
            NanoError::Execution(format!("failed to deserialize query result: {err}"))
})
}
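    /// Encodes the schema and all batches as an Arrow IPC stream.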
pub fn to_arrow_ipc(&self) -> Result<Vec<u8>> {
let mut buffer = Vec::new();
let mut writer = StreamWriter::try_new(&mut buffer, &self.schema)?;
for batch in &self.batches {
writer.write(batch)?;
}
writer.finish()?;
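        // `writer` holds a mutable borrow of `buffer`; drop it explicitly so
        // the buffer can be returned.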
drop(writer);
Ok(buffer)
}
}
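/// The result of a mutation: how many nodes and edges were affected.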
#[derive(Debug, Clone, Copy, Default)]
pub struct MutationResult {
pub affected_nodes: usize,
pub affected_edges: usize,
}
impl MutationResult {
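    /// Renders the counters as the SDK's camelCase JSON object.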
pub fn to_sdk_json(&self) -> serde_json::Value {
serde_json::json!({
"affectedNodes": self.affected_nodes,
"affectedEdges": self.affected_edges,
})
}
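    /// Renders the counters as a single-row [`RecordBatch`] with
    /// `affected_nodes` and `affected_edges` columns.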
pub fn to_record_batch(&self) -> Result<RecordBatch> {
let schema = Arc::new(Schema::new(vec![
Field::new("affected_nodes", DataType::UInt64, false),
Field::new("affected_edges", DataType::UInt64, false),
]));
Ok(RecordBatch::try_new(
schema,
vec![
Arc::new(UInt64Array::from(vec![self.affected_nodes as u64])),
Arc::new(UInt64Array::from(vec![self.affected_edges as u64])),
],
)?)
}
}
impl From<MutationExecResult> for MutationResult {
fn from(value: MutationExecResult) -> Self {
Self {
affected_nodes: value.affected_nodes,
affected_edges: value.affected_edges,
}
}
}
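/// The outcome of running a statement: either query rows or mutation counters.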
#[derive(Debug, Clone)]
pub enum RunResult {
Query(QueryResult),
Mutation(MutationResult),
}
impl RunResult {
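    /// Renders the result as SDK JSON: query rows as an array, mutation
    /// counters as an object.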
pub fn to_sdk_json(&self) -> serde_json::Value {
match self {
Self::Query(result) => result.to_sdk_json(),
Self::Mutation(result) => result.to_sdk_json(),
}
}
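    /// Converts the result into record batches; a mutation becomes a single
    /// one-row summary batch.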
pub fn into_record_batches(self) -> Result<Vec<RecordBatch>> {
match self {
Self::Query(result) => Ok(result.into_batches()),
Self::Mutation(result) => Ok(vec![result.to_record_batch()?]),
}
}
}
#[cfg(test)]
mod tests {
use std::io::Cursor;
use arrow_array::Int64Array;
use arrow_ipc::reader::StreamReader;
use serde::Deserialize;
use super::*;
#[test]
fn query_result_arrow_ipc_round_trips_empty_schema() {
let schema = Arc::new(Schema::new(vec![Field::new("name", DataType::Utf8, false)]));
let result = QueryResult::new(schema.clone(), vec![]);
let encoded = result.to_arrow_ipc().expect("encode empty result");
let reader = StreamReader::try_new(Cursor::new(encoded), None).expect("open stream");
assert_eq!(reader.schema().as_ref(), schema.as_ref());
assert_eq!(reader.count(), 0);
}
#[test]
fn query_result_arrow_ipc_round_trips_batches() {
let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::UInt64, false)]));
let batch = RecordBatch::try_new(
schema.clone(),
vec![Arc::new(UInt64Array::from(vec![1_u64, 2_u64]))],
)
.expect("batch");
let result = QueryResult::new(schema.clone(), vec![batch]);
let encoded = result.to_arrow_ipc().expect("encode result");
let mut reader = StreamReader::try_new(Cursor::new(encoded), None).expect("open stream");
let decoded = reader.next().expect("first batch").expect("decode batch");
assert_eq!(reader.schema().as_ref(), schema.as_ref());
assert_eq!(decoded.num_rows(), 2);
assert_eq!(decoded.schema().as_ref(), schema.as_ref());
}
#[test]
fn query_result_num_rows_and_concat_cover_multiple_batches() {
let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::UInt64, false)]));
let batch1 = RecordBatch::try_new(
schema.clone(),
vec![Arc::new(UInt64Array::from(vec![1_u64, 2_u64]))],
)
.expect("batch1");
let batch2 = RecordBatch::try_new(
schema.clone(),
vec![Arc::new(UInt64Array::from(vec![3_u64]))],
)
.expect("batch2");
let result = QueryResult::new(schema.clone(), vec![batch1, batch2]);
assert_eq!(result.num_rows(), 3);
let concatenated = result.concat_batches().expect("concat batches");
let ids = concatenated
.column(0)
.as_any()
.downcast_ref::<UInt64Array>()
.expect("u64 ids");
assert_eq!(concatenated.schema().as_ref(), schema.as_ref());
assert_eq!(ids.values(), &[1, 2, 3]);
}
#[test]
fn query_result_concat_empty_batches_returns_empty_batch() {
let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::UInt64, false)]));
let result = QueryResult::new(schema.clone(), vec![]);
let concatenated = result.concat_batches().expect("concat empty");
assert_eq!(concatenated.schema().as_ref(), schema.as_ref());
assert_eq!(concatenated.num_rows(), 0);
}
#[test]
fn query_result_to_rust_json_preserves_wide_integers() {
let schema = Arc::new(Schema::new(vec![
Field::new("signed", DataType::Int64, false),
Field::new("unsigned", DataType::UInt64, false),
]));
let batch = RecordBatch::try_new(
schema.clone(),
vec![
Arc::new(Int64Array::from(vec![i64::MIN])),
Arc::new(UInt64Array::from(vec![u64::MAX])),
],
)
.expect("batch");
let result = QueryResult::new(schema, vec![batch]);
assert_eq!(
result.to_rust_json(),
serde_json::json!([{
"signed": i64::MIN,
"unsigned": u64::MAX,
}])
);
}
#[derive(Debug, Deserialize, PartialEq)]
struct PersonRow {
id: u64,
age: i64,
}
#[test]
fn query_result_deserialize_decodes_rust_rows() {
let schema = Arc::new(Schema::new(vec![
Field::new("id", DataType::UInt64, false),
Field::new("age", DataType::Int64, false),
]));
let batch1 = RecordBatch::try_new(
schema.clone(),
vec![
Arc::new(UInt64Array::from(vec![1_u64])),
Arc::new(Int64Array::from(vec![40_i64])),
],
)
.expect("batch1");
let batch2 = RecordBatch::try_new(
schema,
vec![
Arc::new(UInt64Array::from(vec![u64::MAX])),
Arc::new(Int64Array::from(vec![-5_i64])),
],
)
.expect("batch2");
let result = QueryResult::new(batch1.schema(), vec![batch1, batch2]);
let rows: Vec<PersonRow> = result.deserialize().expect("deserialize rows");
assert_eq!(
rows,
vec![
PersonRow { id: 1, age: 40 },
PersonRow {
id: u64::MAX,
age: -5,
},
]
);
}
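    // Sketch tests for the mutation-side conversions; they exercise only the
    // APIs defined in this module, mirroring the query-side coverage above.
    #[test]
    fn mutation_result_sdk_json_and_record_batch_report_counts() {
        let result = MutationResult {
            affected_nodes: 3,
            affected_edges: 7,
        };
        assert_eq!(
            result.to_sdk_json(),
            serde_json::json!({"affectedNodes": 3, "affectedEdges": 7})
        );
        let batch = result.to_record_batch().expect("record batch");
        assert_eq!(batch.num_rows(), 1);
        let nodes = batch
            .column(0)
            .as_any()
            .downcast_ref::<UInt64Array>()
            .expect("u64 affected_nodes");
        assert_eq!(nodes.value(0), 3);
    }
    #[test]
    fn run_result_converts_mutation_into_single_summary_batch() {
        let exec = MutationExecResult {
            affected_nodes: 1,
            affected_edges: 2,
        };
        let run = RunResult::Mutation(MutationResult::from(exec));
        assert_eq!(
            run.to_sdk_json(),
            serde_json::json!({"affectedNodes": 1, "affectedEdges": 2})
        );
        let batches = run.into_record_batches().expect("batches");
        assert_eq!(batches.len(), 1);
        assert_eq!(batches[0].num_rows(), 1);
    }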
}