use std::sync::Arc;
use std::thread::JoinHandle;
use async_trait::async_trait;
use datafusion::common::ParamValues;
use datafusion::logical_expr::LogicalPlan;
use datafusion::prelude::SessionContext;
use datafusion::sql::sqlparser::ast::Statement;
use tokio::sync::oneshot;
use datafusion_postgres::auth::{AuthManager, DfAuthSource};
use datafusion_postgres::datafusion_pg_catalog::pg_catalog::context::User;
use datafusion_postgres::datafusion_pg_catalog::setup_pg_catalog;
use datafusion_postgres::hooks::HookClient;
use datafusion_postgres::hooks::cursor::CursorStatementHook;
use datafusion_postgres::hooks::set_show::SetShowHook;
use datafusion_postgres::hooks::transactions::TransactionStatementHook;
use datafusion_postgres::pgwire::api::PgWireServerHandlers;
use datafusion_postgres::pgwire::api::ClientInfo;
use datafusion_postgres::pgwire::api::auth::StartupHandler;
use datafusion_postgres::pgwire::api::auth::DefaultServerParameterProvider;
use datafusion_postgres::pgwire::api::auth::cleartext::CleartextPasswordAuthStartupHandler;
use datafusion_postgres::pgwire::api::query::{ExtendedQueryHandler, SimpleQueryHandler};
use datafusion_postgres::pgwire::api::results::{Response, Tag};
use datafusion_postgres::pgwire::error::PgWireResult;
use datafusion_postgres::{
DfSessionService, QueryHook, ServerOptions, serve_with_handlers, serve_with_hooks,
};
use datapress_core::config::PgwireConfig;
/// Startup handler used when a password is configured: cleartext-password
/// authentication backed by a [`DfAuthSource`].
type CleartextStartup =
    CleartextPasswordAuthStartupHandler<DfAuthSource, DefaultServerParameterProvider>;

/// Handler bundle passed to `serve_with_handlers`.
///
/// A single [`DfSessionService`] answers both the simple and the extended query
/// protocol, while `startup` performs the authentication handshake.
struct DatapressHandlers {
    session_service: Arc<DfSessionService>,
    startup: Arc<CleartextStartup>,
}

impl PgWireServerHandlers for DatapressHandlers {
    fn simple_query_handler(&self) -> Arc<impl SimpleQueryHandler> {
        Arc::clone(&self.session_service)
    }

    fn extended_query_handler(&self) -> Arc<impl ExtendedQueryHandler> {
        Arc::clone(&self.session_service)
    }

    fn startup_handler(&self) -> Arc<impl StartupHandler> {
        Arc::clone(&self.startup)
    }
}
/// Query hook that short-circuits session-maintenance statements (`DISCARD`,
/// `DEALLOCATE`, `RESET`, `UNLISTEN`) with a success command tag instead of
/// executing them.
#[derive(Debug)]
struct SessionResetHook;

impl SessionResetHook {
    /// Returns the command tag to report for `statement`, or `None` when the
    /// statement is not one this hook handles.
    fn tag_for(statement: &Statement) -> Option<String> {
        let fixed_tag = match statement {
            // The only tag that varies: it carries the DISCARD target (e.g. `DISCARD ALL`).
            Statement::Discard { object_type } => {
                return Some(format!("DISCARD {object_type}"));
            }
            Statement::Deallocate { .. } => "DEALLOCATE",
            Statement::Reset(_) => "RESET",
            Statement::UNLISTEN { .. } => "UNLISTEN",
            _ => return None,
        };
        Some(fixed_tag.to_owned())
    }
}
#[async_trait]
impl QueryHook for SessionResetHook {
    /// Acknowledges a recognised session-maintenance statement with its command
    /// tag without executing it; any other statement yields `None`.
    async fn handle_simple_query(
        &self,
        statement: &Statement,
        _session_context: &SessionContext,
        _client: &mut dyn HookClient,
    ) -> Option<PgWireResult<Response>> {
        Self::tag_for(statement).map(|tag| {
            log::debug!("pgwire: swallowing session-maintenance statement: {statement}");
            Ok(Response::Execution(Tag::new(&tag)))
        })
    }

    /// At Parse time, describes a recognised statement as an empty relation
    /// (no columns, no rows) so no real planning is attempted; other statements
    /// yield `None`.
    async fn handle_extended_parse_query(
        &self,
        statement: &Statement,
        _session_context: &SessionContext,
        _client: &(dyn ClientInfo + Send + Sync),
    ) -> Option<PgWireResult<LogicalPlan>> {
        Self::tag_for(statement).map(|_| {
            let relation = datafusion::logical_expr::EmptyRelation {
                produce_one_row: false,
                schema: Arc::new(datafusion::common::DFSchema::empty()),
            };
            Ok(LogicalPlan::EmptyRelation(relation))
        })
    }

    /// At Execute time, behaves exactly like the simple-query path; the planned
    /// logical plan and bound parameters are not consulted.
    async fn handle_extended_query(
        &self,
        statement: &Statement,
        _logical_plan: &LogicalPlan,
        _params: &ParamValues,
        session_context: &SessionContext,
        client: &mut dyn HookClient,
    ) -> Option<PgWireResult<Response>> {
        let response = self
            .handle_simple_query(statement, session_context, client)
            .await;
        response
    }
}
/// Builds the list of query hooks installed on every session service.
///
/// The order below is kept deliberately. Presumably datafusion-postgres consults
/// hooks first-to-last — confirm before reordering.
fn query_hooks() -> Vec<Arc<dyn QueryHook>> {
    let hooks: [Arc<dyn QueryHook>; 4] = [
        Arc::new(CursorStatementHook),
        Arc::new(SetShowHook),
        Arc::new(TransactionStatementHook),
        Arc::new(SessionResetHook),
    ];
    Vec::from(hooks)
}
/// Stack size in bytes (32 MiB) for the `pgwire` OS thread and each of its Tokio
/// worker threads. This is presumably sized for deep recursion in SQL parsing or
/// planning — confirm before lowering.
const PGWIRE_THREAD_STACK: usize = 32 << 20;
/// Handle to the background pgwire server started by `spawn_pgwire`.
///
/// Dropping the handle sends the stop signal and then blocks until the server
/// thread has exited.
pub struct PgwireServer {
    /// One-shot stop signal for the server loop; consumed on drop.
    shutdown_tx: Option<oneshot::Sender<()>>,
    /// The dedicated OS thread running the server; joined on drop.
    thread: Option<JoinHandle<()>>,
}

