Skip to main content

kyu_api/
flight.rs

1//! Arrow Flight gRPC server for KyuGraph.
2//!
3//! Exposes KyuGraph as an Arrow Flight endpoint. Clients send Cypher queries
4//! as Flight tickets and receive results as Arrow RecordBatch streams.
5
6use std::sync::Arc;
7
8use arrow::array::{
9    BooleanArray, Float32Array, Float64Array, Int8Array, Int16Array, Int32Array, Int64Array,
10    NullArray, StringBuilder,
11};
12use arrow::datatypes::{DataType, Field, Schema};
13use arrow::record_batch::RecordBatch;
14use arrow_flight::encode::FlightDataEncoderBuilder;
15use arrow_flight::flight_service_server::{FlightService, FlightServiceServer};
16use arrow_flight::{
17    Action, ActionType, Criteria, Empty, FlightData, FlightDescriptor, FlightInfo,
18    HandshakeRequest, HandshakeResponse, PollInfo, PutResult, SchemaResult, Ticket,
19};
20use futures::stream::BoxStream;
21use futures::{StreamExt, TryStreamExt};
22use tonic::transport::Server;
23use tonic::{Request, Response, Status, Streaming};
24
25use kyu_executor::QueryResult;
26use kyu_types::{LogicalType, TypedValue};
27
28use crate::Database;
29
30// ---- QueryResult → RecordBatch conversion ----
31
32/// Map a KyuGraph LogicalType to an Arrow DataType.
33fn logical_to_arrow(ty: &LogicalType) -> DataType {
34    match ty {
35        LogicalType::Bool => DataType::Boolean,
36        LogicalType::Int8 => DataType::Int8,
37        LogicalType::Int16 => DataType::Int16,
38        LogicalType::Int32 => DataType::Int32,
39        LogicalType::Int64 | LogicalType::Serial => DataType::Int64,
40        LogicalType::Float => DataType::Float32,
41        LogicalType::Double => DataType::Float64,
42        LogicalType::String => DataType::Utf8,
43        _ => DataType::Utf8, // fallback: serialize as string
44    }
45}
46
47/// Convert a QueryResult into an Arrow RecordBatch.
48///
49/// Returns `None` if the result has zero columns (DDL/DML with no output).
50pub fn to_record_batch(result: &QueryResult) -> Option<RecordBatch> {
51    let nc = result.num_columns();
52    if nc == 0 {
53        return None;
54    }
55    let nr = result.num_rows();
56
57    let fields: Vec<Field> = result
58        .column_names
59        .iter()
60        .zip(&result.column_types)
61        .map(|(name, ty)| Field::new(name.as_str(), logical_to_arrow(ty), true))
62        .collect();
63    let schema = Arc::new(Schema::new(fields));
64
65    let columns: Vec<Arc<dyn arrow::array::Array>> = (0..nc)
66        .map(|col_idx| build_arrow_column(result, col_idx, nr, &result.column_types[col_idx]))
67        .collect();
68
69    RecordBatch::try_new(schema, columns).ok()
70}
71
72/// Build a single Arrow array column from a QueryResult.
73fn build_arrow_column(
74    result: &QueryResult,
75    col_idx: usize,
76    num_rows: usize,
77    ty: &LogicalType,
78) -> Arc<dyn arrow::array::Array> {
79    match ty {
80        LogicalType::Bool => {
81            let values: Vec<Option<bool>> = (0..num_rows)
82                .map(|r| match &result.row(r)[col_idx] {
83                    TypedValue::Bool(v) => Some(*v),
84                    TypedValue::Null => None,
85                    _ => None,
86                })
87                .collect();
88            Arc::new(BooleanArray::from(values))
89        }
90        LogicalType::Int8 => {
91            let values: Vec<Option<i8>> = (0..num_rows)
92                .map(|r| match &result.row(r)[col_idx] {
93                    TypedValue::Int8(v) => Some(*v),
94                    TypedValue::Null => None,
95                    _ => None,
96                })
97                .collect();
98            Arc::new(Int8Array::from(values))
99        }
100        LogicalType::Int16 => {
101            let values: Vec<Option<i16>> = (0..num_rows)
102                .map(|r| match &result.row(r)[col_idx] {
103                    TypedValue::Int16(v) => Some(*v),
104                    TypedValue::Null => None,
105                    _ => None,
106                })
107                .collect();
108            Arc::new(Int16Array::from(values))
109        }
110        LogicalType::Int32 => {
111            let values: Vec<Option<i32>> = (0..num_rows)
112                .map(|r| match &result.row(r)[col_idx] {
113                    TypedValue::Int32(v) => Some(*v),
114                    TypedValue::Null => None,
115                    _ => None,
116                })
117                .collect();
118            Arc::new(Int32Array::from(values))
119        }
120        LogicalType::Int64 | LogicalType::Serial => {
121            let values: Vec<Option<i64>> = (0..num_rows)
122                .map(|r| match &result.row(r)[col_idx] {
123                    TypedValue::Int64(v) => Some(*v),
124                    TypedValue::Null => None,
125                    _ => None,
126                })
127                .collect();
128            Arc::new(Int64Array::from(values))
129        }
130        LogicalType::Float => {
131            let values: Vec<Option<f32>> = (0..num_rows)
132                .map(|r| match &result.row(r)[col_idx] {
133                    TypedValue::Float(v) => Some(*v),
134                    TypedValue::Null => None,
135                    _ => None,
136                })
137                .collect();
138            Arc::new(Float32Array::from(values))
139        }
140        LogicalType::Double => {
141            let values: Vec<Option<f64>> = (0..num_rows)
142                .map(|r| match &result.row(r)[col_idx] {
143                    TypedValue::Double(v) => Some(*v),
144                    TypedValue::Null => None,
145                    _ => None,
146                })
147                .collect();
148            Arc::new(Float64Array::from(values))
149        }
150        LogicalType::String => {
151            let mut builder = StringBuilder::with_capacity(num_rows, num_rows * 16);
152            for r in 0..num_rows {
153                match &result.row(r)[col_idx] {
154                    TypedValue::String(s) => builder.append_value(s.as_str()),
155                    TypedValue::Null => builder.append_null(),
156                    other => builder.append_value(format!("{other:?}")),
157                }
158            }
159            Arc::new(builder.finish())
160        }
161        _ => {
162            // Fallback: null column for unsupported types.
163            Arc::new(NullArray::new(num_rows))
164        }
165    }
166}
167
168// ---- Arrow Flight Service ----
169
170/// KyuGraph Arrow Flight service.
171///
172/// Each `do_get` executes the Cypher query encoded in the ticket and streams
173/// the result as Arrow RecordBatches. DDL and mutations are supported via
174/// `do_action`.
175pub struct KyuFlightService {
176    db: Arc<Database>,
177}
178
179impl KyuFlightService {
180    pub fn new(db: Arc<Database>) -> Self {
181        Self { db }
182    }
183
184    #[allow(clippy::result_large_err)]
185    fn execute_query(&self, cypher: &str) -> Result<QueryResult, Status> {
186        let conn = self.db.connect();
187        conn.query(cypher)
188            .map_err(|e| Status::internal(format!("query error: {e}")))
189    }
190}
191
192#[tonic::async_trait]
193impl FlightService for KyuFlightService {
194    type HandshakeStream = BoxStream<'static, Result<HandshakeResponse, Status>>;
195    type ListFlightsStream = BoxStream<'static, Result<FlightInfo, Status>>;
196    type DoGetStream = BoxStream<'static, Result<FlightData, Status>>;
197    type DoPutStream = BoxStream<'static, Result<PutResult, Status>>;
198    type DoExchangeStream = BoxStream<'static, Result<FlightData, Status>>;
199    type DoActionStream = BoxStream<'static, Result<arrow_flight::Result, Status>>;
200    type ListActionsStream = BoxStream<'static, Result<ActionType, Status>>;
201
202    /// Execute a Cypher query and stream the result as Arrow RecordBatches.
203    ///
204    /// The ticket body is the UTF-8 encoded Cypher query string.
205    async fn do_get(
206        &self,
207        request: Request<Ticket>,
208    ) -> Result<Response<Self::DoGetStream>, Status> {
209        let ticket = request.into_inner();
210        let cypher = String::from_utf8(ticket.ticket.to_vec())
211            .map_err(|_| Status::invalid_argument("ticket must be valid UTF-8"))?;
212
213        let result = self.execute_query(&cypher)?;
214
215        let batches: Vec<RecordBatch> = match to_record_batch(&result) {
216            Some(batch) => vec![batch],
217            None => {
218                // DDL/DML — return empty schema with zero rows.
219                let schema = Arc::new(Schema::new(vec![Field::new(
220                    "ok",
221                    DataType::Boolean,
222                    false,
223                )]));
224                vec![
225                    RecordBatch::try_new(schema, vec![Arc::new(BooleanArray::from(vec![true]))])
226                        .unwrap(),
227                ]
228            }
229        };
230
231        let schema = batches[0].schema();
232        let batch_stream = futures::stream::iter(batches.into_iter().map(Ok));
233        let flight_stream = FlightDataEncoderBuilder::new()
234            .with_schema(schema)
235            .build(batch_stream)
236            .map_err(|e| Status::internal(e.to_string()));
237
238        Ok(Response::new(flight_stream.boxed()))
239    }
240
241    /// Execute DDL/DML statements via actions.
242    ///
243    /// Action type: "query". Body: UTF-8 Cypher string.
244    async fn do_action(
245        &self,
246        request: Request<Action>,
247    ) -> Result<Response<Self::DoActionStream>, Status> {
248        let action = request.into_inner();
249
250        match action.r#type.as_str() {
251            "query" => {
252                let cypher = String::from_utf8(action.body.to_vec())
253                    .map_err(|_| Status::invalid_argument("body must be valid UTF-8"))?;
254                self.execute_query(&cypher)?;
255                let result = arrow_flight::Result {
256                    body: bytes::Bytes::from("OK"),
257                };
258                let stream = futures::stream::once(async { Ok(result) });
259                Ok(Response::new(stream.boxed()))
260            }
261            other => Err(Status::invalid_argument(format!(
262                "unknown action type: {other}"
263            ))),
264        }
265    }
266
267    async fn list_actions(
268        &self,
269        _request: Request<Empty>,
270    ) -> Result<Response<Self::ListActionsStream>, Status> {
271        let actions = vec![Ok(ActionType {
272            r#type: "query".into(),
273            description: "Execute a Cypher DDL/DML statement".into(),
274        })];
275        Ok(Response::new(futures::stream::iter(actions).boxed()))
276    }
277
278    async fn handshake(
279        &self,
280        _request: Request<Streaming<HandshakeRequest>>,
281    ) -> Result<Response<Self::HandshakeStream>, Status> {
282        Err(Status::unimplemented("handshake not supported"))
283    }
284
285    async fn list_flights(
286        &self,
287        _request: Request<Criteria>,
288    ) -> Result<Response<Self::ListFlightsStream>, Status> {
289        Err(Status::unimplemented("list_flights not supported"))
290    }
291
292    async fn get_flight_info(
293        &self,
294        _request: Request<FlightDescriptor>,
295    ) -> Result<Response<FlightInfo>, Status> {
296        Err(Status::unimplemented("get_flight_info not supported"))
297    }
298
299    async fn poll_flight_info(
300        &self,
301        _request: Request<FlightDescriptor>,
302    ) -> Result<Response<PollInfo>, Status> {
303        Err(Status::unimplemented("poll_flight_info not supported"))
304    }
305
306    async fn get_schema(
307        &self,
308        _request: Request<FlightDescriptor>,
309    ) -> Result<Response<SchemaResult>, Status> {
310        Err(Status::unimplemented("get_schema not supported"))
311    }
312
313    async fn do_put(
314        &self,
315        _request: Request<Streaming<FlightData>>,
316    ) -> Result<Response<Self::DoPutStream>, Status> {
317        Err(Status::unimplemented("do_put not supported"))
318    }
319
320    async fn do_exchange(
321        &self,
322        _request: Request<Streaming<FlightData>>,
323    ) -> Result<Response<Self::DoExchangeStream>, Status> {
324        Err(Status::unimplemented("do_exchange not supported"))
325    }
326}
327
328/// Start the Arrow Flight server.
329///
330/// Binds to `host:port` and serves until the process is terminated.
331pub async fn serve_flight(
332    db: Arc<Database>,
333    host: &str,
334    port: u16,
335) -> Result<(), Box<dyn std::error::Error>> {
336    let addr = format!("{host}:{port}").parse()?;
337    let service = KyuFlightService::new(db);
338    println!("KyuGraph Flight server listening on {addr}");
339    Server::builder()
340        .add_service(FlightServiceServer::new(service))
341        .serve(addr)
342        .await?;
343    Ok(())
344}
345
/// Unit tests for the `QueryResult` → `RecordBatch` conversion.
#[cfg(test)]
mod tests {
    use super::*;
    use kyu_types::LogicalType;
    use smol_str::SmolStr;

