datafusion_table_providers/sql/db_connection_pool/
postgrespool.rs1use std::{collections::HashMap, path::PathBuf, str::FromStr, sync::Arc};
2
3use crate::{
4 util::{self, ns_lookup::verify_ns_lookup_and_tcp_connect},
5 UnsupportedTypeAction,
6};
7use async_trait::async_trait;
8use bb8::ErrorSink;
9use bb8_postgres::{
10 tokio_postgres::{config::Host, types::ToSql, Config},
11 PostgresConnectionManager,
12};
13use native_tls::{Certificate, TlsConnector};
14use postgres_native_tls::MakeTlsConnector;
15use secrecy::{ExposeSecret, SecretBox, SecretString};
16use snafu::{prelude::*, ResultExt};
17use tokio_postgres;
18
19use super::{runtime::run_async_with_tokio, DbConnectionPool};
20use crate::sql::db_connection_pool::{
21 dbconnection::{postgresconn::PostgresConnection, AsyncDbConnection, DbConnection},
22 JoinPushDown,
23};
24
25#[derive(Debug, Snafu)]
26pub enum Error {
27 #[snafu(display("PostgreSQL connection failed.\n{source}\nFor details, refer to the PostgreSQL documentation: https://www.postgresql.org/docs/17/index.html"))]
28 ConnectionPoolError {
29 source: bb8_postgres::tokio_postgres::Error,
30 },
31
32 #[snafu(display("PostgreSQL connection failed.\n{source}\nAdjust the connection pool parameters for sufficient capacity."))]
33 ConnectionPoolRunError {
34 source: bb8::RunError<bb8_postgres::tokio_postgres::Error>,
35 },
36
37 #[snafu(display(
38 "Invalid parameter: {parameter_name}. Ensure the parameter name is correct."
39 ))]
40 InvalidParameterError { parameter_name: String },
41
42 #[snafu(display("Could not parse {parameter_name} into a valid integer. Ensure it is configured with a valid value."))]
43 InvalidIntegerParameterError {
44 parameter_name: String,
45 source: std::num::ParseIntError,
46 },
47
48 #[snafu(display("Cannot connect to PostgreSQL on {host}:{port}. Ensure the host and port are correct and reachable."))]
49 InvalidHostOrPortError {
50 source: crate::util::ns_lookup::Error,
51 host: String,
52 port: u16,
53 },
54
55 #[snafu(display(
56 "Invalid root certificate path: {path}. Ensure it points to a valid root certificate."
57 ))]
58 InvalidRootCertPathError { path: String },
59
60 #[snafu(display(
61 "Failed to read certificate.\n{source}\nEnsure the root certificate path points to a valid certificate."
62 ))]
63 FailedToReadCertError { source: std::io::Error },
64
65 #[snafu(display(
66 "Certificate loading failed.\n{source}\nEnsure the root certificate path points to a valid certificate."
67 ))]
68 FailedToLoadCertError { source: native_tls::Error },
69
70 #[snafu(display("TLS connector initialization failed.\n{source}\nVerify SSL mode and root certificate validity"))]
71 FailedToBuildTlsConnectorError { source: native_tls::Error },
72
73 #[snafu(display("PostgreSQL connection failed.\n{source}\nFor details, refer to the PostgreSQL documentation: https://www.postgresql.org/docs/17/index.html"))]
74 PostgresConnectionError { source: tokio_postgres::Error },
75
76 #[snafu(display("Authentication failed. Verify username and password."))]
77 InvalidUsernameOrPassword { source: tokio_postgres::Error },
78}
79
80pub type Result<T, E = Error> = std::result::Result<T, E>;
81
82#[derive(Debug)]
83pub struct PostgresConnectionPool {
84 pool: Arc<bb8::Pool<PostgresConnectionManager<MakeTlsConnector>>>,
85 join_push_down: JoinPushDown,
86 unsupported_type_action: UnsupportedTypeAction,
87}
88
89impl PostgresConnectionPool {
90 pub async fn new(params: HashMap<String, SecretString>) -> Result<Self> {
96 let params = util::remove_prefix_from_hashmap_keys(params, "pg_");
98
99 let mut connection_string = String::new();
100 let mut ssl_mode = "verify-full".to_string();
101 let mut ssl_rootcert_path: Option<PathBuf> = None;
102
103 if let Some(pg_connection_string) = params
104 .get("connection_string")
105 .map(SecretBox::expose_secret)
106 {
107 let (str, mode, cert_path) = parse_connection_string(pg_connection_string);
108 connection_string = str;
109 ssl_mode = mode;
110 if let Some(cert_path) = cert_path {
111 let sslrootcert = cert_path.as_str();
112 ensure!(
113 std::path::Path::new(sslrootcert).exists(),
114 InvalidRootCertPathSnafu { path: cert_path }
115 );
116 ssl_rootcert_path = Some(PathBuf::from(sslrootcert));
117 }
118 } else {
119 if let Some(pg_host) = params.get("host").map(SecretBox::expose_secret) {
120 connection_string.push_str(format!("host={pg_host} ").as_str());
121 }
122 if let Some(pg_user) = params.get("user").map(SecretBox::expose_secret) {
123 connection_string.push_str(format!("user={pg_user} ").as_str());
124 }
125 if let Some(pg_db) = params.get("db").map(SecretBox::expose_secret) {
126 connection_string.push_str(format!("dbname={pg_db} ").as_str());
127 }
128 if let Some(pg_pass) = params.get("pass").map(SecretBox::expose_secret) {
129 connection_string.push_str(format!("password={pg_pass} ").as_str());
130 }
131 if let Some(pg_port) = params.get("port").map(SecretBox::expose_secret) {
132 connection_string.push_str(format!("port={pg_port} ").as_str());
133 }
134 }
135
136 if let Some(pg_sslmode) = params.get("sslmode").map(SecretBox::expose_secret) {
137 match pg_sslmode.to_lowercase().as_str() {
138 "disable" | "require" | "prefer" | "verify-ca" | "verify-full" => {
139 ssl_mode = pg_sslmode.to_string();
140 }
141 _ => {
142 InvalidParameterSnafu {
143 parameter_name: "sslmode".to_string(),
144 }
145 .fail()?;
146 }
147 }
148 }
149 if let Some(pg_sslrootcert) = params.get("sslrootcert").map(SecretBox::expose_secret) {
150 ensure!(
151 std::path::Path::new(pg_sslrootcert).exists(),
152 InvalidRootCertPathSnafu {
153 path: pg_sslrootcert,
154 }
155 );
156
157 ssl_rootcert_path = Some(PathBuf::from(pg_sslrootcert));
158 }
159
160 let mode = match ssl_mode.as_str() {
161 "disable" => "disable",
162 "prefer" => "prefer",
163 _ => "require",
165 };
166
167 connection_string.push_str(format!("sslmode={mode} ").as_str());
168 let mut config =
169 Config::from_str(connection_string.as_str()).context(ConnectionPoolSnafu)?;
170
171 if let Some(application_name) = params.get("application_name").map(SecretBox::expose_secret)
172 {
173 config.application_name(application_name);
174 }
175
176 verify_postgres_config(&config).await?;
177
178 let mut certs: Option<Vec<Certificate>> = None;
179
180 if let Some(path) = ssl_rootcert_path {
181 let buf = tokio::fs::read(path).await.context(FailedToReadCertSnafu)?;
182 certs = Some(parse_certs(&buf)?);
183 }
184
185 let tls_connector = get_tls_connector(ssl_mode.as_str(), certs)?;
186 let connector = MakeTlsConnector::new(tls_connector);
187 test_postgres_connection(connection_string.as_str(), connector.clone()).await?;
188
189 let join_push_down = get_join_context(&config);
190
191 let manager = PostgresConnectionManager::new(config, connector);
192 let error_sink = PostgresErrorSink::new();
193
194 let mut connection_pool_size = 10; if let Some(pg_pool_size) = params
196 .get("connection_pool_size")
197 .map(SecretBox::expose_secret)
198 {
199 connection_pool_size = pg_pool_size.parse().context(InvalidIntegerParameterSnafu {
200 parameter_name: "pool_size".to_string(),
201 })?;
202 }
203
204 let pool = bb8::Pool::builder()
205 .max_size(connection_pool_size)
206 .error_sink(Box::new(error_sink))
207 .build(manager)
208 .await
209 .context(ConnectionPoolSnafu)?;
210
211 let conn = pool.get().await.context(ConnectionPoolRunSnafu)?;
213 conn.execute("SELECT 1", &[])
214 .await
215 .context(ConnectionPoolSnafu)?;
216
217 Ok(PostgresConnectionPool {
218 pool: Arc::new(pool.clone()),
219 join_push_down,
220 unsupported_type_action: UnsupportedTypeAction::default(),
221 })
222 }
223
224 #[must_use]
226 pub fn with_unsupported_type_action(mut self, action: UnsupportedTypeAction) -> Self {
227 self.unsupported_type_action = action;
228 self
229 }
230
231 pub async fn connect_direct(&self) -> super::Result<PostgresConnection> {
237 let pool = Arc::clone(&self.pool);
238 let conn = pool.get_owned().await.context(ConnectionPoolRunSnafu)?;
239 Ok(PostgresConnection::new(conn))
240 }
241}
242
/// Splits a libpq keyword/value connection string into three parts.
///
/// It pulls out `sslmode` and `sslrootcert` and re-emits every other `keyword=value`
/// pair unchanged. Returns `(remaining_connection_string, ssl_mode, ssl_rootcert_path)`:
/// - Each pair in the remaining string is followed by a single space.
/// - `ssl_mode` defaults to `"verify-full"` when the input does not set it.
/// - Tokens without an `=` are dropped.
///
/// Only the first `=` in a token separates keyword from value, so values may contain
/// `=`. Single-quoted or backslash-escaped whitespace stays inside its token, as in
/// libpq. Quoted `sslmode`/`sslrootcert` values are unquoted. All other values pass
/// through verbatim so the `tokio_postgres` parser can interpret them.
fn parse_connection_string(pg_connection_string: &str) -> (String, String, Option<String>) {
    let mut connection_string = String::new();
    let mut ssl_mode = "verify-full".to_string();
    let mut ssl_rootcert_path: Option<String> = None;

    for token in split_connection_string_tokens(pg_connection_string) {
        let Some((name, value)) = token.split_once('=') else {
            continue;
        };
        match name {
            "sslmode" => ssl_mode = unquote_connection_value(value),
            "sslrootcert" => ssl_rootcert_path = Some(unquote_connection_value(value)),
            _ => {
                connection_string.push_str(token);
                connection_string.push(' ');
            }
        }
    }

    (connection_string, ssl_mode, ssl_rootcert_path)
}

