use std::{sync::Arc, time::Duration};
use anyhow::{Context, Result, bail};
use arrow_array::{ArrayRef, Datum, RecordBatch, StringArray};
use arrow_cast::{CastOptions, cast_with_options, pretty::pretty_format_batches};
use arrow_flight::{
FlightInfo,
flight_service_client::FlightServiceClient,
sql::{CommandGetDbSchemas, CommandGetTables, client::FlightSqlServiceClient},
};
use arrow_schema::Schema;
use clap::{Parser, Subcommand, ValueEnum};
use core::str;
use futures::TryStreamExt;
use tonic::{
metadata::MetadataMap,
transport::{Channel, ClientTlsConfig, Endpoint},
};
use tracing_log::log::info;
// Command-line options controlling log verbosity.
//
// NOTE: plain `//` comments are used instead of `///` so that the help text clap
// generates from this struct stays exactly as it is.
#[derive(Debug, Parser)]
pub struct LoggingArgs {
// Counted flag: every `-v` / `--verbose` occurrence increments the value.
// `setup_logging` maps 0 => "warn", 1 => "info", 2 => "debug", anything higher => "trace".
#[clap(
short = 'v',
long = "verbose",
action = clap::ArgAction::Count,
)]
log_verbose_count: u8,
}
// Compression algorithms that can be selected on the command line (via `ValueEnum`).
//
// This is a local mirror of `tonic::codec::CompressionEncoding`; the `From` impl in this
// file converts each variant to the tonic variant of the same name.
#[derive(Clone, Copy, Debug, PartialEq, Eq, ValueEnum)]
pub enum CompressionEncoding {
// Converted to `tonic::codec::CompressionEncoding::Gzip`.
Gzip,
// Converted to `tonic::codec::CompressionEncoding::Deflate`.
Deflate,
// Converted to `tonic::codec::CompressionEncoding::Zstd`.
Zstd,
}
/// Converts the CLI-facing compression choice into the `tonic` codec type.
impl From<CompressionEncoding> for tonic::codec::CompressionEncoding {
    /// Returns the `tonic` encoding whose variant name matches `encoding`.
    fn from(encoding: CompressionEncoding) -> Self {
        // Local alias keeps the two identically named enums apart while reading the match.
        use tonic::codec::CompressionEncoding as Wire;

        match encoding {
            CompressionEncoding::Gzip => Wire::Gzip,
            CompressionEncoding::Deflate => Wire::Deflate,
            CompressionEncoding::Zstd => Wire::Zstd,
        }
    }
}
// Command-line options describing how to reach and authenticate against the server.
// Consumed by `setup_client`.
//
// NOTE: plain `//` comments are used instead of `///` so that the help text clap
// generates from this struct stays exactly as it is.
#[derive(Debug, Parser)]
struct ClientArgs {
// Repeatable `--header` / `-H` option in `KEY=value` form (parsed by `parse_key_val`).
// Each pair is passed to `set_header` on the client.
#[clap(long = "header", short = 'H', value_parser = parse_key_val)]
headers: Vec<(String, String)>,
// User name for the handshake; clap rejects it unless `--password` is also given.
#[clap(long, requires = "password")]
username: Option<String>,
// Password for the handshake; clap rejects it unless `--username` is also given.
#[clap(long, requires = "username")]
password: Option<String>,
// Optional token, passed to `set_token` on the client.
#[clap(long)]
token: Option<String>,
// Use TLS: switches the URI scheme to `https` and the default port to 443.
#[clap(long)]
tls: bool,
// Enable TLS key logging (`use_key_log` on the TLS config); only valid together with `--tls`.
#[clap(long, requires = "tls")]
key_log: bool,
// Server host name; required.
#[clap(long)]
host: String,
// Server port; when omitted, `setup_client` uses 443 with `--tls` and 80 otherwise.
#[clap(long)]
port: Option<u16>,
// Comma-separated list of encodings; each one is registered with `accept_compressed`.
#[clap(long, value_delimiter = ',')]
accept_compression: Vec<CompressionEncoding>,
// Optional encoding registered with `send_compressed`.
#[clap(long)]
send_compression: Option<CompressionEncoding>,
}
// Top-level command line: logging options, connection options and the subcommand to run.
// Parsed once in `main` via `Args::parse()`.
#[derive(Debug, Parser)]
struct Args {
// Verbosity flags, flattened into the top-level option set; handed to `setup_logging`.
#[clap(flatten)]
logging_args: LoggingArgs,
// Connection/authentication flags, flattened as well; handed to `setup_client`.
#[clap(flatten)]
client_args: ClientArgs,
// The operation to perform against the server.
#[clap(subcommand)]
cmd: Command,
}
// Subcommands of the client. `main` turns each variant into one Flight SQL request and
// then fetches and prints the resulting data.
//
// NOTE: plain `//` comments are used instead of `///` so that the help text clap
// generates from this enum stays exactly as it is.
#[derive(Debug, Subcommand)]
enum Command {
// Calls `get_catalogs` on the client.
Catalogs,
// Calls `get_db_schemas` for one catalog.
DbSchemas {
// Catalog name (positional); sent as `catalog: Some(..)`.
catalog: String,
// Optional value forwarded as `db_schema_filter_pattern`.
#[clap(short, long)]
db_schema_filter: Option<String>,
},
// Calls `get_tables` for one catalog; `main` always sends `include_schema: false`.
Tables {
// Catalog name (positional); sent as `catalog: Some(..)`.
catalog: String,
// Optional value forwarded as `db_schema_filter_pattern`.
#[clap(short, long)]
db_schema_filter: Option<String>,
// Optional value forwarded as `table_name_filter_pattern`.
#[clap(short, long)]
table_filter: Option<String>,
// Values forwarded unchanged as `table_types`.
#[clap(long)]
table_types: Vec<String>,
},
// Calls `get_table_types` on the client.
TableTypes,
// Executes a query directly via `execute`.
StatementQuery {
// Query text (positional).
query: String,
},
// Prepares the query, optionally binds parameters, then executes it.
PreparedStatementQuery {
// Query text (positional).
query: String,
// `KEY=value` pairs (parsed by `parse_key_val`). Each key is looked up by name in the
// prepared statement's parameter schema by `construct_record_batch_from_params`.
#[clap(short, value_parser = parse_key_val)]
params: Vec<(String, String)>,
},
}
/// Entry point of the command-line client.
///
/// Parses the command line, sets up logging, connects to the server, issues the request
/// selected by the subcommand, downloads the data it refers to and prints it to stdout as
/// a formatted table.
///
/// # Errors
///
/// Any failure is returned with a context string naming the step that failed (for example
/// `"setup client"` or `"read flight data"`).
#[tokio::main]
async fn main() -> Result<()> {
    let args = Args::parse();
    setup_logging(args.logging_args)?;
    let mut client = setup_client(args.client_args)
        .await
        .context("setup client")?;

    // Every subcommand yields a `FlightInfo`; the data itself is fetched afterwards.
    let flight_info = match args.cmd {
        Command::Catalogs => client.get_catalogs().await.context("get catalogs")?,
        Command::DbSchemas {
            catalog,
            db_schema_filter,
        } => client
            .get_db_schemas(CommandGetDbSchemas {
                catalog: Some(catalog),
                db_schema_filter_pattern: db_schema_filter,
            })
            .await
            .context("get db schemas")?,
        Command::Tables {
            catalog,
            db_schema_filter,
            table_filter,
            table_types,
        } => client
            .get_tables(CommandGetTables {
                catalog: Some(catalog),
                db_schema_filter_pattern: db_schema_filter,
                table_name_filter_pattern: table_filter,
                table_types,
                include_schema: false,
            })
            .await
            .context("get tables")?,
        Command::TableTypes => client.get_table_types().await.context("get table types")?,
        Command::StatementQuery { query } => client
            .execute(query, None)
            .await
            .context("execute statement")?,
        Command::PreparedStatementQuery { query, params } => {
            let mut prepared_stmt = client
                .prepare(query, None)
                .await
                .context("prepare statement")?;

            // Only bind when parameters were given on the command line.
            if !params.is_empty() {
                // Build the owned batch first so the borrow of the parameter schema has ended
                // before `set_parameters` needs mutable access to the statement.
                let parameter_batch = construct_record_batch_from_params(
                    &params,
                    prepared_stmt
                        .parameter_schema()
                        .context("get parameter schema")?,
                )
                .context("construct parameters")?;
                prepared_stmt
                    .set_parameters(parameter_batch)
                    .context("bind parameters")?;
            }

            prepared_stmt
                .execute()
                .await
                .context("execute prepared statement")?
        }
    };

    let batches = execute_flight(&mut client, flight_info)
        .await
        .context("read flight data")?;

    let res = pretty_format_batches(batches.as_slice()).context("format results")?;
    println!("{res}");

    Ok(())
}
/// Downloads the data behind every endpoint of `info` and returns it as record batches.
///
/// The first element of the result is always an empty batch built from the schema decoded
/// out of `info`; the batches received from the endpoints follow in endpoint order.
/// Response headers and trailers of each stream are logged through `log_metadata`.
///
/// # Errors
///
/// Fails when the schema cannot be decoded, when an endpoint carries no ticket, or when
/// requesting or reading a data stream fails.
async fn execute_flight(
    client: &mut FlightSqlServiceClient<Channel>,
    info: FlightInfo,
) -> Result<Vec<RecordBatch>> {
    // `Schema::try_from` consumes its argument, while the endpoints are still needed below.
    let schema = Schema::try_from(info.clone()).context("valid schema")?;

    // One slot per endpoint plus one for the leading schema-only batch.
    let mut batches = Vec::with_capacity(info.endpoint.len() + 1);
    batches.push(RecordBatch::new_empty(Arc::new(schema)));
    info!("decoded schema");

    for endpoint in info.endpoint {
        let ticket = endpoint.ticket.context("did not get ticket")?;
        let mut stream = client.do_get(ticket).await.context("do get")?;
        log_metadata(stream.headers(), "header");

        // Collect through a mutable borrow so the stream can still be asked for its
        // trailers once it has been drained.
        let received: Vec<RecordBatch> = (&mut stream)
            .try_collect()
            .await
            .context("collect data stream")?;
        batches.extend(received);

        if let Some(trailers) = stream.trailers() {
            log_metadata(&trailers, "trailer");
        }
    }

    info!("received data");
    Ok(batches)
}
fn construct_record_batch_from_params(
params: &[(String, String)],
parameter_schema: &Schema,
) -> Result<RecordBatch> {
let mut items = Vec::<(&String, ArrayRef)>::new();
for (name, value) in params {
let field = parameter_schema.field_with_name(name)?;
let value_as_array = StringArray::new_scalar(value);
let casted = cast_with_options(
value_as_array.get().0,
field.data_type(),
&CastOptions::default(),
)?;
items.push((name, casted))
}
Ok(RecordBatch::try_from_iter(items)?)
}
/// Installs the process-wide logging setup at the verbosity requested on the command line.
///
/// Initializes `tracing_log::LogTracer`, then installs a formatting subscriber whose filter
/// is chosen from the number of `-v` flags: 0 => `warn`, 1 => `info`, 2 => `debug`, more =>
/// `trace`.
///
/// # Errors
///
/// Fails when the log tracer, the filter or the subscriber cannot be set up.
fn setup_logging(args: LoggingArgs) -> Result<()> {
    use tracing_subscriber::{EnvFilter, FmtSubscriber, util::SubscriberInitExt};

    tracing_log::LogTracer::init().context("tracing log init")?;

    let directive = match args.log_verbose_count {
        0 => "warn",
        1 => "info",
        2 => "debug",
        _ => "trace",
    };
    let env_filter = EnvFilter::try_new(directive).context("set up log env filter")?;

    FmtSubscriber::builder()
        .with_env_filter(env_filter)
        .finish()
        .try_init()
        .context("init logging subscriber")
}
/// Connects to the server described by `args` and returns a configured Flight SQL client.
///
/// In order: builds the endpoint URI (`https`/443 with `--tls`, `http`/80 otherwise, unless
/// a port is given), applies timeouts and keep-alive settings, optionally configures TLS,
/// connects, registers the requested compression encodings, sets custom headers and the
/// token, and finally performs a handshake when a username and password were supplied.
///
/// # Errors
///
/// Fails when the endpoint or TLS configuration is invalid, when the connection or the
/// handshake fails, or when only one of username and password is present.
async fn setup_client(args: ClientArgs) -> Result<FlightSqlServiceClient<Channel>> {
    let (protocol, default_port) = if args.tls {
        ("https", 443)
    } else {
        ("http", 80)
    };
    let port = args.port.unwrap_or(default_port);
    let uri = format!("{protocol}://{host}:{port}", host = args.host);

    let mut endpoint = Endpoint::new(uri)
        .context("create endpoint")?
        .connect_timeout(Duration::from_secs(20))
        .timeout(Duration::from_secs(20))
        .tcp_nodelay(true)
        .tcp_keepalive(Some(Duration::from_secs(3600)))
        .http2_keep_alive_interval(Duration::from_secs(300))
        .keep_alive_timeout(Duration::from_secs(20))
        .keep_alive_while_idle(true);

    if args.tls {
        let base_config = ClientTlsConfig::new().with_enabled_roots();
        let tls_config = if args.key_log {
            base_config.use_key_log()
        } else {
            base_config
        };
        endpoint = endpoint
            .tls_config(tls_config)
            .context("create TLS endpoint")?;
    }

    let channel = endpoint.connect().await.context("connect to endpoint")?;

    // Compression is configured on the raw Flight client before it is wrapped.
    let inner = args
        .accept_compression
        .into_iter()
        .fold(FlightServiceClient::new(channel), |flight, encoding| {
            flight.accept_compressed(encoding.into())
        });
    let inner = match args.send_compression {
        Some(encoding) => inner.send_compressed(encoding.into()),
        None => inner,
    };

    let mut client = FlightSqlServiceClient::new_from_inner(inner);
    info!("connected");

    for (key, value) in args.headers {
        client.set_header(key, value);
    }

    if let Some(token) = args.token {
        client.set_token(token);
        info!("token set");
    }

    // clap's `requires` already enforces the pairing; the mismatch arms are kept as a
    // second line of defence.
    match (args.username, args.password) {
        (Some(username), Some(password)) => {
            client
                .handshake(&username, &password)
                .await
                .context("handshake")?;
            info!("performed handshake");
        }
        (Some(_), None) => bail!("when username is set, you also need to set a password"),
        (None, Some(_)) => bail!("when password is set, you also need to set a username"),
        (None, None) => {}
    }

    Ok(client)
}
/// Parses a `KEY=value` command-line argument into an owned `(key, value)` pair.
///
/// The input is split at the first `=`; any later `=` remains part of the value. Either
/// side may be empty.
///
/// # Errors
///
/// Returns a message quoting the input when it contains no `=`.
fn parse_key_val(s: &str) -> Result<(String, String), String> {
    match s.split_once('=') {
        Some((key, value)) => Ok((key.to_owned(), value.to_owned())),
        None => Err(format!("invalid KEY=value: no `=` found in `{s}`")),
    }
}
/// Logs every entry of a metadata map at `info` level as `<what>: <key>=<value>`.
///
/// `what` labels the kind of metadata (the callers in this file pass `"header"` and
/// `"trailer"`). An ASCII value for which `to_str` fails is logged as `<invalid>`; the bytes
/// of a binary value are rendered with lossy UTF-8 conversion.
fn log_metadata(map: &MetadataMap, what: &'static str) {
    use std::borrow::Cow;
    use tonic::metadata::KeyAndValueRef;

    for entry in map.iter() {
        // Reduce both entry kinds to a printable key/value pair, then log once.
        let (key, value): (&str, Cow<'_, str>) = match entry {
            KeyAndValueRef::Ascii(k, v) => (
                k.as_str(),
                Cow::Borrowed(v.to_str().unwrap_or("<invalid>")),
            ),
            KeyAndValueRef::Binary(k, v) => (k.as_str(), String::from_utf8_lossy(v.as_ref())),
        };
        info!("{}: {}={}", what, key, value);
    }
}