impl Drop for PgwireServer {
    fn drop(&mut self) {
        let stop_signal = self.shutdown_tx.take();
        let server_thread = self.thread.take();

        // A failed send only means the receiver is gone, i.e. the server loop has
        // already stopped on its own.
        if let Some(sender) = stop_signal {
            let _ = sender.send(());
        }
        // Wait for the thread to finish; a panic on that thread is deliberately ignored.
        if let Some(handle) = server_thread {
            let _ = handle.join();
        }
    }
}
/// Starts the PostgreSQL wire-protocol server on a dedicated OS thread named `pgwire`.
///
/// The thread builds its own multi-threaded Tokio runtime (workers named
/// `pgwire-worker`) and drives [`serve_pgwire`] until one of two things happens:
/// that future completes, or the returned [`PgwireServer`] is dropped. Dropping the
/// handle sends the shutdown signal and joins the thread. The thread and every runtime
/// worker use [`PGWIRE_THREAD_STACK`] bytes of stack.
///
/// # Errors
///
/// Returns the error from [`std::thread::Builder::spawn`] if the OS thread cannot be
/// created. Failures that happen later on that thread (building the runtime, or
/// `serve_pgwire` itself) are only logged, because they occur after this function
/// has already returned `Ok`.
pub fn spawn_pgwire(ctx: SessionContext, cfg: PgwireConfig) -> std::io::Result<PgwireServer> {
    // Upper bound on runtime teardown. Dropping a Tokio runtime waits indefinitely for
    // outstanding `spawn_blocking` work. `PgwireServer::drop` joins this thread, so an
    // unbounded wait here would make dropping the handle hang.
    const RUNTIME_SHUTDOWN_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(5);

    let (shutdown_tx, shutdown_rx) = oneshot::channel::<()>();
    let (listen, port) = (cfg.listen, cfg.port);
    let thread = std::thread::Builder::new()
        .name("pgwire".to_string())
        .stack_size(PGWIRE_THREAD_STACK)
        .spawn(move || {
            let runtime = match tokio::runtime::Builder::new_multi_thread()
                .enable_all()
                .thread_name("pgwire-worker")
                .thread_stack_size(PGWIRE_THREAD_STACK)
                .build()
            {
                Ok(runtime) => runtime,
                Err(e) => {
                    log::error!("pgwire: failed to build runtime: {e}");
                    return;
                }
            };
            runtime.block_on(async move {
                log::info!("pgwire: PostgreSQL wire protocol listening on {listen}:{port}");
                tokio::select! {
                    _ = shutdown_rx => {
                        log::info!("pgwire: shutdown signalled; stopping listener");
                    }
                    res = serve_pgwire(ctx, cfg) => {
                        match res {
                            Ok(()) => log::info!("pgwire: server task exited"),
                            Err(e) => log::error!("pgwire: server task exited with error: {e}"),
                        }
                    }
                }
            });
            // Cancels any tasks still on the runtime. Blocking tasks still running when
            // the timeout expires are left behind rather than stalling the join in
            // `PgwireServer::drop`.
            runtime.shutdown_timeout(RUNTIME_SHUTDOWN_TIMEOUT);
        })?;
    Ok(PgwireServer {
        shutdown_tx: Some(shutdown_tx),
        thread: Some(thread),
    })
}
pub async fn serve_pgwire(ctx: SessionContext, cfg: PgwireConfig) -> std::io::Result<()> {
let mut opts = ServerOptions::new()
.with_host(cfg.listen.to_string())
.with_port(cfg.port);
if let (Some(cert), Some(key)) = (&cfg.tls_cert, &cfg.tls_key) {
opts = opts
.with_tls_cert_path(Some(cert.to_string_lossy().into_owned()))
.with_tls_key_path(Some(key.to_string_lossy().into_owned()));
}
let ctx = Arc::new(ctx);
let auth_manager = Arc::new(AuthManager::new());
let default_catalog = ctx
.copied_config()
.options()
.catalog
.default_catalog
.clone();
setup_pg_catalog(&ctx, &default_catalog, auth_manager.clone())
.map_err(|e| std::io::Error::other(*e))?;
match &cfg.password {
Some(password) => {
auth_manager
.add_user(User {
username: cfg.username.clone(),
password_hash: password.clone(),
roles: Vec::new(),
is_superuser: true,
can_login: true,
connection_limit: None,
})
.await
.map_err(std::io::Error::other)?;
let auth_source = DfAuthSource::new(auth_manager);
let startup = Arc::new(CleartextPasswordAuthStartupHandler::new(
auth_source,
DefaultServerParameterProvider::default(),
));
let handlers = Arc::new(DatapressHandlers {
session_service: Arc::new(DfSessionService::new_with_hooks(ctx, query_hooks())),
startup,
});
serve_with_handlers(handlers, &opts).await
}
None => serve_with_hooks(ctx, &opts, query_hooks()).await,
}
}