1use std::sync::Arc;
7
8use arrow::array::{
9 BooleanArray, Float32Array, Float64Array, Int8Array, Int16Array, Int32Array, Int64Array,
10 NullArray, StringBuilder,
11};
12use arrow::datatypes::{DataType, Field, Schema};
13use arrow::record_batch::RecordBatch;
14use arrow_flight::encode::FlightDataEncoderBuilder;
15use arrow_flight::flight_service_server::{FlightService, FlightServiceServer};
16use arrow_flight::{
17 Action, ActionType, Criteria, Empty, FlightData, FlightDescriptor, FlightInfo,
18 HandshakeRequest, HandshakeResponse, PollInfo, PutResult, SchemaResult, Ticket,
19};
20use futures::stream::BoxStream;
21use futures::{StreamExt, TryStreamExt};
22use tonic::transport::Server;
23use tonic::{Request, Response, Status, Streaming};
24
25use kyu_executor::QueryResult;
26use kyu_types::{LogicalType, TypedValue};
27
28use crate::Database;
29
30fn logical_to_arrow(ty: &LogicalType) -> DataType {
34 match ty {
35 LogicalType::Bool => DataType::Boolean,
36 LogicalType::Int8 => DataType::Int8,
37 LogicalType::Int16 => DataType::Int16,
38 LogicalType::Int32 => DataType::Int32,
39 LogicalType::Int64 | LogicalType::Serial => DataType::Int64,
40 LogicalType::Float => DataType::Float32,
41 LogicalType::Double => DataType::Float64,
42 LogicalType::String => DataType::Utf8,
43 _ => DataType::Utf8, }
45}
46
47pub fn to_record_batch(result: &QueryResult) -> Option<RecordBatch> {
51 let nc = result.num_columns();
52 if nc == 0 {
53 return None;
54 }
55 let nr = result.num_rows();
56
57 let fields: Vec<Field> = result
58 .column_names
59 .iter()
60 .zip(&result.column_types)
61 .map(|(name, ty)| Field::new(name.as_str(), logical_to_arrow(ty), true))
62 .collect();
63 let schema = Arc::new(Schema::new(fields));
64
65 let columns: Vec<Arc<dyn arrow::array::Array>> = (0..nc)
66 .map(|col_idx| build_arrow_column(result, col_idx, nr, &result.column_types[col_idx]))
67 .collect();
68
69 RecordBatch::try_new(schema, columns).ok()
70}
71
72fn build_arrow_column(
74 result: &QueryResult,
75 col_idx: usize,
76 num_rows: usize,
77 ty: &LogicalType,
78) -> Arc<dyn arrow::array::Array> {
79 match ty {
80 LogicalType::Bool => {
81 let values: Vec<Option<bool>> = (0..num_rows)
82 .map(|r| match &result.row(r)[col_idx] {
83 TypedValue::Bool(v) => Some(*v),
84 TypedValue::Null => None,
85 _ => None,
86 })
87 .collect();
88 Arc::new(BooleanArray::from(values))
89 }
90 LogicalType::Int8 => {
91 let values: Vec<Option<i8>> = (0..num_rows)
92 .map(|r| match &result.row(r)[col_idx] {
93 TypedValue::Int8(v) => Some(*v),
94 TypedValue::Null => None,
95 _ => None,
96 })
97 .collect();
98 Arc::new(Int8Array::from(values))
99 }
100 LogicalType::Int16 => {
101 let values: Vec<Option<i16>> = (0..num_rows)
102 .map(|r| match &result.row(r)[col_idx] {
103 TypedValue::Int16(v) => Some(*v),
104 TypedValue::Null => None,
105 _ => None,
106 })
107 .collect();
108 Arc::new(Int16Array::from(values))
109 }
110 LogicalType::Int32 => {
111 let values: Vec<Option<i32>> = (0..num_rows)
112 .map(|r| match &result.row(r)[col_idx] {
113 TypedValue::Int32(v) => Some(*v),
114 TypedValue::Null => None,
115 _ => None,
116 })
117 .collect();
118 Arc::new(Int32Array::from(values))
119 }
120 LogicalType::Int64 | LogicalType::Serial => {
121 let values: Vec<Option<i64>> = (0..num_rows)
122 .map(|r| match &result.row(r)[col_idx] {
123 TypedValue::Int64(v) => Some(*v),
124 TypedValue::Null => None,
125 _ => None,
126 })
127 .collect();
128 Arc::new(Int64Array::from(values))
129 }
130 LogicalType::Float => {
131 let values: Vec<Option<f32>> = (0..num_rows)
132 .map(|r| match &result.row(r)[col_idx] {
133 TypedValue::Float(v) => Some(*v),
134 TypedValue::Null => None,
135 _ => None,
136 })
137 .collect();
138 Arc::new(Float32Array::from(values))
139 }
140 LogicalType::Double => {
141 let values: Vec<Option<f64>> = (0..num_rows)
142 .map(|r| match &result.row(r)[col_idx] {
143 TypedValue::Double(v) => Some(*v),
144 TypedValue::Null => None,
145 _ => None,
146 })
147 .collect();
148 Arc::new(Float64Array::from(values))
149 }
150 LogicalType::String => {
151 let mut builder = StringBuilder::with_capacity(num_rows, num_rows * 16);
152 for r in 0..num_rows {
153 match &result.row(r)[col_idx] {
154 TypedValue::String(s) => builder.append_value(s.as_str()),
155 TypedValue::Null => builder.append_null(),
156 other => builder.append_value(format!("{other:?}")),
157 }
158 }
159 Arc::new(builder.finish())
160 }
161 _ => {
162 Arc::new(NullArray::new(num_rows))
164 }
165 }
166}
167
/// Arrow Flight service that runs Cypher statements against a shared KyuGraph [`Database`].
pub struct KyuFlightService {
    /// Shared database handle. Each executed query opens its own connection via
    /// `Database::connect`.
    db: Arc<Database>,
}
178
179impl KyuFlightService {
180 pub fn new(db: Arc<Database>) -> Self {
181 Self { db }
182 }
183
184 #[allow(clippy::result_large_err)]
185 fn execute_query(&self, cypher: &str) -> Result<QueryResult, Status> {
186 let conn = self.db.connect();
187 conn.query(cypher)
188 .map_err(|e| Status::internal(format!("query error: {e}")))
189 }
190}
191
192#[tonic::async_trait]
193impl FlightService for KyuFlightService {
194 type HandshakeStream = BoxStream<'static, Result<HandshakeResponse, Status>>;
195 type ListFlightsStream = BoxStream<'static, Result<FlightInfo, Status>>;
196 type DoGetStream = BoxStream<'static, Result<FlightData, Status>>;
197 type DoPutStream = BoxStream<'static, Result<PutResult, Status>>;
198 type DoExchangeStream = BoxStream<'static, Result<FlightData, Status>>;
199 type DoActionStream = BoxStream<'static, Result<arrow_flight::Result, Status>>;
200 type ListActionsStream = BoxStream<'static, Result<ActionType, Status>>;
201
202 async fn do_get(
206 &self,
207 request: Request<Ticket>,
208 ) -> Result<Response<Self::DoGetStream>, Status> {
209 let ticket = request.into_inner();
210 let cypher = String::from_utf8(ticket.ticket.to_vec())
211 .map_err(|_| Status::invalid_argument("ticket must be valid UTF-8"))?;
212
213 let result = self.execute_query(&cypher)?;
214
215 let batches: Vec<RecordBatch> = match to_record_batch(&result) {
216 Some(batch) => vec![batch],
217 None => {
218 let schema = Arc::new(Schema::new(vec![Field::new(
220 "ok",
221 DataType::Boolean,
222 false,
223 )]));
224 vec![
225 RecordBatch::try_new(schema, vec![Arc::new(BooleanArray::from(vec![true]))])
226 .unwrap(),
227 ]
228 }
229 };
230
231 let schema = batches[0].schema();
232 let batch_stream = futures::stream::iter(batches.into_iter().map(Ok));
233 let flight_stream = FlightDataEncoderBuilder::new()
234 .with_schema(schema)
235 .build(batch_stream)
236 .map_err(|e| Status::internal(e.to_string()));
237
238 Ok(Response::new(flight_stream.boxed()))
239 }
240
241 async fn do_action(
245 &self,
246 request: Request<Action>,
247 ) -> Result<Response<Self::DoActionStream>, Status> {
248 let action = request.into_inner();
249
250 match action.r#type.as_str() {
251 "query" => {
252 let cypher = String::from_utf8(action.body.to_vec())
253 .map_err(|_| Status::invalid_argument("body must be valid UTF-8"))?;
254 self.execute_query(&cypher)?;
255 let result = arrow_flight::Result {
256 body: bytes::Bytes::from("OK"),
257 };
258 let stream = futures::stream::once(async { Ok(result) });
259 Ok(Response::new(stream.boxed()))
260 }
261 other => Err(Status::invalid_argument(format!(
262 "unknown action type: {other}"
263 ))),
264 }
265 }
266
267 async fn list_actions(
268 &self,
269 _request: Request<Empty>,
270 ) -> Result<Response<Self::ListActionsStream>, Status> {
271 let actions = vec![Ok(ActionType {
272 r#type: "query".into(),
273 description: "Execute a Cypher DDL/DML statement".into(),
274 })];
275 Ok(Response::new(futures::stream::iter(actions).boxed()))
276 }
277
278 async fn handshake(
279 &self,
280 _request: Request<Streaming<HandshakeRequest>>,
281 ) -> Result<Response<Self::HandshakeStream>, Status> {
282 Err(Status::unimplemented("handshake not supported"))
283 }
284
285 async fn list_flights(
286 &self,
287 _request: Request<Criteria>,
288 ) -> Result<Response<Self::ListFlightsStream>, Status> {
289 Err(Status::unimplemented("list_flights not supported"))
290 }
291
292 async fn get_flight_info(
293 &self,
294 _request: Request<FlightDescriptor>,
295 ) -> Result<Response<FlightInfo>, Status> {
296 Err(Status::unimplemented("get_flight_info not supported"))
297 }
298
299 async fn poll_flight_info(
300 &self,
301 _request: Request<FlightDescriptor>,
302 ) -> Result<Response<PollInfo>, Status> {
303 Err(Status::unimplemented("poll_flight_info not supported"))
304 }
305
306 async fn get_schema(
307 &self,
308 _request: Request<FlightDescriptor>,
309 ) -> Result<Response<SchemaResult>, Status> {
310 Err(Status::unimplemented("get_schema not supported"))
311 }
312
313 async fn do_put(
314 &self,
315 _request: Request<Streaming<FlightData>>,
316 ) -> Result<Response<Self::DoPutStream>, Status> {
317 Err(Status::unimplemented("do_put not supported"))
318 }
319
320 async fn do_exchange(
321 &self,
322 _request: Request<Streaming<FlightData>>,
323 ) -> Result<Response<Self::DoExchangeStream>, Status> {
324 Err(Status::unimplemented("do_exchange not supported"))
325 }
326}
327
328pub async fn serve_flight(
332 db: Arc<Database>,
333 host: &str,
334 port: u16,
335) -> Result<(), Box<dyn std::error::Error>> {
336 let addr = format!("{host}:{port}").parse()?;
337 let service = KyuFlightService::new(db);
338 println!("KyuGraph Flight server listening on {addr}");
339 Server::builder()
340 .add_service(FlightServiceServer::new(service))
341 .serve(addr)
342 .await?;
343 Ok(())
344}
345
#[cfg(test)]
mod tests {
    use super::*;
    use kyu_types::LogicalType;
    use smol_str::SmolStr;

