Skip to main content

kyu_api/
flight.rs

1//! Arrow Flight gRPC server for KyuGraph.
2//!
3//! Exposes KyuGraph as an Arrow Flight endpoint. Clients send Cypher queries
4//! as Flight tickets and receive results as Arrow RecordBatch streams.
5
6use std::sync::Arc;
7
8use arrow::array::{
9    BooleanArray, Float32Array, Float64Array, Int8Array, Int16Array, Int32Array, Int64Array,
10    NullArray, StringBuilder,
11};
12use arrow::datatypes::{DataType, Field, Schema};
13use arrow::record_batch::RecordBatch;
14use arrow_flight::encode::FlightDataEncoderBuilder;
15use arrow_flight::flight_service_server::{FlightService, FlightServiceServer};
16use arrow_flight::{
17    Action, ActionType, Criteria, Empty, FlightData, FlightDescriptor, FlightInfo,
18    HandshakeRequest, HandshakeResponse, PollInfo, PutResult, SchemaResult, Ticket,
19};
20use futures::stream::BoxStream;
21use futures::{StreamExt, TryStreamExt};
22use tonic::transport::Server;
23use tonic::{Request, Response, Status, Streaming};
24
25use kyu_executor::QueryResult;
26use kyu_types::{LogicalType, TypedValue};
27
28use crate::Database;
29
30// ---- QueryResult → RecordBatch conversion ----
31
32/// Map a KyuGraph LogicalType to an Arrow DataType.
33fn logical_to_arrow(ty: &LogicalType) -> DataType {
34    match ty {
35        LogicalType::Bool => DataType::Boolean,
36        LogicalType::Int8 => DataType::Int8,
37        LogicalType::Int16 => DataType::Int16,
38        LogicalType::Int32 => DataType::Int32,
39        LogicalType::Int64 | LogicalType::Serial => DataType::Int64,
40        LogicalType::Float => DataType::Float32,
41        LogicalType::Double => DataType::Float64,
42        LogicalType::String => DataType::Utf8,
43        _ => DataType::Utf8, // fallback: serialize as string
44    }
45}
46
47/// Convert a QueryResult into an Arrow RecordBatch.
48///
49/// Returns `None` if the result has zero columns (DDL/DML with no output).
50pub fn to_record_batch(result: &QueryResult) -> Option<RecordBatch> {
51    let nc = result.num_columns();
52    if nc == 0 {
53        return None;
54    }
55    let nr = result.num_rows();
56
57    let fields: Vec<Field> = result
58        .column_names
59        .iter()
60        .zip(&result.column_types)
61        .map(|(name, ty)| Field::new(name.as_str(), logical_to_arrow(ty), true))
62        .collect();
63    let schema = Arc::new(Schema::new(fields));
64
65    let columns: Vec<Arc<dyn arrow::array::Array>> = (0..nc)
66        .map(|col_idx| build_arrow_column(result, col_idx, nr, &result.column_types[col_idx]))
67        .collect();
68
69    RecordBatch::try_new(schema, columns).ok()
70}
71
72/// Build a single Arrow array column from a QueryResult.
73fn build_arrow_column(
74    result: &QueryResult,
75    col_idx: usize,
76    num_rows: usize,
77    ty: &LogicalType,
78) -> Arc<dyn arrow::array::Array> {
79    match ty {
80        LogicalType::Bool => {
81            let values: Vec<Option<bool>> = (0..num_rows)
82                .map(|r| match &result.row(r)[col_idx] {
83                    TypedValue::Bool(v) => Some(*v),
84                    TypedValue::Null => None,
85                    _ => None,
86                })
87                .collect();
88            Arc::new(BooleanArray::from(values))
89        }
90        LogicalType::Int8 => {
91            let values: Vec<Option<i8>> = (0..num_rows)
92                .map(|r| match &result.row(r)[col_idx] {
93                    TypedValue::Int8(v) => Some(*v),
94                    TypedValue::Null => None,
95                    _ => None,
96                })
97                .collect();
98            Arc::new(Int8Array::from(values))
99        }
100        LogicalType::Int16 => {
101            let values: Vec<Option<i16>> = (0..num_rows)
102                .map(|r| match &result.row(r)[col_idx] {
103                    TypedValue::Int16(v) => Some(*v),
104                    TypedValue::Null => None,
105                    _ => None,
106                })
107                .collect();
108            Arc::new(Int16Array::from(values))
109        }
110        LogicalType::Int32 => {
111            let values: Vec<Option<i32>> = (0..num_rows)
112                .map(|r| match &result.row(r)[col_idx] {
113                    TypedValue::Int32(v) => Some(*v),
114                    TypedValue::Null => None,
115                    _ => None,
116                })
117                .collect();
118            Arc::new(Int32Array::from(values))
119        }
120        LogicalType::Int64 | LogicalType::Serial => {
121            let values: Vec<Option<i64>> = (0..num_rows)
122                .map(|r| match &result.row(r)[col_idx] {
123                    TypedValue::Int64(v) => Some(*v),
124                    TypedValue::Null => None,
125                    _ => None,
126                })
127                .collect();
128            Arc::new(Int64Array::from(values))
129        }
130        LogicalType::Float => {
131            let values: Vec<Option<f32>> = (0..num_rows)
132                .map(|r| match &result.row(r)[col_idx] {
133                    TypedValue::Float(v) => Some(*v),
134                    TypedValue::Null => None,
135                    _ => None,
136                })
137                .collect();
138            Arc::new(Float32Array::from(values))
139        }
140        LogicalType::Double => {
141            let values: Vec<Option<f64>> = (0..num_rows)
142                .map(|r| match &result.row(r)[col_idx] {
143                    TypedValue::Double(v) => Some(*v),
144                    TypedValue::Null => None,
145                    _ => None,
146                })
147                .collect();
148            Arc::new(Float64Array::from(values))
149        }
150        LogicalType::String => {
151            let mut builder = StringBuilder::with_capacity(num_rows, num_rows * 16);
152            for r in 0..num_rows {
153                match &result.row(r)[col_idx] {
154                    TypedValue::String(s) => builder.append_value(s.as_str()),
155                    TypedValue::Null => builder.append_null(),
156                    other => builder.append_value(format!("{other:?}")),
157                }
158            }
159            Arc::new(builder.finish())
160        }
161        _ => {
162            // Fallback: null column for unsupported types.
163            Arc::new(NullArray::new(num_rows))
164        }
165    }
166}
167
168// ---- Arrow Flight Service ----
169
170/// KyuGraph Arrow Flight service.
171///
172/// Each `do_get` executes the Cypher query encoded in the ticket and streams
173/// the result as Arrow RecordBatches. DDL and mutations are supported via
174/// `do_action`.
175pub struct KyuFlightService {
176    db: Arc<Database>,
177}
178
179impl KyuFlightService {
180    pub fn new(db: Arc<Database>) -> Self {
181        Self { db }
182    }
183
184    #[allow(clippy::result_large_err)]
185    fn execute_query(&self, cypher: &str) -> Result<QueryResult, Status> {
186        let conn = self.db.connect();
187        conn.query(cypher)
188            .map_err(|e| Status::internal(format!("query error: {e}")))
189    }
190}
191
192#[tonic::async_trait]
193impl FlightService for KyuFlightService {
194    type HandshakeStream = BoxStream<'static, Result<HandshakeResponse, Status>>;
195    type ListFlightsStream = BoxStream<'static, Result<FlightInfo, Status>>;
196    type DoGetStream = BoxStream<'static, Result<FlightData, Status>>;
197    type DoPutStream = BoxStream<'static, Result<PutResult, Status>>;
198    type DoExchangeStream = BoxStream<'static, Result<FlightData, Status>>;
199    type DoActionStream = BoxStream<'static, Result<arrow_flight::Result, Status>>;
200    type ListActionsStream = BoxStream<'static, Result<ActionType, Status>>;
201
202    /// Execute a Cypher query and stream the result as Arrow RecordBatches.
203    ///
204    /// The ticket body is the UTF-8 encoded Cypher query string.
205    async fn do_get(
206        &self,
207        request: Request<Ticket>,
208    ) -> Result<Response<Self::DoGetStream>, Status> {
209        let ticket = request.into_inner();
210        let cypher = String::from_utf8(ticket.ticket.to_vec())
211            .map_err(|_| Status::invalid_argument("ticket must be valid UTF-8"))?;
212
213        let result = self.execute_query(&cypher)?;
214
215        let batches: Vec<RecordBatch> = match to_record_batch(&result) {
216            Some(batch) => vec![batch],
217            None => {
218                // DDL/DML — return empty schema with zero rows.
219                let schema = Arc::new(Schema::new(vec![Field::new(
220                    "ok",
221                    DataType::Boolean,
222                    false,
223                )]));
224                vec![RecordBatch::try_new(
225                    schema,
226                    vec![Arc::new(BooleanArray::from(vec![true]))],
227                )
228                .unwrap()]
229            }
230        };
231
232        let schema = batches[0].schema();
233        let batch_stream = futures::stream::iter(batches.into_iter().map(Ok));
234        let flight_stream = FlightDataEncoderBuilder::new()
235            .with_schema(schema)
236            .build(batch_stream)
237            .map_err(|e| Status::internal(e.to_string()));
238
239        Ok(Response::new(flight_stream.boxed()))
240    }
241
242    /// Execute DDL/DML statements via actions.
243    ///
244    /// Action type: "query". Body: UTF-8 Cypher string.
245    async fn do_action(
246        &self,
247        request: Request<Action>,
248    ) -> Result<Response<Self::DoActionStream>, Status> {
249        let action = request.into_inner();
250
251        match action.r#type.as_str() {
252            "query" => {
253                let cypher = String::from_utf8(action.body.to_vec())
254                    .map_err(|_| Status::invalid_argument("body must be valid UTF-8"))?;
255                self.execute_query(&cypher)?;
256                let result = arrow_flight::Result {
257                    body: bytes::Bytes::from("OK"),
258                };
259                let stream = futures::stream::once(async { Ok(result) });
260                Ok(Response::new(stream.boxed()))
261            }
262            other => Err(Status::invalid_argument(format!(
263                "unknown action type: {other}"
264            ))),
265        }
266    }
267
268    async fn list_actions(
269        &self,
270        _request: Request<Empty>,
271    ) -> Result<Response<Self::ListActionsStream>, Status> {
272        let actions = vec![Ok(ActionType {
273            r#type: "query".into(),
274            description: "Execute a Cypher DDL/DML statement".into(),
275        })];
276        Ok(Response::new(futures::stream::iter(actions).boxed()))
277    }
278
279    async fn handshake(
280        &self,
281        _request: Request<Streaming<HandshakeRequest>>,
282    ) -> Result<Response<Self::HandshakeStream>, Status> {
283        Err(Status::unimplemented("handshake not supported"))
284    }
285
286    async fn list_flights(
287        &self,
288        _request: Request<Criteria>,
289    ) -> Result<Response<Self::ListFlightsStream>, Status> {
290        Err(Status::unimplemented("list_flights not supported"))
291    }
292
293    async fn get_flight_info(
294        &self,
295        _request: Request<FlightDescriptor>,
296    ) -> Result<Response<FlightInfo>, Status> {
297        Err(Status::unimplemented("get_flight_info not supported"))
298    }
299
300    async fn poll_flight_info(
301        &self,
302        _request: Request<FlightDescriptor>,
303    ) -> Result<Response<PollInfo>, Status> {
304        Err(Status::unimplemented("poll_flight_info not supported"))
305    }
306
307    async fn get_schema(
308        &self,
309        _request: Request<FlightDescriptor>,
310    ) -> Result<Response<SchemaResult>, Status> {
311        Err(Status::unimplemented("get_schema not supported"))
312    }
313
314    async fn do_put(
315        &self,
316        _request: Request<Streaming<FlightData>>,
317    ) -> Result<Response<Self::DoPutStream>, Status> {
318        Err(Status::unimplemented("do_put not supported"))
319    }
320
321    async fn do_exchange(
322        &self,
323        _request: Request<Streaming<FlightData>>,
324    ) -> Result<Response<Self::DoExchangeStream>, Status> {
325        Err(Status::unimplemented("do_exchange not supported"))
326    }
327}
328
329/// Start the Arrow Flight server.
330///
331/// Binds to `host:port` and serves until the process is terminated.
332pub async fn serve_flight(db: Arc<Database>, host: &str, port: u16) -> Result<(), Box<dyn std::error::Error>> {
333    let addr = format!("{host}:{port}").parse()?;
334    let service = KyuFlightService::new(db);
335    println!("KyuGraph Flight server listening on {addr}");
336    Server::builder()
337        .add_service(FlightServiceServer::new(service))
338        .serve(addr)
339        .await?;
340    Ok(())
341}
342
#[cfg(test)]
mod tests {
    // `super::*` already brings in `LogicalType`, `TypedValue`, the Arrow array
    // types, `DataType`, `QueryResult` and `Database`.
    use super::*;
    use smol_str::SmolStr;

