use anyhow::{Result, anyhow};
use once_cell::sync::OnceCell;
use secrecy::{ExposeSecret, SecretString};
use sqlx::PgPool;
use sqlx::postgres::{PgConnectOptions, PgPoolOptions};
use std::{collections::HashMap, str::FromStr, sync::Arc};
use tokio::sync::RwLock;
/// Database names to skip; written once by `set_excluded_databases`.
static EXCLUDED: OnceCell<Arc<[String]>> = OnceCell::new();
/// Connection options parsed from the startup DSN; cloned and re-targeted per
/// database by `connect_options_for_db`.
static BASE_OPTS: OnceCell<PgConnectOptions> = OnceCell::new();
/// Database name taken from the startup DSN (`"postgres"` when the DSN names none).
static DEFAULT_DB: OnceCell<String> = OnceCell::new();
/// Cache of per-database pools keyed by database name; filled lazily by
/// `get_or_create_pool_for_db`.
static POOLS: OnceCell<RwLock<HashMap<String, PgPool>>> = OnceCell::new();
/// Stores the list of database names that should be excluded.
///
/// Each entry is trimmed; entries that are empty after trimming are dropped.
/// Duplicates are removed wherever they occur in the list (not only when
/// adjacent), and the first occurrence of each name keeps its position.
///
/// The list can only be stored once per process: if it has already been set,
/// this call leaves the existing list untouched and returns silently.
pub fn set_excluded_databases(list: Vec<String>) {
    // `Vec::dedup` only collapses *consecutive* repeats, so track every name
    // seen so far to also catch duplicates that are not adjacent.
    let mut seen: std::collections::HashSet<String> = std::collections::HashSet::new();
    let cleaned: Vec<String> = list
        .into_iter()
        .map(|s| s.trim().to_string())
        .filter(|s| !s.is_empty() && seen.insert(s.clone()))
        .collect();
    // First writer wins; a later call is intentionally a no-op.
    let _ = EXCLUDED.set(Arc::from(cleaned));
}
/// Returns the configured exclusion list.
///
/// Yields an empty slice when `set_excluded_databases` has not stored a list yet.
#[inline]
pub fn get_excluded_databases() -> &'static [String] {
    let Some(names) = EXCLUDED.get() else {
        // Nothing configured: behave as if no database is excluded.
        return &[];
    };
    &names[..]
}
/// Reports whether `datname` appears in the configured exclusion list.
///
/// The comparison is an exact string match against each stored name.
#[inline]
pub fn is_database_excluded(datname: &str) -> bool {
    get_excluded_databases()
        .iter()
        .map(String::as_str)
        .any(|name| name == datname)
}
/// Parses `dsn` and records it as the base connection configuration.
///
/// On the first successful call this stores the parsed connection options and
/// the DSN's database name (falling back to `"postgres"` when the DSN names
/// none). Once base options exist, later calls do not parse or replace
/// anything. The pool cache is created if it does not exist yet.
///
/// # Errors
///
/// Returns an error when base options are not yet set and `dsn` cannot be
/// parsed; in that case the pool cache is not created by this call.
pub fn set_base_connect_options_from_dsn(dsn: &SecretString) -> Result<()> {
    if BASE_OPTS.get().is_none() {
        let parsed = PgConnectOptions::from_str(dsn.expose_secret())?;
        // Read the database name before handing the options to the cell.
        let default_name = parsed.get_database().unwrap_or("postgres").to_owned();
        let _ = BASE_OPTS.set(parsed);
        let _ = DEFAULT_DB.set(default_name);
    }
    // Creates the empty cache only if it is still missing.
    POOLS.get_or_init(|| RwLock::new(HashMap::new()));
    Ok(())
}
/// Returns the default database name recorded from the startup DSN.
///
/// Yields `None` until `set_base_connect_options_from_dsn` has stored it.
#[inline]
pub fn get_default_database() -> Option<&'static str> {
    let name = DEFAULT_DB.get()?;
    Some(name.as_str())
}
/// Builds connection options for `datname` from the stored base options.
///
/// The base options are cloned and only the target database is replaced.
///
/// # Errors
///
/// Returns an error when the base options have not been set yet.
pub fn connect_options_for_db(datname: &str) -> Result<PgConnectOptions> {
    let Some(base) = BASE_OPTS.get() else {
        return Err(anyhow!(
            "BASE_OPTS not set; call set_base_connect_options_from_dsn() at startup"
        ));
    };
    Ok(base.clone().database(datname))
}
/// Returns the cached pool for `datname`, creating and caching it on first use.
///
/// New pools are limited to a single connection (no minimum) with a 5 second
/// acquire timeout. Concurrent callers asking for the same database all
/// receive the same pool: if another task caches a pool while this one is
/// still connecting, the freshly built pool is closed and the cached one is
/// returned instead.
///
/// # Errors
///
/// Returns an error when:
/// - `datname` is the default database (callers must use the shared pool),
/// - the pool cache or the base connection options are not initialized,
/// - connecting to the database fails.
pub async fn get_or_create_pool_for_db(datname: &str) -> Result<PgPool> {
    if let Some(def) = get_default_database()
        && def == datname
    {
        return Err(anyhow!(
            "get_or_create_pool_for_db called for default database; use shared pool"
        ));
    }
    let pools = POOLS.get().ok_or_else(|| {
        anyhow!("Pool cache not initialized; call set_base_connect_options_from_dsn()")
    })?;

    // Fast path: a shared read lock is enough when the pool already exists.
    {
        let guard = pools.read().await;
        if let Some(pool) = guard.get(datname) {
            return Ok(pool.clone());
        }
    }

    // Connect without holding any lock so a slow database does not block
    // lookups for other databases.
    let opts = connect_options_for_db(datname)?;
    let pool = PgPoolOptions::new()
        .max_connections(1)
        .min_connections(0)
        .acquire_timeout(std::time::Duration::from_secs(5))
        .connect_with(opts)
        .await?;

    let mut guard = pools.write().await;
    // Re-check under the write lock: another task may have cached a pool for
    // this database while we were connecting. Keep the cached one so every
    // caller shares a single pool, and discard ours instead of overwriting.
    if let Some(existing) = guard.get(datname).cloned() {
        drop(guard);
        pool.close().await;
        return Ok(existing);
    }
    guard.insert(datname.to_string(), pool.clone());
    Ok(pool)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Storing a list trims entries, drops blank ones, collapses the repeated
    /// name, and makes the result visible through the accessors.
    #[test]
    fn test_set_and_get_exclusions() {
        // Input contains an adjacent duplicate and a whitespace-only entry.
        let raw: Vec<String> = ["postgres", "template0", "template0", " "]
            .iter()
            .map(|name| name.to_string())
            .collect();
        set_excluded_databases(raw);

        let expected = vec![String::from("postgres"), String::from("template0")];
        assert_eq!(get_excluded_databases(), expected.as_slice());

        assert!(is_database_excluded("postgres"));
        assert!(!is_database_excluded("not_there"));
    }
}