    #[test]
    fn logical_to_arrow_scalar_mapping() {
        assert_eq!(logical_to_arrow(&LogicalType::Bool), DataType::Boolean);
        assert_eq!(logical_to_arrow(&LogicalType::Serial), DataType::Int64);
        assert_eq!(logical_to_arrow(&LogicalType::Float), DataType::Float32);
        assert_eq!(logical_to_arrow(&LogicalType::Double), DataType::Float64);
        assert_eq!(logical_to_arrow(&LogicalType::String), DataType::Utf8);
    }

    #[test]
    fn to_record_batch_empty_columns() {
        let result = QueryResult::new(vec![], vec![]);
        assert!(to_record_batch(&result).is_none());
    }

    #[test]
    fn to_record_batch_int64() {
        let mut result = QueryResult::new(vec![SmolStr::new("x")], vec![LogicalType::Int64]);
        result.push_row(vec![TypedValue::Int64(42)]);
        result.push_row(vec![TypedValue::Int64(99)]);

        let batch = to_record_batch(&result).unwrap();
        assert_eq!(batch.num_rows(), 2);
        assert_eq!(batch.num_columns(), 1);

        let col = batch
            .column(0)
            .as_any()
            .downcast_ref::<Int64Array>()
            .unwrap();
        assert_eq!(col.value(0), 42);
        assert_eq!(col.value(1), 99);
    }

