datafusion_table_providers/sql/db_connection_pool/
postgrespool.rs

1use std::{collections::HashMap, path::PathBuf, str::FromStr, sync::Arc};
2
3use crate::{
4    util::{self, ns_lookup::verify_ns_lookup_and_tcp_connect},
5    UnsupportedTypeAction,
6};
7use async_trait::async_trait;
8use bb8::ErrorSink;
9use bb8_postgres::{
10    tokio_postgres::{config::Host, types::ToSql, Config},
11    PostgresConnectionManager,
12};
13use native_tls::{Certificate, TlsConnector};
14use postgres_native_tls::MakeTlsConnector;
15use secrecy::{ExposeSecret, SecretBox, SecretString};
16use snafu::{prelude::*, ResultExt};
17use tokio_postgres;
18
19use super::{runtime::run_async_with_tokio, DbConnectionPool};
20use crate::sql::db_connection_pool::{
21    dbconnection::{postgresconn::PostgresConnection, AsyncDbConnection, DbConnection},
22    JoinPushDown,
23};
24
25#[derive(Debug, Snafu)]
26pub enum Error {
27    #[snafu(display("PostgreSQL connection failed.\n{source}\nFor details, refer to the PostgreSQL documentation: https://www.postgresql.org/docs/17/index.html"))]
28    ConnectionPoolError {
29        source: bb8_postgres::tokio_postgres::Error,
30    },
31
32    #[snafu(display("PostgreSQL connection failed.\n{source}\nAdjust the connection pool parameters for sufficient capacity."))]
33    ConnectionPoolRunError {
34        source: bb8::RunError<bb8_postgres::tokio_postgres::Error>,
35    },
36
37    #[snafu(display(
38        "Invalid parameter: {parameter_name}. Ensure the parameter name is correct."
39    ))]
40    InvalidParameterError { parameter_name: String },
41
42    #[snafu(display("Could not parse {parameter_name} into a valid integer. Ensure it is configured with a valid value."))]
43    InvalidIntegerParameterError {
44        parameter_name: String,
45        source: std::num::ParseIntError,
46    },
47
48    #[snafu(display("Cannot connect to PostgreSQL on {host}:{port}. Ensure the host and port are correct and reachable."))]
49    InvalidHostOrPortError {
50        source: crate::util::ns_lookup::Error,
51        host: String,
52        port: u16,
53    },
54
55    #[snafu(display(
56        "Invalid root certificate path: {path}. Ensure it points to a valid root certificate."
57    ))]
58    InvalidRootCertPathError { path: String },
59
60    #[snafu(display(
61        "Failed to read certificate.\n{source}\nEnsure the root certificate path points to a valid certificate."
62    ))]
63    FailedToReadCertError { source: std::io::Error },
64
65    #[snafu(display(
66        "Certificate loading failed.\n{source}\nEnsure the root certificate path points to a valid certificate."
67    ))]
68    FailedToLoadCertError { source: native_tls::Error },
69
70    #[snafu(display("TLS connector initialization failed.\n{source}\nVerify SSL mode and root certificate validity"))]
71    FailedToBuildTlsConnectorError { source: native_tls::Error },
72
73    #[snafu(display("PostgreSQL connection failed.\n{source}\nFor details, refer to the PostgreSQL documentation: https://www.postgresql.org/docs/17/index.html"))]
74    PostgresConnectionError { source: tokio_postgres::Error },
75
76    #[snafu(display("Authentication failed. Verify username and password."))]
77    InvalidUsernameOrPassword { source: tokio_postgres::Error },
78}
79
80pub type Result<T, E = Error> = std::result::Result<T, E>;
81
82#[derive(Debug)]
83pub struct PostgresConnectionPool {
84    pool: Arc<bb8::Pool<PostgresConnectionManager<MakeTlsConnector>>>,
85    join_push_down: JoinPushDown,
86    unsupported_type_action: UnsupportedTypeAction,
87}
88
89impl PostgresConnectionPool {
90    /// Creates a new instance of `PostgresConnectionPool`.
91    ///
92    /// # Errors
93    ///
94    /// Returns an error if there is a problem creating the connection pool.
95    pub async fn new(params: HashMap<String, SecretString>) -> Result<Self> {
96        // Remove the "pg_" prefix from the keys to keep backward compatibility
97        let params = util::remove_prefix_from_hashmap_keys(params, "pg_");
98
99        let mut connection_string = String::new();
100        let mut ssl_mode = "verify-full".to_string();
101        let mut ssl_rootcert_path: Option<PathBuf> = None;
102
103        if let Some(pg_connection_string) = params
104            .get("connection_string")
105            .map(SecretBox::expose_secret)
106        {
107            let (str, mode, cert_path) = parse_connection_string(pg_connection_string);
108            connection_string = str;
109            ssl_mode = mode;
110            if let Some(cert_path) = cert_path {
111                let sslrootcert = cert_path.as_str();
112                ensure!(
113                    std::path::Path::new(sslrootcert).exists(),
114                    InvalidRootCertPathSnafu { path: cert_path }
115                );
116                ssl_rootcert_path = Some(PathBuf::from(sslrootcert));
117            }
118        } else {
119            if let Some(pg_host) = params.get("host").map(SecretBox::expose_secret) {
120                connection_string.push_str(format!("host={pg_host} ").as_str());
121            }
122            if let Some(pg_user) = params.get("user").map(SecretBox::expose_secret) {
123                connection_string.push_str(format!("user={pg_user} ").as_str());
124            }
125            if let Some(pg_db) = params.get("db").map(SecretBox::expose_secret) {
126                connection_string.push_str(format!("dbname={pg_db} ").as_str());
127            }
128            if let Some(pg_pass) = params.get("pass").map(SecretBox::expose_secret) {
129                connection_string.push_str(format!("password={pg_pass} ").as_str());
130            }
131            if let Some(pg_port) = params.get("port").map(SecretBox::expose_secret) {
132                connection_string.push_str(format!("port={pg_port} ").as_str());
133            }
134        }
135
136        if let Some(pg_sslmode) = params.get("sslmode").map(SecretBox::expose_secret) {
137            match pg_sslmode.to_lowercase().as_str() {
138                "disable" | "require" | "prefer" | "verify-ca" | "verify-full" => {
139                    ssl_mode = pg_sslmode.to_string();
140                }
141                _ => {
142                    InvalidParameterSnafu {
143                        parameter_name: "sslmode".to_string(),
144                    }
145                    .fail()?;
146                }
147            }
148        }
149        if let Some(pg_sslrootcert) = params.get("sslrootcert").map(SecretBox::expose_secret) {
150            ensure!(
151                std::path::Path::new(pg_sslrootcert).exists(),
152                InvalidRootCertPathSnafu {
153                    path: pg_sslrootcert,
154                }
155            );
156
157            ssl_rootcert_path = Some(PathBuf::from(pg_sslrootcert));
158        }
159
160        let mode = match ssl_mode.as_str() {
161            "disable" => "disable",
162            "prefer" => "prefer",
163            // tokio_postgres supports only disable, require and prefer
164            _ => "require",
165        };
166
167        connection_string.push_str(format!("sslmode={mode} ").as_str());
168        let mut config =
169            Config::from_str(connection_string.as_str()).context(ConnectionPoolSnafu)?;
170
171        if let Some(application_name) = params.get("application_name").map(SecretBox::expose_secret)
172        {
173            config.application_name(application_name);
174        }
175
176        verify_postgres_config(&config).await?;
177
178        let mut certs: Option<Vec<Certificate>> = None;
179
180        if let Some(path) = ssl_rootcert_path {
181            let buf = tokio::fs::read(path).await.context(FailedToReadCertSnafu)?;
182            certs = Some(parse_certs(&buf)?);
183        }
184
185        let tls_connector = get_tls_connector(ssl_mode.as_str(), certs)?;
186        let connector = MakeTlsConnector::new(tls_connector);
187        test_postgres_connection(connection_string.as_str(), connector.clone()).await?;
188
189        let join_push_down = get_join_context(&config);
190
191        let manager = PostgresConnectionManager::new(config, connector);
192        let error_sink = PostgresErrorSink::new();
193
194        let mut connection_pool_size = 10; // The BB8 default is 10
195        if let Some(pg_pool_size) = params
196            .get("connection_pool_size")
197            .map(SecretBox::expose_secret)
198        {
199            connection_pool_size = pg_pool_size.parse().context(InvalidIntegerParameterSnafu {
200                parameter_name: "pool_size".to_string(),
201            })?;
202        }
203
204        let pool = bb8::Pool::builder()
205            .max_size(connection_pool_size)
206            .error_sink(Box::new(error_sink))
207            .build(manager)
208            .await
209            .context(ConnectionPoolSnafu)?;
210
211        // Test the connection
212        let conn = pool.get().await.context(ConnectionPoolRunSnafu)?;
213        conn.execute("SELECT 1", &[])
214            .await
215            .context(ConnectionPoolSnafu)?;
216
217        Ok(PostgresConnectionPool {
218            pool: Arc::new(pool.clone()),
219            join_push_down,
220            unsupported_type_action: UnsupportedTypeAction::default(),
221        })
222    }
223
224    /// Specify the action to take when an invalid type is encountered.
225    #[must_use]
226    pub fn with_unsupported_type_action(mut self, action: UnsupportedTypeAction) -> Self {
227        self.unsupported_type_action = action;
228        self
229    }
230
231    /// Returns a direct connection to the underlying database.
232    ///
233    /// # Errors
234    ///
235    /// Returns an error if there is a problem creating the connection pool.
236    pub async fn connect_direct(&self) -> super::Result<PostgresConnection> {
237        let pool = Arc::clone(&self.pool);
238        let conn = pool.get_owned().await.context(ConnectionPoolRunSnafu)?;
239        Ok(PostgresConnection::new(conn))
240    }
241}
242
/// Splits a whitespace-separated `key=value` connection string into the parts this
/// module handles itself and the parts forwarded to the driver.
///
/// Returns `(connection_string, ssl_mode, ssl_rootcert_path)`:
/// * `connection_string` - every pair except `sslmode` / `sslrootcert`, re-emitted as
///   `key=value ` (each followed by a single space);
/// * `ssl_mode` - the value of `sslmode`, or `"verify-full"` when absent;
/// * `ssl_rootcert_path` - the value of `sslrootcert`, if any.
///
/// Each token is split at its FIRST `=` only, so values that themselves contain `=`
/// (for example a password such as `a=b`) are preserved intact. Tokens without any `=`
/// are ignored.
fn parse_connection_string(pg_connection_string: &str) -> (String, String, Option<String>) {
    let mut connection_string = String::new();
    let mut ssl_mode = "verify-full".to_string();
    let mut ssl_rootcert_path: Option<String> = None;

    for token in pg_connection_string.split_whitespace() {
        let Some((name, value)) = token.split_once('=') else {
            continue;
        };

        match name {
            "sslmode" => ssl_mode = value.to_string(),
            "sslrootcert" => ssl_rootcert_path = Some(value.to_string()),
            _ => {
                connection_string.push_str(name);
                connection_string.push('=');
                connection_string.push_str(value);
                connection_string.push(' ');
            }
        }
    }

    (connection_string, ssl_mode, ssl_rootcert_path)
}
269
270fn get_join_context(config: &Config) -> JoinPushDown {
271    let mut join_push_context_str = String::new();
272    for host in config.get_hosts() {
273        join_push_context_str.push_str(&format!("host={host:?},"));
274    }
275    if !config.get_ports().is_empty() {
276        join_push_context_str.push_str(&format!("port={port},", port = config.get_ports()[0]));
277    }
278    if let Some(dbname) = config.get_dbname() {
279        join_push_context_str.push_str(&format!("db={dbname},"));
280    }
281    if let Some(user) = config.get_user() {
282        join_push_context_str.push_str(&format!("user={user},"));
283    }
284
285    JoinPushDown::AllowedFor(join_push_context_str)
286}
287
288async fn test_postgres_connection(
289    connection_string: &str,
290    connector: MakeTlsConnector,
291) -> Result<()> {
292    match tokio_postgres::connect(connection_string, connector).await {
293        Ok(_) => Ok(()),
294        Err(err) => {
295            if let Some(code) = err.code() {
296                if *code == tokio_postgres::error::SqlState::INVALID_PASSWORD {
297                    return Err(Error::InvalidUsernameOrPassword { source: err });
298                }
299            }
300
301            Err(Error::PostgresConnectionError { source: err })
302        }
303    }
304}
305
306async fn verify_postgres_config(config: &Config) -> Result<()> {
307    for host in config.get_hosts() {
308        for port in config.get_ports() {
309            if let Host::Tcp(host) = host {
310                verify_ns_lookup_and_tcp_connect(host, *port)
311                    .await
312                    .context(InvalidHostOrPortSnafu { host, port: *port })?;
313            }
314        }
315    }
316
317    Ok(())
318}
319
320fn get_tls_connector(ssl_mode: &str, rootcerts: Option<Vec<Certificate>>) -> Result<TlsConnector> {
321    let mut builder = TlsConnector::builder();
322
323    if ssl_mode == "disable" {
324        return builder.build().context(FailedToBuildTlsConnectorSnafu);
325    }
326
327    if let Some(certs) = rootcerts {
328        for cert in certs {
329            builder.add_root_certificate(cert);
330        }
331    }
332
333    builder
334        .danger_accept_invalid_hostnames(ssl_mode != "verify-full")
335        .danger_accept_invalid_certs(ssl_mode != "verify-full" && ssl_mode != "verify-ca")
336        .build()
337        .context(FailedToBuildTlsConnectorSnafu)
338}
339
340fn parse_certs(buf: &[u8]) -> Result<Vec<Certificate>> {
341    Certificate::from_der(buf)
342        .map(|x| vec![x])
343        .or_else(|_| {
344            pem::parse_many(buf)
345                .unwrap_or_default()
346                .iter()
347                .map(pem::encode)
348                .map(|s| Certificate::from_pem(s.as_bytes()))
349                .collect()
350        })
351        .context(FailedToLoadCertSnafu)
352}
353
/// Stateless bb8 error sink for the PostgreSQL pool.
#[derive(Debug, Clone, Copy)]
struct PostgresErrorSink {}
356
357impl PostgresErrorSink {
358    pub fn new() -> Self {
359        PostgresErrorSink {}
360    }
361}
362
363impl<E> ErrorSink<E> for PostgresErrorSink
364where
365    E: std::fmt::Debug,
366    E: std::fmt::Display,
367{
368    fn sink(&self, error: E) {
369        tracing::debug!("Postgres Pool Error: {}", error);
370    }
371
372    fn boxed_clone(&self) -> Box<dyn ErrorSink<E>> {
373        Box::new(*self)
374    }
375}
376
377#[async_trait]
378impl
379    DbConnectionPool<
380        bb8::PooledConnection<'static, PostgresConnectionManager<MakeTlsConnector>>,
381        &'static (dyn ToSql + Sync),
382    > for PostgresConnectionPool
383{
384    async fn connect(
385        &self,
386    ) -> super::Result<
387        Box<
388            dyn DbConnection<
389                bb8::PooledConnection<'static, PostgresConnectionManager<MakeTlsConnector>>,
390                &'static (dyn ToSql + Sync),
391            >,
392        >,
393    > {
394        let pool = Arc::clone(&self.pool);
395        let get_conn = async || pool.get_owned().await.context(ConnectionPoolRunSnafu);
396        let conn = run_async_with_tokio(get_conn).await?;
397        Ok(Box::new(
398            PostgresConnection::new(conn)
399                .with_unsupported_type_action(self.unsupported_type_action),
400        ))
401    }
402
403    fn join_push_down(&self) -> JoinPushDown {
404        self.join_push_down.clone()
405    }
406}