pub mod service;
use crate::args::{Command, DftArgs};
use crate::config::AppConfig;
use crate::db::register_db;
use crate::execution::AppExecution;
use color_eyre::{eyre::eyre, Result};
use datafusion_app::config::merge_configs;
use datafusion_app::extensions::DftSessionStateBuilder;
use datafusion_app::local::ExecutionContext;
use log::info;
use service::FlightSqlServiceImpl;
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use std::time::Duration;
use tokio::net::TcpListener;
use tokio::sync::oneshot;
use tokio::task::JoinHandle;
use tonic::transport::Server;
#[cfg(feature = "flightsql")]
use tower_http::validate_request::ValidateRequestHeaderLayer;
use super::try_start_metrics_server;
const DEFAULT_TIMEOUT_SECONDS: u64 = 60;
#[allow(deprecated)]
pub fn create_server_handle(
config: &AppConfig,
flightsql: FlightSqlServiceImpl,
listener: TcpListener,
rx: oneshot::Receiver<()>,
) -> Result<JoinHandle<std::result::Result<(), tonic::transport::Error>>> {
let server_timeout = Duration::from_secs(DEFAULT_TIMEOUT_SECONDS);
let mut server_builder = Server::builder().timeout(server_timeout);
let shutdown_future = async move {
rx.await.ok();
};
if cfg!(feature = "flightsql") {
match (
&config.flightsql_server.auth.basic_auth,
&config.flightsql_server.auth.bearer_token,
) {
(Some(_), Some(_)) => Err(eyre!("Only one auth type can be used at a time")),
(Some(basic), None) => {
let basic_auth_layer =
ValidateRequestHeaderLayer::basic(&basic.username, &basic.password);
let f = server_builder
.layer(basic_auth_layer)
.add_service(flightsql.service())
.serve_with_incoming_shutdown(
tokio_stream::wrappers::TcpListenerStream::new(listener),
shutdown_future,
);
Ok(tokio::task::spawn(f))
}
(None, Some(token)) => {
let bearer_auth_layer = ValidateRequestHeaderLayer::bearer(token);
let f = server_builder
.layer(bearer_auth_layer)
.add_service(flightsql.service())
.serve_with_incoming_shutdown(
tokio_stream::wrappers::TcpListenerStream::new(listener),
shutdown_future,
);
Ok(tokio::task::spawn(f))
}
(None, None) => {
let f = server_builder
.add_service(flightsql.service())
.serve_with_incoming_shutdown(
tokio_stream::wrappers::TcpListenerStream::new(listener),
shutdown_future,
);
Ok(tokio::task::spawn(f))
}
}
} else {
let f = server_builder
.add_service(flightsql.service())
.serve_with_incoming_shutdown(
tokio_stream::wrappers::TcpListenerStream::new(listener),
shutdown_future,
);
Ok(tokio::task::spawn(f))
}
}
pub struct FlightSqlApp {
shutdown: Option<tokio::sync::oneshot::Sender<()>>,
pub addr: SocketAddr,
handle: Option<JoinHandle<Result<(), tonic::transport::Error>>>,
}
impl FlightSqlApp {
pub async fn try_new(
app_execution: AppExecution,
config: &AppConfig,
addr: SocketAddr,
metrics_addr: SocketAddr,
) -> Result<Self> {
info!("listening to FlightSQL on {addr}");
let flightsql = service::FlightSqlServiceImpl::new(app_execution);
let listener = TcpListener::bind(addr).await.unwrap();
let (tx, rx) = tokio::sync::oneshot::channel();
let handle = create_server_handle(config, flightsql, listener, rx)?;
try_start_metrics_server(metrics_addr)?;
let app = Self {
shutdown: Some(tx),
addr: metrics_addr,
handle: Some(handle),
};
Ok(app)
}
pub async fn shutdown_and_wait(mut self) {
if let Some(shutdown) = self.shutdown.take() {
shutdown.send(()).expect("server quit early");
}
if let Some(handle) = self.handle.take() {
handle
.await
.expect("task join error (panic?)")
.expect("Server Error found at shutdown");
}
}
pub async fn run(self) {
if let Some(handle) = self.handle {
handle
.await
.expect("Unable to run server task")
.expect("Server Error found at shutdown");
} else {
panic!("Server task not found");
}
}
}
pub async fn try_run(cli: DftArgs, config: AppConfig) -> Result<()> {
let merged_exec_config = merge_configs(
config.shared.clone(),
config.flightsql_server.execution.clone(),
);
let session_state_builder = DftSessionStateBuilder::try_new(Some(merged_exec_config.clone()))?
.with_extensions()
.await?;
let session_state = session_state_builder.build()?;
let execution_ctx = ExecutionContext::try_new(
&merged_exec_config,
session_state,
crate::APP_NAME,
env!("CARGO_PKG_VERSION"),
)?;
if cli.run_ddl {
execution_ctx.execute_ddl().await;
}
let app_execution = AppExecution::new(execution_ctx);
let (addr, metrics_addr) = if let Some(cmd) = cli.command.clone() {
match cmd {
Command::ServeFlightSql {
addr: Some(addr),
metrics_addr: Some(metrics_addr),
..
} => (addr, metrics_addr),
Command::ServeFlightSql {
addr: Some(addr),
metrics_addr: None,
..
} => (addr, config.flightsql_server.server_metrics_addr),
Command::ServeFlightSql {
addr: None,
metrics_addr: Some(metrics_addr),
..
} => (
SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 50051),
metrics_addr,
),
_ => (
SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 50051),
config.flightsql_server.server_metrics_addr,
),
}
} else {
(
SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 50051),
config.flightsql_server.server_metrics_addr,
)
};
register_db(app_execution.session_ctx(), &config.db).await?;
let app = FlightSqlApp::try_new(app_execution, &config, addr, metrics_addr).await?;
app.run().await;
Ok(())
}