    #[test]
    fn to_record_batch_mixed_types() {
        let mut result = QueryResult::new(
            vec![
                SmolStr::new("id"),
                SmolStr::new("name"),
                SmolStr::new("score"),
            ],
            vec![LogicalType::Int64, LogicalType::String, LogicalType::Double],
        );
        result.push_row(vec![
            TypedValue::Int64(1),
            TypedValue::String(SmolStr::new("Alice")),
            TypedValue::Double(95.5),
        ]);
        result.push_row(vec![
            TypedValue::Int64(2),
            TypedValue::Null,
            TypedValue::Double(87.3),
        ]);

        let batch = to_record_batch(&result).unwrap();
        assert_eq!(batch.num_rows(), 2);
        assert_eq!(batch.num_columns(), 3);

        // A `TypedValue::Null` in a string column must become an Arrow null.
        let name_col = batch.column(1);
        assert!(!name_col.is_null(0));
        assert!(name_col.is_null(1));
        let names = name_col
            .as_any()
            .downcast_ref::<arrow::array::StringArray>()
            .unwrap();
        assert_eq!(names.value(0), "Alice");

        let scores = batch
            .column(2)
            .as_any()
            .downcast_ref::<Float64Array>()
            .unwrap();
        assert!((scores.value(0) - 95.5).abs() < 1e-9);
        assert!((scores.value(1) - 87.3).abs() < 1e-9);
    }

