datafusion_postgres/
lib.rs

1pub mod auth;
2pub(crate) mod client;
3mod handlers;
4pub mod hooks;
5#[cfg(any(test, debug_assertions))]
6pub mod testing;
7
8use std::fs::File;
9use std::io::{BufReader, Error as IOError, ErrorKind};
10use std::sync::Arc;
11
12use datafusion::prelude::SessionContext;
13use getset::{Getters, Setters, WithSetters};
14use log::{info, warn};
15use pgwire::api::PgWireServerHandlers;
16use pgwire::tokio::process_socket;
17use rustls_pemfile::{certs, pkcs8_private_keys};
18use rustls_pki_types::{CertificateDer, PrivateKeyDer};
19use tokio::net::TcpListener;
20use tokio::sync::Semaphore;
21use tokio_rustls::rustls::{self, ServerConfig};
22use tokio_rustls::TlsAcceptor;
23
24use handlers::HandlerFactory;
25pub use handlers::{DfSessionService, Parser};
26pub use hooks::QueryHook;
27
// Re-exports of the crates this server builds on, for use by downstream code.
29pub use arrow_pg;
30pub use datafusion_pg_catalog;
31pub use pgwire;
32
33#[derive(Getters, Setters, WithSetters, Debug)]
34#[getset(get = "pub", set = "pub", set_with = "pub")]
35pub struct ServerOptions {
36    host: String,
37    port: u16,
38    tls_cert_path: Option<String>,
39    tls_key_path: Option<String>,
40    max_connections: usize,
41}
42
43impl ServerOptions {
44    pub fn new() -> ServerOptions {
45        ServerOptions::default()
46    }
47}
48
49impl Default for ServerOptions {
50    fn default() -> Self {
51        ServerOptions {
52            host: "127.0.0.1".to_string(),
53            port: 5432,
54            tls_cert_path: None,
55            tls_key_path: None,
56            max_connections: 0, // 0 = no limit
57        }
58    }
59}
60
61/// Set up TLS configuration if certificate and key paths are provided
62fn setup_tls(cert_path: &str, key_path: &str) -> Result<TlsAcceptor, IOError> {
63    // Install ring crypto provider for rustls
64    let _ = rustls::crypto::ring::default_provider().install_default();
65
66    let cert = certs(&mut BufReader::new(File::open(cert_path)?))
67        .collect::<Result<Vec<CertificateDer>, IOError>>()?;
68
69    let key = pkcs8_private_keys(&mut BufReader::new(File::open(key_path)?))
70        .map(|key| key.map(PrivateKeyDer::from))
71        .collect::<Result<Vec<PrivateKeyDer>, IOError>>()?
72        .into_iter()
73        .next()
74        .ok_or_else(|| IOError::new(ErrorKind::InvalidInput, "No private key found"))?;
75
76    let config = ServerConfig::builder()
77        .with_no_client_auth()
78        .with_single_cert(cert, key)
79        .map_err(|err| IOError::new(ErrorKind::InvalidInput, err))?;
80
81    Ok(TlsAcceptor::from(Arc::new(config)))
82}
83
84/// Serve the Datafusion `SessionContext` with Postgres protocol.
85pub async fn serve(
86    session_context: Arc<SessionContext>,
87    opts: &ServerOptions,
88) -> Result<(), std::io::Error> {
89    // Create the handler factory with authentication
90    let factory = Arc::new(HandlerFactory::new(session_context));
91
92    serve_with_handlers(factory, opts).await
93}
94
95/// Serve the Datafusion `SessionContext` with Postgres protocol, using custom
96/// query processing hooks.
97pub async fn serve_with_hooks(
98    session_context: Arc<SessionContext>,
99    opts: &ServerOptions,
100    hooks: Vec<Arc<dyn QueryHook>>,
101) -> Result<(), std::io::Error> {
102    // Create the handler factory with authentication
103    let factory = Arc::new(HandlerFactory::new_with_hooks(session_context, hooks));
104
105    serve_with_handlers(factory, opts).await
106}
107
108/// Serve with custom pgwire handlers
109///
110/// This function allows you to rewrite some of the built-in logic including
111/// authentication and query processing. You can Implement your own
112/// `PgWireServerHandlers` by reusing `DfSessionService`.
113pub async fn serve_with_handlers(
114    handlers: Arc<impl PgWireServerHandlers + Sync + Send + 'static>,
115    opts: &ServerOptions,
116) -> Result<(), std::io::Error> {
117    // Set up TLS if configured
118    let tls_acceptor =
119        if let (Some(cert_path), Some(key_path)) = (&opts.tls_cert_path, &opts.tls_key_path) {
120            match setup_tls(cert_path, key_path) {
121                Ok(acceptor) => {
122                    info!("TLS enabled using cert: {cert_path} and key: {key_path}");
123                    Some(acceptor)
124                }
125                Err(e) => {
126                    warn!("Failed to setup TLS: {e}. Running without encryption.");
127                    None
128                }
129            }
130        } else {
131            info!("TLS not configured. Running without encryption.");
132            None
133        };
134
135    // Bind to the specified host and port
136    let server_addr = format!("{}:{}", opts.host, opts.port);
137    let listener = TcpListener::bind(&server_addr).await?;
138    if tls_acceptor.is_some() {
139        info!("Listening on {server_addr} with TLS encryption");
140    } else {
141        info!("Listening on {server_addr} (unencrypted)");
142    }
143
144    // Connection limiter (if configured)
145    let max_conn_count = opts.max_connections;
146    let connection_limiter = if max_conn_count > 0 {
147        Some(Arc::new(Semaphore::new(max_conn_count)))
148    } else {
149        None
150    };
151
152    // Accept incoming connections
153    loop {
154        match listener.accept().await {
155            Ok((socket, addr)) => {
156                let factory_ref = handlers.clone();
157                let tls_acceptor_ref = tls_acceptor.clone();
158                let limiter_ref = connection_limiter.clone();
159
160                tokio::spawn(async move {
161                    // Check connection limit if configured
162                    let _permit = if let Some(ref semaphore) = limiter_ref {
163                        match semaphore.try_acquire() {
164                            Ok(permit) => Some(permit),
165                            Err(_) => {
166                                warn!("Connection rejected from {addr}: max connections ({max_conn_count}) reached");
167                                return;
168                            }
169                        }
170                    } else {
171                        None
172                    };
173
174                    if let Err(e) = process_socket(socket, tls_acceptor_ref, factory_ref).await {
175                        warn!("Error processing socket from {addr}: {e}");
176                    }
177                    // Permit is automatically released when _permit is dropped
178                });
179            }
180            Err(e) => {
181                warn!("Error accept socket: {e}");
182            }
183        }
184    }
185}
186
#[cfg(test)]
mod tests {
    use super::ServerOptions;

    /// Default options must not cap the number of connections.
    #[test]
    fn test_server_options_default_max_connections() {
        let default_limit = ServerOptions::default().max_connections;
        assert_eq!(default_limit, 0);
    }

    /// The chaining setter stores the requested limit, including the
    /// "no limit" value of zero.
    #[test]
    fn test_server_options_max_connections_configuration() {
        let limited = ServerOptions::new().with_max_connections(500);
        assert_eq!(limited.max_connections, 500);

        let unlimited = ServerOptions::new().with_max_connections(0);
        assert_eq!(unlimited.max_connections, 0);
    }
}