Skip to main content

omnigraph_compiler/
result.rs

1use std::sync::Arc;
2
3use arrow_array::{RecordBatch, UInt64Array};
4use arrow_ipc::writer::StreamWriter;
5use arrow_schema::{DataType, Field, Schema, SchemaRef};
6use serde::de::DeserializeOwned;
7
8use crate::error::{NanoError, Result};
9use crate::json_output::{record_batches_to_json_rows, record_batches_to_rust_json_rows};
10
/// Counts of graph entities touched by a mutation.
///
/// Converted into the public [`MutationResult`] through its `From` impl.
#[derive(Clone, Copy, Debug, Default)]
pub struct MutationExecResult {
    /// Number of nodes affected by the mutation.
    pub affected_nodes: usize,
    /// Number of edges affected by the mutation.
    pub affected_edges: usize,
}
16
17#[derive(Debug, Clone)]
18pub struct QueryResult {
19    schema: SchemaRef,
20    batches: Vec<RecordBatch>,
21}
22
23impl QueryResult {
24    pub fn new(schema: SchemaRef, batches: Vec<RecordBatch>) -> Self {
25        Self { schema, batches }
26    }
27
28    pub fn schema(&self) -> &SchemaRef {
29        &self.schema
30    }
31
32    pub fn batches(&self) -> &[RecordBatch] {
33        &self.batches
34    }
35
36    pub fn into_batches(self) -> Vec<RecordBatch> {
37        self.batches
38    }
39
40    pub fn num_rows(&self) -> usize {
41        self.batches.iter().map(RecordBatch::num_rows).sum()
42    }
43
44    pub fn concat_batches(&self) -> Result<RecordBatch> {
45        if self.batches.is_empty() {
46            return Ok(RecordBatch::new_empty(self.schema.clone()));
47        }
48
49        arrow_select::concat::concat_batches(&self.schema, &self.batches)
50            .map_err(|err| NanoError::Execution(err.to_string()))
51    }
52
53    pub fn to_sdk_json(&self) -> serde_json::Value {
54        serde_json::Value::Array(record_batches_to_json_rows(&self.batches))
55    }
56
57    pub fn to_rust_json(&self) -> serde_json::Value {
58        serde_json::Value::Array(record_batches_to_rust_json_rows(&self.batches))
59    }
60
61    pub fn deserialize<T: DeserializeOwned>(&self) -> Result<T> {
62        serde_json::from_value(self.to_rust_json()).map_err(|err| {
63            NanoError::Execution(format!("failed to deserialize query result: {}", err))
64        })
65    }
66
67    pub fn to_arrow_ipc(&self) -> Result<Vec<u8>> {
68        let mut buffer = Vec::new();
69        let mut writer = StreamWriter::try_new(&mut buffer, &self.schema)?;
70        for batch in &self.batches {
71            writer.write(batch)?;
72        }
73        writer.finish()?;
74        drop(writer);
75        Ok(buffer)
76    }
77}
78
/// Public summary of a mutation: how many nodes and edges it touched.
#[derive(Clone, Copy, Debug, Default)]
pub struct MutationResult {
    /// Number of nodes affected by the mutation.
    pub affected_nodes: usize,
    /// Number of edges affected by the mutation.
    pub affected_edges: usize,
}
84
85impl MutationResult {
86    pub fn to_sdk_json(&self) -> serde_json::Value {
87        serde_json::json!({
88            "affectedNodes": self.affected_nodes,
89            "affectedEdges": self.affected_edges,
90        })
91    }
92
93    pub fn to_record_batch(&self) -> Result<RecordBatch> {
94        let schema = Arc::new(Schema::new(vec![
95            Field::new("affected_nodes", DataType::UInt64, false),
96            Field::new("affected_edges", DataType::UInt64, false),
97        ]));
98        Ok(RecordBatch::try_new(
99            schema,
100            vec![
101                Arc::new(UInt64Array::from(vec![self.affected_nodes as u64])),
102                Arc::new(UInt64Array::from(vec![self.affected_edges as u64])),
103            ],
104        )?)
105    }
106}
107
108impl From<MutationExecResult> for MutationResult {
109    fn from(value: MutationExecResult) -> Self {
110        Self {
111            affected_nodes: value.affected_nodes,
112            affected_edges: value.affected_edges,
113        }
114    }
115}
116
117#[derive(Debug, Clone)]
118pub enum RunResult {
119    Query(QueryResult),
120    Mutation(MutationResult),
121}
122
123impl RunResult {
124    pub fn to_sdk_json(&self) -> serde_json::Value {
125        match self {
126            Self::Query(result) => result.to_sdk_json(),
127            Self::Mutation(result) => result.to_sdk_json(),
128        }
129    }
130
131    pub fn into_record_batches(self) -> Result<Vec<RecordBatch>> {
132        match self {
133            Self::Query(result) => Ok(result.into_batches()),
134            Self::Mutation(result) => Ok(vec![result.to_record_batch()?]),
135        }
136    }
137}
138
#[cfg(test)]
mod tests {
    use std::io::Cursor;

