1use std::sync::Arc;
7
8use arrow::array::{
9 BooleanArray, Float32Array, Float64Array, Int8Array, Int16Array, Int32Array, Int64Array,
10 NullArray, StringBuilder,
11};
12use arrow::datatypes::{DataType, Field, Schema};
13use arrow::record_batch::RecordBatch;
14use arrow_flight::encode::FlightDataEncoderBuilder;
15use arrow_flight::flight_service_server::{FlightService, FlightServiceServer};
16use arrow_flight::{
17 Action, ActionType, Criteria, Empty, FlightData, FlightDescriptor, FlightInfo,
18 HandshakeRequest, HandshakeResponse, PollInfo, PutResult, SchemaResult, Ticket,
19};
20use futures::stream::BoxStream;
21use futures::{StreamExt, TryStreamExt};
22use tonic::transport::Server;
23use tonic::{Request, Response, Status, Streaming};
24
25use kyu_executor::QueryResult;
26use kyu_types::{LogicalType, TypedValue};
27
28use crate::Database;
29
30fn logical_to_arrow(ty: &LogicalType) -> DataType {
34 match ty {
35 LogicalType::Bool => DataType::Boolean,
36 LogicalType::Int8 => DataType::Int8,
37 LogicalType::Int16 => DataType::Int16,
38 LogicalType::Int32 => DataType::Int32,
39 LogicalType::Int64 | LogicalType::Serial => DataType::Int64,
40 LogicalType::Float => DataType::Float32,
41 LogicalType::Double => DataType::Float64,
42 LogicalType::String => DataType::Utf8,
43 _ => DataType::Utf8, }
45}
46
47pub fn to_record_batch(result: &QueryResult) -> Option<RecordBatch> {
51 let nc = result.num_columns();
52 if nc == 0 {
53 return None;
54 }
55 let nr = result.num_rows();
56
57 let fields: Vec<Field> = result
58 .column_names
59 .iter()
60 .zip(&result.column_types)
61 .map(|(name, ty)| Field::new(name.as_str(), logical_to_arrow(ty), true))
62 .collect();
63 let schema = Arc::new(Schema::new(fields));
64
65 let columns: Vec<Arc<dyn arrow::array::Array>> = (0..nc)
66 .map(|col_idx| build_arrow_column(result, col_idx, nr, &result.column_types[col_idx]))
67 .collect();
68
69 RecordBatch::try_new(schema, columns).ok()
70}
71
72fn build_arrow_column(
74 result: &QueryResult,
75 col_idx: usize,
76 num_rows: usize,
77 ty: &LogicalType,
78) -> Arc<dyn arrow::array::Array> {
79 match ty {
80 LogicalType::Bool => {
81 let values: Vec<Option<bool>> = (0..num_rows)
82 .map(|r| match &result.row(r)[col_idx] {
83 TypedValue::Bool(v) => Some(*v),
84 TypedValue::Null => None,
85 _ => None,
86 })
87 .collect();
88 Arc::new(BooleanArray::from(values))
89 }
90 LogicalType::Int8 => {
91 let values: Vec<Option<i8>> = (0..num_rows)
92 .map(|r| match &result.row(r)[col_idx] {
93 TypedValue::Int8(v) => Some(*v),
94 TypedValue::Null => None,
95 _ => None,
96 })
97 .collect();
98 Arc::new(Int8Array::from(values))
99 }
100 LogicalType::Int16 => {
101 let values: Vec<Option<i16>> = (0..num_rows)
102 .map(|r| match &result.row(r)[col_idx] {
103 TypedValue::Int16(v) => Some(*v),
104 TypedValue::Null => None,
105 _ => None,
106 })
107 .collect();
108 Arc::new(Int16Array::from(values))
109 }
110 LogicalType::Int32 => {
111 let values: Vec<Option<i32>> = (0..num_rows)
112 .map(|r| match &result.row(r)[col_idx] {
113 TypedValue::Int32(v) => Some(*v),
114 TypedValue::Null => None,
115 _ => None,
116 })
117 .collect();
118 Arc::new(Int32Array::from(values))
119 }
120 LogicalType::Int64 | LogicalType::Serial => {
121 let values: Vec<Option<i64>> = (0..num_rows)
122 .map(|r| match &result.row(r)[col_idx] {
123 TypedValue::Int64(v) => Some(*v),
124 TypedValue::Null => None,
125 _ => None,
126 })
127 .collect();
128 Arc::new(Int64Array::from(values))
129 }
130 LogicalType::Float => {
131 let values: Vec<Option<f32>> = (0..num_rows)
132 .map(|r| match &result.row(r)[col_idx] {
133 TypedValue::Float(v) => Some(*v),
134 TypedValue::Null => None,
135 _ => None,
136 })
137 .collect();
138 Arc::new(Float32Array::from(values))
139 }
140 LogicalType::Double => {
141 let values: Vec<Option<f64>> = (0..num_rows)
142 .map(|r| match &result.row(r)[col_idx] {
143 TypedValue::Double(v) => Some(*v),
144 TypedValue::Null => None,
145 _ => None,
146 })
147 .collect();
148 Arc::new(Float64Array::from(values))
149 }
150 LogicalType::String => {
151 let mut builder = StringBuilder::with_capacity(num_rows, num_rows * 16);
152 for r in 0..num_rows {
153 match &result.row(r)[col_idx] {
154 TypedValue::String(s) => builder.append_value(s.as_str()),
155 TypedValue::Null => builder.append_null(),
156 other => builder.append_value(format!("{other:?}")),
157 }
158 }
159 Arc::new(builder.finish())
160 }
161 _ => {
162 Arc::new(NullArray::new(num_rows))
164 }
165 }
166}
167
168pub struct KyuFlightService {
176 db: Arc<Database>,
177}
178
179impl KyuFlightService {
180 pub fn new(db: Arc<Database>) -> Self {
181 Self { db }
182 }
183
184 #[allow(clippy::result_large_err)]
185 fn execute_query(&self, cypher: &str) -> Result<QueryResult, Status> {
186 let conn = self.db.connect();
187 conn.query(cypher)
188 .map_err(|e| Status::internal(format!("query error: {e}")))
189 }
190}
191
192#[tonic::async_trait]
193impl FlightService for KyuFlightService {
194 type HandshakeStream = BoxStream<'static, Result<HandshakeResponse, Status>>;
195 type ListFlightsStream = BoxStream<'static, Result<FlightInfo, Status>>;
196 type DoGetStream = BoxStream<'static, Result<FlightData, Status>>;
197 type DoPutStream = BoxStream<'static, Result<PutResult, Status>>;
198 type DoExchangeStream = BoxStream<'static, Result<FlightData, Status>>;
199 type DoActionStream = BoxStream<'static, Result<arrow_flight::Result, Status>>;
200 type ListActionsStream = BoxStream<'static, Result<ActionType, Status>>;
201
202 async fn do_get(
206 &self,
207 request: Request<Ticket>,
208 ) -> Result<Response<Self::DoGetStream>, Status> {
209 let ticket = request.into_inner();
210 let cypher = String::from_utf8(ticket.ticket.to_vec())
211 .map_err(|_| Status::invalid_argument("ticket must be valid UTF-8"))?;
212
213 let result = self.execute_query(&cypher)?;
214
215 let batches: Vec<RecordBatch> = match to_record_batch(&result) {
216 Some(batch) => vec![batch],
217 None => {
218 let schema = Arc::new(Schema::new(vec![Field::new(
220 "ok",
221 DataType::Boolean,
222 false,
223 )]));
224 vec![RecordBatch::try_new(
225 schema,
226 vec![Arc::new(BooleanArray::from(vec![true]))],
227 )
228 .unwrap()]
229 }
230 };
231
232 let schema = batches[0].schema();
233 let batch_stream = futures::stream::iter(batches.into_iter().map(Ok));
234 let flight_stream = FlightDataEncoderBuilder::new()
235 .with_schema(schema)
236 .build(batch_stream)
237 .map_err(|e| Status::internal(e.to_string()));
238
239 Ok(Response::new(flight_stream.boxed()))
240 }
241
242 async fn do_action(
246 &self,
247 request: Request<Action>,
248 ) -> Result<Response<Self::DoActionStream>, Status> {
249 let action = request.into_inner();
250
251 match action.r#type.as_str() {
252 "query" => {
253 let cypher = String::from_utf8(action.body.to_vec())
254 .map_err(|_| Status::invalid_argument("body must be valid UTF-8"))?;
255 self.execute_query(&cypher)?;
256 let result = arrow_flight::Result {
257 body: bytes::Bytes::from("OK"),
258 };
259 let stream = futures::stream::once(async { Ok(result) });
260 Ok(Response::new(stream.boxed()))
261 }
262 other => Err(Status::invalid_argument(format!(
263 "unknown action type: {other}"
264 ))),
265 }
266 }
267
268 async fn list_actions(
269 &self,
270 _request: Request<Empty>,
271 ) -> Result<Response<Self::ListActionsStream>, Status> {
272 let actions = vec![Ok(ActionType {
273 r#type: "query".into(),
274 description: "Execute a Cypher DDL/DML statement".into(),
275 })];
276 Ok(Response::new(futures::stream::iter(actions).boxed()))
277 }
278
279 async fn handshake(
280 &self,
281 _request: Request<Streaming<HandshakeRequest>>,
282 ) -> Result<Response<Self::HandshakeStream>, Status> {
283 Err(Status::unimplemented("handshake not supported"))
284 }
285
286 async fn list_flights(
287 &self,
288 _request: Request<Criteria>,
289 ) -> Result<Response<Self::ListFlightsStream>, Status> {
290 Err(Status::unimplemented("list_flights not supported"))
291 }
292
293 async fn get_flight_info(
294 &self,
295 _request: Request<FlightDescriptor>,
296 ) -> Result<Response<FlightInfo>, Status> {
297 Err(Status::unimplemented("get_flight_info not supported"))
298 }
299
300 async fn poll_flight_info(
301 &self,
302 _request: Request<FlightDescriptor>,
303 ) -> Result<Response<PollInfo>, Status> {
304 Err(Status::unimplemented("poll_flight_info not supported"))
305 }
306
307 async fn get_schema(
308 &self,
309 _request: Request<FlightDescriptor>,
310 ) -> Result<Response<SchemaResult>, Status> {
311 Err(Status::unimplemented("get_schema not supported"))
312 }
313
314 async fn do_put(
315 &self,
316 _request: Request<Streaming<FlightData>>,
317 ) -> Result<Response<Self::DoPutStream>, Status> {
318 Err(Status::unimplemented("do_put not supported"))
319 }
320
321 async fn do_exchange(
322 &self,
323 _request: Request<Streaming<FlightData>>,
324 ) -> Result<Response<Self::DoExchangeStream>, Status> {
325 Err(Status::unimplemented("do_exchange not supported"))
326 }
327}
328
329pub async fn serve_flight(db: Arc<Database>, host: &str, port: u16) -> Result<(), Box<dyn std::error::Error>> {
333 let addr = format!("{host}:{port}").parse()?;
334 let service = KyuFlightService::new(db);
335 println!("KyuGraph Flight server listening on {addr}");
336 Server::builder()
337 .add_service(FlightServiceServer::new(service))
338 .serve(addr)
339 .await?;
340 Ok(())
341}
342
#[cfg(test)]
mod tests {
    use super::*;
    use kyu_types::LogicalType;
    use smol_str::SmolStr;

