datafusion_postgres/
lib.rs

1pub mod auth;
2pub(crate) mod client;
3mod handlers;
4pub mod hooks;
5#[cfg(any(test, debug_assertions))]
6pub mod testing;
7
8use std::fs::File;
9use std::io::{BufReader, Error as IOError, ErrorKind};
10use std::sync::Arc;
11
12use datafusion::prelude::SessionContext;
13use getset::{Getters, Setters, WithSetters};
14use log::{info, warn};
15use pgwire::api::PgWireServerHandlers;
16use pgwire::tokio::process_socket;
17use rustls_pemfile::{certs, pkcs8_private_keys};
18use rustls_pki_types::{CertificateDer, PrivateKeyDer};
19use tokio::net::TcpListener;
20use tokio::sync::Semaphore;
21use tokio_rustls::rustls::{self, ServerConfig};
22use tokio_rustls::TlsAcceptor;
23
24use crate::auth::AuthManager;
25use handlers::HandlerFactory;
26pub use handlers::{DfSessionService, Parser};
27pub use hooks::QueryHook;
28
29/// re-exports
30pub use arrow_pg;
31pub use datafusion_pg_catalog;
32pub use pgwire;
33
/// Configuration for the Postgres wire-protocol server.
///
/// The `getset` derives below generate a public getter, a `set_*` setter and
/// a chainable `with_*` setter for every field.
#[derive(Getters, Setters, WithSetters, Debug)]
#[getset(get = "pub", set = "pub", set_with = "pub")]
pub struct ServerOptions {
    /// Host name or IP address the listener binds to.
    host: String,
    /// TCP port the listener binds to.
    port: u16,
    /// Path to the PEM certificate file. TLS is only attempted when both this
    /// and `tls_key_path` are set.
    tls_cert_path: Option<String>,
    /// Path to the PEM private key file. TLS is only attempted when both this
    /// and `tls_cert_path` are set.
    tls_key_path: Option<String>,
    /// Maximum number of connections served at the same time; `0` means no
    /// limit.
    max_connections: usize,
}
43
44impl ServerOptions {
45    pub fn new() -> ServerOptions {
46        ServerOptions::default()
47    }
48}
49
50impl Default for ServerOptions {
51    fn default() -> Self {
52        ServerOptions {
53            host: "127.0.0.1".to_string(),
54            port: 5432,
55            tls_cert_path: None,
56            tls_key_path: None,
57            max_connections: 0, // 0 = no limit
58        }
59    }
60}
61
62/// Set up TLS configuration if certificate and key paths are provided
63fn setup_tls(cert_path: &str, key_path: &str) -> Result<TlsAcceptor, IOError> {
64    // Install ring crypto provider for rustls
65    let _ = rustls::crypto::ring::default_provider().install_default();
66
67    let cert = certs(&mut BufReader::new(File::open(cert_path)?))
68        .collect::<Result<Vec<CertificateDer>, IOError>>()?;
69
70    let key = pkcs8_private_keys(&mut BufReader::new(File::open(key_path)?))
71        .map(|key| key.map(PrivateKeyDer::from))
72        .collect::<Result<Vec<PrivateKeyDer>, IOError>>()?
73        .into_iter()
74        .next()
75        .ok_or_else(|| IOError::new(ErrorKind::InvalidInput, "No private key found"))?;
76
77    let config = ServerConfig::builder()
78        .with_no_client_auth()
79        .with_single_cert(cert, key)
80        .map_err(|err| IOError::new(ErrorKind::InvalidInput, err))?;
81
82    Ok(TlsAcceptor::from(Arc::new(config)))
83}
84
85/// Serve the Datafusion `SessionContext` with Postgres protocol.
86pub async fn serve(
87    session_context: Arc<SessionContext>,
88    opts: &ServerOptions,
89    auth_manager: Arc<AuthManager>,
90) -> Result<(), std::io::Error> {
91    // Create the handler factory with authentication
92    let factory = Arc::new(HandlerFactory::new(session_context, auth_manager));
93
94    serve_with_handlers(factory, opts).await
95}
96
97/// Serve the Datafusion `SessionContext` with Postgres protocol, using custom
98/// query processing hooks.
99pub async fn serve_with_hooks(
100    session_context: Arc<SessionContext>,
101    opts: &ServerOptions,
102    auth_manager: Arc<AuthManager>,
103    hooks: Vec<Arc<dyn QueryHook>>,
104) -> Result<(), std::io::Error> {
105    // Create the handler factory with authentication
106    let factory = Arc::new(HandlerFactory::new_with_hooks(
107        session_context,
108        auth_manager,
109        hooks,
110    ));
111
112    serve_with_handlers(factory, opts).await
113}
114
115/// Serve with custom pgwire handlers
116///
117/// This function allows you to rewrite some of the built-in logic including
118/// authentication and query processing. You can Implement your own
119/// `PgWireServerHandlers` by reusing `DfSessionService`.
120pub async fn serve_with_handlers(
121    handlers: Arc<impl PgWireServerHandlers + Sync + Send + 'static>,
122    opts: &ServerOptions,
123) -> Result<(), std::io::Error> {
124    // Set up TLS if configured
125    let tls_acceptor =
126        if let (Some(cert_path), Some(key_path)) = (&opts.tls_cert_path, &opts.tls_key_path) {
127            match setup_tls(cert_path, key_path) {
128                Ok(acceptor) => {
129                    info!("TLS enabled using cert: {cert_path} and key: {key_path}");
130                    Some(acceptor)
131                }
132                Err(e) => {
133                    warn!("Failed to setup TLS: {e}. Running without encryption.");
134                    None
135                }
136            }
137        } else {
138            info!("TLS not configured. Running without encryption.");
139            None
140        };
141
142    // Bind to the specified host and port
143    let server_addr = format!("{}:{}", opts.host, opts.port);
144    let listener = TcpListener::bind(&server_addr).await?;
145    if tls_acceptor.is_some() {
146        info!("Listening on {server_addr} with TLS encryption");
147    } else {
148        info!("Listening on {server_addr} (unencrypted)");
149    }
150
151    // Connection limiter (if configured)
152    let max_conn_count = opts.max_connections;
153    let connection_limiter = if max_conn_count > 0 {
154        Some(Arc::new(Semaphore::new(max_conn_count)))
155    } else {
156        None
157    };
158
159    // Accept incoming connections
160    loop {
161        match listener.accept().await {
162            Ok((socket, addr)) => {
163                let factory_ref = handlers.clone();
164                let tls_acceptor_ref = tls_acceptor.clone();
165                let limiter_ref = connection_limiter.clone();
166
167                tokio::spawn(async move {
168                    // Check connection limit if configured
169                    let _permit = if let Some(ref semaphore) = limiter_ref {
170                        match semaphore.try_acquire() {
171                            Ok(permit) => Some(permit),
172                            Err(_) => {
173                                warn!("Connection rejected from {addr}: max connections ({max_conn_count}) reached");
174                                return;
175                            }
176                        }
177                    } else {
178                        None
179                    };
180
181                    if let Err(e) = process_socket(socket, tls_acceptor_ref, factory_ref).await {
182                        warn!("Error processing socket from {addr}: {e}");
183                    }
184                    // Permit is automatically released when _permit is dropped
185                });
186            }
187            Err(e) => {
188                warn!("Error accept socket: {e}");
189            }
190        }
191    }
192}
193
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration places no limit on connections.
    #[test]
    fn test_server_options_default_max_connections() {
        let defaults = ServerOptions::default();
        assert_eq!(defaults.max_connections, 0);
    }

    /// `with_max_connections` stores the given value, including the `0`
    /// sentinel that means "no limit".
    #[test]
    fn test_server_options_max_connections_configuration() {
        for limit in [500usize, 0] {
            let configured = ServerOptions::new().with_max_connections(limit);
            assert_eq!(configured.max_connections, limit);
        }
    }
}