    use arrow_array::Int64Array;
    use arrow_ipc::reader::StreamReader;
    use serde::Deserialize;

    use super::*;

    #[test]
    fn query_result_arrow_ipc_round_trips_empty_schema() {
        let schema = Arc::new(Schema::new(vec![Field::new("name", DataType::Utf8, false)]));
        let result = QueryResult::new(schema.clone(), vec![]);

        let encoded = result.to_arrow_ipc().expect("encode empty result");
        let reader = StreamReader::try_new(Cursor::new(encoded), None).expect("open stream");

        assert_eq!(reader.schema().as_ref(), schema.as_ref());
        assert_eq!(reader.count(), 0);
    }

    #[test]
    fn query_result_arrow_ipc_round_trips_batches() {
        let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::UInt64, false)]));
        let batch = RecordBatch::try_new(
            schema.clone(),
            vec![Arc::new(UInt64Array::from(vec![1_u64, 2_u64]))],
        )
        .expect("batch");
        let result = QueryResult::new(schema.clone(), vec![batch]);

        let encoded = result.to_arrow_ipc().expect("encode result");
        let mut reader = StreamReader::try_new(Cursor::new(encoded), None).expect("open stream");
        let decoded = reader.next().expect("first batch").expect("decode batch");

        assert_eq!(reader.schema().as_ref(), schema.as_ref());
        assert_eq!(decoded.num_rows(), 2);
        assert_eq!(decoded.schema().as_ref(), schema.as_ref());

        // Check the payload survived the round trip, not just the row count.
        let ids = decoded
            .column(0)
            .as_any()
            .downcast_ref::<UInt64Array>()
            .expect("u64 ids");
        assert_eq!(ids.values(), &[1, 2]);

        // The stream must contain exactly the one batch that was written.
        assert!(reader.next().is_none());
    }

    #[test]
    fn query_result_num_rows_and_concat_cover_multiple_batches() {
        let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::UInt64, false)]));
        let batch1 = RecordBatch::try_new(
            schema.clone(),
            vec![Arc::new(UInt64Array::from(vec![1_u64, 2_u64]))],
        )
        .expect("batch1");
        let batch2 = RecordBatch::try_new(
            schema.clone(),
            vec![Arc::new(UInt64Array::from(vec![3_u64]))],
        )
        .expect("batch2");
        let result = QueryResult::new(schema.clone(), vec![batch1, batch2]);

        assert_eq!(result.num_rows(), 3);

        let concatenated = result.concat_batches().expect("concat batches");
        let ids = concatenated
            .column(0)
            .as_any()
            .downcast_ref::<UInt64Array>()
            .expect("u64 ids");
        assert_eq!(concatenated.schema().as_ref(), schema.as_ref());
        assert_eq!(ids.values(), &[1, 2, 3]);
    }

    #[test]
    fn query_result_concat_empty_batches_returns_empty_batch() {
        let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::UInt64, false)]));
        let result = QueryResult::new(schema.clone(), vec![]);

        let concatenated = result.concat_batches().expect("concat empty");

        assert_eq!(concatenated.schema().as_ref(), schema.as_ref());
        assert_eq!(concatenated.num_rows(), 0);
    }

    #[test]
    fn query_result_to_rust_json_preserves_wide_integers() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("signed", DataType::Int64, false),
            Field::new("unsigned", DataType::UInt64, false),
        ]));
        let batch = RecordBatch::try_new(
            schema.clone(),
            vec![
                Arc::new(Int64Array::from(vec![i64::MIN])),
                Arc::new(UInt64Array::from(vec![u64::MAX])),
            ],
        )
        .expect("batch");
        let result = QueryResult::new(schema, vec![batch]);

        assert_eq!(
            result.to_rust_json(),
            serde_json::json!([{
                "signed": i64::MIN,
                "unsigned": u64::MAX,
            }])
        );
    }

    #[derive(Debug, Deserialize, PartialEq)]
    struct PersonRow {
        id: u64,
        age: i64,
    }

    #[test]
    fn query_result_deserialize_decodes_rust_rows() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("id", DataType::UInt64, false),
            Field::new("age", DataType::Int64, false),
        ]));
        let batch1 = RecordBatch::try_new(
            schema.clone(),
            vec![
                Arc::new(UInt64Array::from(vec![1_u64])),
                Arc::new(Int64Array::from(vec![40_i64])),
            ],
        )
        .expect("batch1");
        let batch2 = RecordBatch::try_new(
            schema,
            vec![
                Arc::new(UInt64Array::from(vec![u64::MAX])),
                Arc::new(Int64Array::from(vec![-5_i64])),
            ],
        )
        .expect("batch2");
        let result = QueryResult::new(batch1.schema(), vec![batch1, batch2]);

        let rows: Vec<PersonRow> = result.deserialize().expect("deserialize rows");

        assert_eq!(
            rows,
            vec![
                PersonRow { id: 1, age: 40 },
                PersonRow {
                    id: u64::MAX,
                    age: -5,
                },
            ]
        );
    }

    #[test]
    fn mutation_result_reports_counts_as_json_and_batch() {
        let result = MutationResult::from(MutationExecResult {
            affected_nodes: 3,
            affected_edges: 7,
        });

        assert_eq!(
            result.to_sdk_json(),
            serde_json::json!({ "affectedNodes": 3, "affectedEdges": 7 })
        );

        let batch = result.to_record_batch().expect("mutation batch");
        assert_eq!(batch.num_rows(), 1);

        let schema = batch.schema();
        assert_eq!(schema.field(0).name(), "affected_nodes");
        assert_eq!(schema.field(1).name(), "affected_edges");

        let nodes = batch
            .column(0)
            .as_any()
            .downcast_ref::<UInt64Array>()
            .expect("u64 nodes");
        let edges = batch
            .column(1)
            .as_any()
            .downcast_ref::<UInt64Array>()
            .expect("u64 edges");
        assert_eq!(nodes.value(0), 3);
        assert_eq!(edges.value(0), 7);
    }

    #[test]
    fn run_result_mutation_yields_single_record_batch() {
        let run = RunResult::Mutation(MutationResult {
            affected_nodes: 1,
            affected_edges: 0,
        });

        assert_eq!(
            run.to_sdk_json(),
            serde_json::json!({ "affectedNodes": 1, "affectedEdges": 0 })
        );

        let batches = run.into_record_batches().expect("record batches");
        assert_eq!(batches.len(), 1);
        assert_eq!(batches[0].num_rows(), 1);
    }
}