    /// A result without columns has no Arrow representation.
    #[test]
    fn to_record_batch_empty_columns() {
        let result = QueryResult::new(vec![], vec![]);
        assert!(to_record_batch(&result).is_none());
    }

    /// A single Int64 column keeps its values and row order.
    #[test]
    fn to_record_batch_int64() {
        let mut result = QueryResult::new(vec![SmolStr::new("x")], vec![LogicalType::Int64]);
        result.push_row(vec![TypedValue::Int64(42)]);
        result.push_row(vec![TypedValue::Int64(99)]);

        let batch = to_record_batch(&result).unwrap();
        assert_eq!(batch.num_rows(), 2);
        assert_eq!(batch.num_columns(), 1);

        let col = batch
            .column(0)
            .as_any()
            .downcast_ref::<Int64Array>()
            .unwrap();
        assert_eq!(col.value(0), 42);
        assert_eq!(col.value(1), 99);
    }

    /// Mixed column types convert side by side, and nulls are preserved.
    #[test]
    fn to_record_batch_mixed_types() {
        let mut result = QueryResult::new(
            vec![
                SmolStr::new("id"),
                SmolStr::new("name"),
                SmolStr::new("score"),
            ],
            vec![LogicalType::Int64, LogicalType::String, LogicalType::Double],
        );
        result.push_row(vec![
            TypedValue::Int64(1),
            TypedValue::String(SmolStr::new("Alice")),
            TypedValue::Double(95.5),
        ]);
        result.push_row(vec![
            TypedValue::Int64(2),
            TypedValue::Null,
            TypedValue::Double(87.3),
        ]);

        let batch = to_record_batch(&result).unwrap();
        assert_eq!(batch.num_rows(), 2);
        assert_eq!(batch.num_columns(), 3);

        let id_col = batch
            .column(0)
            .as_any()
            .downcast_ref::<Int64Array>()
            .unwrap();
        assert_eq!(id_col.value(0), 1);
        assert_eq!(id_col.value(1), 2);

        let name_col = batch.column(1);
        assert!(!name_col.is_null(0));
        assert!(name_col.is_null(1));
        let names = name_col
            .as_any()
            .downcast_ref::<arrow::array::StringArray>()
            .unwrap();
        assert_eq!(names.value(0), "Alice");

        let score_col = batch
            .column(2)
            .as_any()
            .downcast_ref::<Float64Array>()
            .unwrap();
        assert!((score_col.value(0) - 95.5).abs() < f64::EPSILON);
        assert!((score_col.value(1) - 87.3).abs() < f64::EPSILON);
    }

