use arrow_schema::ArrowError;
use bytes::Bytes;
use paste::paste;
use prost::Message;
// Protobuf types of the Flight SQL protocol (`arrow.flight.protocol.sql`), pulled in from a
// generated source file. Lints are relaxed for it because the generated code is not
// hand-maintained: clippy is silenced and undocumented items are allowed.
#[allow(clippy::all)]
mod r#gen {
#![allow(missing_docs)]
include!("arrow.flight.protocol.sql.rs");
}
pub use r#gen::ActionBeginSavepointRequest;
pub use r#gen::ActionBeginSavepointResult;
pub use r#gen::ActionBeginTransactionRequest;
pub use r#gen::ActionBeginTransactionResult;
pub use r#gen::ActionCancelQueryRequest;
pub use r#gen::ActionCancelQueryResult;
pub use r#gen::ActionClosePreparedStatementRequest;
pub use r#gen::ActionCreatePreparedStatementRequest;
pub use r#gen::ActionCreatePreparedStatementResult;
pub use r#gen::ActionCreatePreparedSubstraitPlanRequest;
pub use r#gen::ActionEndSavepointRequest;
pub use r#gen::ActionEndTransactionRequest;
pub use r#gen::CommandGetCatalogs;
pub use r#gen::CommandGetCrossReference;
pub use r#gen::CommandGetDbSchemas;
pub use r#gen::CommandGetExportedKeys;
pub use r#gen::CommandGetImportedKeys;
pub use r#gen::CommandGetPrimaryKeys;
pub use r#gen::CommandGetSqlInfo;
pub use r#gen::CommandGetTableTypes;
pub use r#gen::CommandGetTables;
pub use r#gen::CommandGetXdbcTypeInfo;
pub use r#gen::CommandPreparedStatementQuery;
pub use r#gen::CommandPreparedStatementUpdate;
pub use r#gen::CommandStatementIngest;
pub use r#gen::CommandStatementQuery;
pub use r#gen::CommandStatementSubstraitPlan;
pub use r#gen::CommandStatementUpdate;
pub use r#gen::DoPutPreparedStatementResult;
pub use r#gen::DoPutUpdateResult;
pub use r#gen::Nullable;
pub use r#gen::Searchable;
pub use r#gen::SqlInfo;
pub use r#gen::SqlNullOrdering;
pub use r#gen::SqlOuterJoinsSupportLevel;
pub use r#gen::SqlSupportedCaseSensitivity;
pub use r#gen::SqlSupportedElementActions;
pub use r#gen::SqlSupportedGroupBy;
pub use r#gen::SqlSupportedPositionedCommands;
pub use r#gen::SqlSupportedResultSetConcurrency;
pub use r#gen::SqlSupportedResultSetType;
pub use r#gen::SqlSupportedSubqueries;
pub use r#gen::SqlSupportedTransaction;
pub use r#gen::SqlSupportedTransactions;
pub use r#gen::SqlSupportedUnions;
pub use r#gen::SqlSupportsConvert;
pub use r#gen::SqlTransactionIsolationLevel;
pub use r#gen::SubstraitPlan;
pub use r#gen::SupportedSqlGrammar;
pub use r#gen::TicketStatementQuery;
pub use r#gen::UpdateDeleteRules;
pub use r#gen::XdbcDataType;
pub use r#gen::XdbcDatetimeSubcode;
pub use r#gen::action_end_transaction_request::EndTransaction;
pub use r#gen::command_statement_ingest::TableDefinitionOptions;
pub use r#gen::command_statement_ingest::table_definition_options::{
TableExistsOption, TableNotExistOption,
};
pub mod client;
pub mod metadata;
pub mod server;
pub use crate::streams::FallibleRequestStream;
/// Extension trait for prost messages that have a fixed protobuf type URL, so they can be
/// wrapped into (and recognised inside) an [`Any`].
pub trait ProstMessageExt: prost::Message + Default {
/// Returns the protobuf type URL that identifies this message type.
fn type_url() -> &'static str;
/// Encodes `self` into an [`Any`]; the implementations generated in this module tag it
/// with [`Self::type_url`].
fn as_any(&self) -> Any;
}
// Identity macro: re-emits a single item unchanged. `prost_message_ext!` wraps the generated
// `Command` enum in it so the expanded tokens are emitted as an item.
macro_rules! as_item {
($i:item) => {
$i
};
}
// Generates the Flight SQL `Any` plumbing for every listed message type `$name`:
//
// * a private `<NAME>_TYPE_URL` constant holding
//   `type.googleapis.com/arrow.flight.protocol.sql.<Name>`,
// * the public `Command` enum with one `$name($name)` variant per type plus `Unknown(Any)`,
// * `Command::into_any`, `Command::type_url` and `impl TryFrom<Any> for Command`,
// * an `impl ProstMessageExt` for each `$name`.
//
// Every name in the invocation must be followed by a comma (pattern `$($name:tt,)*`).
macro_rules! prost_message_ext {
    ($($name:tt,)*) => {
        paste! {
            $(
                const [<$name:snake:upper _TYPE_URL>]: &str = concat!("type.googleapis.com/arrow.flight.protocol.sql.", stringify!($name));
            )*

            as_item! {
                /// A Flight SQL message decoded from (or destined for) a protobuf [`Any`].
                ///
                /// Build one with `Command::try_from(any)`, which dispatches on the
                /// `type_url` of the [`Any`]; unrecognised type URLs yield
                /// [`Command::Unknown`].
                #[derive(Clone, Debug, PartialEq)]
                pub enum Command {
                    $(
                        #[doc = concat!(stringify!($name), " variant")]
                        $name($name),
                    )*
                    /// A message whose `type_url` is not one of the known Flight SQL messages;
                    /// the original [`Any`] is kept untouched.
                    Unknown(Any),
                }
            }

            impl Command {
                /// Converts this command back into an [`Any`].
                ///
                /// Known variants are re-encoded with [`ProstMessageExt::as_any`];
                /// `Unknown` hands back the [`Any`] it was built from.
                pub fn into_any(self) -> Any {
                    match self {
                        $(
                            Self::$name(cmd) => cmd.as_any(),
                        )*
                        Self::Unknown(any) => any,
                    }
                }

                /// Returns the protobuf type URL of this command (for `Unknown`, the
                /// type URL stored in the wrapped [`Any`]).
                pub fn type_url(&self) -> &str {
                    match self {
                        $(
                            Self::$name(_) => [<$name:snake:upper _TYPE_URL>],
                        )*
                        Self::Unknown(any) => any.type_url.as_str(),
                    }
                }
            }

            impl TryFrom<Any> for Command {
                type Error = ArrowError;

                // Dispatch on the type URL. A known URL whose payload fails to decode is an
                // `ArrowError::ParseError`; an unrecognised URL is not an error and is kept
                // as `Command::Unknown`.
                fn try_from(any: Any) -> Result<Self, Self::Error> {
                    match any.type_url.as_str() {
                        $(
                            [<$name:snake:upper _TYPE_URL>] => {
                                let m: $name = Message::decode(&*any.value).map_err(|err| {
                                    ArrowError::ParseError(format!("Unable to decode Any value: {err}"))
                                })?;
                                Ok(Self::$name(m))
                            }
                        )*
                        _ => Ok(Self::Unknown(any)),
                    }
                }
            }

            $(
                // Ties `$name` to its generated type URL constant.
                impl ProstMessageExt for $name {
                    fn type_url() -> &'static str {
                        [<$name:snake:upper _TYPE_URL>]
                    }

                    fn as_any(&self) -> Any {
                        Any {
                            type_url: <$name>::type_url().to_string(),
                            value: self.encode_to_vec().into(),
                        }
                    }
                }
            )*
        }
    };
}
// Generate the `Command` enum, the `*_TYPE_URL` constants and the `ProstMessageExt` impls for
// every Flight SQL message listed here. The order of this list is the order of the generated
// `Command` variants. Each entry must end with a comma (the macro matches `$($name:tt,)*`).
prost_message_ext!(
ActionBeginSavepointRequest,
ActionBeginSavepointResult,
ActionBeginTransactionRequest,
ActionBeginTransactionResult,
ActionCancelQueryRequest,
ActionCancelQueryResult,
ActionClosePreparedStatementRequest,
ActionCreatePreparedStatementRequest,
ActionCreatePreparedStatementResult,
ActionCreatePreparedSubstraitPlanRequest,
ActionEndSavepointRequest,
ActionEndTransactionRequest,
CommandGetCatalogs,
CommandGetCrossReference,
CommandGetDbSchemas,
CommandGetExportedKeys,
CommandGetImportedKeys,
CommandGetPrimaryKeys,
CommandGetSqlInfo,
CommandGetTableTypes,
CommandGetTables,
CommandGetXdbcTypeInfo,
CommandPreparedStatementQuery,
CommandPreparedStatementUpdate,
CommandStatementIngest,
CommandStatementQuery,
CommandStatementSubstraitPlan,
CommandStatementUpdate,
DoPutPreparedStatementResult,
DoPutUpdateResult,
TicketStatementQuery,
);
/// An encoded protobuf message tagged with a URL identifying its type.
///
/// Uses the same field numbers as the well-known `google.protobuf.Any` message
/// (`type_url` = 1, `value` = 2). Convert to and from concrete types with
/// [`Any::pack`] / [`Any::unpack`], or dispatch on the type URL with `Command::try_from`.
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct Any {
/// Type URL of the encoded message, e.g.
/// `type.googleapis.com/arrow.flight.protocol.sql.CommandStatementQuery`.
#[prost(string, tag = "1")]
pub type_url: String,
/// Protobuf-encoded bytes of the message identified by `type_url`.
#[prost(bytes = "bytes", tag = "2")]
pub value: Bytes,
}
impl Any {
    /// Returns `true` when this message's `type_url` is exactly the type URL of `M`.
    pub fn is<M: ProstMessageExt>(&self) -> bool {
        self.type_url == M::type_url()
    }