/// Splits `input` on whitespace that is neither inside single quotes nor escaped by a
/// backslash. The raw tokens are returned with their quotes and escapes preserved.
fn split_connection_string_tokens(input: &str) -> Vec<&str> {
    let mut tokens = Vec::new();
    let mut start: Option<usize> = None;
    let mut in_quotes = false;
    let mut escaped = false;

    for (idx, c) in input.char_indices() {
        if escaped {
            // The preceding backslash makes this character literal (even whitespace).
            escaped = false;
            continue;
        }
        if c.is_whitespace() && !in_quotes {
            if let Some(begin) = start.take() {
                tokens.push(&input[begin..idx]);
            }
            continue;
        }
        if start.is_none() {
            start = Some(idx);
        }
        match c {
            '\\' => escaped = true,
            '\'' => in_quotes = !in_quotes,
            _ => {}
        }
    }
    if let Some(begin) = start {
        tokens.push(&input[begin..]);
    }
    tokens
}

/// Strips surrounding single quotes from `value` and resolves backslash escapes
/// inside them. Unquoted values are returned unchanged, so raw paths keep their
/// backslashes.
fn unquote_connection_value(value: &str) -> String {
    let Some(inner) = value
        .strip_prefix('\'')
        .and_then(|rest| rest.strip_suffix('\''))
    else {
        return value.to_string();
    };
    let mut out = String::with_capacity(inner.len());
    let mut chars = inner.chars();
    while let Some(c) = chars.next() {
        if c == '\\' {
            if let Some(next) = chars.next() {
                out.push(next);
            }
        } else {
            out.push(c);
        }
    }
    out
}
269
270fn get_join_context(config: &Config) -> JoinPushDown {
271 let mut join_push_context_str = String::new();
272 for host in config.get_hosts() {
273 join_push_context_str.push_str(&format!("host={host:?},"));
274 }
275 if !config.get_ports().is_empty() {
276 join_push_context_str.push_str(&format!("port={port},", port = config.get_ports()[0]));
277 }
278 if let Some(dbname) = config.get_dbname() {
279 join_push_context_str.push_str(&format!("db={dbname},"));
280 }
281 if let Some(user) = config.get_user() {
282 join_push_context_str.push_str(&format!("user={user},"));
283 }
284
285 JoinPushDown::AllowedFor(join_push_context_str)
286}
287
288async fn test_postgres_connection(
289 connection_string: &str,
290 connector: MakeTlsConnector,
291) -> Result<()> {
292 match tokio_postgres::connect(connection_string, connector).await {
293 Ok(_) => Ok(()),
294 Err(err) => {
295 if let Some(code) = err.code() {
296 if *code == tokio_postgres::error::SqlState::INVALID_PASSWORD {
297 return Err(Error::InvalidUsernameOrPassword { source: err });
298 }
299 }
300
301 Err(Error::PostgresConnectionError { source: err })
302 }
303 }
304}
305
306async fn verify_postgres_config(config: &Config) -> Result<()> {
307 for host in config.get_hosts() {
308 for port in config.get_ports() {
309 if let Host::Tcp(host) = host {
310 verify_ns_lookup_and_tcp_connect(host, *port)
311 .await
312 .context(InvalidHostOrPortSnafu { host, port: *port })?;
313 }
314 }
315 }
316
317 Ok(())
318}
319
320fn get_tls_connector(ssl_mode: &str, rootcerts: Option<Vec<Certificate>>) -> Result<TlsConnector> {
321 let mut builder = TlsConnector::builder();
322
323 if ssl_mode == "disable" {
324 return builder.build().context(FailedToBuildTlsConnectorSnafu);
325 }
326
327 if let Some(certs) = rootcerts {
328 for cert in certs {
329 builder.add_root_certificate(cert);
330 }
331 }
332
333 builder
334 .danger_accept_invalid_hostnames(ssl_mode != "verify-full")
335 .danger_accept_invalid_certs(ssl_mode != "verify-full" && ssl_mode != "verify-ca")
336 .build()
337 .context(FailedToBuildTlsConnectorSnafu)
338}
339
340fn parse_certs(buf: &[u8]) -> Result<Vec<Certificate>> {
341 Certificate::from_der(buf)
342 .map(|x| vec![x])
343 .or_else(|_| {
344 pem::parse_many(buf)
345 .unwrap_or_default()
346 .iter()
347 .map(pem::encode)
348 .map(|s| Certificate::from_pem(s.as_bytes()))
349 .collect()
350 })
351 .context(FailedToLoadCertSnafu)
352}
353
354#[derive(Debug, Clone, Copy)]
355struct PostgresErrorSink {}
356
357impl PostgresErrorSink {
358 pub fn new() -> Self {
359 PostgresErrorSink {}
360 }
361}
362
363impl<E> ErrorSink<E> for PostgresErrorSink
364where
365 E: std::fmt::Debug,
366 E: std::fmt::Display,
367{
368 fn sink(&self, error: E) {
369 tracing::debug!("Postgres Pool Error: {}", error);
370 }
371
372 fn boxed_clone(&self) -> Box<dyn ErrorSink<E>> {
373 Box::new(*self)
374 }
375}
376
377#[async_trait]
378impl
379 DbConnectionPool<
380 bb8::PooledConnection<'static, PostgresConnectionManager<MakeTlsConnector>>,
381 &'static (dyn ToSql + Sync),
382 > for PostgresConnectionPool
383{
384 async fn connect(
385 &self,
386 ) -> super::Result<
387 Box<
388 dyn DbConnection<
389 bb8::PooledConnection<'static, PostgresConnectionManager<MakeTlsConnector>>,
390 &'static (dyn ToSql + Sync),
391 >,
392 >,
393 > {
394 let pool = Arc::clone(&self.pool);
395 let get_conn = async || pool.get_owned().await.context(ConnectionPoolRunSnafu);
396 let conn = run_async_with_tokio(get_conn).await?;
397 Ok(Box::new(
398 PostgresConnection::new(conn)
399 .with_unsupported_type_action(self.unsupported_type_action),
400 ))
401 }
402
403 fn join_push_down(&self) -> JoinPushDown {
404 self.join_push_down.clone()
405 }
406}