use anyhow::{Context, Result};
use sqlx::PgPool;
use sqlx::postgres::{PgConnectOptions, PgPoolOptions};
use std::time::Duration;
use tracing::info;
use tracing::log::LevelFilter;
pub fn mask_url_password(url: &str) -> String {
if !url.contains("://") {
return url.to_string();
}
let parts: Vec<&str> = url.splitn(2, "://").collect();
if parts.len() != 2 {
return url.to_string();
}
let protocol = parts[0];
let rest = parts[1];
if let Some(at_pos) = rest.find('@') {
let user_info = &rest[..at_pos];
let host_and_path = &rest[at_pos + 1..];
if let Some(colon_pos) = user_info.find(':') {
let username = &user_info[..colon_pos];
return format!("{}://{}:***@{}", protocol, username, host_and_path);
}
}
url.to_string()
}
pub async fn connect_to_database(url: &str, label: &str) -> Result<PgPool> {
PgPoolOptions::new()
.acquire_timeout(Duration::from_secs(5))
.connect(url)
.await
.with_context(|| {
format!(
"Failed to connect to {} at {}",
label,
mask_url_password(url)
)
})
}
#[derive(Debug, Clone)]
pub struct ConnectionConfig {
pub max_retries: u32,
pub retry_delay: Duration,
}
impl Default for ConnectionConfig {
fn default() -> Self {
Self {
max_retries: 5,
retry_delay: Duration::from_millis(200),
}
}
}
pub async fn connect_with_retry(url: &str) -> Result<PgPool> {
connect_with_retry_config(url, &ConnectionConfig::default()).await
}
pub async fn connect_with_retry_config(url: &str, config: &ConnectionConfig) -> Result<PgPool> {
let mut last_error = None;
for attempt in 0..=config.max_retries {
match PgPool::connect(url).await {
Ok(pool) => {
if attempt > 0 {
info!(
"✅ Connected to database (after {} retry{})",
attempt,
if attempt == 1 { "" } else { "ies" }
);
} else {
info!("✅ Connected to database");
}
return Ok(pool);
}
Err(e) => {
last_error = Some(e);
if attempt < config.max_retries {
if attempt == 0 {
info!("🔄 Database not ready, retrying...");
}
tokio::time::sleep(config.retry_delay).await;
}
}
}
}
Err(anyhow::anyhow!(
"Failed to connect to database at {} after {} attempts: {}",
mask_url_password(url),
config.max_retries + 1,
last_error.unwrap()
))
}
pub async fn connect_with_retry_quiet(url: &str) -> Result<PgPool> {
connect_with_retry_config_quiet(url, &ConnectionConfig::default()).await
}
pub async fn connect_with_retry_config_quiet(
url: &str,
config: &ConnectionConfig,
) -> Result<PgPool> {
use sqlx::ConnectOptions;
use std::str::FromStr;
let mut last_error = None;
let connect_options = PgConnectOptions::from_str(url)
.map_err(|e| anyhow::anyhow!("Invalid database URL: {}", e))?
.log_slow_statements(LevelFilter::Off, Duration::from_secs(0));
for attempt in 0..=config.max_retries {
let result = PgPoolOptions::new()
.connect_with(connect_options.clone())
.await;
match result {
Ok(pool) => {
if attempt > 0 {
info!(
"✅ Connected to database (after {} retry{})",
attempt,
if attempt == 1 { "" } else { "ies" }
);
} else {
info!("✅ Connected to database");
}
return Ok(pool);
}
Err(e) => {
last_error = Some(e);
if attempt < config.max_retries {
if attempt == 0 {
info!("🔄 Database not ready, retrying...");
}
tokio::time::sleep(config.retry_delay).await;
}
}
}
}
Err(anyhow::anyhow!(
"Failed to connect to database at {} after {} attempts: {}",
mask_url_password(url),
config.max_retries + 1,
last_error.unwrap()
))
}
pub async fn initialize_database_session(pool: &PgPool) -> Result<()> {
sqlx::query("SET search_path = public, pg_catalog")
.execute(pool)
.await
.map_err(|e| anyhow::anyhow!("Failed to set search_path: {}", e))?;
sqlx::query("SET standard_conforming_strings = on")
.execute(pool)
.await
.map_err(|e| anyhow::anyhow!("Failed to set standard_conforming_strings: {}", e))?;
let search_path: (String,) = sqlx::query_as("SHOW search_path")
.fetch_one(pool)
.await
.map_err(|e| anyhow::anyhow!("Failed to query search_path: {}", e))?;
info!(
"🔧 Database session initialized with search_path: {}",
search_path.0
);
Ok(())
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_connection_config_default() {
let config = ConnectionConfig::default();
assert_eq!(config.max_retries, 5);
assert_eq!(config.retry_delay, Duration::from_millis(200));
}
#[test]
fn test_connection_config_custom() {
let config = ConnectionConfig {
max_retries: 3,
retry_delay: Duration::from_millis(100),
};
assert_eq!(config.max_retries, 3);
assert_eq!(config.retry_delay, Duration::from_millis(100));
}
#[test]
fn test_mask_url_password() {
assert_eq!(
mask_url_password("postgres://user:secret@localhost:5432/mydb"),
"postgres://user:***@localhost:5432/mydb"
);
assert_eq!(
mask_url_password("postgres://user@localhost/mydb"),
"postgres://user@localhost/mydb"
);
assert_eq!(
mask_url_password("postgres://localhost/mydb"),
"postgres://localhost/mydb"
);
assert_eq!(mask_url_password("not a url"), "not a url");
}
}