    /// A result without columns has nothing to convert.
    #[test]
    fn to_record_batch_empty_columns() {
        let result = QueryResult::new(vec![], vec![]);
        assert!(to_record_batch(&result).is_none());
    }

    /// Int64 values are carried through unchanged and `Null` becomes a null slot.
    #[test]
    fn to_record_batch_int64() {
        let mut result = QueryResult::new(vec![SmolStr::new("x")], vec![LogicalType::Int64]);
        result.push_row(vec![TypedValue::Int64(42)]);
        result.push_row(vec![TypedValue::Int64(99)]);
        result.push_row(vec![TypedValue::Null]);

        let batch = to_record_batch(&result).unwrap();
        assert_eq!(batch.num_rows(), 3);
        assert_eq!(batch.num_columns(), 1);

        let col = batch
            .column(0)
            .as_any()
            .downcast_ref::<Int64Array>()
            .unwrap();
        assert_eq!(col.value(0), 42);
        assert_eq!(col.value(1), 99);

        let raw = batch.column(0);
        assert!(!raw.is_null(0));
        assert!(raw.is_null(2));
    }

    /// Columns of different types coexist in one batch; string nulls are preserved.
    #[test]
    fn to_record_batch_mixed_types() {
        let mut result = QueryResult::new(
            vec![
                SmolStr::new("id"),
                SmolStr::new("name"),
                SmolStr::new("score"),
            ],
            vec![LogicalType::Int64, LogicalType::String, LogicalType::Double],
        );
        result.push_row(vec![
            TypedValue::Int64(1),
            TypedValue::String(SmolStr::new("Alice")),
            TypedValue::Double(95.5),
        ]);
        result.push_row(vec![
            TypedValue::Int64(2),
            TypedValue::Null,
            TypedValue::Double(87.3),
        ]);

        let batch = to_record_batch(&result).unwrap();
        assert_eq!(batch.num_rows(), 2);
        assert_eq!(batch.num_columns(), 3);

        // Check string column has a null.
        let name_col = batch.column(1);
        assert!(!name_col.is_null(0));
        assert!(name_col.is_null(1));
    }

