use anyhow::{Result, anyhow};
use once_cell::sync::OnceCell;
use secrecy::{ExposeSecret, SecretString};
use sqlx::PgPool;
use sqlx::postgres::{PgConnectOptions, PgPoolOptions};
use std::{collections::HashMap, str::FromStr, sync::Arc};
use tokio::sync::RwLock;
/// Database names to skip; written once by `set_excluded_databases` and read
/// through `get_excluded_databases`.
static EXCLUDED: OnceCell<Arc<[String]>> = OnceCell::new();
/// Connect options parsed from the DSN by `set_base_connect_options_from_dsn`;
/// `connect_options_for_db` clones them and swaps in the target database name.
static BASE_OPTS: OnceCell<PgConnectOptions> = OnceCell::new();
/// Database name taken from the DSN (falls back to `"postgres"` when the DSN
/// names none); stored by `set_base_connect_options_from_dsn`.
static DEFAULT_DB: OnceCell<String> = OnceCell::new();
/// Cache of connection pools keyed by database name, guarded by an async `RwLock`.
static POOLS: OnceCell<RwLock<HashMap<String, PgPool>>> = OnceCell::new();
/// Version number stored by `set_pg_version`; `get_pg_version` reports `0`
/// while this is still unset.
static PG_VERSION: OnceCell<i32> = OnceCell::new();
/// The literal `pg_catalog` (presumably the PostgreSQL system catalog schema
/// name; its users are outside this view).
pub const PG_CATALOG: &str = "pg_catalog";
/// The literal `information_schema` (presumably the schema name; its users are
/// outside this view).
pub const INFORMATION_SCHEMA: &str = "information_schema";
/// The database name `template0`.
pub const TEMPLATE0: &str = "template0";
/// The database name `template1`.
pub const TEMPLATE1: &str = "template1";
/// Number of milliseconds in one second.
/// NOTE(review): presumably used as a divisor to turn ms into seconds — confirm
/// at the call sites, which are not visible here.
pub const MS_TO_SEC: f64 = 1000.0;
/// Application name applied to connections: this crate's package name,
/// captured from `CARGO_PKG_NAME` at compile time.
const DEFAULT_APPLICATION_NAME: &str = env!("CARGO_PKG_NAME");
#[inline]
#[must_use]
pub fn apply_default_application_name(opts: PgConnectOptions) -> PgConnectOptions {
opts.application_name(DEFAULT_APPLICATION_NAME)
}
/// Stores the list of database names that should be skipped.
///
/// Each entry is trimmed of surrounding whitespace; entries that end up empty
/// are dropped. Duplicates are removed wherever they occur in the list (not
/// only when adjacent), keeping the first occurrence so the original order is
/// preserved.
///
/// Only the first call has any effect: the list lives in a `OnceCell`, and a
/// later call's list is silently discarded.
pub fn set_excluded_databases(list: Vec<String>) {
    // `HashSet::insert` returns `false` for a value already present, which
    // lets the filter reject repeats regardless of their position.
    let mut seen = std::collections::HashSet::new();
    let cleaned: Vec<String> = list
        .into_iter()
        .map(|s| s.trim().to_string())
        .filter(|s| !s.is_empty() && seen.insert(s.clone()))
        .collect();
    // Ignoring the result is deliberate: an already-populated cell wins.
    let _ = EXCLUDED.set(Arc::from(cleaned));
}
/// Returns the excluded database names as a slice.
///
/// Yields an empty slice when `set_excluded_databases` has not stored a list yet.
#[inline]
pub fn get_excluded_databases() -> &'static [String] {
    EXCLUDED
        .get()
        .map(|names| &names[..])
        .unwrap_or_default()
}
/// Reports whether `datname` exactly matches one of the excluded database names.
///
/// The comparison is a plain string equality check against every entry
/// returned by `get_excluded_databases`.
#[inline]
#[must_use]
pub fn is_database_excluded(datname: &str) -> bool {
    get_excluded_databases()
        .iter()
        .map(String::as_str)
        .any(|excluded| excluded == datname)
}
/// Records the server version number.
///
/// The value is kept in a `OnceCell`, so only the first call stores anything;
/// subsequent calls leave the existing value in place.
pub fn set_pg_version(version: i32) {
    // Initializes the cell only if it is still empty; the returned reference
    // is not needed here.
    let _ = PG_VERSION.get_or_init(|| version);
}
/// Returns the version number stored by `set_pg_version`, or `0` if none has
/// been stored yet.
#[inline]
pub fn get_pg_version() -> i32 {
    PG_VERSION.get().map_or(0, |&version| version)
}
/// Reports whether the stored version number is greater than or equal to
/// `min_version`.
///
/// While no version has been stored, `get_pg_version` yields `0`, so this is
/// `false` for any positive `min_version`.
#[inline]
#[must_use]
pub fn is_pg_version_at_least(min_version: i32) -> bool {
    let current = get_pg_version();
    min_version <= current
}
/// Initializes the process-wide connection settings from a DSN.
///
/// On the first call the DSN is parsed, the default application name is
/// applied, and both the resulting options and the DSN's database name
/// (`"postgres"` when the DSN names none) are stored. The pool cache is also
/// created if it does not exist yet.
///
/// Once the base options are stored, later calls do not parse `dsn` at all, so
/// they return `Ok(())` without changing anything.
///
/// # Errors
///
/// Returns an error if the base options are not yet stored and `dsn` cannot be
/// parsed into `PgConnectOptions`.
pub fn set_base_connect_options_from_dsn(dsn: &SecretString) -> Result<()> {
    if BASE_OPTS.get().is_none() {
        let parsed = PgConnectOptions::from_str(dsn.expose_secret())?;
        let base = apply_default_application_name(parsed);
        let default_db = base.get_database().unwrap_or("postgres").to_owned();
        let _ = BASE_OPTS.set(base);
        let _ = DEFAULT_DB.set(default_db);
    }
    // Creates the empty pool map only when it is missing.
    let _ = POOLS.get_or_init(|| RwLock::new(HashMap::new()));
    Ok(())
}
/// Returns the default database name recorded by
/// `set_base_connect_options_from_dsn`, or `None` if nothing has been recorded.
#[inline]
pub fn get_default_database() -> Option<&'static str> {
    let name = DEFAULT_DB.get()?;
    Some(name.as_str())
}
/// Builds connect options for `datname` from the stored base options.
///
/// The base options are cloned and only the database name is replaced.
///
/// # Errors
///
/// Returns an error if `set_base_connect_options_from_dsn` has not stored the
/// base options yet.
pub fn connect_options_for_db(datname: &str) -> Result<PgConnectOptions> {
    let Some(base) = BASE_OPTS.get() else {
        return Err(anyhow!(
            "BASE_OPTS not set; call set_base_connect_options_from_dsn() at startup"
        ));
    };
    Ok(base.clone().database(datname))
}
/// Returns the cached connection pool for `datname`, creating and caching one
/// on first use.
///
/// New pools are configured with at most 2 connections, no minimum, a
/// 5-second acquire timeout, and `test_before_acquire` disabled.
///
/// The connection attempt happens without holding the cache lock, so two
/// callers may race to create a pool for the same database. The cache is
/// therefore re-checked under the write lock: the first pool inserted wins,
/// and a caller that lost the race closes its redundant pool and returns the
/// cached one. Every caller thus ends up sharing a single pool per database.
///
/// # Errors
///
/// Returns an error if `datname` is the default database (callers are expected
/// to use the shared pool for it), if the pool cache or base connect options
/// have not been initialized, or if connecting to the database fails.
pub async fn get_or_create_pool_for_db(datname: &str) -> Result<PgPool> {
    if get_default_database().is_some_and(|def| def == datname) {
        return Err(anyhow!(
            "get_or_create_pool_for_db called for default database; use shared pool"
        ));
    }
    let pools = POOLS.get().ok_or_else(|| {
        anyhow!("Pool cache not initialized; call set_base_connect_options_from_dsn()")
    })?;

    // Fast path: the read guard is a temporary of this statement, so the lock
    // is released before anything else is awaited.
    let cached = pools.read().await.get(datname).cloned();
    if let Some(pool) = cached {
        return Ok(pool);
    }

    let opts = connect_options_for_db(datname)?;
    let pool = PgPoolOptions::new()
        .max_connections(2)
        .min_connections(0)
        .acquire_timeout(std::time::Duration::from_secs(5))
        .test_before_acquire(false)
        .connect_with(opts)
        .await?;

    // Re-check under the write lock: another task may have cached a pool for
    // this database while this one was connecting.
    let winner = {
        let mut guard = pools.write().await;
        match guard.get(datname) {
            Some(existing) => Some(existing.clone()),
            None => {
                guard.insert(datname.to_string(), pool.clone());
                None
            }
        }
    };

    match winner {
        Some(existing) => {
            // Lost the race: close the duplicate outside the lock so its
            // connections are not left open, and share the cached pool.
            pool.close().await;
            Ok(existing)
        }
        None => Ok(pool),
    }
}
/// Evicts the cached pool for `datname`, if any, and closes it.
///
/// Does nothing when the pool cache has not been initialized or holds no pool
/// for `datname`.
pub async fn drop_cached_pool_for_db(datname: &str) {
    let Some(pools) = POOLS.get() else {
        return;
    };
    // The write guard is a temporary of this statement, so the lock is
    // released before the pool is closed.
    let evicted = pools.write().await.remove(datname);
    if let Some(pool) = evicted {
        pool.close().await;
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::str::FromStr;

    /// Whitespace-only entries are dropped and the repeated name is collapsed.
    #[test]
    fn test_set_and_get_exclusions() {
        let input: Vec<String> = vec![
            "postgres".into(),
            TEMPLATE0.into(),
            TEMPLATE0.into(),
            " ".into(),
        ];
        set_excluded_databases(input);

        let expected = ["postgres".to_string(), TEMPLATE0.to_string()];
        let stored = get_excluded_databases();
        assert_eq!(stored, &expected);

        assert!(is_database_excluded("postgres"));
        assert!(!is_database_excluded("not_there"));
    }

    /// The version reads as 0 until set, after which threshold checks use it.
    #[test]
    fn test_pg_version_utilities() {
        // Nothing stored yet.
        assert_eq!(get_pg_version(), 0);
        assert!(!is_pg_version_at_least(140_000));

        set_pg_version(160_000);
        assert_eq!(get_pg_version(), 160_000);

        assert!(is_pg_version_at_least(140_000));
        assert!(is_pg_version_at_least(160_000));
        assert!(!is_pg_version_at_least(170_000));
    }

    /// The debug output of the options mentions the default application name.
    #[test]
    fn test_apply_default_application_name_sets_pkg_name() -> Result<()> {
        let parsed = PgConnectOptions::from_str("postgresql://localhost/postgres")?;
        let named = apply_default_application_name(parsed);
        let debug_repr = format!("{:?}", named);

        assert!(debug_repr.contains("application_name"));
        assert!(debug_repr.contains(DEFAULT_APPLICATION_NAME));
        Ok(())
    }
}