use std::sync::Arc;
use std::thread::JoinHandle;
use datafusion::prelude::SessionContext;
use tokio::sync::oneshot;
use datafusion_postgres::auth::{AuthManager, DfAuthSource};
use datafusion_postgres::datafusion_pg_catalog::pg_catalog::context::User;
use datafusion_postgres::datafusion_pg_catalog::setup_pg_catalog;
use datafusion_postgres::pgwire::api::PgWireServerHandlers;
use datafusion_postgres::pgwire::api::auth::StartupHandler;
use datafusion_postgres::pgwire::api::auth::DefaultServerParameterProvider;
use datafusion_postgres::pgwire::api::auth::cleartext::CleartextPasswordAuthStartupHandler;
use datafusion_postgres::pgwire::api::query::{ExtendedQueryHandler, SimpleQueryHandler};
use datafusion_postgres::{DfSessionService, ServerOptions, serve, serve_with_handlers};
use datapress_core::config::PgwireConfig;
/// Startup handler installed when a password is configured: pgwire's
/// cleartext-password authentication, with credentials looked up through
/// [`DfAuthSource`] and server parameters supplied by the default provider.
type CleartextStartup =
    CleartextPasswordAuthStartupHandler<DfAuthSource, DefaultServerParameterProvider>;
/// Handler set handed to `serve_with_handlers` when password auth is enabled.
///
/// Query traffic (both simple and extended protocol) is delegated to a single
/// shared DataFusion session service; connection startup goes through the
/// cleartext-password handler.
struct DatapressHandlers {
    /// DataFusion-backed service answering simple and extended queries.
    session_service: Arc<DfSessionService>,
    /// Startup/authentication handler for new connections.
    startup: Arc<CleartextStartup>,
}
impl PgWireServerHandlers for DatapressHandlers {
    /// Simple-query protocol messages go to the shared session service.
    fn simple_query_handler(&self) -> Arc<impl SimpleQueryHandler> {
        Arc::clone(&self.session_service)
    }

    /// Extended-query protocol messages go to the same session service.
    fn extended_query_handler(&self) -> Arc<impl ExtendedQueryHandler> {
        Arc::clone(&self.session_service)
    }