    /// Bool and Float columns map to `BooleanArray` and `Float32Array`.
    #[test]
    fn to_record_batch_bool_float() {
        let mut result = QueryResult::new(
            vec![SmolStr::new("active"), SmolStr::new("temp")],
            vec![LogicalType::Bool, LogicalType::Float],
        );
        // 2.5 is exactly representable in f32 and, unlike 3.14, does not look
        // like a truncated PI to clippy's `approx_constant` lint.
        result.push_row(vec![TypedValue::Bool(true), TypedValue::Float(2.5)]);

        let batch = to_record_batch(&result).unwrap();
        let bool_col = batch
            .column(0)
            .as_any()
            .downcast_ref::<BooleanArray>()
            .unwrap();
        assert!(bool_col.value(0));

        let float_col = batch
            .column(1)
            .as_any()
            .downcast_ref::<Float32Array>()
            .unwrap();
        assert!((float_col.value(0) - 2.5).abs() < f32::EPSILON);
    }

    /// End to end: rows written through Cypher come back with the expected Arrow schema.
    #[test]
    fn roundtrip_via_database() {
        let db = Database::in_memory();
        let conn = db.connect();
        conn.query("CREATE NODE TABLE Person (id INT64, name STRING, PRIMARY KEY (id))")
            .unwrap();
        conn.query("CREATE (n:Person {id: 1, name: 'Alice'})")
            .unwrap();
        conn.query("CREATE (n:Person {id: 2, name: 'Bob'})")
            .unwrap();

        let result = conn.query("MATCH (p:Person) RETURN p.id, p.name").unwrap();
        let batch = to_record_batch(&result).unwrap();

        assert_eq!(batch.num_rows(), 2);
        assert_eq!(batch.schema().fields().len(), 2);
        assert_eq!(batch.schema().field(0).data_type(), &DataType::Int64);
        assert_eq!(batch.schema().field(1).data_type(), &DataType::Utf8);
    }
}