use std::pin::Pin;
use crate::sql::Any;
use futures::Stream;
use prost::Message;
use tonic::{Request, Response, Status, Streaming};
use super::{
super::{
flight_service_server::FlightService, Action, ActionType, Criteria, Empty,
FlightData, FlightDescriptor, FlightInfo, HandshakeRequest, HandshakeResponse,
PutResult, SchemaResult, Ticket,
},
ActionClosePreparedStatementRequest, ActionCreatePreparedStatementRequest,
ActionCreatePreparedStatementResult, CommandGetCatalogs, CommandGetCrossReference,
CommandGetDbSchemas, CommandGetExportedKeys, CommandGetImportedKeys,
CommandGetPrimaryKeys, CommandGetSqlInfo, CommandGetTableTypes, CommandGetTables,
CommandPreparedStatementQuery, CommandPreparedStatementUpdate, CommandStatementQuery,
CommandStatementUpdate, DoPutUpdateResult, ProstMessageExt, SqlInfo,
TicketStatementQuery,
};
/// Action type string advertised by `list_actions` and matched by `do_action`
/// to create a prepared statement (dispatched to
/// `FlightSqlService::do_action_create_prepared_statement`).
pub(crate) static CREATE_PREPARED_STATEMENT: &str = "CreatePreparedStatement";
/// Action type string advertised by `list_actions` and matched by `do_action`
/// to close a prepared statement (dispatched to
/// `FlightSqlService::do_action_close_prepared_statement`).
pub(crate) static CLOSE_PREPARED_STATEMENT: &str = "ClosePreparedStatement";
/// Server-side hooks for an Arrow Flight SQL service.
///
/// Every implementor automatically implements the gRPC `FlightService` through
/// the blanket impl in this module. That impl decodes the Flight SQL command
/// carried by each request (a protobuf `Any`) and dispatches it to the matching
/// method below.
#[tonic::async_trait]
pub trait FlightSqlService: Sync + Send + Sized + 'static {
/// NOTE(review): this associated type is not referenced by the blanket
/// `FlightService` impl in this module; confirm it is still required.
type FlightService: FlightService;
/// Implements the Flight `Handshake` RPC; `FlightService::handshake` delegates
/// here. The default implementation returns `Unimplemented`.
async fn do_handshake(
&self,
_request: Request<Streaming<HandshakeRequest>>,
) -> Result<
Response<Pin<Box<dyn Stream<Item = Result<HandshakeResponse, Status>> + Send>>>,
Status,
> {
Err(Status::unimplemented(
"Handshake has no default implementation",
))
}
/// Handles `do_get` tickets whose decoded command is none of the Flight SQL
/// types that `do_get` recognizes. `message` is the decoded ticket payload.
/// The default implementation returns `Unimplemented` naming its type URL.
async fn do_get_fallback(
&self,
_request: Request<Ticket>,
message: Any,
) -> Result<Response<<Self as FlightService>::DoGetStream>, Status> {
Err(Status::unimplemented(format!(
"do_get: The defined request is invalid: {}",
message.type_url
)))
}
/// Returns the `FlightInfo` for a `CommandStatementQuery`; called by
/// `get_flight_info` when the descriptor's `cmd` decodes to that command.
async fn get_flight_info_statement(
&self,
query: CommandStatementQuery,
request: Request<FlightDescriptor>,
) -> Result<Response<FlightInfo>, Status>;
/// Returns the `FlightInfo` for a `CommandPreparedStatementQuery`; called by
/// `get_flight_info` when the descriptor's `cmd` decodes to that command.
async fn get_flight_info_prepared_statement(
&self,
query: CommandPreparedStatementQuery,
request: Request<FlightDescriptor>,
) -> Result<Response<FlightInfo>, Status>;
/// Returns the `FlightInfo` for a `CommandGetCatalogs`; called by
/// `get_flight_info` when the descriptor's `cmd` decodes to that command.
async fn get_flight_info_catalogs(
&self,
query: CommandGetCatalogs,
request: Request<FlightDescriptor>,
) -> Result<Response<FlightInfo>, Status>;
/// Returns the `FlightInfo` for a `CommandGetDbSchemas`; called by
/// `get_flight_info` when the descriptor's `cmd` decodes to that command.
async fn get_flight_info_schemas(
&self,
query: CommandGetDbSchemas,
request: Request<FlightDescriptor>,
) -> Result<Response<FlightInfo>, Status>;
/// Returns the `FlightInfo` for a `CommandGetTables`; called by
/// `get_flight_info` when the descriptor's `cmd` decodes to that command.
async fn get_flight_info_tables(
&self,
query: CommandGetTables,
request: Request<FlightDescriptor>,
) -> Result<Response<FlightInfo>, Status>;
/// Returns the `FlightInfo` for a `CommandGetTableTypes`; called by
/// `get_flight_info` when the descriptor's `cmd` decodes to that command.
async fn get_flight_info_table_types(
&self,
query: CommandGetTableTypes,
request: Request<FlightDescriptor>,
) -> Result<Response<FlightInfo>, Status>;
/// Returns the `FlightInfo` for a `CommandGetSqlInfo`; called by
/// `get_flight_info` when the descriptor's `cmd` decodes to that command.
async fn get_flight_info_sql_info(
&self,
query: CommandGetSqlInfo,
request: Request<FlightDescriptor>,
) -> Result<Response<FlightInfo>, Status>;
/// Returns the `FlightInfo` for a `CommandGetPrimaryKeys`; called by
/// `get_flight_info` when the descriptor's `cmd` decodes to that command.
async fn get_flight_info_primary_keys(
&self,
query: CommandGetPrimaryKeys,
request: Request<FlightDescriptor>,
) -> Result<Response<FlightInfo>, Status>;
/// Returns the `FlightInfo` for a `CommandGetExportedKeys`; called by
/// `get_flight_info` when the descriptor's `cmd` decodes to that command.
async fn get_flight_info_exported_keys(
&self,
query: CommandGetExportedKeys,
request: Request<FlightDescriptor>,
) -> Result<Response<FlightInfo>, Status>;
/// Returns the `FlightInfo` for a `CommandGetImportedKeys`; called by
/// `get_flight_info` when the descriptor's `cmd` decodes to that command.
async fn get_flight_info_imported_keys(
&self,
query: CommandGetImportedKeys,
request: Request<FlightDescriptor>,
) -> Result<Response<FlightInfo>, Status>;
/// Returns the `FlightInfo` for a `CommandGetCrossReference`; called by
/// `get_flight_info` when the descriptor's `cmd` decodes to that command.
async fn get_flight_info_cross_reference(
&self,
query: CommandGetCrossReference,
request: Request<FlightDescriptor>,
) -> Result<Response<FlightInfo>, Status>;
/// Streams the data for a `TicketStatementQuery`; called by `do_get` when
/// the ticket decodes to that message.
async fn do_get_statement(
&self,
ticket: TicketStatementQuery,
request: Request<Ticket>,
) -> Result<Response<<Self as FlightService>::DoGetStream>, Status>;
/// Streams the data for a `CommandPreparedStatementQuery`; called by
/// `do_get` when the ticket decodes to that command.
async fn do_get_prepared_statement(
&self,
query: CommandPreparedStatementQuery,
request: Request<Ticket>,
) -> Result<Response<<Self as FlightService>::DoGetStream>, Status>;
/// Streams the data for a `CommandGetCatalogs`; called by `do_get` when
/// the ticket decodes to that command.
async fn do_get_catalogs(
&self,
query: CommandGetCatalogs,
request: Request<Ticket>,
) -> Result<Response<<Self as FlightService>::DoGetStream>, Status>;
/// Streams the data for a `CommandGetDbSchemas`; called by `do_get` when
/// the ticket decodes to that command.
async fn do_get_schemas(
&self,
query: CommandGetDbSchemas,
request: Request<Ticket>,
) -> Result<Response<<Self as FlightService>::DoGetStream>, Status>;
/// Streams the data for a `CommandGetTables`; called by `do_get` when the
/// ticket decodes to that command.
async fn do_get_tables(
&self,
query: CommandGetTables,
request: Request<Ticket>,
) -> Result<Response<<Self as FlightService>::DoGetStream>, Status>;
/// Streams the data for a `CommandGetTableTypes`; called by `do_get` when
/// the ticket decodes to that command.
async fn do_get_table_types(
&self,
query: CommandGetTableTypes,
request: Request<Ticket>,
) -> Result<Response<<Self as FlightService>::DoGetStream>, Status>;
/// Streams the data for a `CommandGetSqlInfo`; called by `do_get` when the
/// ticket decodes to that command.
async fn do_get_sql_info(
&self,
query: CommandGetSqlInfo,
request: Request<Ticket>,
) -> Result<Response<<Self as FlightService>::DoGetStream>, Status>;
/// Streams the data for a `CommandGetPrimaryKeys`; called by `do_get` when
/// the ticket decodes to that command.
async fn do_get_primary_keys(
&self,
query: CommandGetPrimaryKeys,
request: Request<Ticket>,
) -> Result<Response<<Self as FlightService>::DoGetStream>, Status>;
/// Streams the data for a `CommandGetExportedKeys`; called by `do_get` when
/// the ticket decodes to that command.
async fn do_get_exported_keys(
&self,
query: CommandGetExportedKeys,
request: Request<Ticket>,
) -> Result<Response<<Self as FlightService>::DoGetStream>, Status>;
/// Streams the data for a `CommandGetImportedKeys`; called by `do_get` when
/// the ticket decodes to that command.
async fn do_get_imported_keys(
&self,
query: CommandGetImportedKeys,
request: Request<Ticket>,
) -> Result<Response<<Self as FlightService>::DoGetStream>, Status>;
/// Streams the data for a `CommandGetCrossReference`; called by `do_get`
/// when the ticket decodes to that command.
async fn do_get_cross_reference(
&self,
query: CommandGetCrossReference,
request: Request<Ticket>,
) -> Result<Response<<Self as FlightService>::DoGetStream>, Status>;
/// Executes a `CommandStatementUpdate` received through `do_put`.
///
/// `do_put` has already read the first message of `request`'s stream (the one
/// carrying the flight descriptor) before calling this. The returned count is
/// sent to the client as a `DoPutUpdateResult` in the `app_metadata` of a
/// single `PutResult`.
async fn do_put_statement_update(
&self,
ticket: CommandStatementUpdate,
request: Request<Streaming<FlightData>>,
) -> Result<i64, Status>;
/// Handles a `CommandPreparedStatementQuery` received through `do_put`.
///
/// The first message of `request`'s stream has already been read by `do_put`.
/// The returned stream is passed back to the client as the `DoPut` response.
async fn do_put_prepared_statement_query(
&self,
query: CommandPreparedStatementQuery,
request: Request<Streaming<FlightData>>,
) -> Result<Response<<Self as FlightService>::DoPutStream>, Status>;
/// Executes a `CommandPreparedStatementUpdate` received through `do_put`.
///
/// The first message of `request`'s stream has already been read by `do_put`.
/// The returned count is sent to the client as a `DoPutUpdateResult` in the
/// `app_metadata` of a single `PutResult`.
async fn do_put_prepared_statement_update(
&self,
query: CommandPreparedStatementUpdate,
request: Request<Streaming<FlightData>>,
) -> Result<i64, Status>;
/// Handles the `CreatePreparedStatement` action.
///
/// `do_action` encodes the returned result via `as_any()` into the body of a
/// single action `Result` message.
async fn do_action_create_prepared_statement(
&self,
query: ActionCreatePreparedStatementRequest,
request: Request<Action>,
) -> Result<ActionCreatePreparedStatementResult, Status>;
/// Handles the `ClosePreparedStatement` action.
///
/// This method returns nothing, so a failure here cannot be reported to the
/// client; `do_action` always answers with an empty stream afterwards.
/// NOTE(review): returning `Result<(), Status>` would fix this but is a
/// breaking change for implementors.
async fn do_action_close_prepared_statement(
&self,
query: ActionClosePreparedStatementRequest,
request: Request<Action>,
);
/// Registers a `SqlInfo` value under `id`.
///
/// NOTE(review): not called from the `FlightService` dispatch in this module;
/// confirm who invokes it and what `id` identifies (presumably a SqlInfo code).
async fn register_sql_info(&self, id: i32, result: &SqlInfo);
}
/// Blanket implementation of the Arrow Flight `FlightService` gRPC API for every
/// [`FlightSqlService`].
///
/// Each RPC decodes the Flight SQL command carried in the request (a protobuf
/// `Any`), checks its type URL, and forwards the unpacked command to the
/// matching `FlightSqlService` method. `list_flights`, `get_schema` and
/// `do_exchange` always return `Unimplemented`.
#[tonic::async_trait]
impl<T: 'static> FlightService for T
where
    T: FlightSqlService + Send,
{
    type HandshakeStream =
        Pin<Box<dyn Stream<Item = Result<HandshakeResponse, Status>> + Send + 'static>>;
    type ListFlightsStream =
        Pin<Box<dyn Stream<Item = Result<FlightInfo, Status>> + Send + 'static>>;
    type DoGetStream =
        Pin<Box<dyn Stream<Item = Result<FlightData, Status>> + Send + 'static>>;
    type DoPutStream =
        Pin<Box<dyn Stream<Item = Result<PutResult, Status>> + Send + 'static>>;
    type DoActionStream = Pin<
        Box<dyn Stream<Item = Result<super::super::Result, Status>> + Send + 'static>,
    >;
    type ListActionsStream =
        Pin<Box<dyn Stream<Item = Result<ActionType, Status>> + Send + 'static>>;
    type DoExchangeStream =
        Pin<Box<dyn Stream<Item = Result<FlightData, Status>> + Send + 'static>>;

    /// Delegates directly to [`FlightSqlService::do_handshake`].
    async fn handshake(
        &self,
        request: Request<Streaming<HandshakeRequest>>,
    ) -> Result<Response<Self::HandshakeStream>, Status> {
        self.do_handshake(request).await
    }

    async fn list_flights(
        &self,
        _request: Request<Criteria>,
    ) -> Result<Response<Self::ListFlightsStream>, Status> {
        Err(Status::unimplemented("Not yet implemented"))
    }

    /// Decodes the descriptor's `cmd` as an `Any` and dispatches it to the
    /// matching `get_flight_info_*` hook.
    ///
    /// # Errors
    /// `InvalidArgument` if `cmd` is not a valid `Any`; `Unimplemented` if the
    /// command type is not one handled here; otherwise whatever the hook returns.
    async fn get_flight_info(
        &self,
        request: Request<FlightDescriptor>,
    ) -> Result<Response<FlightInfo>, Status> {
        let message =
            Any::decode(&*request.get_ref().cmd).map_err(decode_error_to_status)?;

        if message.is::<CommandStatementQuery>() {
            let token = unpack_command(&message)?;
            return self.get_flight_info_statement(token, request).await;
        }
        if message.is::<CommandPreparedStatementQuery>() {
            let handle = unpack_command(&message)?;
            return self
                .get_flight_info_prepared_statement(handle, request)
                .await;
        }
        if message.is::<CommandGetCatalogs>() {
            let token = unpack_command(&message)?;
            return self.get_flight_info_catalogs(token, request).await;
        }
        if message.is::<CommandGetDbSchemas>() {
            let token = unpack_command(&message)?;
            return self.get_flight_info_schemas(token, request).await;
        }
        if message.is::<CommandGetTables>() {
            let token = unpack_command(&message)?;
            return self.get_flight_info_tables(token, request).await;
        }
        if message.is::<CommandGetTableTypes>() {
            let token = unpack_command(&message)?;
            return self.get_flight_info_table_types(token, request).await;
        }
        if message.is::<CommandGetSqlInfo>() {
            let token = unpack_command(&message)?;
            return self.get_flight_info_sql_info(token, request).await;
        }
        if message.is::<CommandGetPrimaryKeys>() {
            let token = unpack_command(&message)?;
            return self.get_flight_info_primary_keys(token, request).await;
        }
        if message.is::<CommandGetExportedKeys>() {
            let token = unpack_command(&message)?;
            return self.get_flight_info_exported_keys(token, request).await;
        }
        if message.is::<CommandGetImportedKeys>() {
            let token = unpack_command(&message)?;
            return self.get_flight_info_imported_keys(token, request).await;
        }
        if message.is::<CommandGetCrossReference>() {
            let token = unpack_command(&message)?;
            return self.get_flight_info_cross_reference(token, request).await;
        }

        Err(Status::unimplemented(format!(
            "get_flight_info: The defined request is invalid: {}",
            message.type_url
        )))
    }

    async fn get_schema(
        &self,
        _request: Request<FlightDescriptor>,
    ) -> Result<Response<SchemaResult>, Status> {
        Err(Status::unimplemented("Not yet implemented"))
    }

    /// Decodes the ticket as an `Any` and dispatches it to the matching
    /// `do_get_*` hook, falling back to `do_get_fallback` for unknown types.
    async fn do_get(
        &self,
        request: Request<Ticket>,
    ) -> Result<Response<Self::DoGetStream>, Status> {
        let msg: Any = Message::decode(&*request.get_ref().ticket)
            .map_err(decode_error_to_status)?;

        if msg.is::<TicketStatementQuery>() {
            return self.do_get_statement(unpack_command(&msg)?, request).await;
        }
        if msg.is::<CommandPreparedStatementQuery>() {
            return self
                .do_get_prepared_statement(unpack_command(&msg)?, request)
                .await;
        }
        if msg.is::<CommandGetCatalogs>() {
            return self.do_get_catalogs(unpack_command(&msg)?, request).await;
        }
        if msg.is::<CommandGetDbSchemas>() {
            return self.do_get_schemas(unpack_command(&msg)?, request).await;
        }
        if msg.is::<CommandGetTables>() {
            return self.do_get_tables(unpack_command(&msg)?, request).await;
        }
        if msg.is::<CommandGetTableTypes>() {
            return self.do_get_table_types(unpack_command(&msg)?, request).await;
        }
        if msg.is::<CommandGetSqlInfo>() {
            return self.do_get_sql_info(unpack_command(&msg)?, request).await;
        }
        if msg.is::<CommandGetPrimaryKeys>() {
            return self.do_get_primary_keys(unpack_command(&msg)?, request).await;
        }
        if msg.is::<CommandGetExportedKeys>() {
            return self.do_get_exported_keys(unpack_command(&msg)?, request).await;
        }
        if msg.is::<CommandGetImportedKeys>() {
            return self.do_get_imported_keys(unpack_command(&msg)?, request).await;
        }
        if msg.is::<CommandGetCrossReference>() {
            return self
                .do_get_cross_reference(unpack_command(&msg)?, request)
                .await;
        }

        self.do_get_fallback(request, msg).await
    }

    /// Reads the first `FlightData` message to obtain the command from its flight
    /// descriptor, then dispatches to the matching `do_put_*` hook.
    ///
    /// # Errors
    /// `InvalidArgument` if the stream is empty, the first message has no flight
    /// descriptor, the descriptor's `cmd` is not a valid `Any`, or the command
    /// type is not handled here.
    async fn do_put(
        &self,
        mut request: Request<Streaming<FlightData>>,
    ) -> Result<Response<Self::DoPutStream>, Status> {
        // Both conditions are under client control, so they must be reported as
        // errors instead of panicking the handler.
        let first = request.get_mut().message().await?.ok_or_else(|| {
            Status::invalid_argument("do_put: the request stream contained no messages")
        })?;
        let descriptor = first.flight_descriptor.ok_or_else(|| {
            Status::invalid_argument("do_put: the first message has no flight descriptor")
        })?;
        let message = Any::decode(&*descriptor.cmd).map_err(decode_error_to_status)?;

        if message.is::<CommandStatementUpdate>() {
            let token = unpack_command(&message)?;
            let record_count = self.do_put_statement_update(token, request).await?;
            return Ok(Response::new(do_put_update_result_stream(record_count)));
        }
        if message.is::<CommandPreparedStatementQuery>() {
            let token = unpack_command(&message)?;
            return self.do_put_prepared_statement_query(token, request).await;
        }
        if message.is::<CommandPreparedStatementUpdate>() {
            let handle = unpack_command(&message)?;
            let record_count = self
                .do_put_prepared_statement_update(handle, request)
                .await?;
            return Ok(Response::new(do_put_update_result_stream(record_count)));
        }

        Err(Status::invalid_argument(format!(
            "do_put: The defined request is invalid: {}",
            message.type_url
        )))
    }

    /// Advertises the two prepared-statement actions handled by `do_action`.
    async fn list_actions(
        &self,
        _request: Request<Empty>,
    ) -> Result<Response<Self::ListActionsStream>, Status> {
        // Explicit `\n\n` escapes plus `\` line continuations keep these
        // descriptions independent of source indentation: a plain multi-line
        // literal would embed the leading whitespace of each continuation line.
        let create_prepared_statement_action_type = ActionType {
            r#type: CREATE_PREPARED_STATEMENT.to_string(),
            description: "Creates a reusable prepared statement resource on the server.\n\n\
                          Request Message: ActionCreatePreparedStatementRequest\n\n\
                          Response Message: ActionCreatePreparedStatementResult"
                .into(),
        };
        let close_prepared_statement_action_type = ActionType {
            r#type: CLOSE_PREPARED_STATEMENT.to_string(),
            description: "Closes a reusable prepared statement resource on the server.\n\n\
                          Request Message: ActionClosePreparedStatementRequest\n\n\
                          Response Message: N/A"
                .into(),
        };
        let actions: Vec<Result<ActionType, Status>> = vec![
            Ok(create_prepared_statement_action_type),
            Ok(close_prepared_statement_action_type),
        ];
        let output = futures::stream::iter(actions);
        Ok(Response::new(Box::pin(output) as Self::ListActionsStream))
    }

    /// Dispatches the `CreatePreparedStatement` and `ClosePreparedStatement`
    /// actions; any other action type yields `InvalidArgument`.
    async fn do_action(
        &self,
        request: Request<Action>,
    ) -> Result<Response<Self::DoActionStream>, Status> {
        if request.get_ref().r#type == CREATE_PREPARED_STATEMENT {
            let any =
                Any::decode(&*request.get_ref().body).map_err(decode_error_to_status)?;
            let cmd: ActionCreatePreparedStatementRequest = any
                .unpack()
                .map_err(arrow_error_to_status)?
                .ok_or_else(|| {
                    Status::invalid_argument(
                        "Unable to unpack ActionCreatePreparedStatementRequest.",
                    )
                })?;
            let stmt = self
                .do_action_create_prepared_statement(cmd, request)
                .await?;
            let output = futures::stream::iter(vec![Ok(super::super::Result {
                body: stmt.as_any().encode_to_vec().into(),
            })]);
            return Ok(Response::new(Box::pin(output)));
        }
        if request.get_ref().r#type == CLOSE_PREPARED_STATEMENT {
            let any =
                Any::decode(&*request.get_ref().body).map_err(decode_error_to_status)?;
            let cmd: ActionClosePreparedStatementRequest = any
                .unpack()
                .map_err(arrow_error_to_status)?
                .ok_or_else(|| {
                    Status::invalid_argument(
                        "Unable to unpack ActionClosePreparedStatementRequest.",
                    )
                })?;
            self.do_action_close_prepared_statement(cmd, request).await;
            return Ok(Response::new(Box::pin(futures::stream::empty())));
        }

        Err(Status::invalid_argument(format!(
            "do_action: The defined request is invalid: {:?}",
            request.get_ref().r#type
        )))
    }

    async fn do_exchange(
        &self,
        _request: Request<Streaming<FlightData>>,
    ) -> Result<Response<Self::DoExchangeStream>, Status> {
        Err(Status::unimplemented("Not yet implemented"))
    }
}

