use std::time::Duration;
use miette::{IntoDiagnostic, Report, Result, WrapErr, miette};
use mobc::{Connection, Pool};
use tokio_postgres::config::SslMode;
use tracing::debug;
pub use manager::PgError;
mod manager;
mod parse;
mod tls;
fn is_tls_error(error: &Report) -> bool {
if error.downcast_ref::<rustls::Error>().is_some() {
return true;
}
let mut source = error.source();
while let Some(err) = source {
if err.downcast_ref::<rustls::Error>().is_some() {
return true;
}
source = err.source();
}
let message = error.to_string();
message.contains("tls:")
|| message.contains("rustls")
|| message.contains("certificate")
|| message.contains("TLS handshake")
|| message.contains("invalid configuration")
}
fn is_auth_error(error: &Report) -> bool {
if let Some(db_error) = error.downcast_ref::<tokio_postgres::Error>() {
if let Some(db_error) = db_error.as_db_error() {
let code = db_error.code().code();
return code == "28000" || code == "28P01";
}
}
let message = error.to_string();
message.contains("password authentication failed")
|| message.contains("no password supplied")
|| message.contains("authentication failed")
}
pub type PgConnection = Connection<manager::PgConnectionManager>;
#[derive(Debug, Clone)]
pub struct PgPool {
pub manager: manager::PgConnectionManager,
pub inner: Pool<manager::PgConnectionManager>,
}
impl PgPool {
pub async fn get(&self) -> Result<PgConnection, mobc::Error<PgError>> {
self.inner.get().await
}
pub async fn get_timeout(
&self,
duration: Duration,
) -> Result<PgConnection, mobc::Error<PgError>> {
self.inner.get_timeout(duration).await
}
}
pub async fn create_pool(url: &str) -> Result<PgPool> {
let mut config = parse::parse_connection_url(url)?;
let mut tried_ssl_fallback = false;
let pool = loop {
debug!("Creating manager");
let tls = config.get_ssl_mode() != SslMode::Disable;
let manager = manager::PgConnectionManager::new(config.clone(), tls);
debug!("Creating pool");
let pool = Pool::builder()
.max_lifetime(Some(Duration::from_secs(3600)))
.build(manager.clone());
let pool = PgPool {
manager,
inner: pool,
};
debug!("Checking pool");
match check_pool(&pool).await {
Ok(_) => {
if tried_ssl_fallback {
tracing::info!("Connected successfully with SSL disabled after TLS error");
}
break pool;
}
Err(e) => {
debug!("Connection error: {:#}", e);
debug!(
"is_tls_error: {}, is_auth_error: {}",
is_tls_error(&e),
is_auth_error(&e)
);
if is_tls_error(&e) {
if config.get_ssl_mode() == SslMode::Prefer && !tried_ssl_fallback {
debug!("TLS failed with prefer mode, retrying with SSL disabled");
config.ssl_mode(SslMode::Disable);
tried_ssl_fallback = true;
continue;
}
return Err(e).wrap_err(
"TLS/SSL connection failed. Try using --ssl disable, \
or use a connection URL with sslmode=disable: \
postgresql://user@host/db?sslmode=disable",
);
} else if is_auth_error(&e) && config.get_password().is_none() {
let password = rpassword::prompt_password("Password: ").into_diagnostic()?;
config.password(password);
} else {
return Err(e);
}
}
}
};
Ok(pool)
}
async fn check_pool(pool: &PgPool) -> Result<()> {
let conn = match pool.get().await {
Err(mobc::Error::Inner(db_err)) => {
return Err(match db_err.as_db_error() {
Some(db_err) => miette!(
"E{code} at {func} in {file}:{line}",
code = db_err.code().code(),
func = db_err.routine().unwrap_or("{unknown}"),
file = db_err.file().unwrap_or("unknown.c"),
line = db_err.line().unwrap_or(0)
),
_ => miette!("{db_err}"),
})
.wrap_err(
db_err
.as_db_error()
.map(|e| e.to_string())
.unwrap_or_default(),
)?;
}
res @ Err(_) => {
let res = res.map(drop).into_diagnostic();
return if let Err(ref err) = res
&& is_auth_error(err)
{
res.wrap_err("hint: check the password")
} else {
res
};
}
Ok(conn) => conn,
};
conn.simple_query("SELECT 1")
.await
.into_diagnostic()
.wrap_err("checking connection")?;
Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;

    // These tests tolerate connection failures, since a server may not be
    // running. They only require that the URL itself is accepted by the parser.
    // NOTE(review): URLs without a password may make `create_pool` prompt for
    // one via `rpassword::prompt_password` on authentication failures. That may
    // block an interactive test run — confirm this is acceptable.

    /// Panics if the formatted `create_pool` error came from URL parsing.
    #[track_caller]
    fn assert_not_parse_error(error_msg: &str) {
        assert!(
            !error_msg.contains("parsing connection string"),
            "Should not be a parsing error: {}",
            error_msg
        );
    }

    #[tokio::test]
    async fn test_create_pool_valid_connection_string() {
        if let Err(e) = create_pool("postgresql://localhost/test").await {
            assert_not_parse_error(&format!("{:?}", e));
        }
    }

    #[tokio::test]
    async fn test_create_pool_with_full_url() {
        if let Err(e) = create_pool("postgresql://user:pass@localhost:5432/testdb").await {
            assert_not_parse_error(&format!("{:?}", e));
        }
    }

    #[tokio::test]
    async fn test_create_pool_with_unix_socket_path() {
        if let Err(e) = create_pool("postgresql:///postgres?host=/var/run/postgresql").await {
            assert_not_parse_error(&format!("{:?}", e));
        }
    }

    #[tokio::test]
    async fn test_create_pool_with_encoded_unix_socket() {
        if let Err(e) = create_pool("postgresql://%2Fvar%2Frun%2Fpostgresql/postgres").await {
            assert_not_parse_error(&format!("{:?}", e));
        }
    }

    #[tokio::test]
    async fn test_create_pool_with_no_host() {
        if let Err(e) = create_pool("postgresql:///postgres").await {
            assert_not_parse_error(&format!("{:?}", e));
        }
    }

    #[tokio::test]
    async fn test_unix_socket_connection_end_to_end() {
        let url = "postgresql:///postgres?host=/var/run/postgresql";
        let pool = match create_pool(url).await {
            Ok(pool) => pool,
            Err(e) => {
                let error_msg = format!("{:?}", e);
                assert_not_parse_error(&error_msg);
                assert!(
                    !error_msg.contains("TLS handshake"),
                    "Should not be a TLS error for Unix socket: {}",
                    error_msg
                );
                return;
            }
        };

        // If a connection can be checked out, it must be able to run a query.
        if let Ok(conn) = pool.get().await {
            let result = conn.simple_query("SELECT 1 as test").await;
            assert!(result.is_ok(), "Query should succeed");
        }
    }
}