    #[test]
    fn to_record_batch_empty_columns() {
        let result = QueryResult::new(vec![], vec![]);
        assert!(to_record_batch(&result).is_none());
    }

    #[test]
    fn to_record_batch_int64() {
        let mut result = QueryResult::new(
            vec![SmolStr::new("x")],
            vec![LogicalType::Int64],
        );
        result.push_row(vec![TypedValue::Int64(42)]);
        result.push_row(vec![TypedValue::Int64(99)]);

        let batch = to_record_batch(&result).unwrap();
        assert_eq!(batch.num_rows(), 2);
        assert_eq!(batch.num_columns(), 1);

        let col = batch.column(0).as_any().downcast_ref::<Int64Array>().unwrap();
        assert_eq!(col.value(0), 42);
        assert_eq!(col.value(1), 99);
    }

    #[test]
    fn to_record_batch_int64_with_null() {
        let mut result = QueryResult::new(
            vec![SmolStr::new("x")],
            vec![LogicalType::Int64],
        );
        result.push_row(vec![TypedValue::Int64(7)]);
        result.push_row(vec![TypedValue::Null]);

        let batch = to_record_batch(&result).unwrap();
        assert_eq!(batch.num_rows(), 2);

        let col = batch.column(0);
        assert!(!col.is_null(0));
        assert!(col.is_null(1));
    }

