arrow-flight 32.0.0

Apache Arrow Flight
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

use std::pin::Pin;

use crate::sql::Any;
use futures::Stream;
use prost::Message;
use tonic::{Request, Response, Status, Streaming};

use super::{
        flight_service_server::FlightService, Action, ActionType, Criteria, Empty,
        FlightData, FlightDescriptor, FlightInfo, HandshakeRequest, HandshakeResponse,
        PutResult, SchemaResult, Ticket,
    ActionClosePreparedStatementRequest, ActionCreatePreparedStatementRequest,
    ActionCreatePreparedStatementResult, CommandGetCatalogs, CommandGetCrossReference,
    CommandGetDbSchemas, CommandGetExportedKeys, CommandGetImportedKeys,
    CommandGetPrimaryKeys, CommandGetSqlInfo, CommandGetTableTypes, CommandGetTables,
    CommandPreparedStatementQuery, CommandPreparedStatementUpdate, CommandStatementQuery,
    CommandStatementUpdate, DoPutUpdateResult, ProstMessageExt, SqlInfo,

/// `Action::type` value identifying the "close prepared statement" action;
/// also advertised as an `ActionType` by `list_actions`.
pub(crate) static CLOSE_PREPARED_STATEMENT: &str = "ClosePreparedStatement";
/// `Action::type` value identifying the "create prepared statement" action;
/// also advertised as an `ActionType` by `list_actions`.
pub(crate) static CREATE_PREPARED_STATEMENT: &str = "CreatePreparedStatement";

// NOTE(review): this excerpt appears truncated — the `&self` receivers, parts of
// the default-method bodies, closing braces and the return type of
// `do_action_close_prepared_statement` are not present, although the
// dispatching impl calls e.g. `self.do_handshake(request)`. TODO confirm
// against the upstream source.
/// Server-side handler for the Arrow Flight SQL protocol.
///
/// Each `get_flight_info_*`, `do_get_*`, `do_put_*` and `do_action_*` method
/// handles one Flight SQL command or action. The blanket
/// `impl FlightService for T` in this module decodes incoming Flight requests
/// and dispatches them to these methods.
pub trait FlightSqlService: Sync + Send + Sized + 'static {
    /// The Flight service type this handler is served as. When implementing
    /// `FlightSqlService`, this can always be set to `Self`.
    type FlightService: FlightService;

    /// Accept authentication and return a token.
    ///
    /// The default implementation fails with the message
    /// "Handshake has no default implementation", so servers that need a
    /// handshake must override it.
    async fn do_handshake(
        _request: Request<Streaming<HandshakeRequest>>,
    ) -> Result<
        Response<Pin<Box<dyn Stream<Item = Result<HandshakeResponse, Status>> + Send>>>,
    > {
            "Handshake has no default implementation",

    /// Handle a `do_get` ticket that matched none of the built-in Flight SQL
    /// commands.
    ///
    /// `message` is the ticket payload decoded as `Any`. Implementors may
    /// override this to handle additional calls to `do_get()`; the default
    /// implementation fails with an error formatted from
    /// "do_get: The defined request is invalid: {}".
    async fn do_get_fallback(
        _request: Request<Ticket>,
        message: Any,
    ) -> Result<Response<<Self as FlightService>::DoGetStream>, Status> {
            "do_get: The defined request is invalid: {}",

    // get_flight_info

    /// Get a FlightInfo for executing a SQL query.
    async fn get_flight_info_statement(
        query: CommandStatementQuery,
        request: Request<FlightDescriptor>,
    ) -> Result<Response<FlightInfo>, Status>;

    /// Get a FlightInfo for executing an already created prepared statement.
    async fn get_flight_info_prepared_statement(
        query: CommandPreparedStatementQuery,
        request: Request<FlightDescriptor>,
    ) -> Result<Response<FlightInfo>, Status>;

    /// Get a FlightInfo for listing catalogs.
    async fn get_flight_info_catalogs(
        query: CommandGetCatalogs,
        request: Request<FlightDescriptor>,
    ) -> Result<Response<FlightInfo>, Status>;

    /// Get a FlightInfo for listing schemas.
    async fn get_flight_info_schemas(
        query: CommandGetDbSchemas,
        request: Request<FlightDescriptor>,
    ) -> Result<Response<FlightInfo>, Status>;

    /// Get a FlightInfo for listing tables.
    async fn get_flight_info_tables(
        query: CommandGetTables,
        request: Request<FlightDescriptor>,
    ) -> Result<Response<FlightInfo>, Status>;

    /// Get a FlightInfo to extract information about the table types.
    async fn get_flight_info_table_types(
        query: CommandGetTableTypes,
        request: Request<FlightDescriptor>,
    ) -> Result<Response<FlightInfo>, Status>;

    /// Get a FlightInfo for retrieving other information (See SqlInfo).
    async fn get_flight_info_sql_info(
        query: CommandGetSqlInfo,
        request: Request<FlightDescriptor>,
    ) -> Result<Response<FlightInfo>, Status>;

    /// Get a FlightInfo to extract information about a table's primary keys.
    async fn get_flight_info_primary_keys(
        query: CommandGetPrimaryKeys,
        request: Request<FlightDescriptor>,
    ) -> Result<Response<FlightInfo>, Status>;

    /// Get a FlightInfo to extract information about exported keys.
    async fn get_flight_info_exported_keys(
        query: CommandGetExportedKeys,
        request: Request<FlightDescriptor>,
    ) -> Result<Response<FlightInfo>, Status>;

    /// Get a FlightInfo to extract information about imported keys.
    async fn get_flight_info_imported_keys(
        query: CommandGetImportedKeys,
        request: Request<FlightDescriptor>,
    ) -> Result<Response<FlightInfo>, Status>;

    /// Get a FlightInfo to extract information about cross reference.
    async fn get_flight_info_cross_reference(
        query: CommandGetCrossReference,
        request: Request<FlightDescriptor>,
    ) -> Result<Response<FlightInfo>, Status>;

    // do_get

    /// Get a FlightDataStream containing the query results.
    ///
    /// `ticket` is the `TicketStatementQuery` that `do_get` unpacked from the
    /// request's ticket.
    async fn do_get_statement(
        ticket: TicketStatementQuery,
        request: Request<Ticket>,
    ) -> Result<Response<<Self as FlightService>::DoGetStream>, Status>;

    /// Get a FlightDataStream containing the prepared statement query results.
    async fn do_get_prepared_statement(
        query: CommandPreparedStatementQuery,
        request: Request<Ticket>,
    ) -> Result<Response<<Self as FlightService>::DoGetStream>, Status>;

    /// Get a FlightDataStream containing the list of catalogs.
    async fn do_get_catalogs(
        query: CommandGetCatalogs,
        request: Request<Ticket>,
    ) -> Result<Response<<Self as FlightService>::DoGetStream>, Status>;

    /// Get a FlightDataStream containing the list of schemas.
    async fn do_get_schemas(
        query: CommandGetDbSchemas,
        request: Request<Ticket>,
    ) -> Result<Response<<Self as FlightService>::DoGetStream>, Status>;

    /// Get a FlightDataStream containing the list of tables.
    async fn do_get_tables(
        query: CommandGetTables,
        request: Request<Ticket>,
    ) -> Result<Response<<Self as FlightService>::DoGetStream>, Status>;

    /// Get a FlightDataStream containing the data related to the table types.
    async fn do_get_table_types(
        query: CommandGetTableTypes,
        request: Request<Ticket>,
    ) -> Result<Response<<Self as FlightService>::DoGetStream>, Status>;

    /// Get a FlightDataStream containing the list of SqlInfo results.
    async fn do_get_sql_info(
        query: CommandGetSqlInfo,
        request: Request<Ticket>,
    ) -> Result<Response<<Self as FlightService>::DoGetStream>, Status>;

    /// Get a FlightDataStream containing the data related to a table's primary keys.
    async fn do_get_primary_keys(
        query: CommandGetPrimaryKeys,
        request: Request<Ticket>,
    ) -> Result<Response<<Self as FlightService>::DoGetStream>, Status>;

    /// Get a FlightDataStream containing the data related to the exported keys.
    async fn do_get_exported_keys(
        query: CommandGetExportedKeys,
        request: Request<Ticket>,
    ) -> Result<Response<<Self as FlightService>::DoGetStream>, Status>;

    /// Get a FlightDataStream containing the data related to the imported keys.
    async fn do_get_imported_keys(
        query: CommandGetImportedKeys,
        request: Request<Ticket>,
    ) -> Result<Response<<Self as FlightService>::DoGetStream>, Status>;

    /// Get a FlightDataStream containing the data related to the cross reference.
    async fn do_get_cross_reference(
        query: CommandGetCrossReference,
        request: Request<Ticket>,
    ) -> Result<Response<<Self as FlightService>::DoGetStream>, Status>;

    // do_put

    /// Execute an update SQL statement.
    ///
    /// The returned count is sent back to the client as the `record_count` of
    /// a `DoPutUpdateResult`.
    async fn do_put_statement_update(
        ticket: CommandStatementUpdate,
        request: Request<Streaming<FlightData>>,
    ) -> Result<i64, Status>;

    /// Bind parameters to given prepared statement.
    ///
    /// The returned stream is passed straight back as the `do_put` response.
    async fn do_put_prepared_statement_query(
        query: CommandPreparedStatementQuery,
        request: Request<Streaming<FlightData>>,
    ) -> Result<Response<<Self as FlightService>::DoPutStream>, Status>;

    /// Execute an update SQL prepared statement.
    ///
    /// The returned count is sent back to the client as the `record_count` of
    /// a `DoPutUpdateResult`.
    async fn do_put_prepared_statement_update(
        query: CommandPreparedStatementUpdate,
        request: Request<Streaming<FlightData>>,
    ) -> Result<i64, Status>;

    // do_action

    /// Create a prepared statement from given SQL statement.
    ///
    /// The result is encoded as an `Any` and returned as the body of a single
    /// action result.
    async fn do_action_create_prepared_statement(
        query: ActionCreatePreparedStatementRequest,
        request: Request<Action>,
    ) -> Result<ActionCreatePreparedStatementResult, Status>;

    /// Close a prepared statement.
    ///
    /// `do_action` awaits this call and discards its value.
    // NOTE(review): the return-type line of this signature is missing from this
    // excerpt — TODO confirm against the upstream source.
    async fn do_action_close_prepared_statement(
        query: ActionClosePreparedStatementRequest,
        request: Request<Action>,

    /// Register a new SqlInfo result, making it available when calling GetSqlInfo.
    async fn register_sql_info(&self, id: i32, result: &SqlInfo);

// NOTE(review): the `where` keyword before `T: FlightSqlService + Send`, the
// opening/closing braces and several statement lines (e.g. the `map_err` /
// `is::<..>()` parts of the dispatch) appear to be missing from this excerpt;
// TODO confirm against the upstream source.
/// Adapts any `FlightSqlService` to the lower-level Flight `FlightService` RPCs.
///
/// `get_flight_info`, `do_get`, `do_put` and `do_action` decode the incoming
/// command and dispatch on its type to the matching `FlightSqlService` method.
/// `list_flights`, `get_schema` and `do_exchange` always return
/// `Status::unimplemented`.
impl<T: 'static> FlightService for T
    T: FlightSqlService + Send,
    // Every RPC stream type is a pinned, boxed `Send + 'static` trait object.
    type HandshakeStream =
        Pin<Box<dyn Stream<Item = Result<HandshakeResponse, Status>> + Send + 'static>>;
    type ListFlightsStream =
        Pin<Box<dyn Stream<Item = Result<FlightInfo, Status>> + Send + 'static>>;
    type DoGetStream =
        Pin<Box<dyn Stream<Item = Result<FlightData, Status>> + Send + 'static>>;
    type DoPutStream =
        Pin<Box<dyn Stream<Item = Result<PutResult, Status>> + Send + 'static>>;
    type DoActionStream = Pin<
        Box<dyn Stream<Item = Result<super::super::Result, Status>> + Send + 'static>,
    type ListActionsStream =
        Pin<Box<dyn Stream<Item = Result<ActionType, Status>> + Send + 'static>>;
    type DoExchangeStream =
        Pin<Box<dyn Stream<Item = Result<FlightData, Status>> + Send + 'static>>;

    /// Delegate to `FlightSqlService::do_handshake`.
    async fn handshake(
        request: Request<Streaming<HandshakeRequest>>,
    ) -> Result<Response<Self::HandshakeStream>, Status> {
        let res = self.do_handshake(request).await?;

    /// Not supported: always returns `Status::unimplemented`.
    async fn list_flights(
        _request: Request<Criteria>,
    ) -> Result<Response<Self::ListFlightsStream>, Status> {
        Err(Status::unimplemented("Not yet implemented"))

    /// Dispatch the command carried by the descriptor (presumably its `cmd`
    /// bytes decoded as `Any` — the decoding line is truncated here) to the
    /// matching `FlightSqlService::get_flight_info_*` method.
    ///
    /// Any other command type is rejected with an error formatted from
    /// "get_flight_info: The defined request is invalid: {}".
    async fn get_flight_info(
        request: Request<FlightDescriptor>,
    ) -> Result<Response<FlightInfo>, Status> {
        let message =

        // Dispatch on the concrete Flight SQL command type.
        if<CommandStatementQuery>() {
            let token = message
            return self.get_flight_info_statement(token, request).await;
        if<CommandPreparedStatementQuery>() {
            let handle = message
            return self
                .get_flight_info_prepared_statement(handle, request)
        if<CommandGetCatalogs>() {
            let token = message
            return self.get_flight_info_catalogs(token, request).await;
        if<CommandGetDbSchemas>() {
            let token = message
            return self.get_flight_info_schemas(token, request).await;
        if<CommandGetTables>() {
            let token = message
            return self.get_flight_info_tables(token, request).await;
        if<CommandGetTableTypes>() {
            let token = message
            return self.get_flight_info_table_types(token, request).await;
        if<CommandGetSqlInfo>() {
            let token = message
            return self.get_flight_info_sql_info(token, request).await;
        if<CommandGetPrimaryKeys>() {
            let token = message
            return self.get_flight_info_primary_keys(token, request).await;
        if<CommandGetExportedKeys>() {
            let token = message
            return self.get_flight_info_exported_keys(token, request).await;
        if<CommandGetImportedKeys>() {
            let token = message
            return self.get_flight_info_imported_keys(token, request).await;
        if<CommandGetCrossReference>() {
            let token = message
            return self.get_flight_info_cross_reference(token, request).await;

            "get_flight_info: The defined request is invalid: {}",

    /// Not supported: always returns `Status::unimplemented`.
    async fn get_schema(
        _request: Request<FlightDescriptor>,
    ) -> Result<Response<SchemaResult>, Status> {
        Err(Status::unimplemented("Not yet implemented"))

    /// Decode the ticket bytes as an `Any` command and dispatch it to the
    /// matching `FlightSqlService::do_get_*` method; commands that match none
    /// of them are handed to `FlightSqlService::do_get_fallback`.
    async fn do_get(
        request: Request<Ticket>,
    ) -> Result<Response<Self::DoGetStream>, Status> {
        let msg: Any = Message::decode(&*request.get_ref().ticket)

        // Unpack `msg` into the concrete command type `T`; a `None` from the
        // unpacking step becomes `Status::internal`.
        fn unpack<T: ProstMessageExt>(msg: Any) -> Result<T, Status> {
                .ok_or_else(|| Status::internal("Expected a command, but found none."))

        // Dispatch on the command type carried by the ticket.
        if<TicketStatementQuery>() {
            return self.do_get_statement(unpack(msg)?, request).await;
        if<CommandPreparedStatementQuery>() {
            return self.do_get_prepared_statement(unpack(msg)?, request).await;
        if<CommandGetCatalogs>() {
            return self.do_get_catalogs(unpack(msg)?, request).await;
        if<CommandGetDbSchemas>() {
            return self.do_get_schemas(unpack(msg)?, request).await;
        if<CommandGetTables>() {
            return self.do_get_tables(unpack(msg)?, request).await;
        if<CommandGetTableTypes>() {
            return self.do_get_table_types(unpack(msg)?, request).await;
        if<CommandGetSqlInfo>() {
            return self.do_get_sql_info(unpack(msg)?, request).await;
        if<CommandGetPrimaryKeys>() {
            return self.do_get_primary_keys(unpack(msg)?, request).await;
        if<CommandGetExportedKeys>() {
            return self.do_get_exported_keys(unpack(msg)?, request).await;
        if<CommandGetImportedKeys>() {
            return self.do_get_imported_keys(unpack(msg)?, request).await;
        if<CommandGetCrossReference>() {
            return self.do_get_cross_reference(unpack(msg)?, request).await;

        // Not a built-in command: defer to the user-overridable fallback.
        self.do_get_fallback(request, msg).await

    /// Read the first `FlightData` message, decode its descriptor's command as
    /// `Any`, and dispatch to the matching `FlightSqlService::do_put_*` method.
    ///
    /// For the two update commands, the returned record count is encoded as a
    /// `DoPutUpdateResult` into the `app_metadata` of a single `PutResult`.
    /// Any other command type is rejected with an error formatted from
    /// "do_put: The defined request is invalid: {}".
    async fn do_put(
        mut request: Request<Streaming<FlightData>>,
    ) -> Result<Response<Self::DoPutStream>, Status> {
        // NOTE(review): the first message is taken off the stream here to find
        // the command, so the handlers below never see it. Both `unwrap()`
        // calls also panic on client-controlled input: an empty stream, or a
        // first message without a `flight_descriptor`. Consider returning
        // `Status::invalid_argument` instead.
        let cmd = request.get_mut().message().await?.unwrap();
        let message = Any::decode(&*cmd.flight_descriptor.unwrap().cmd)
        if<CommandStatementUpdate>() {
            let token = message
            let record_count = self.do_put_statement_update(token, request).await?;
            let result = DoPutUpdateResult { record_count };
            let output = futures::stream::iter(vec![Ok(PutResult {
                app_metadata: result.encode_to_vec().into(),
            return Ok(Response::new(Box::pin(output)));
        if<CommandPreparedStatementQuery>() {
            let token = message
            return self.do_put_prepared_statement_query(token, request).await;
        if<CommandPreparedStatementUpdate>() {
            let handle = message
            let record_count = self
                .do_put_prepared_statement_update(handle, request)
            let result = DoPutUpdateResult { record_count };
            let output = futures::stream::iter(vec![Ok(PutResult {
                app_metadata: result.encode_to_vec().into(),
            return Ok(Response::new(Box::pin(output)));

            "do_put: The defined request is invalid: {}",

    /// Advertise the two prepared-statement actions handled by `do_action`.
    async fn list_actions(
        _request: Request<Empty>,
    ) -> Result<Response<Self::ListActionsStream>, Status> {
        // NOTE(review): the description literals span several source lines
        // without a trailing `\`. As a result, each `\n` is followed by a real
        // line break plus the source indentation, and that whitespace ends up
        // in the advertised text.
        let create_prepared_statement_action_type = ActionType {
            r#type: CREATE_PREPARED_STATEMENT.to_string(),
            description: "Creates a reusable prepared statement resource on the server.\n
                Request Message: ActionCreatePreparedStatementRequest\n
                Response Message: ActionCreatePreparedStatementResult"
        let close_prepared_statement_action_type = ActionType {
            r#type: CLOSE_PREPARED_STATEMENT.to_string(),
            description: "Closes a reusable prepared statement resource on the server.\n
                Request Message: ActionClosePreparedStatementRequest\n
                Response Message: N/A"
        let actions: Vec<Result<ActionType, Status>> = vec![
        let output = futures::stream::iter(actions);
        Ok(Response::new(Box::pin(output) as Self::ListActionsStream))

    /// Dispatch an `Action` by its `type`: `CREATE_PREPARED_STATEMENT` or
    /// `CLOSE_PREPARED_STATEMENT`. Any other type is rejected with an error
    /// formatted from "do_action: The defined request is invalid: {:?}".
    ///
    /// Create returns a single result whose body is the handler's
    /// `ActionCreatePreparedStatementResult` encoded as `Any`. Close returns an
    /// empty stream.
    // NOTE(review): the awaited value of `do_action_close_prepared_statement`
    // is discarded. If that method returns a `Result`, its error is dropped
    // here instead of reaching the client.
    async fn do_action(
        request: Request<Action>,
    ) -> Result<Response<Self::DoActionStream>, Status> {
        if request.get_ref().r#type == CREATE_PREPARED_STATEMENT {
            let any =

            let cmd: ActionCreatePreparedStatementRequest = any
                .ok_or_else(|| {
                        "Unable to unpack ActionCreatePreparedStatementRequest.",
            let stmt = self
                .do_action_create_prepared_statement(cmd, request)
            let output = futures::stream::iter(vec![Ok(super::super::gen::Result {
                body: stmt.as_any().encode_to_vec().into(),
            return Ok(Response::new(Box::pin(output)));
        if request.get_ref().r#type == CLOSE_PREPARED_STATEMENT {
            let any =

            let cmd: ActionClosePreparedStatementRequest = any
                .ok_or_else(|| {
                        "Unable to unpack ActionClosePreparedStatementRequest.",
            self.do_action_close_prepared_statement(cmd, request).await;
            return Ok(Response::new(Box::pin(futures::stream::empty())));

            "do_action: The defined request is invalid: {:?}",

    /// Not supported: always returns `Status::unimplemented`.
    async fn do_exchange(
        _request: Request<Streaming<FlightData>>,
    ) -> Result<Response<Self::DoExchangeStream>, Status> {
        Err(Status::unimplemented("Not yet implemented"))
/// Convert a protobuf (`prost`) decoding failure into a gRPC `Status`.
fn decode_error_to_status(err: prost::DecodeError) -> Status {

fn arrow_error_to_status(err: arrow_schema::ArrowError) -> Status {