datafusion_postgres/
lib.rs1mod handlers;
2pub mod pg_catalog;
3
4use std::fs::File;
5use std::io::{BufReader, Error as IOError, ErrorKind};
6use std::sync::Arc;
7
8use datafusion::prelude::SessionContext;
9
10pub mod auth;
11use getset::{Getters, Setters, WithSetters};
12use log::{info, warn};
13use pgwire::api::PgWireServerHandlers;
14use pgwire::tokio::process_socket;
15use rustls_pemfile::{certs, pkcs8_private_keys};
16use rustls_pki_types::{CertificateDer, PrivateKeyDer};
17use tokio::net::TcpListener;
18use tokio_rustls::rustls::{self, ServerConfig};
19use tokio_rustls::TlsAcceptor;
20
21use crate::auth::AuthManager;
22use handlers::HandlerFactory;
23pub use handlers::{DfSessionService, Parser};
24
25#[derive(Getters, Setters, WithSetters, Debug)]
26#[getset(get = "pub", set = "pub", set_with = "pub")]
27pub struct ServerOptions {
28 host: String,
29 port: u16,
30 tls_cert_path: Option<String>,
31 tls_key_path: Option<String>,
32}
33
34impl ServerOptions {
35 pub fn new() -> ServerOptions {
36 ServerOptions::default()
37 }
38}
39
40impl Default for ServerOptions {
41 fn default() -> Self {
42 ServerOptions {
43 host: "127.0.0.1".to_string(),
44 port: 5432,
45 tls_cert_path: None,
46 tls_key_path: None,
47 }
48 }
49}
50
51fn setup_tls(cert_path: &str, key_path: &str) -> Result<TlsAcceptor, IOError> {
53 let _ = rustls::crypto::ring::default_provider().install_default();
55
56 let cert = certs(&mut BufReader::new(File::open(cert_path)?))
57 .collect::<Result<Vec<CertificateDer>, IOError>>()?;
58
59 let key = pkcs8_private_keys(&mut BufReader::new(File::open(key_path)?))
60 .map(|key| key.map(PrivateKeyDer::from))
61 .collect::<Result<Vec<PrivateKeyDer>, IOError>>()?
62 .into_iter()
63 .next()
64 .ok_or_else(|| IOError::new(ErrorKind::InvalidInput, "No private key found"))?;
65
66 let config = ServerConfig::builder()
67 .with_no_client_auth()
68 .with_single_cert(cert, key)
69 .map_err(|err| IOError::new(ErrorKind::InvalidInput, err))?;
70
71 Ok(TlsAcceptor::from(Arc::new(config)))
72}
73
74pub async fn serve(
76 session_context: Arc<SessionContext>,
77 opts: &ServerOptions,
78) -> Result<(), std::io::Error> {
79 let auth_manager = Arc::new(AuthManager::new());
81
82 let factory = Arc::new(HandlerFactory::new(session_context, auth_manager));
84
85 serve_with_handlers(factory, opts).await
86}
87
88pub async fn serve_with_handlers(
94 handlers: Arc<impl PgWireServerHandlers + Sync + Send + 'static>,
95 opts: &ServerOptions,
96) -> Result<(), std::io::Error> {
97 let tls_acceptor =
99 if let (Some(cert_path), Some(key_path)) = (&opts.tls_cert_path, &opts.tls_key_path) {
100 match setup_tls(cert_path, key_path) {
101 Ok(acceptor) => {
102 info!("TLS enabled using cert: {cert_path} and key: {key_path}");
103 Some(acceptor)
104 }
105 Err(e) => {
106 warn!("Failed to setup TLS: {e}. Running without encryption.");
107 None
108 }
109 }
110 } else {
111 info!("TLS not configured. Running without encryption.");
112 None
113 };
114
115 let server_addr = format!("{}:{}", opts.host, opts.port);
117 let listener = TcpListener::bind(&server_addr).await?;
118 if tls_acceptor.is_some() {
119 info!("Listening on {server_addr} with TLS encryption");
120 } else {
121 info!("Listening on {server_addr} (unencrypted)");
122 }
123
124 loop {
126 match listener.accept().await {
127 Ok((socket, _addr)) => {
128 let factory_ref = handlers.clone();
129 let tls_acceptor_ref = tls_acceptor.clone();
130
131 tokio::spawn(async move {
132 if let Err(e) = process_socket(socket, tls_acceptor_ref, factory_ref).await {
133 warn!("Error processing socket: {e}");
134 }
135 });
136 }
137 Err(e) => {
138 warn!("Error accept socket: {e}");
139 }
140 }
141 }
142}