    #[test]
    fn to_record_batch_mixed_types() {
        let mut result = QueryResult::new(
            vec![SmolStr::new("id"), SmolStr::new("name"), SmolStr::new("score")],
            vec![LogicalType::Int64, LogicalType::String, LogicalType::Double],
        );
        result.push_row(vec![
            TypedValue::Int64(1),
            TypedValue::String(SmolStr::new("Alice")),
            TypedValue::Double(95.5),
        ]);
        result.push_row(vec![
            TypedValue::Int64(2),
            TypedValue::Null,
            TypedValue::Double(87.3),
        ]);

        let batch = to_record_batch(&result).unwrap();
        assert_eq!(batch.num_rows(), 2);
        assert_eq!(batch.num_columns(), 3);

        // Check string column has a null.
        let name_col = batch.column(1);
        assert!(!name_col.is_null(0));
        assert!(name_col.is_null(1));
    }

    #[test]
    fn to_record_batch_bool_float() {
        let mut result = QueryResult::new(
            vec![SmolStr::new("active"), SmolStr::new("temp")],
            vec![LogicalType::Bool, LogicalType::Float],
        );
        // 2.5 is exactly representable and avoids Clippy's `approx_constant`
        // lint, which rejects literals such as 3.14 that approximate PI.
        result.push_row(vec![TypedValue::Bool(true), TypedValue::Float(2.5)]);

        let batch = to_record_batch(&result).unwrap();
        let bool_col = batch.column(0).as_any().downcast_ref::<BooleanArray>().unwrap();
        assert!(bool_col.value(0));

        let float_col = batch.column(1).as_any().downcast_ref::<Float32Array>().unwrap();
        assert!((float_col.value(0) - 2.5).abs() < 0.01);
    }

    #[test]
    fn roundtrip_via_database() {
        let db = Database::in_memory();
        let conn = db.connect();
        conn.query("CREATE NODE TABLE Person (id INT64, name STRING, PRIMARY KEY (id))")
            .unwrap();
        conn.query("CREATE (n:Person {id: 1, name: 'Alice'})").unwrap();
        conn.query("CREATE (n:Person {id: 2, name: 'Bob'})").unwrap();

        let result = conn.query("MATCH (p:Person) RETURN p.id, p.name").unwrap();
        let batch = to_record_batch(&result).unwrap();

        assert_eq!(batch.num_rows(), 2);
        assert_eq!(batch.schema().fields().len(), 2);
        assert_eq!(batch.schema().field(0).data_type(), &DataType::Int64);
        assert_eq!(batch.schema().field(1).data_type(), &DataType::Utf8);
    }
}
429        assert_eq!(batch.schema().fields().len(), 2);
430        assert_eq!(batch.schema().field(0).data_type(), &DataType::Int64);
431        assert_eq!(batch.schema().field(1).data_type(), &DataType::Utf8);
432    }
433}