/// Unpacks `message` into the concrete Flight SQL command `M`.
///
/// Callers check `message.is::<M>()` first, so the `None` case should not occur
/// in practice; it is still reported as `Internal` instead of panicking.
fn unpack_command<M: ProstMessageExt>(message: &Any) -> Result<M, Status> {
    message
        .unpack()
        .map_err(arrow_error_to_status)?
        .ok_or_else(|| Status::internal("Expected a command, but found none."))
}

/// Builds the single-message `DoPut` response stream carrying `record_count`
/// as an encoded `DoPutUpdateResult` in the `PutResult`'s `app_metadata`.
fn do_put_update_result_stream(
    record_count: i64,
) -> Pin<Box<dyn Stream<Item = Result<PutResult, Status>> + Send + 'static>> {
    let result = DoPutUpdateResult { record_count };
    let output = futures::stream::iter(vec![Ok(PutResult {
        app_metadata: result.encode_to_vec().into(),
    })]);
    Box::pin(output)
}
/// Converts a protobuf decoding failure into an `InvalidArgument` status.
///
/// Used wherever request bytes are decoded; the status message is the error's
/// `Debug` representation.
fn decode_error_to_status(err: prost::DecodeError) -> Status {
    let detail = format!("{:?}", err);
    Status::invalid_argument(detail)
}
/// Converts an `ArrowError` into an `Internal` status.
///
/// In this module it is applied to errors from `Any::unpack`; the status
/// message is the error's `Debug` representation.
fn arrow_error_to_status(err: arrow_schema::ArrowError) -> Status {
    let detail = format!("{:?}", err);
    Status::internal(detail)
}