    /// Bool and Float columns map to BooleanArray and Float32Array.
    #[test]
    fn to_record_batch_bool_float() {
        let mut result = QueryResult::new(
            vec![SmolStr::new("active"), SmolStr::new("temp")],
            vec![LogicalType::Bool, LogicalType::Float],
        );
        // 2.5 is exactly representable in f32 and does not resemble a
        // well-known constant (3.14 trips clippy::approx_constant).
        result.push_row(vec![TypedValue::Bool(true), TypedValue::Float(2.5)]);

        let batch = to_record_batch(&result).unwrap();
        let bool_col = batch
            .column(0)
            .as_any()
            .downcast_ref::<BooleanArray>()
            .unwrap();
        assert!(bool_col.value(0));

        let float_col = batch
            .column(1)
            .as_any()
            .downcast_ref::<Float32Array>()
            .unwrap();
        assert!((float_col.value(0) - 2.5).abs() < f32::EPSILON);
    }

    /// End-to-end: rows created through the database convert with the
    /// expected schema.
    #[test]
    fn roundtrip_via_database() {
        let db = Database::in_memory();
        let conn = db.connect();
        conn.query("CREATE NODE TABLE Person (id INT64, name STRING, PRIMARY KEY (id))")
            .unwrap();
        conn.query("CREATE (n:Person {id: 1, name: 'Alice'})")
            .unwrap();
        conn.query("CREATE (n:Person {id: 2, name: 'Bob'})")
            .unwrap();

        let result = conn
            .query("MATCH (p:Person) RETURN p.id, p.name")
            .unwrap();
        let batch = to_record_batch(&result).unwrap();

        assert_eq!(batch.num_rows(), 2);
        assert_eq!(batch.schema().fields().len(), 2);
        assert_eq!(batch.schema().field(0).data_type(), &DataType::Int64);
        assert_eq!(batch.schema().field(1).data_type(), &DataType::Utf8);
    }
}
433}