use std::{sync::Arc, time::Duration};
use arrow_array::RecordBatch;
use arrow_cast::pretty::pretty_format_batches;
use arrow_flight::{
sql::client::FlightSqlServiceClient, utils::flight_data_to_batches, FlightData,
};
use arrow_schema::{ArrowError, Schema};
use clap::Parser;
use futures::TryStreamExt;
use tonic::transport::{Channel, ClientTlsConfig, Endpoint};
use tracing_log::log::info;
/// A `KEY:VALUE` pair parsed from a single command-line token.
///
/// Only the first `:` separates the key from the value, so the value itself may
/// contain further colons (e.g. `x-url:http://host:1234`).
#[derive(Debug, Clone)]
struct KeyValue<K, V> {
    /// Text before the first `:`, parsed as `K`.
    pub key: K,
    /// Text after the first `:`, parsed as `V`.
    pub value: V,
}

impl<K, V> std::str::FromStr for KeyValue<K, V>
where
    K: std::str::FromStr,
    V: std::str::FromStr,
    K::Err: std::fmt::Display,
    V::Err: std::fmt::Display,
{
    type Err = String;

    /// Splits `s` at its first `:` and parses each half; the key is parsed before the
    /// value, and any parse error is rendered to a `String` via `Display`.
    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
        let Some((raw_key, raw_value)) = s.split_once(':') else {
            return Err(format!(
                "Invalid key value pair - expected 'KEY:VALUE' got '{s}'"
            ));
        };
        let key = raw_key.parse::<K>().map_err(|err| err.to_string())?;
        let value = raw_value.parse::<V>().map_err(|err| err.to_string())?;
        Ok(Self { key, value })
    }
}
// Connection and authentication options for the Flight SQL server; flattened into `Args`.
// Plain `//` comments are used at struct level on purpose: clap turns `///` doc comments
// into help/about text, and the field-level `///` docs below are the `--help` entries.
#[derive(Debug, Parser)]
struct ClientArgs {
    /// Additional request headers, as comma-separated `KEY:VALUE` pairs.
    #[clap(long, value_delimiter = ',')]
    headers: Vec<KeyValue<String, String>>,

    /// Username for the authentication handshake (requires --password).
    #[clap(long)]
    username: Option<String>,

    /// Password for the authentication handshake (requires --username).
    #[clap(long)]
    password: Option<String>,

    /// Authentication token to use for requests.
    #[clap(long)]
    token: Option<String>,

    /// Connect using TLS (https) instead of plain http.
    #[clap(long)]
    tls: bool,

    /// Server host to connect to.
    #[clap(long)]
    host: String,

    /// Server port [default: 443 with --tls, 80 otherwise].
    #[clap(long)]
    port: Option<u16>,
}
/// Run a single SQL query against an Arrow Flight SQL server and print the results
/// as a table.
#[derive(Debug, Parser)]
struct Args {
    #[clap(flatten)]
    client_args: ClientArgs,

    /// SQL query to execute.
    query: String,
}
#[tokio::main]
async fn main() {
let args = Args::parse();
setup_logging();
let mut client = setup_client(args.client_args).await.expect("setup client");
let info = client.execute(args.query).await.expect("prepare statement");
info!("got flight info");
let schema = Arc::new(Schema::try_from(info.clone()).expect("valid schema"));
let mut batches = Vec::with_capacity(info.endpoint.len() + 1);
batches.push(RecordBatch::new_empty(schema));
info!("decoded schema");
for endpoint in info.endpoint {
let Some(ticket) = &endpoint.ticket else {
panic!("did not get ticket");
};
let flight_data = client.do_get(ticket.clone()).await.expect("do get");
let flight_data: Vec<FlightData> = flight_data
.try_collect()
.await
.expect("collect data stream");
let mut endpoint_batches = flight_data_to_batches(&flight_data)
.expect("convert flight data to record batches");
batches.append(&mut endpoint_batches);
}
info!("received data");
let res = pretty_format_batches(batches.as_slice()).expect("format results");
println!("{res}");
}
fn setup_logging() {
tracing_log::LogTracer::init().expect("tracing log init");
tracing_subscriber::fmt::init();
}
/// Connects to the Flight SQL server described by `args` and returns a ready-to-use client.
///
/// The endpoint URI is `https://HOST:PORT` when `--tls` is set and `http://HOST:PORT`
/// otherwise; `PORT` defaults to 443 / 80 respectively. After connecting, any `--headers`
/// are attached, `--token` is installed, and, if both `--username` and `--password` were
/// given, a handshake is performed.
///
/// # Errors
///
/// * [`ArrowError::InvalidArgumentError`] if only one of `--username` / `--password` is set
///   (checked before any network activity).
/// * [`ArrowError::IoError`] if the endpoint URI or TLS configuration is rejected, or the
///   connection cannot be established.
/// * Any error returned by the handshake.
async fn setup_client(
    args: ClientArgs,
) -> Result<FlightSqlServiceClient<Channel>, ArrowError> {
    // Validate credential pairing up front so misuse is reported as an error without
    // first opening a connection.
    let credentials = match (args.username, args.password) {
        (None, None) => None,
        (Some(username), Some(password)) => Some((username, password)),
        (Some(_), None) => {
            return Err(ArrowError::InvalidArgumentError(
                "when username is set, you also need to set a password".to_string(),
            ));
        }
        (None, Some(_)) => {
            return Err(ArrowError::InvalidArgumentError(
                "when password is set, you also need to set a username".to_string(),
            ));
        }
    };

    let port = args.port.unwrap_or(if args.tls { 443 } else { 80 });
    let protocol = if args.tls { "https" } else { "http" };

    // Keep the underlying tonic error in the message instead of discarding it.
    let mut endpoint = Endpoint::new(format!("{}://{}:{}", protocol, args.host, port))
        .map_err(|e| ArrowError::IoError(format!("Cannot create endpoint: {e}")))?
        .connect_timeout(Duration::from_secs(20))
        .timeout(Duration::from_secs(20))
        .tcp_nodelay(true)
        .tcp_keepalive(Some(Duration::from_secs(3600)))
        .http2_keep_alive_interval(Duration::from_secs(300))
        .keep_alive_timeout(Duration::from_secs(20))
        .keep_alive_while_idle(true);

    if args.tls {
        // Default TLS settings; no custom CA or client identity is configured here.
        let tls_config = ClientTlsConfig::new();
        endpoint = endpoint
            .tls_config(tls_config)
            .map_err(|e| ArrowError::IoError(format!("Cannot create TLS endpoint: {e}")))?;
    }

    let channel = endpoint
        .connect()
        .await
        .map_err(|e| ArrowError::IoError(format!("Cannot connect to endpoint: {e}")))?;

    let mut client = FlightSqlServiceClient::new(channel);
    info!("connected");

    for kv in args.headers {
        client.set_header(kv.key, kv.value);
    }

    if let Some(token) = args.token {
        client.set_token(token);
        info!("token set");
    }

    if let Some((username, password)) = credentials {
        // Propagate handshake failures to the caller rather than panicking, since this
        // function already reports failures through its `Result`.
        client.handshake(&username, &password).await?;
        info!("performed handshake");
    }

    Ok(client)
}