    /// Decodes the payload as `M`, provided the type URL identifies `M`.
    ///
    /// Returns `Ok(None)` without touching the payload when the type URL belongs to a
    /// different message, and `Ok(Some(message))` when decoding succeeds.
    ///
    /// # Errors
    ///
    /// `ArrowError::ParseError` when the type URL matches `M` but `value` is not a valid
    /// protobuf encoding of `M`.
    pub fn unpack<M: ProstMessageExt>(&self) -> Result<Option<M>, ArrowError> {
        if self.is::<M>() {
            M::decode(&*self.value)
                .map(Some)
                .map_err(|err| ArrowError::ParseError(format!("Unable to decode Any value: {err}")))
        } else {
            Ok(None)
        }
    }

    /// Encodes `message` into an [`Any`] via [`ProstMessageExt::as_any`].
    ///
    /// Encoding cannot fail; the `Result` return type is part of the public API.
    pub fn pack<M: ProstMessageExt>(message: &M) -> Result<Any, ArrowError> {
        let packed = message.as_any();
        Ok(packed)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The `select 1` statement query used as the fixture by the tests below.
    fn select_one() -> CommandStatementQuery {
        CommandStatementQuery {
            query: "select 1".to_string(),
            transaction_id: None,
        }
    }

    #[test]
    fn test_type_url() {
        let expected = [
            (
                TicketStatementQuery::type_url(),
                "type.googleapis.com/arrow.flight.protocol.sql.TicketStatementQuery",
            ),
            (
                CommandStatementQuery::type_url(),
                "type.googleapis.com/arrow.flight.protocol.sql.CommandStatementQuery",
            ),
        ];
        for (actual, want) in expected {
            assert_eq!(actual, want);
        }
    }

    #[test]
    fn test_prost_any_pack_unpack() {
        let original = select_one();
        let packed = Any::pack(&original).unwrap();
        assert!(packed.is::<CommandStatementQuery>());

        let restored = packed.unpack::<CommandStatementQuery>().unwrap().unwrap();
        assert_eq!(original, restored);
    }

    #[test]
    fn test_command() {
        // A known type URL is decoded into its dedicated variant.
        let known = Any::pack(&select_one()).unwrap();
        let cmd = Command::try_from(known).unwrap();
        assert!(matches!(cmd, Command::CommandStatementQuery(_)));
        assert_eq!(cmd.type_url(), COMMAND_STATEMENT_QUERY_TYPE_URL);

        // Anything else is preserved as `Unknown`, keeping its type URL.
        let foreign = Any {
            type_url: "fake_url".to_string(),
            value: Default::default(),
        };
        let cmd = Command::try_from(foreign).unwrap();
        assert!(matches!(cmd, Command::Unknown(_)));
        assert_eq!(cmd.type_url(), "fake_url");
    }
}