use std::sync::Arc;

use arrow_array::{RecordBatch, UInt64Array};
use arrow_ipc::writer::StreamWriter;
use arrow_schema::{DataType, Field, Schema, SchemaRef};
use serde::de::DeserializeOwned;

use crate::error::{NanoError, Result};
use crate::json_output::{record_batches_to_json_rows, record_batches_to_rust_json_rows};

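/// Raw mutation counts in the form produced by execution; convertible into
/// the public [`MutationResult`] via the `From` impl below.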
#[derive(Debug, Clone, Copy, Default)]
pub struct MutationExecResult {
    pub affected_nodes: usize,
    pub affected_edges: usize,
}

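/// Columnar result of a read query: a schema plus the Arrow record batches
/// that conform to it.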
#[derive(Debug, Clone)]
pub struct QueryResult {
    schema: SchemaRef,
    batches: Vec<RecordBatch>,
}

impl QueryResult {
    pub fn new(schema: SchemaRef, batches: Vec<RecordBatch>) -> Self {
        Self { schema, batches }
    }

    pub fn schema(&self) -> &SchemaRef {
        &self.schema
    }

    pub fn batches(&self) -> &[RecordBatch] {
        &self.batches
    }

    pub fn into_batches(self) -> Vec<RecordBatch> {
        self.batches
    }

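    /// Total number of rows across all batches.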
    pub fn num_rows(&self) -> usize {
        self.batches.iter().map(RecordBatch::num_rows).sum()
    }

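    /// Concatenates all batches into one [`RecordBatch`]; when there are no
    /// batches, an empty batch with the stored schema is returned.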
    pub fn concat_batches(&self) -> Result<RecordBatch> {
        if self.batches.is_empty() {
            return Ok(RecordBatch::new_empty(self.schema.clone()));
        }

        arrow_select::concat::concat_batches(&self.schema, &self.batches)
            .map_err(|err| NanoError::Execution(err.to_string()))
    }

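    /// Renders all rows as a JSON array in the SDK-facing format produced by
    /// `record_batches_to_json_rows`.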
    pub fn to_sdk_json(&self) -> serde_json::Value {
        serde_json::Value::Array(record_batches_to_json_rows(&self.batches))
    }

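    /// Renders all rows as a JSON array in the Rust-facing format, which
    /// preserves the full range of 64-bit integers.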
    pub fn to_rust_json(&self) -> serde_json::Value {
        serde_json::Value::Array(record_batches_to_rust_json_rows(&self.batches))
    }

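    /// Deserializes all rows into `T` by way of [`Self::to_rust_json`].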
    pub fn deserialize<T: DeserializeOwned>(&self) -> Result<T> {
        serde_json::from_value(self.to_rust_json()).map_err(|err| {
            NanoError::Execution(format!("failed to deserialize query result: {}", err))
        })
    }

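    /// Encodes the schema and all batches as an Arrow IPC stream.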
    pub fn to_arrow_ipc(&self) -> Result<Vec<u8>> {
        let mut buffer = Vec::new();
        let mut writer = StreamWriter::try_new(&mut buffer, &self.schema)?;
        for batch in &self.batches {
            writer.write(batch)?;
        }
        writer.finish()?;
        // Drop the writer explicitly so its mutable borrow of `buffer` ends
        // before the buffer is returned.
        drop(writer);
        Ok(buffer)
    }
}

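/// Public result of a mutation, reporting how many nodes and edges were
/// affected.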
#[derive(Debug, Clone, Copy, Default)]
pub struct MutationResult {
    pub affected_nodes: usize,
    pub affected_edges: usize,
}

impl MutationResult {
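    /// Renders the counts as camelCase JSON for the SDK.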
    pub fn to_sdk_json(&self) -> serde_json::Value {
        serde_json::json!({
            "affectedNodes": self.affected_nodes,
            "affectedEdges": self.affected_edges,
        })
    }

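    /// Encodes the counts as a single-row [`RecordBatch`] with two
    /// non-nullable `UInt64` columns.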
    pub fn to_record_batch(&self) -> Result<RecordBatch> {
        let schema = Arc::new(Schema::new(vec![
            Field::new("affected_nodes", DataType::UInt64, false),
            Field::new("affected_edges", DataType::UInt64, false),
        ]));
        Ok(RecordBatch::try_new(
            schema,
            vec![
                Arc::new(UInt64Array::from(vec![self.affected_nodes as u64])),
                Arc::new(UInt64Array::from(vec![self.affected_edges as u64])),
            ],
        )?)
    }
}

impl From<MutationExecResult> for MutationResult {
    fn from(value: MutationExecResult) -> Self {
        Self {
            affected_nodes: value.affected_nodes,
            affected_edges: value.affected_edges,
        }
    }
}

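/// Result of running a statement: rows for a query, affected counts for a
/// mutation.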
#[derive(Debug, Clone)]
pub enum RunResult {
    Query(QueryResult),
    Mutation(MutationResult),
}

impl RunResult {
    pub fn to_sdk_json(&self) -> serde_json::Value {
        match self {
            Self::Query(result) => result.to_sdk_json(),
            Self::Mutation(result) => result.to_sdk_json(),
        }
    }

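    /// Converts either variant into record batches; a mutation becomes a
    /// single one-row batch via [`MutationResult::to_record_batch`].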
    pub fn into_record_batches(self) -> Result<Vec<RecordBatch>> {
        match self {
            Self::Query(result) => Ok(result.into_batches()),
            Self::Mutation(result) => Ok(vec![result.to_record_batch()?]),
        }
    }
}

#[cfg(test)]
mod tests {
    use std::io::Cursor;

    use arrow_array::Int64Array;
    use arrow_ipc::reader::StreamReader;
    use serde::Deserialize;

    use super::*;

    #[test]
    fn query_result_arrow_ipc_round_trips_empty_schema() {
        let schema = Arc::new(Schema::new(vec![Field::new("name", DataType::Utf8, false)]));
        let result = QueryResult::new(schema.clone(), vec![]);

        let encoded = result.to_arrow_ipc().expect("encode empty result");
        let reader = StreamReader::try_new(Cursor::new(encoded), None).expect("open stream");

        assert_eq!(reader.schema().as_ref(), schema.as_ref());
        assert_eq!(reader.count(), 0);
    }

    #[test]
    fn query_result_arrow_ipc_round_trips_batches() {
        let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::UInt64, false)]));
        let batch = RecordBatch::try_new(
            schema.clone(),
            vec![Arc::new(UInt64Array::from(vec![1_u64, 2_u64]))],
        )
        .expect("batch");
        let result = QueryResult::new(schema.clone(), vec![batch]);

        let encoded = result.to_arrow_ipc().expect("encode result");
        let mut reader = StreamReader::try_new(Cursor::new(encoded), None).expect("open stream");
        let decoded = reader.next().expect("first batch").expect("decode batch");

        assert_eq!(reader.schema().as_ref(), schema.as_ref());
        assert_eq!(decoded.num_rows(), 2);
        assert_eq!(decoded.schema().as_ref(), schema.as_ref());
    }

    #[test]
    fn query_result_num_rows_and_concat_cover_multiple_batches() {
        let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::UInt64, false)]));
        let batch1 = RecordBatch::try_new(
            schema.clone(),
            vec![Arc::new(UInt64Array::from(vec![1_u64, 2_u64]))],
        )
        .expect("batch1");
        let batch2 = RecordBatch::try_new(
            schema.clone(),
            vec![Arc::new(UInt64Array::from(vec![3_u64]))],
        )
        .expect("batch2");
        let result = QueryResult::new(schema.clone(), vec![batch1, batch2]);

        assert_eq!(result.num_rows(), 3);

        let concatenated = result.concat_batches().expect("concat batches");
        let ids = concatenated
            .column(0)
            .as_any()
            .downcast_ref::<UInt64Array>()
            .expect("u64 ids");
        assert_eq!(concatenated.schema().as_ref(), schema.as_ref());
        assert_eq!(ids.values(), &[1, 2, 3]);
    }

    #[test]
    fn query_result_concat_empty_batches_returns_empty_batch() {
        let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::UInt64, false)]));
        let result = QueryResult::new(schema.clone(), vec![]);

        let concatenated = result.concat_batches().expect("concat empty");

        assert_eq!(concatenated.schema().as_ref(), schema.as_ref());
        assert_eq!(concatenated.num_rows(), 0);
    }

    #[test]
    fn query_result_to_rust_json_preserves_wide_integers() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("signed", DataType::Int64, false),
            Field::new("unsigned", DataType::UInt64, false),
        ]));
        let batch = RecordBatch::try_new(
            schema.clone(),
            vec![
                Arc::new(Int64Array::from(vec![i64::MIN])),
                Arc::new(UInt64Array::from(vec![u64::MAX])),
            ],
        )
        .expect("batch");
        let result = QueryResult::new(schema, vec![batch]);

        assert_eq!(
            result.to_rust_json(),
            serde_json::json!([{
                "signed": i64::MIN,
                "unsigned": u64::MAX,
            }])
        );
    }

    #[derive(Debug, Deserialize, PartialEq)]
    struct PersonRow {
        id: u64,
        age: i64,
    }

    #[test]
    fn query_result_deserialize_decodes_rust_rows() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("id", DataType::UInt64, false),
            Field::new("age", DataType::Int64, false),
        ]));
        let batch1 = RecordBatch::try_new(
            schema.clone(),
            vec![
                Arc::new(UInt64Array::from(vec![1_u64])),
                Arc::new(Int64Array::from(vec![40_i64])),
            ],
        )
        .expect("batch1");
        let batch2 = RecordBatch::try_new(
            schema,
            vec![
                Arc::new(UInt64Array::from(vec![u64::MAX])),
                Arc::new(Int64Array::from(vec![-5_i64])),
            ],
        )
        .expect("batch2");
        let result = QueryResult::new(batch1.schema(), vec![batch1, batch2]);

        let rows: Vec<PersonRow> = result.deserialize().expect("deserialize rows");

        assert_eq!(
            rows,
            vec![
                PersonRow { id: 1, age: 40 },
                PersonRow {
                    id: u64::MAX,
                    age: -5,
                },
            ]
        );
    }
}