use base64::prelude::BASE64_STANDARD;
use base64::Engine;
use bytes::Bytes;
use std::collections::HashMap;
use std::str::FromStr;
use tonic::metadata::AsciiMetadataKey;
use crate::flight_service_client::FlightServiceClient;
use crate::sql::server::{CLOSE_PREPARED_STATEMENT, CREATE_PREPARED_STATEMENT};
use crate::sql::{
ActionClosePreparedStatementRequest, ActionCreatePreparedStatementRequest,
ActionCreatePreparedStatementResult, Any, CommandGetCatalogs,
CommandGetCrossReference, CommandGetDbSchemas, CommandGetExportedKeys,
CommandGetImportedKeys, CommandGetPrimaryKeys, CommandGetSqlInfo,
CommandGetTableTypes, CommandGetTables, CommandGetXdbcTypeInfo,
CommandPreparedStatementQuery, CommandStatementQuery, CommandStatementUpdate,
DoPutUpdateResult, ProstMessageExt, SqlInfo,
};
use crate::{
Action, FlightData, FlightDescriptor, FlightInfo, HandshakeRequest,
HandshakeResponse, IpcMessage, PutResult, Ticket,
};
use arrow_array::RecordBatch;
use arrow_buffer::Buffer;
use arrow_ipc::convert::fb_to_schema;
use arrow_ipc::reader::read_record_batch;
use arrow_ipc::{root_as_message, MessageHeader};
use arrow_schema::{ArrowError, Schema, SchemaRef};
use futures::{stream, TryStreamExt};
use prost::Message;
use tonic::transport::Channel;
use tonic::{IntoRequest, Streaming};
/// Client for a Flight SQL server, built on top of a [`FlightServiceClient`].
///
/// Besides the wrapped client it keeps an optional bearer token and a set of
/// custom headers that `set_request_headers` copies onto outgoing requests.
#[derive(Debug, Clone)]
pub struct FlightSqlServiceClient<T> {
/// Bearer token; when present it is sent as `authorization: Bearer <token>`.
token: Option<String>,
/// Additional header name/value pairs attached to outgoing requests.
headers: HashMap<String, String>,
/// The wrapped low-level Flight client that performs the RPCs.
flight_client: FlightServiceClient<T>,
}
impl FlightSqlServiceClient<Channel> {
pub fn new(channel: Channel) -> Self {
let flight_client = FlightServiceClient::new(channel);
FlightSqlServiceClient {
token: None,
flight_client,
headers: HashMap::default(),
}
}
pub fn inner(&self) -> &FlightServiceClient<Channel> {
&self.flight_client
}
pub fn inner_mut(&mut self) -> &mut FlightServiceClient<Channel> {
&mut self.flight_client
}
pub fn into_inner(self) -> FlightServiceClient<Channel> {
self.flight_client
}
/// Stores `token` as the bearer token, replacing any previous one.
///
/// `set_request_headers` sends it as `authorization: Bearer <token>`.
pub fn set_token(&mut self, token: String) {
    let _previous = self.token.replace(token);
}
/// Registers a custom header to attach to outgoing requests.
///
/// Setting the same `key` again replaces the earlier value. The key and
/// value are only validated when a request is built, not here.
pub fn set_header(&mut self, key: impl Into<String>, value: impl Into<String>) {
    self.headers.insert(key.into(), value.into());
}
/// Wraps `cmd` in a command [`FlightDescriptor`] and calls `get_flight_info`.
///
/// The command is packed with `as_any()` and encoded to bytes. The
/// configured headers and token are applied before the call.
///
/// # Errors
///
/// Returns an error if a header cannot be converted or if the server
/// answers with a non-OK status.
async fn get_flight_info_for_command<M: ProstMessageExt>(
    &mut self,
    cmd: M,
) -> Result<FlightInfo, ArrowError> {
    let encoded_cmd = cmd.as_any().encode_to_vec();
    let request = self
        .set_request_headers(FlightDescriptor::new_cmd(encoded_cmd).into_request())?;
    self.flight_client
        .get_flight_info(request)
        .await
        .map(|response| response.into_inner())
        .map_err(status_to_arrow_error)
}
/// Submits `query` as a [`CommandStatementQuery`] and returns the
/// [`FlightInfo`] the server replies with.
pub async fn execute(&mut self, query: String) -> Result<FlightInfo, ArrowError> {
    let flight_info = self
        .get_flight_info_for_command(CommandStatementQuery { query })
        .await?;
    Ok(flight_info)
}
/// Performs a Flight handshake using HTTP Basic credentials.
///
/// The request carries `authorization: Basic base64(username:password)`.
/// If the response metadata contains an `authorization` header, it must
/// start with `Bearer `; the remainder is stored as this client's token.
///
/// Returns the payload of the single handshake response, or empty bytes
/// when the server sent no response message.
///
/// NOTE(review): `set_request_headers` runs after the Basic header is
/// inserted, so an already stored token appears to replace it — confirm
/// this is intended.
///
/// # Errors
///
/// Returns an error if a header cannot be built or read, if the RPC
/// fails, if the `authorization` response header is not a bearer token,
/// or if more than one handshake response is received.
pub async fn handshake(
    &mut self,
    username: &str,
    password: &str,
) -> Result<Bytes, ArrowError> {
    let handshake_request = HandshakeRequest {
        protocol_version: 0,
        payload: Default::default(),
    };
    let mut request = tonic::Request::new(stream::iter(vec![handshake_request]));

    let credentials = BASE64_STANDARD.encode(format!("{username}:{password}"));
    let basic_header = format!("Basic {credentials}")
        .parse()
        .map_err(|_| ArrowError::ParseError("Cannot parse header".to_string()))?;
    request.metadata_mut().insert("authorization", basic_header);
    let request = self.set_request_headers(request)?;

    let response = match self.flight_client.handshake(request).await {
        Ok(response) => response,
        Err(e) => return Err(ArrowError::IoError(format!("Can't handshake {e}"))),
    };

    // Remember the bearer token before draining the response stream.
    if let Some(header) = response.metadata().get("authorization") {
        let header = header
            .to_str()
            .map_err(|_| ArrowError::ParseError("Can't read auth header".to_string()))?;
        let token = header
            .strip_prefix("Bearer ")
            .ok_or_else(|| ArrowError::ParseError("Invalid auth header!".to_string()))?;
        self.token = Some(token.to_string());
    }

    let responses: Vec<HandshakeResponse> = response
        .into_inner()
        .try_collect()
        .await
        .map_err(|_| ArrowError::ParseError("Can't collect responses".to_string()))?;

    match responses.as_slice() {
        [] => Ok(Bytes::new()),
        [only] => Ok(only.payload.clone()),
        _ => Err(ArrowError::ParseError(
            "Multiple handshake responses".to_string(),
        )),
    }
}
pub async fn execute_update(&mut self, query: String) -> Result<i64, ArrowError> {
let cmd = CommandStatementUpdate { query };
let descriptor = FlightDescriptor::new_cmd(cmd.as_any().encode_to_vec());
let req = self.set_request_headers(
stream::iter(vec![FlightData {
flight_descriptor: Some(descriptor),
..Default::default()
}])
.into_request(),
)?;
let mut result = self
.flight_client
.do_put(req)
.await
.map_err(status_to_arrow_error)?
.into_inner();
let result = result
.message()
.await
.map_err(status_to_arrow_error)?
.unwrap();
let any =
Any::decode(&*result.app_metadata).map_err(decode_error_to_arrow_error)?;
let result: DoPutUpdateResult = any.unpack()?.unwrap();
Ok(result.record_count)
}
/// Sends a [`CommandGetCatalogs`] and returns the resulting [`FlightInfo`].
pub async fn get_catalogs(&mut self) -> Result<FlightInfo, ArrowError> {
    let command = CommandGetCatalogs {};
    self.get_flight_info_for_command(command).await
}
/// Sends the given [`CommandGetDbSchemas`] and returns the resulting
/// [`FlightInfo`].
pub async fn get_db_schemas(
    &mut self,
    request: CommandGetDbSchemas,
) -> Result<FlightInfo, ArrowError> {
    let flight_info = self.get_flight_info_for_command(request).await?;
    Ok(flight_info)
}
pub async fn do_get(
&mut self,
ticket: impl IntoRequest<Ticket>,
) -> Result<Streaming<FlightData>, ArrowError> {
let req = self.set_request_headers(ticket.into_request())?;
Ok(self
.flight_client
.do_get(req)
.await
.map_err(status_to_arrow_error)?
.into_inner())
}
pub async fn do_put(
&mut self,
request: impl tonic::IntoStreamingRequest<Message = FlightData>,
) -> Result<Streaming<PutResult>, ArrowError> {
let req = self.set_request_headers(request.into_streaming_request())?;
Ok(self
.flight_client
.do_put(req)
.await
.map_err(status_to_arrow_error)?
.into_inner())
}
pub async fn do_action(
&mut self,
request: impl IntoRequest<Action>,
) -> Result<Streaming<crate::Result>, ArrowError> {
let req = self.set_request_headers(request.into_request())?;
Ok(self
.flight_client
.do_action(req)
.await
.map_err(status_to_arrow_error)?
.into_inner())
}
/// Sends the given [`CommandGetTables`] and returns the resulting
/// [`FlightInfo`].
pub async fn get_tables(
    &mut self,
    request: CommandGetTables,
) -> Result<FlightInfo, ArrowError> {
    let flight_info = self.get_flight_info_for_command(request).await?;
    Ok(flight_info)
}
/// Sends the given [`CommandGetPrimaryKeys`] and returns the resulting
/// [`FlightInfo`].
pub async fn get_primary_keys(
    &mut self,
    request: CommandGetPrimaryKeys,
) -> Result<FlightInfo, ArrowError> {
    let flight_info = self.get_flight_info_for_command(request).await?;
    Ok(flight_info)
}
/// Sends the given [`CommandGetExportedKeys`] and returns the resulting
/// [`FlightInfo`].
pub async fn get_exported_keys(
    &mut self,
    request: CommandGetExportedKeys,
) -> Result<FlightInfo, ArrowError> {
    let flight_info = self.get_flight_info_for_command(request).await?;
    Ok(flight_info)
}
/// Sends the given [`CommandGetImportedKeys`] and returns the resulting
/// [`FlightInfo`].
pub async fn get_imported_keys(
    &mut self,
    request: CommandGetImportedKeys,
) -> Result<FlightInfo, ArrowError> {
    let flight_info = self.get_flight_info_for_command(request).await?;
    Ok(flight_info)
}
/// Sends the given [`CommandGetCrossReference`] and returns the resulting
/// [`FlightInfo`].
pub async fn get_cross_reference(
    &mut self,
    request: CommandGetCrossReference,
) -> Result<FlightInfo, ArrowError> {
    let flight_info = self.get_flight_info_for_command(request).await?;
    Ok(flight_info)
}
/// Sends a [`CommandGetTableTypes`] and returns the resulting
/// [`FlightInfo`].
pub async fn get_table_types(&mut self) -> Result<FlightInfo, ArrowError> {
    let command = CommandGetTableTypes {};
    self.get_flight_info_for_command(command).await
}
/// Requests server metadata for the given [`SqlInfo`] entries.
///
/// Each entry is converted to its numeric `u32` code and sent in a
/// [`CommandGetSqlInfo`]; the resulting [`FlightInfo`] is returned.
pub async fn get_sql_info(
    &mut self,
    sql_infos: Vec<SqlInfo>,
) -> Result<FlightInfo, ArrowError> {
    let info: Vec<u32> = sql_infos
        .into_iter()
        .map(|sql_info| sql_info as u32)
        .collect();
    self.get_flight_info_for_command(CommandGetSqlInfo { info })
        .await
}
/// Sends the given [`CommandGetXdbcTypeInfo`] and returns the resulting
/// [`FlightInfo`].
pub async fn get_xdbc_type_info(
    &mut self,
    request: CommandGetXdbcTypeInfo,
) -> Result<FlightInfo, ArrowError> {
    let flight_info = self.get_flight_info_for_command(request).await?;
    Ok(flight_info)
}
pub async fn prepare(
&mut self,
query: String,
) -> Result<PreparedStatement<Channel>, ArrowError> {
let cmd = ActionCreatePreparedStatementRequest { query };
let action = Action {
r#type: CREATE_PREPARED_STATEMENT.to_string(),
body: cmd.as_any().encode_to_vec().into(),
};
let req = self.set_request_headers(action.into_request())?;
let mut result = self
.flight_client
.do_action(req)
.await
.map_err(status_to_arrow_error)?
.into_inner();
let result = result
.message()
.await
.map_err(status_to_arrow_error)?
.unwrap();
let any = Any::decode(&*result.body).map_err(decode_error_to_arrow_error)?;
let prepared_result: ActionCreatePreparedStatementResult = any.unpack()?.unwrap();
let dataset_schema = match prepared_result.dataset_schema.len() {
0 => Schema::empty(),
_ => Schema::try_from(IpcMessage(prepared_result.dataset_schema))?,
};
let parameter_schema = match prepared_result.parameter_schema.len() {
0 => Schema::empty(),
_ => Schema::try_from(IpcMessage(prepared_result.parameter_schema))?,
};
Ok(PreparedStatement::new(
self.clone(),
prepared_result.prepared_statement_handle,
dataset_schema,
parameter_schema,
))
}
/// Closes the client.
///
/// Currently a no-op: it performs no RPC and always returns `Ok(())`.
pub async fn close(&mut self) -> Result<(), ArrowError> {
Ok(())
}
/// Copies the configured custom headers and bearer token onto `req`.
///
/// Custom headers are inserted first. When a token is set,
/// `authorization: Bearer <token>` is inserted afterwards.
///
/// # Errors
///
/// Returns an `IoError` if a header key, a header value, or the token
/// cannot be converted into gRPC metadata.
fn set_request_headers<T>(
    &self,
    mut req: tonic::Request<T>,
) -> Result<tonic::Request<T>, ArrowError> {
    let metadata = req.metadata_mut();
    for (name, raw_value) in self.headers.iter() {
        let key = AsciiMetadataKey::from_str(name.as_str()).map_err(|e| {
            ArrowError::IoError(format!("Cannot convert header key \"{name}\": {e}"))
        })?;
        let value = raw_value.parse().map_err(|e| {
            ArrowError::IoError(format!(
                "Cannot convert header value \"{raw_value}\": {e}"
            ))
        })?;
        metadata.insert(key, value);
    }
    if let Some(token) = self.token.as_ref() {
        let bearer = format!("Bearer {token}").parse().map_err(|e| {
            ArrowError::IoError(format!("Cannot convert token to header value: {e}"))
        })?;
        metadata.insert("authorization", bearer);
    }
    Ok(req)
}
}
/// A statement prepared on a Flight SQL server.
///
/// Holds its own [`FlightSqlServiceClient`] together with the statement
/// handle and the schemas supplied at construction.
#[derive(Debug, Clone)]
pub struct PreparedStatement<T> {
/// Client used to issue requests for this statement.
flight_sql_client: FlightSqlServiceClient<T>,
/// Parameter values stored by `set_parameters`; `None` until then.
parameter_binding: Option<RecordBatch>,
/// Handle identifying the prepared statement in requests to the server.
handle: Bytes,
/// Schema of the statement's result set.
dataset_schema: Schema,
/// Schema of the statement's parameters.
parameter_schema: Schema,
}
impl PreparedStatement<Channel> {
pub(crate) fn new(
flight_client: FlightSqlServiceClient<Channel>,
handle: impl Into<Bytes>,
dataset_schema: Schema,
parameter_schema: Schema,
) -> Self {
PreparedStatement {
flight_sql_client: flight_client,
parameter_binding: None,
handle: handle.into(),
dataset_schema,
parameter_schema,
}
}
/// Executes the prepared statement as a query.
///
/// Sends a [`CommandPreparedStatementQuery`] carrying this statement's
/// handle and returns the resulting [`FlightInfo`].
pub async fn execute(&mut self) -> Result<FlightInfo, ArrowError> {
    let prepared_statement_handle = self.handle.clone();
    self.flight_sql_client
        .get_flight_info_for_command(CommandPreparedStatementQuery {
            prepared_statement_handle,
        })
        .await
}
pub async fn execute_update(&mut self) -> Result<i64, ArrowError> {
let cmd = CommandPreparedStatementQuery {
prepared_statement_handle: self.handle.clone(),
};
let descriptor = FlightDescriptor::new_cmd(cmd.as_any().encode_to_vec());
let mut result = self
.flight_sql_client
.do_put(stream::iter(vec![FlightData {
flight_descriptor: Some(descriptor),
..Default::default()
}]))
.await?;
let result = result
.message()
.await
.map_err(status_to_arrow_error)?
.unwrap();
let any =
Any::decode(&*result.app_metadata).map_err(decode_error_to_arrow_error)?;
let result: DoPutUpdateResult = any.unpack()?.unwrap();
Ok(result.record_count)
}
/// Returns the schema of the statement's parameters.
///
/// This never fails; the `Result` wrapper is part of the signature.
pub fn parameter_schema(&self) -> Result<&Schema, ArrowError> {
    let schema = &self.parameter_schema;
    Ok(schema)
}
/// Returns the schema of the statement's result set.
///
/// This never fails; the `Result` wrapper is part of the signature.
pub fn dataset_schema(&self) -> Result<&Schema, ArrowError> {
    let schema = &self.dataset_schema;
    Ok(schema)
}
/// Stores `parameter_binding` as the statement's parameter values,
/// replacing any previously stored batch.
///
/// The batch is only stored; it is not validated against the parameter
/// schema. This never fails.
pub fn set_parameters(
    &mut self,
    parameter_binding: RecordBatch,
) -> Result<(), ArrowError> {
    let _previous = self.parameter_binding.replace(parameter_binding);
    Ok(())
}
pub async fn close(mut self) -> Result<(), ArrowError> {
let cmd = ActionClosePreparedStatementRequest {
prepared_statement_handle: self.handle.clone(),
};
let action = Action {
r#type: CLOSE_PREPARED_STATEMENT.to_string(),
body: cmd.as_any().encode_to_vec().into(),
};
let _ = self.flight_sql_client.do_action(action).await?;
Ok(())
}
}
/// Converts a protobuf decode error into an `ArrowError::IoError` carrying
/// the error's display text.
fn decode_error_to_arrow_error(err: prost::DecodeError) -> ArrowError {
    let message = format!("{err}");
    ArrowError::IoError(message)
}
/// Converts a gRPC status into an `ArrowError::IoError` carrying the
/// status' `Debug` representation.
fn status_to_arrow_error(status: tonic::Status) -> ArrowError {
    let message = format!("{status:?}");
    ArrowError::IoError(message)
}
/// The Arrow value decoded from a single `FlightData` message by
/// [`arrow_data_from_flight_data`].
#[derive(Debug)]
pub enum ArrowFlightData {
    /// A decoded record batch.
    RecordBatch(RecordBatch),
    /// A decoded schema.
    Schema(Schema),
}
/// Decodes one `FlightData` message into an [`ArrowFlightData`] value.
///
/// The IPC message in `data_header` decides the outcome:
///
/// * a record batch is read from `data_body` using `arrow_schema_ref`,
///   with an empty dictionary map and no projection;
/// * a schema is converted with `fb_to_schema`;
/// * dictionary batches, tensors and sparse tensors are rejected with
///   `NotYetImplemented`.
///
/// # Errors
///
/// Returns a `ParseError` if the header is not a valid IPC message, a
/// `ComputeError` if the header does not match its declared type or the
/// type is unknown, and any error raised while reading the record batch.
pub fn arrow_data_from_flight_data(
    flight_data: FlightData,
    arrow_schema_ref: &SchemaRef,
) -> Result<ArrowFlightData, ArrowError> {
    let ipc_message = match root_as_message(&flight_data.data_header[..]) {
        Ok(message) => message,
        Err(err) => {
            return Err(ArrowError::ParseError(format!(
                "Unable to get root as message: {err:?}"
            )))
        }
    };

    match ipc_message.header_type() {
        MessageHeader::RecordBatch => {
            let ipc_record_batch = match ipc_message.header_as_record_batch() {
                Some(batch) => batch,
                None => {
                    return Err(ArrowError::ComputeError(
                        "Unable to convert flight data header to a record batch"
                            .to_string(),
                    ))
                }
            };
            // No dictionaries are tracked here, so the lookup map is empty.
            let no_dictionaries = HashMap::new();
            let body = Buffer::from(&flight_data.data_body);
            let schema = arrow_schema_ref.clone();
            let version = ipc_message.version();
            read_record_batch(
                &body,
                ipc_record_batch,
                schema,
                &no_dictionaries,
                None,
                &version,
            )
            .map(ArrowFlightData::RecordBatch)
        }
        MessageHeader::Schema => match ipc_message.header_as_schema() {
            Some(ipc_schema) => Ok(ArrowFlightData::Schema(fb_to_schema(ipc_schema))),
            None => Err(ArrowError::ComputeError(
                "Unable to convert flight data header to a schema".to_string(),
            )),
        },
        // The remaining known header types are validated but not supported.
        MessageHeader::DictionaryBatch => match ipc_message.header_as_dictionary_batch() {
            None => Err(ArrowError::ComputeError(
                "Unable to convert flight data header to a dictionary batch".to_string(),
            )),
            Some(_) => Err(ArrowError::NotYetImplemented(
                "no idea on how to convert an ipc dictionary batch to an arrow type"
                    .to_string(),
            )),
        },
        MessageHeader::Tensor => match ipc_message.header_as_tensor() {
            None => Err(ArrowError::ComputeError(
                "Unable to convert flight data header to a tensor".to_string(),
            )),
            Some(_) => Err(ArrowError::NotYetImplemented(
                "no idea on how to convert an ipc tensor to an arrow type".to_string(),
            )),
        },
        MessageHeader::SparseTensor => match ipc_message.header_as_sparse_tensor() {
            None => Err(ArrowError::ComputeError(
                "Unable to convert flight data header to a sparse tensor".to_string(),
            )),
            Some(_) => Err(ArrowError::NotYetImplemented(
                "no idea on how to convert an ipc sparse tensor to an arrow type"
                    .to_string(),
            )),
        },
        other => Err(ArrowError::ComputeError(format!(
            "Unable to convert message with header_type: '{:?}' to arrow data",
            other
        ))),
    }
}