use arrow_flight::decode::FlightRecordBatchStream;
use arrow_flight::encode::FlightDataEncoderBuilder;
use arrow_flight::error::FlightError;
use arrow_flight::sql::metadata::{SqlInfoData, SqlInfoDataBuilder};
use arrow_flight::sql::{
server::FlightSqlService, ActionBeginSavepointRequest, ActionBeginSavepointResult,
ActionBeginTransactionRequest, ActionBeginTransactionResult, ActionCancelQueryRequest,
ActionCancelQueryResult, ActionClosePreparedStatementRequest,
ActionCreatePreparedStatementRequest, ActionCreatePreparedStatementResult,
ActionCreatePreparedSubstraitPlanRequest, ActionEndSavepointRequest,
ActionEndTransactionRequest, Any, CommandGetCatalogs, CommandGetCrossReference,
CommandGetDbSchemas, CommandGetExportedKeys, CommandGetImportedKeys, CommandGetPrimaryKeys,
CommandGetSqlInfo, CommandGetTableTypes, CommandGetTables, CommandGetXdbcTypeInfo,
CommandPreparedStatementQuery, CommandPreparedStatementUpdate, CommandStatementQuery,
CommandStatementSubstraitPlan, CommandStatementUpdate, ProstMessageExt, SqlInfo,
TicketStatementQuery,
};
use arrow_flight::{
flight_service_server::FlightService, Action, FlightData, FlightDescriptor, FlightEndpoint,
FlightInfo, HandshakeRequest, HandshakeResponse, Ticket,
};
use datafusion::datasource::TableProvider;
use datafusion::sql::parser::Statement;
use datafusion::sql::sqlparser::ast::{self, SetExpr};
use ella_engine::engine::EllaState;
use ella_engine::{EngineError, Plan};
use futures::{SinkExt, Stream, TryStreamExt};
use once_cell::sync::Lazy;
use prost::Message;
use std::pin::Pin;
use std::sync::Arc;
use tonic::{Request, Response, Status, Streaming};
use super::auth::{connection, ConnectionManager};
/// Builds a `Status::internal` error whose message reads `"<desc>: <err> at <file>:<line>"`.
///
/// `file!()` and `line!()` expand at the `status!` call site, so the message points at the
/// code that raised the error. Both `desc` and `err` must implement `Display`.
macro_rules! status {
    ($desc:expr, $err:expr) => {
        Status::internal(format!(
            "{desc}: {err} at {file}:{line}",
            desc = $desc,
            err = $err,
            file = file!(),
            line = line!(),
        ))
    };
}
/// `SqlInfo` metadata advertised to Flight SQL clients via `CommandGetSqlInfo`.
///
/// Built lazily on first access from a fixed set of entries: the server name `"ella"`,
/// this crate's package version, and `"1.3"` for `FlightSqlServerArrowVersion`.
static SQL_INFO: Lazy<SqlInfoData> = Lazy::new(|| {
    let mut builder = SqlInfoDataBuilder::new();
    builder.append(SqlInfo::FlightSqlServerName, "ella");
    builder.append(SqlInfo::FlightSqlServerVersion, env!("CARGO_PKG_VERSION"));
    builder.append(SqlInfo::FlightSqlServerArrowVersion, "1.3");
    // The entries above are fixed at compile time, so a failure here is a programming error.
    builder
        .build()
        .expect("static SQL_INFO entries failed to build; check the SqlInfo values above")
});
/// Arrow Flight SQL front-end for ella.
///
/// Owns the [`ConnectionManager`] that issues handshake tokens. Per-request connection
/// state is looked up separately through `connection(&request)` in the service methods.
#[derive(Clone, Debug)]
pub(crate) struct EllaSqlService {
    /// Issues connection tokens during the Flight handshake.
    connections: ConnectionManager,
}
impl EllaSqlService {
pub fn new(connections: ConnectionManager) -> Self {
Self { connections }
}
}
impl EllaSqlService {
async fn execute_plan(
&self,
state: &EllaState,
ticket: &[u8],
) -> Result<Response<<Self as FlightService>::DoGetStream>, Status> {
let stream =
ella_engine::lazy::Lazy::new(Plan::from_bytes(ticket)?, Arc::new(state.backend()))
.stream()
.await?;
let schema = stream.arrow_schema();
let stream = stream
.into_inner()
.map_err(|err| FlightError::ExternalError(Box::new(err)));
let stream = FlightDataEncoderBuilder::new()
.with_schema(schema)
.build(stream)
.map_err(Into::into);
Ok(Response::new(Box::pin(stream)))
}
}
#[tonic::async_trait]
impl FlightSqlService for EllaSqlService {
type FlightService = EllaSqlService;
async fn do_handshake(
&self,
_request: Request<Streaming<HandshakeRequest>>,
) -> Result<
Response<Pin<Box<dyn Stream<Item = Result<HandshakeResponse, Status>> + Send>>>,
Status,
> {
let token = self.connections.handshake()?.into_bytes();
let result = HandshakeResponse {
protocol_version: 0,
payload: token.into(),
};
let result = Ok(result);
let output = futures::stream::iter(vec![result]);
return Ok(Response::new(Box::pin(output)));
}
#[tracing::instrument(skip_all)]
async fn do_get_fallback(
&self,
request: Request<Ticket>,
_message: Any,
) -> Result<Response<<Self as FlightService>::DoGetStream>, Status> {
let state = connection(&request)?.read();
let ticket = request.into_inner().ticket;
self.execute_plan(&state, &ticket).await
}
#[tracing::instrument(skip(self, request))]
async fn get_flight_info_statement(
&self,
query: CommandStatementQuery,
request: Request<FlightDescriptor>,
) -> Result<Response<FlightInfo>, Status> {
let state = connection(&request)?.read();
let plan = state.query(&query.query).await?;
let statement_handle = plan.plan().to_bytes().into();
let ticket = TicketStatementQuery { statement_handle };
let endpoint = FlightEndpoint {
ticket: Some(Ticket {
ticket: ticket.as_any().encode_to_vec().into(),
}),
location: vec![],
};
let info = FlightInfo::new()
.try_with_schema(&plan.plan().arrow_schema())
.map_err(crate::Error::from)?
.with_endpoint(endpoint)
.with_ordered(true)
.with_descriptor(request.into_inner());
Ok(Response::new(info))
}
#[tracing::instrument(skip(self, _request))]
async fn get_flight_info_substrait_plan(
&self,
_query: CommandStatementSubstraitPlan,
_request: Request<FlightDescriptor>,
) -> Result<Response<FlightInfo>, Status> {
Err(Status::unimplemented(
"get_flight_info_substrait_plan not implemented",
))
}
#[tracing::instrument(skip(self, _request))]
async fn get_flight_info_prepared_statement(
&self,
_cmd: CommandPreparedStatementQuery,
_request: Request<FlightDescriptor>,
) -> Result<Response<FlightInfo>, Status> {
Err(Status::unimplemented(
"get_flight_info_prepared_statement not implemented",
))
}
#[tracing::instrument(skip(self, request))]
async fn get_flight_info_catalogs(
&self,
query: CommandGetCatalogs,
request: Request<FlightDescriptor>,
) -> Result<Response<FlightInfo>, Status> {
let flight_descriptor = request.into_inner();
let ticket = Ticket {
ticket: query.as_any().encode_to_vec().into(),
};
let endpoint = FlightEndpoint::new().with_ticket(ticket);
let flight_info = FlightInfo::new()
.try_with_schema(&query.into_builder().schema())
.map_err(|e| status!("Unable to encode schema", e))?
.with_endpoint(endpoint)
.with_descriptor(flight_descriptor);
Ok(tonic::Response::new(flight_info))
}
#[tracing::instrument(skip(self, request))]
async fn get_flight_info_schemas(
&self,
query: CommandGetDbSchemas,
request: Request<FlightDescriptor>,
) -> Result<Response<FlightInfo>, Status> {
let flight_descriptor = request.into_inner();
let ticket = Ticket {
ticket: query.as_any().encode_to_vec().into(),
};
let endpoint = FlightEndpoint::new().with_ticket(ticket);
let flight_info = FlightInfo::new()
.try_with_schema(&query.into_builder().schema())
.map_err(|e| status!("Unable to encode schema", e))?
.with_endpoint(endpoint)
.with_descriptor(flight_descriptor);
Ok(tonic::Response::new(flight_info))
}
#[tracing::instrument(skip(self, request))]
async fn get_flight_info_tables(
&self,
query: CommandGetTables,
request: Request<FlightDescriptor>,
) -> Result<Response<FlightInfo>, Status> {
let flight_descriptor = request.into_inner();
let ticket = Ticket {
ticket: query.as_any().encode_to_vec().into(),
};
let endpoint = FlightEndpoint::new().with_ticket(ticket);
let flight_info = FlightInfo::new()
.try_with_schema(&query.into_builder().schema())
.map_err(|e| status!("Unable to encode schema", e))?
.with_endpoint(endpoint)
.with_descriptor(flight_descriptor);
Ok(tonic::Response::new(flight_info))
}
#[tracing::instrument(skip(self, _request))]
async fn get_flight_info_table_types(
&self,
_query: CommandGetTableTypes,
_request: Request<FlightDescriptor>,
) -> Result<Response<FlightInfo>, Status> {
Err(Status::unimplemented(
"get_flight_info_table_types not implemented",
))
}
#[tracing::instrument(skip(self, request))]
async fn get_flight_info_sql_info(
&self,
query: CommandGetSqlInfo,
request: Request<FlightDescriptor>,
) -> Result<Response<FlightInfo>, Status> {
let flight_descriptor = request.into_inner();
let ticket = Ticket::new(query.as_any().encode_to_vec());
let endpoint = FlightEndpoint::new().with_ticket(ticket);
let flight_info = FlightInfo::new()
.try_with_schema(query.into_builder(&SQL_INFO).schema().as_ref())
.map_err(|e| status!("Unable to encode schema", e))?
.with_endpoint(endpoint)
.with_descriptor(flight_descriptor);
Ok(tonic::Response::new(flight_info))
}
#[tracing::instrument(skip(self, _request))]
async fn get_flight_info_primary_keys(
&self,
_query: CommandGetPrimaryKeys,
_request: Request<FlightDescriptor>,
) -> Result<Response<FlightInfo>, Status> {
Err(Status::unimplemented(
"get_flight_info_primary_keys not implemented",
))
}
#[tracing::instrument(skip(self, _request))]
async fn get_flight_info_exported_keys(
&self,
_query: CommandGetExportedKeys,
_request: Request<FlightDescriptor>,
) -> Result<Response<FlightInfo>, Status> {
Err(Status::unimplemented(
"get_flight_info_exported_keys not implemented",
))
}
#[tracing::instrument(skip(self, _request))]
async fn get_flight_info_imported_keys(
&self,
_query: CommandGetImportedKeys,
_request: Request<FlightDescriptor>,
) -> Result<Response<FlightInfo>, Status> {
Err(Status::unimplemented(
"get_flight_info_imported_keys not implemented",
))
}
#[tracing::instrument(skip(self, _request))]
async fn get_flight_info_cross_reference(
&self,
_query: CommandGetCrossReference,
_request: Request<FlightDescriptor>,
) -> Result<Response<FlightInfo>, Status> {
Err(Status::unimplemented(
"get_flight_info_imported_keys not implemented",
))
}
#[tracing::instrument(skip(self, _request))]
async fn get_flight_info_xdbc_type_info(
&self,
_query: CommandGetXdbcTypeInfo,
_request: Request<FlightDescriptor>,
) -> Result<Response<FlightInfo>, Status> {
Err(Status::unimplemented(
"get_flight_info_xdbc_type_info not implemented",
))
}
#[tracing::instrument(skip_all)]
async fn do_get_statement(
&self,
ticket: TicketStatementQuery,
request: Request<Ticket>,
) -> Result<Response<<Self as FlightService>::DoGetStream>, Status> {
let state = connection(&request)?.read();
self.execute_plan(&state, &ticket.statement_handle).await
}
#[tracing::instrument(skip(self, _request))]
async fn do_get_prepared_statement(
&self,
_query: CommandPreparedStatementQuery,
_request: Request<Ticket>,
) -> Result<Response<<Self as FlightService>::DoGetStream>, Status> {
Err(Status::unimplemented(
"do_get_prepared_statement not implemented",
))
}
#[tracing::instrument(skip(self, request))]
async fn do_get_catalogs(
&self,
query: CommandGetCatalogs,
request: Request<Ticket>,
) -> Result<Response<<Self as FlightService>::DoGetStream>, Status> {
let state = connection(&request)?.read();
let mut builder = query.into_builder();
for catalog in state.cluster().catalogs() {
builder.append(catalog.id().to_string());
}
let schema = builder.schema();
let batch = builder.build();
let stream = FlightDataEncoderBuilder::new()
.with_schema(schema)
.build(futures::stream::once(async { batch }))
.map_err(Status::from);
Ok(Response::new(Box::pin(stream)))
}
#[tracing::instrument(skip(self, request))]
async fn do_get_schemas(
&self,
query: CommandGetDbSchemas,
request: Request<Ticket>,
) -> Result<Response<<Self as FlightService>::DoGetStream>, Status> {
let state = connection(&request)?.read();
let mut builder = query.into_builder();
for catalog in state.cluster().catalogs() {
for schema in catalog.schemas() {
builder.append(&schema.id().catalog, &schema.id().schema);
}
}
let schema = builder.schema();
let batch = builder.build();
let stream = FlightDataEncoderBuilder::new()
.with_schema(schema)
.build(futures::stream::once(async { batch }))
.map_err(Status::from);
Ok(Response::new(Box::pin(stream)))
}
#[tracing::instrument(skip(self, request))]
async fn do_get_tables(
&self,
query: CommandGetTables,
request: Request<Ticket>,
) -> Result<Response<<Self as FlightService>::DoGetStream>, Status> {
let state = connection(&request)?.read();
let mut builder = query.into_builder();
for catalog in state.cluster().catalogs() {
for schema in catalog.schemas() {
for table in schema.tables() {
let id = table.id();
builder
.append(
&id.catalog,
&id.schema,
&id.table,
table.kind(),
&table.schema(),
)
.map_err(|e| status!("Failed to serialize table info", e))?;
}
}
}
let schema = builder.schema();
let batch = builder.build();
let stream = FlightDataEncoderBuilder::new()
.with_schema(schema)
.build(futures::stream::once(async { batch }))
.map_err(Status::from);
Ok(Response::new(Box::pin(stream)))
}
#[tracing::instrument(skip(self, _request))]
async fn do_get_table_types(
&self,
_query: CommandGetTableTypes,
_request: Request<Ticket>,
) -> Result<Response<<Self as FlightService>::DoGetStream>, Status> {
Err(Status::unimplemented("do_get_table_types not implemented"))
}
#[tracing::instrument(skip(self, _request))]
async fn do_get_sql_info(
&self,
query: CommandGetSqlInfo,
_request: Request<Ticket>,
) -> Result<Response<<Self as FlightService>::DoGetStream>, Status> {
let builder = query.into_builder(&SQL_INFO);
let schema = builder.schema();
let batch = builder.build();
let stream = FlightDataEncoderBuilder::new()
.with_schema(schema)
.build(futures::stream::once(async { batch }))
.map_err(Status::from);
Ok(Response::new(Box::pin(stream)))
}
#[tracing::instrument(skip(self, _request))]
async fn do_get_primary_keys(
&self,
_query: CommandGetPrimaryKeys,
_request: Request<Ticket>,
) -> Result<Response<<Self as FlightService>::DoGetStream>, Status> {
Err(Status::unimplemented("do_get_primary_keys not implemented"))
}
#[tracing::instrument(skip(self, _request))]
async fn do_get_exported_keys(
&self,
_query: CommandGetExportedKeys,
_request: Request<Ticket>,
) -> Result<Response<<Self as FlightService>::DoGetStream>, Status> {
Err(Status::unimplemented(
"do_get_exported_keys not implemented",
))
}
#[tracing::instrument(skip(self, _request))]
async fn do_get_imported_keys(
&self,
_query: CommandGetImportedKeys,
_request: Request<Ticket>,
) -> Result<Response<<Self as FlightService>::DoGetStream>, Status> {
Err(Status::unimplemented(
"do_get_imported_keys not implemented",
))
}
#[tracing::instrument(skip(self, _request))]
async fn do_get_cross_reference(
&self,
_query: CommandGetCrossReference,
_request: Request<Ticket>,
) -> Result<Response<<Self as FlightService>::DoGetStream>, Status> {
Err(Status::unimplemented(
"do_get_cross_reference not implemented",
))
}
#[tracing::instrument(skip(self, _request))]
async fn do_get_xdbc_type_info(
&self,
_query: CommandGetXdbcTypeInfo,
_request: Request<Ticket>,
) -> Result<Response<<Self as FlightService>::DoGetStream>, Status> {
Err(Status::unimplemented(
"do_get_xdbc_type_info not implemented",
))
}
#[tracing::instrument(skip(self, request))]
async fn do_put_statement_update(
&self,
ticket: CommandStatementUpdate,
request: Request<Streaming<FlightData>>,
) -> Result<i64, Status> {
let state = connection(&request)?.read();
let session = state.session();
let stmt = session
.sql_to_statement(
&ticket.query,
&session.config().options().sql_parser.dialect,
)
.map_err(crate::Error::from)?;
if let Statement::Statement(stmt) = stmt {
if let ast::Statement::Insert {
source, table_name, ..
} = stmt.as_ref()
{
if let SetExpr::Table(src) = source.body.as_ref() {
if src.schema_name.is_none() && src.table_name.as_deref() == Some("this") {
let mut stream = FlightRecordBatchStream::new_from_flight_data(
request.into_inner().map_err(Into::into),
);
let mut pb = state
.table(state.resolve(table_name.to_string().into()))
.and_then(|t| t.as_topic())
.ok_or_else(|| {
crate::Error::from(EngineError::TableNotFound(
table_name.to_string(),
))
})?
.publish();
let mut rows = 0;
while let Some(batch) = stream.try_next().await? {
rows += batch.num_rows();
pb.send(batch).await?;
}
pb.flush().await?;
return Ok(rows as i64);
}
}
}
}
todo!()
}
#[tracing::instrument(skip(self, _request))]
async fn do_put_substrait_plan(
&self,
_ticket: CommandStatementSubstraitPlan,
_request: Request<Streaming<FlightData>>,
) -> Result<i64, Status> {
Err(Status::unimplemented(
"do_put_substrait_plan not implemented",
))
}
#[tracing::instrument(skip(self, _request))]
async fn do_put_prepared_statement_query(
&self,
_query: CommandPreparedStatementQuery,
_request: Request<Streaming<FlightData>>,
) -> Result<Response<<Self as FlightService>::DoPutStream>, Status> {
Err(Status::unimplemented(
"do_put_prepared_statement_query not implemented",
))
}
#[tracing::instrument(skip(self, _request))]
async fn do_put_prepared_statement_update(
&self,
_query: CommandPreparedStatementUpdate,
_request: Request<Streaming<FlightData>>,
) -> Result<i64, Status> {
Err(Status::unimplemented(
"do_put_prepared_statement_update not implemented",
))
}
#[tracing::instrument(skip(self, _request))]
async fn do_action_create_prepared_statement(
&self,
_query: ActionCreatePreparedStatementRequest,
_request: Request<Action>,
) -> Result<ActionCreatePreparedStatementResult, Status> {
Err(Status::unimplemented(
"do_action_create_prepared_statement not implemented",
))
}
#[tracing::instrument(skip(self, _request))]
async fn do_action_close_prepared_statement(
&self,
_query: ActionClosePreparedStatementRequest,
_request: Request<Action>,
) -> Result<(), Status> {
Err(Status::unimplemented(
"Implement do_action_close_prepared_statement",
))
}
#[tracing::instrument(skip(self, _request))]
async fn do_action_create_prepared_substrait_plan(
&self,
_query: ActionCreatePreparedSubstraitPlanRequest,
_request: Request<Action>,
) -> Result<ActionCreatePreparedStatementResult, Status> {
Err(Status::unimplemented(
"Implement do_action_create_prepared_substrait_plan",
))
}
#[tracing::instrument(skip(self, _request))]
async fn do_action_begin_transaction(
&self,
_query: ActionBeginTransactionRequest,
_request: Request<Action>,
) -> Result<ActionBeginTransactionResult, Status> {
Err(Status::unimplemented(
"Implement do_action_begin_transaction",
))
}
#[tracing::instrument(skip(self, _request))]
async fn do_action_end_transaction(
&self,
_query: ActionEndTransactionRequest,
_request: Request<Action>,
) -> Result<(), Status> {
Err(Status::unimplemented("Implement do_action_end_transaction"))
}
#[tracing::instrument(skip(self, _request))]
async fn do_action_begin_savepoint(
&self,
_query: ActionBeginSavepointRequest,
_request: Request<Action>,
) -> Result<ActionBeginSavepointResult, Status> {
Err(Status::unimplemented("Implement do_action_begin_savepoint"))
}
#[tracing::instrument(skip(self, _request))]
async fn do_action_end_savepoint(
&self,
_query: ActionEndSavepointRequest,
_request: Request<Action>,
) -> Result<(), Status> {
Err(Status::unimplemented("Implement do_action_end_savepoint"))
}
#[tracing::instrument(skip(self, _request))]
async fn do_action_cancel_query(
&self,
_query: ActionCancelQueryRequest,
_request: Request<Action>,
) -> Result<ActionCancelQueryResult, Status> {
Err(Status::unimplemented("Implement do_action_cancel_query"))
}
async fn register_sql_info(&self, _id: i32, _result: &SqlInfo) {}
}