datafusion_postgres/
lib.rs

1mod handlers;
2pub mod pg_catalog;
3
4use std::fs::File;
5use std::io::{BufReader, Error as IOError, ErrorKind};
6use std::sync::Arc;
7
8use datafusion::prelude::SessionContext;
9
10pub mod auth;
11use getset::{Getters, Setters, WithSetters};
12use log::{info, warn};
13use pgwire::api::PgWireServerHandlers;
14use pgwire::tokio::process_socket;
15use rustls_pemfile::{certs, pkcs8_private_keys};
16use rustls_pki_types::{CertificateDer, PrivateKeyDer};
17use tokio::net::TcpListener;
18use tokio_rustls::rustls::{self, ServerConfig};
19use tokio_rustls::TlsAcceptor;
20
21use crate::auth::AuthManager;
22use handlers::HandlerFactory;
23pub use handlers::{DfSessionService, Parser};
24
25/// re-exports
26pub use arrow_pg;
27pub use pgwire;
28
29#[derive(Getters, Setters, WithSetters, Debug)]
30#[getset(get = "pub", set = "pub", set_with = "pub")]
31pub struct ServerOptions {
32    host: String,
33    port: u16,
34    tls_cert_path: Option<String>,
35    tls_key_path: Option<String>,
36}
37
38impl ServerOptions {
39    pub fn new() -> ServerOptions {
40        ServerOptions::default()
41    }
42}
43
44impl Default for ServerOptions {
45    fn default() -> Self {
46        ServerOptions {
47            host: "127.0.0.1".to_string(),
48            port: 5432,
49            tls_cert_path: None,
50            tls_key_path: None,
51        }
52    }
53}
54
55/// Set up TLS configuration if certificate and key paths are provided
56fn setup_tls(cert_path: &str, key_path: &str) -> Result<TlsAcceptor, IOError> {
57    // Install ring crypto provider for rustls
58    let _ = rustls::crypto::ring::default_provider().install_default();
59
60    let cert = certs(&mut BufReader::new(File::open(cert_path)?))
61        .collect::<Result<Vec<CertificateDer>, IOError>>()?;
62
63    let key = pkcs8_private_keys(&mut BufReader::new(File::open(key_path)?))
64        .map(|key| key.map(PrivateKeyDer::from))
65        .collect::<Result<Vec<PrivateKeyDer>, IOError>>()?
66        .into_iter()
67        .next()
68        .ok_or_else(|| IOError::new(ErrorKind::InvalidInput, "No private key found"))?;
69
70    let config = ServerConfig::builder()
71        .with_no_client_auth()
72        .with_single_cert(cert, key)
73        .map_err(|err| IOError::new(ErrorKind::InvalidInput, err))?;
74
75    Ok(TlsAcceptor::from(Arc::new(config)))
76}
77
78/// Serve the Datafusion `SessionContext` with Postgres protocol.
79pub async fn serve(
80    session_context: Arc<SessionContext>,
81    opts: &ServerOptions,
82) -> Result<(), std::io::Error> {
83    // Create authentication manager
84    let auth_manager = Arc::new(AuthManager::new());
85
86    // Create the handler factory with authentication
87    let factory = Arc::new(HandlerFactory::new(session_context, auth_manager));
88
89    serve_with_handlers(factory, opts).await
90}
91
92/// Serve with custom pgwire handlers
93///
94/// This function allows you to rewrite some of the built-in logic including
95/// authentication and query processing. You can Implement your own
96/// `PgWireServerHandlers` by reusing `DfSessionService`.
97pub async fn serve_with_handlers(
98    handlers: Arc<impl PgWireServerHandlers + Sync + Send + 'static>,
99    opts: &ServerOptions,
100) -> Result<(), std::io::Error> {
101    // Set up TLS if configured
102    let tls_acceptor =
103        if let (Some(cert_path), Some(key_path)) = (&opts.tls_cert_path, &opts.tls_key_path) {
104            match setup_tls(cert_path, key_path) {
105                Ok(acceptor) => {
106                    info!("TLS enabled using cert: {cert_path} and key: {key_path}");
107                    Some(acceptor)
108                }
109                Err(e) => {
110                    warn!("Failed to setup TLS: {e}. Running without encryption.");
111                    None
112                }
113            }
114        } else {
115            info!("TLS not configured. Running without encryption.");
116            None
117        };
118
119    // Bind to the specified host and port
120    let server_addr = format!("{}:{}", opts.host, opts.port);
121    let listener = TcpListener::bind(&server_addr).await?;
122    if tls_acceptor.is_some() {
123        info!("Listening on {server_addr} with TLS encryption");
124    } else {
125        info!("Listening on {server_addr} (unencrypted)");
126    }
127
128    // Accept incoming connections
129    loop {
130        match listener.accept().await {
131            Ok((socket, _addr)) => {
132                let factory_ref = handlers.clone();
133                let tls_acceptor_ref = tls_acceptor.clone();
134
135                tokio::spawn(async move {
136                    if let Err(e) = process_socket(socket, tls_acceptor_ref, factory_ref).await {
137                        warn!("Error processing socket: {e}");
138                    }
139                });
140            }
141            Err(e) => {
142                warn!("Error accept socket: {e}");
143            }
144        }
145    }
146}