datafusion_postgres/
lib.rs

1mod handlers;
2
3use std::fs::File;
4use std::io::{BufReader, Error as IOError, ErrorKind};
5use std::sync::Arc;
6
7use datafusion::prelude::SessionContext;
8
9pub mod auth;
10use getset::{Getters, Setters, WithSetters};
11use log::{info, warn};
12use pgwire::api::PgWireServerHandlers;
13use pgwire::tokio::process_socket;
14use rustls_pemfile::{certs, pkcs8_private_keys};
15use rustls_pki_types::{CertificateDer, PrivateKeyDer};
16use tokio::net::TcpListener;
17use tokio::sync::Semaphore;
18use tokio_rustls::rustls::{self, ServerConfig};
19use tokio_rustls::TlsAcceptor;
20
21use crate::auth::AuthManager;
22use handlers::HandlerFactory;
23pub use handlers::{DfSessionService, Parser};
24
25/// re-exports
26pub use arrow_pg;
27pub use datafusion_pg_catalog;
28pub use pgwire;
29
/// Configuration for the Postgres-protocol server.
///
/// The `getset` derives generate a public getter, a `set_*` setter and a
/// chaining `with_*` setter for every field.
#[derive(Getters, Setters, WithSetters, Debug)]
#[getset(get = "pub", set = "pub", set_with = "pub")]
pub struct ServerOptions {
    /// Host the listener binds to (defaults to `127.0.0.1`).
    host: String,
    /// TCP port the listener binds to (defaults to `5432`).
    port: u16,
    /// Path to the TLS certificate file; TLS is only attempted when both this
    /// and `tls_key_path` are set.
    tls_cert_path: Option<String>,
    /// Path to the TLS private key file; TLS is only attempted when both this
    /// and `tls_cert_path` are set.
    tls_key_path: Option<String>,
    /// Maximum number of concurrently served connections; `0` means no limit.
    max_connections: usize,
}
39
40impl ServerOptions {
41    pub fn new() -> ServerOptions {
42        ServerOptions::default()
43    }
44}
45
46impl Default for ServerOptions {
47    fn default() -> Self {
48        ServerOptions {
49            host: "127.0.0.1".to_string(),
50            port: 5432,
51            tls_cert_path: None,
52            tls_key_path: None,
53            max_connections: 0, // 0 = no limit
54        }
55    }
56}
57
58/// Set up TLS configuration if certificate and key paths are provided
59fn setup_tls(cert_path: &str, key_path: &str) -> Result<TlsAcceptor, IOError> {
60    // Install ring crypto provider for rustls
61    let _ = rustls::crypto::ring::default_provider().install_default();
62
63    let cert = certs(&mut BufReader::new(File::open(cert_path)?))
64        .collect::<Result<Vec<CertificateDer>, IOError>>()?;
65
66    let key = pkcs8_private_keys(&mut BufReader::new(File::open(key_path)?))
67        .map(|key| key.map(PrivateKeyDer::from))
68        .collect::<Result<Vec<PrivateKeyDer>, IOError>>()?
69        .into_iter()
70        .next()
71        .ok_or_else(|| IOError::new(ErrorKind::InvalidInput, "No private key found"))?;
72
73    let config = ServerConfig::builder()
74        .with_no_client_auth()
75        .with_single_cert(cert, key)
76        .map_err(|err| IOError::new(ErrorKind::InvalidInput, err))?;
77
78    Ok(TlsAcceptor::from(Arc::new(config)))
79}
80
81/// Serve the Datafusion `SessionContext` with Postgres protocol.
82pub async fn serve(
83    session_context: Arc<SessionContext>,
84    opts: &ServerOptions,
85    auth_manager: Arc<AuthManager>,
86) -> Result<(), std::io::Error> {
87    // Create the handler factory with authentication
88    let factory = Arc::new(HandlerFactory::new(session_context, auth_manager));
89
90    serve_with_handlers(factory, opts).await
91}
92
93/// Serve with custom pgwire handlers
94///
95/// This function allows you to rewrite some of the built-in logic including
96/// authentication and query processing. You can Implement your own
97/// `PgWireServerHandlers` by reusing `DfSessionService`.
98pub async fn serve_with_handlers(
99    handlers: Arc<impl PgWireServerHandlers + Sync + Send + 'static>,
100    opts: &ServerOptions,
101) -> Result<(), std::io::Error> {
102    // Set up TLS if configured
103    let tls_acceptor =
104        if let (Some(cert_path), Some(key_path)) = (&opts.tls_cert_path, &opts.tls_key_path) {
105            match setup_tls(cert_path, key_path) {
106                Ok(acceptor) => {
107                    info!("TLS enabled using cert: {cert_path} and key: {key_path}");
108                    Some(acceptor)
109                }
110                Err(e) => {
111                    warn!("Failed to setup TLS: {e}. Running without encryption.");
112                    None
113                }
114            }
115        } else {
116            info!("TLS not configured. Running without encryption.");
117            None
118        };
119
120    // Bind to the specified host and port
121    let server_addr = format!("{}:{}", opts.host, opts.port);
122    let listener = TcpListener::bind(&server_addr).await?;
123    if tls_acceptor.is_some() {
124        info!("Listening on {server_addr} with TLS encryption");
125    } else {
126        info!("Listening on {server_addr} (unencrypted)");
127    }
128
129    // Connection limiter (if configured)
130    let max_conn_count = opts.max_connections;
131    let connection_limiter = if max_conn_count > 0 {
132        Some(Arc::new(Semaphore::new(max_conn_count)))
133    } else {
134        None
135    };
136
137    // Accept incoming connections
138    loop {
139        match listener.accept().await {
140            Ok((socket, addr)) => {
141                let factory_ref = handlers.clone();
142                let tls_acceptor_ref = tls_acceptor.clone();
143                let limiter_ref = connection_limiter.clone();
144
145                tokio::spawn(async move {
146                    // Check connection limit if configured
147                    let _permit = if let Some(ref semaphore) = limiter_ref {
148                        match semaphore.try_acquire() {
149                            Ok(permit) => Some(permit),
150                            Err(_) => {
151                                warn!("Connection rejected from {addr}: max connections ({max_conn_count}) reached");
152                                return;
153                            }
154                        }
155                    } else {
156                        None
157                    };
158
159                    if let Err(e) = process_socket(socket, tls_acceptor_ref, factory_ref).await {
160                        warn!("Error processing socket from {addr}: {e}");
161                    }
162                    // Permit is automatically released when _permit is dropped
163                });
164            }
165            Err(e) => {
166                warn!("Error accept socket: {e}");
167            }
168        }
169    }
170}
171
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration imposes no connection limit.
    #[test]
    fn test_server_options_default_max_connections() {
        let defaults = ServerOptions::default();
        assert_eq!(defaults.max_connections, 0);
    }

    /// `with_max_connections` stores the value it is given, including `0`
    /// (which means no limit).
    #[test]
    fn test_server_options_max_connections_configuration() {
        for limit in [500_usize, 0] {
            let configured = ServerOptions::new().with_max_connections(limit);
            assert_eq!(configured.max_connections, limit);
        }
    }
}