    #[test]
    fn to_record_batch_bool_float() {
        // Uses 2.5 rather than 3.14: the latter trips the deny-by-default
        // `clippy::approx_constant` lint.
        let mut result = QueryResult::new(
            vec![SmolStr::new("active"), SmolStr::new("temp")],
            vec![LogicalType::Bool, LogicalType::Float],
        );
        result.push_row(vec![TypedValue::Bool(true), TypedValue::Float(2.5)]);

        let batch = to_record_batch(&result).unwrap();
        let bool_col = batch
            .column(0)
            .as_any()
            .downcast_ref::<BooleanArray>()
            .unwrap();
        assert!(bool_col.value(0));

        let float_col = batch
            .column(1)
            .as_any()
            .downcast_ref::<Float32Array>()
            .unwrap();
        assert!((float_col.value(0) - 2.5).abs() < 0.01);
    }

    #[test]
    fn roundtrip_via_database() {
        let db = Database::in_memory();
        let conn = db.connect();
        conn.query("CREATE NODE TABLE Person (id INT64, name STRING, PRIMARY KEY (id))")
            .unwrap();
        conn.query("CREATE (n:Person {id: 1, name: 'Alice'})")
            .unwrap();
        conn.query("CREATE (n:Person {id: 2, name: 'Bob'})")
            .unwrap();

        let result = conn.query("MATCH (p:Person) RETURN p.id, p.name").unwrap();
        let batch = to_record_batch(&result).unwrap();

        assert_eq!(batch.num_rows(), 2);
        assert_eq!(batch.schema().fields().len(), 2);
        assert_eq!(batch.schema().field(0).data_type(), &DataType::Int64);
        assert_eq!(batch.schema().field(1).data_type(), &DataType::Utf8);
    }
}