    /// Connection startup is handled by the cleartext-password handler.
    fn startup_handler(&self) -> Arc<impl StartupHandler> {
        Arc::clone(&self.startup)
    }
}
/// Stack size in bytes (32 MiB) for the `pgwire` thread and its Tokio workers.
// NOTE(review): presumably enlarged because deeply nested queries can recurse
// deeply during DataFusion planning/optimization — confirm.
const PGWIRE_THREAD_STACK: usize = 32 << 20;
/// Handle to the pgwire server thread started by [`spawn_pgwire`].
///
/// Dropping the handle signals the server to shut down and then blocks until
/// its thread has exited. Keep it alive for as long as the server should keep
/// accepting connections; discarding it (e.g. `spawn_pgwire(..)?;` as a bare
/// statement) stops the server immediately, hence `#[must_use]`.
#[must_use = "dropping PgwireServer immediately shuts the pgwire server down"]
pub struct PgwireServer {
    /// One-shot sender used to tell the server thread to stop; taken on drop.
    shutdown_tx: Option<oneshot::Sender<()>>,
    /// Join handle of the `pgwire` thread; taken and joined on drop.
    thread: Option<JoinHandle<()>>,
}
impl Drop for PgwireServer {
fn drop(&mut self) {
if let Some(tx) = self.shutdown_tx.take() {
let _ = tx.send(());
}
if let Some(thread) = self.thread.take() {
let _ = thread.join();
}
}
}
/// Starts [`serve_pgwire`] on a dedicated OS thread and returns a handle to it.
///
/// The thread (named `pgwire`) builds its own multi-threaded Tokio runtime
/// whose workers are named `pgwire-worker`. Both the thread and the workers
/// use [`PGWIRE_THREAD_STACK`]-byte stacks. The server runs until one of two
/// things happens:
/// - `serve_pgwire` returns on its own (an error is logged);
/// - the returned [`PgwireServer`] is dropped, which signals shutdown and
///   joins the thread.
///
/// # Errors
///
/// Returns an error only if the OS thread cannot be spawned. Failures inside
/// the thread (runtime construction, server errors) are logged, not returned.
pub fn spawn_pgwire(ctx: SessionContext, cfg: PgwireConfig) -> std::io::Result<PgwireServer> {
    // Dropping a Tokio runtime waits without limit for still-running blocking
    // tasks; bound that wait so `PgwireServer::drop` (which joins this thread)
    // cannot hang forever behind an in-flight query.
    const SHUTDOWN_GRACE: std::time::Duration = std::time::Duration::from_secs(5);

    let (shutdown_tx, shutdown_rx) = oneshot::channel::<()>();
    // Copied out before `cfg` is moved into the server future, for logging.
    let (listen, port) = (cfg.listen, cfg.port);
    let thread = std::thread::Builder::new()
        .name("pgwire".to_string())
        .stack_size(PGWIRE_THREAD_STACK)
        .spawn(move || {
            let runtime = match tokio::runtime::Builder::new_multi_thread()
                .enable_all()
                .thread_name("pgwire-worker")
                .thread_stack_size(PGWIRE_THREAD_STACK)
                .build()
            {
                Ok(runtime) => runtime,
                Err(e) => {
                    log::error!("pgwire: failed to build runtime: {e}");
                    return;
                }
            };
            runtime.block_on(async move {
                // Logged before `serve_pgwire` actually binds the socket.
                log::info!("pgwire: PostgreSQL wire protocol listening on {listen}:{port}");
                // Whichever finishes first wins; the other future is dropped.
                // A dropped sender also resolves `shutdown_rx`, stopping the server.
                tokio::select! {
                    _ = shutdown_rx => {
                        log::info!("pgwire: shutdown signalled; stopping listener");
                    }
                    res = serve_pgwire(ctx, cfg) => {
                        if let Err(e) = res {
                            log::error!("pgwire: server task exited with error: {e}");
                        }
                    }
                }
            });
            runtime.shutdown_timeout(SHUTDOWN_GRACE);
        })?;
    Ok(PgwireServer {
        shutdown_tx: Some(shutdown_tx),
        thread: Some(thread),
    })
}
/// Runs the PostgreSQL wire-protocol server on the current Tokio runtime.
///
/// Setup steps:
/// - Binds to `cfg.listen:cfg.port`. TLS is enabled when both `cfg.tls_cert`
///   and `cfg.tls_key` are set.
/// - Registers the `pg_catalog` schema on `ctx` under the session's default
///   catalog.
///
/// Authentication depends on `cfg.password`:
/// - When it is set, `cfg.username` is registered as a login-capable superuser
///   and connections are served with the cleartext-password startup handler.
///   The configured password is stored in `password_hash` as given; no hashing
///   happens here.
/// - When it is unset, the library's `serve` entry point is used with its
///   default handlers.
///
/// # Errors
///
/// - `InvalidInput` if exactly one of `tls_cert` / `tls_key` is configured.
///   Previously this silently disabled TLS.
/// - Errors from `pg_catalog` setup or user registration.
/// - Errors returned by the server itself (e.g. failing to bind).
pub async fn serve_pgwire(ctx: SessionContext, cfg: PgwireConfig) -> std::io::Result<()> {
    let mut opts = ServerOptions::new()
        .with_host(cfg.listen.to_string())
        .with_port(cfg.port);
    match (&cfg.tls_cert, &cfg.tls_key) {
        (Some(cert), Some(key)) => {
            opts = opts
                .with_tls_cert_path(Some(cert.to_string_lossy().into_owned()))
                .with_tls_key_path(Some(key.to_string_lossy().into_owned()));
        }
        (None, None) => {}
        // A half-configured TLS setup would otherwise fall back to plaintext,
        // which together with cleartext password auth exposes credentials.
        _ => {
            return Err(std::io::Error::new(
                std::io::ErrorKind::InvalidInput,
                "pgwire: tls_cert and tls_key must be configured together",
            ));
        }
    }

    let ctx = Arc::new(ctx);
    let auth_manager = Arc::new(AuthManager::new());
    let default_catalog = ctx
        .copied_config()
        .options()
        .catalog
        .default_catalog
        .clone();
    setup_pg_catalog(&ctx, &default_catalog, auth_manager.clone())
        .map_err(|e| std::io::Error::other(*e))?;

    match &cfg.password {
        Some(password) => {
            auth_manager
                .add_user(User {
                    username: cfg.username.clone(),
                    password_hash: password.clone(),
                    roles: Vec::new(),
                    is_superuser: true,
                    can_login: true,
                    connection_limit: None,
                })
                .await
                .map_err(std::io::Error::other)?;
            let auth_source = DfAuthSource::new(auth_manager);
            let startup = Arc::new(CleartextPasswordAuthStartupHandler::new(
                auth_source,
                DefaultServerParameterProvider::default(),
            ));
            let handlers = Arc::new(DatapressHandlers {
                session_service: Arc::new(DfSessionService::new(ctx)),
                startup,
            });
            serve_with_handlers(handlers, &opts).await
        }
        None => serve(ctx, &opts).await,
    }
}