use std::sync::Arc;
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::time::Duration;
use sqlx::ConnectOptions;
use sqlx::postgres::{PgConnectOptions, PgPool, PgPoolOptions};
use tokio::sync::broadcast;
use tokio::task::JoinHandle;
use tracing::log::LevelFilter;
use forge_core::config::DatabaseConfig;
use forge_core::error::{ForgeError, Result};
/// One read replica: its connection pool and the verdict of its latest
/// health probe.
///
/// The flag is consulted by `Database::read_pool` to skip replicas that
/// failed their last probe, and rewritten by the background health monitor.
struct ReplicaEntry {
    /// Connection pool dedicated to this replica.
    pool: Arc<PgPool>,
    /// `true` if the most recent probe succeeded (replicas start out healthy).
    healthy: Arc<AtomicBool>,
}
/// Handle to the primary PostgreSQL pool plus any read-replica pools.
///
/// Pools, the replica list and the round-robin cursor live behind `Arc`, so
/// clones of a `Database` share them; only `config` is copied per clone.
#[derive(Clone)]
pub struct Database {
    /// Pool for writes, and for reads when no usable replica is available.
    primary: Arc<PgPool>,
    /// Read replicas, in the order they were configured.
    replicas: Arc<Vec<ReplicaEntry>>,
    /// Configuration this handle was built from.
    config: DatabaseConfig,
    /// Monotonic cursor that spreads reads across replicas round-robin.
    replica_counter: Arc<AtomicUsize>,
}
impl Database {
    /// Connects using `config`, reporting `"forge"` as the application name.
    ///
    /// # Errors
    ///
    /// Same as [`Database::from_config_with_service`].
    pub async fn from_config(config: &DatabaseConfig) -> Result<Self> {
        Self::from_config_with_service(config, "forge").await
    }

    /// Connects to the primary and to every URL in `config.replica_urls`,
    /// reporting `service_name` as the PostgreSQL `application_name`.
    ///
    /// The primary pool uses the configured size, minimum size, acquire
    /// timeout, statement timeout and `test_before_acquire` flag. It is then
    /// checked for a supported PostgreSQL version and for a PgBouncer proxy.
    ///
    /// Each replica pool uses `config.replica_pool_size` connections (default:
    /// half of `config.pool_size`, but at least 1), the same acquire timeout,
    /// `test_before_acquire = true`, no minimum size and no statement timeout.
    /// Each replica is version-checked and starts out marked healthy.
    ///
    /// # Errors
    ///
    /// Fails if `config.url` is empty, if any pool cannot connect, or if a
    /// version / PgBouncer check fails.
    pub async fn from_config_with_service(
        config: &DatabaseConfig,
        service_name: &str,
    ) -> Result<Self> {
        if config.url.is_empty() {
            return Err(ForgeError::internal(
                "database.url cannot be empty. Provide a PostgreSQL connection URL.",
            ));
        }

        // Timeouts are handed down as `Duration`s so sub-second settings
        // survive. Converting with `as_secs()` turned e.g. 500ms into 0, which
        // meant an instant acquire timeout and a silently disabled statement
        // timeout.
        let primary = Self::create_pool_with_statement_timeout(
            &config.url,
            config.pool_size,
            config.min_pool_size,
            config.pool_timeout,
            config.statement_timeout,
            config.test_before_acquire,
            service_name,
        )
        .await
        .map_err(|e| ForgeError::internal_with("Failed to connect to primary", e))?;
        verify_postgres_version(&primary, "primary").await?;
        detect_pgbouncer(&primary).await?;

        // `pool_size / 2` is 0 when `pool_size == 1`. A pool capped at zero
        // connections can never hand one out, so floor the derived default at 1.
        let replica_pool_size = config
            .replica_pool_size
            .unwrap_or_else(|| (config.pool_size / 2).max(1));

        let mut replicas = Vec::new();
        for replica_url in &config.replica_urls {
            let pool = Self::create_pool(
                replica_url,
                replica_pool_size,
                config.pool_timeout,
                service_name,
            )
            .await
            .map_err(|e| ForgeError::internal_with("Failed to connect to replica", e))?;
            verify_postgres_version(&pool, "replica").await?;
            replicas.push(ReplicaEntry {
                pool: Arc::new(pool),
                healthy: Arc::new(AtomicBool::new(true)),
            });
        }

        Ok(Self {
            primary: Arc::new(primary),
            replicas: Arc::new(replicas),
            config: config.clone(),
            replica_counter: Arc::new(AtomicUsize::new(0)),
        })
    }

    /// Parses `url` and sets the application name. Also turns off
    /// per-statement logging and logs statements slower than 500ms at `Warn`.
    fn connect_options(url: &str, service_name: &str) -> sqlx::Result<PgConnectOptions> {
        let options: PgConnectOptions = url.parse()?;
        Ok(options
            .application_name(service_name)
            .log_statements(LevelFilter::Off)
            .log_slow_statements(LevelFilter::Warn, Duration::from_millis(500)))
    }

    /// Like [`Database::connect_options`], but also sets the server-side
    /// `statement_timeout` for every connection. A zero duration leaves it
    /// unset.
    fn connect_options_with_timeout(
        url: &str,
        service_name: &str,
        statement_timeout: Duration,
    ) -> sqlx::Result<PgConnectOptions> {
        let options = Self::connect_options(url, service_name)?;
        if statement_timeout.is_zero() {
            return Ok(options);
        }
        // Express the value in milliseconds so sub-second timeouts are kept.
        // A non-zero sub-millisecond value is rounded up to 1ms rather than
        // collapsing to 0, which PostgreSQL would treat as "no timeout".
        let millis = statement_timeout.as_millis().max(1);
        Ok(options.options([("statement_timeout", format!("{millis}ms"))]))
    }

    /// Creates a replica-style pool: no minimum size, no statement timeout,
    /// and connections tested before they are handed out.
    async fn create_pool(
        url: &str,
        size: u32,
        acquire_timeout: Duration,
        service_name: &str,
    ) -> sqlx::Result<PgPool> {
        Self::create_pool_with_statement_timeout(
            url,
            size,
            0,
            acquire_timeout,
            Duration::ZERO,
            true,
            service_name,
        )
        .await
    }

    /// Creates a pool and opens it with the given limits.
    ///
    /// A zero `statement_timeout` disables the server-side statement timeout.
    async fn create_pool_with_statement_timeout(
        url: &str,
        size: u32,
        min_size: u32,
        acquire_timeout: Duration,
        statement_timeout: Duration,
        test_before_acquire: bool,
        service_name: &str,
    ) -> sqlx::Result<PgPool> {
        let options = Self::connect_options_with_timeout(url, service_name, statement_timeout)?;
        PgPoolOptions::new()
            .max_connections(size)
            .min_connections(min_size)
            .acquire_timeout(acquire_timeout)
            .test_before_acquire(test_before_acquire)
            .connect_with(options)
            .await
    }

    /// The primary (read-write) pool.
    pub fn primary(&self) -> &PgPool {
        &self.primary
    }

    /// Pool to use for read-only queries.
    ///
    /// Returns the primary when `read_from_replica` is off, when there are no
    /// replicas, or when every replica is currently marked unhealthy.
    /// Otherwise replicas are picked round-robin, skipping unhealthy ones.
    pub fn read_pool(&self) -> &PgPool {
        if !self.config.read_from_replica || self.replicas.is_empty() {
            return &self.primary;
        }
        let len = self.replicas.len();
        // The counter only balances load, so `Relaxed` is sufficient.
        // `fetch_add` wraps on overflow, which is harmless here.
        let start = self.replica_counter.fetch_add(1, Ordering::Relaxed) % len;
        for offset in 0..len {
            let idx = (start + offset) % len;
            if let Some(entry) = self.replicas.get(idx)
                && entry.healthy.load(Ordering::Relaxed)
            {
                return &entry.pool;
            }
        }
        &self.primary
    }

    /// Spawns a task that probes every replica with `SELECT 1` every 15
    /// seconds and updates its health flag. State transitions are logged.
    ///
    /// Each probe is capped at 5 seconds. Without that bound, a hung replica
    /// (replicas run without a statement timeout) would stall the loop
    /// forever: it would never be marked unhealthy, and the shutdown signal
    /// would never be observed. The task stops on any result from
    /// `shutdown_rx` (a message, a lag notice, or the sender being dropped).
    ///
    /// Returns `None` when there are no replicas to monitor.
    ///
    /// # Panics
    ///
    /// `tokio::spawn` panics if called outside a Tokio runtime.
    pub fn start_health_monitor(
        &self,
        mut shutdown_rx: broadcast::Receiver<()>,
    ) -> Option<JoinHandle<()>> {
        // How often the whole replica set is probed.
        const CHECK_INTERVAL: Duration = Duration::from_secs(15);
        // Upper bound on one probe, including acquiring a connection.
        const CHECK_TIMEOUT: Duration = Duration::from_secs(5);

        if self.replicas.is_empty() {
            return None;
        }
        let replicas = Arc::clone(&self.replicas);
        let handle = tokio::spawn(async move {
            let mut interval = tokio::time::interval(CHECK_INTERVAL);
            // A slow round must not be followed by a burst of catch-up rounds.
            interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
            loop {
                tokio::select! {
                    _ = shutdown_rx.recv() => {
                        tracing::debug!("Replica health monitor shutting down");
                        break;
                    }
                    _ = interval.tick() => {
                        for (index, entry) in replicas.iter().enumerate() {
                            let probe = sqlx::query_scalar!("SELECT 1 as \"v!\"")
                                .fetch_one(entry.pool.as_ref());
                            let ok = matches!(
                                tokio::time::timeout(CHECK_TIMEOUT, probe).await,
                                Ok(Ok(_))
                            );
                            let was_healthy = entry.healthy.swap(ok, Ordering::Relaxed);
                            if was_healthy && !ok {
                                tracing::warn!(replica = index, "Replica marked unhealthy");
                            } else if !was_healthy && ok {
                                tracing::info!(replica = index, "Replica recovered");
                            }
                        }
                    }
                }
            }
        });
        Some(handle)
    }

    /// Wraps an existing pool as the primary, with no replicas and a default
    /// configuration. No version or PgBouncer checks are run.
    #[doc(hidden)]
    pub fn from_pool(pool: PgPool) -> Self {
        Self {
            primary: Arc::new(pool),
            replicas: Arc::new(Vec::new()),
            config: DatabaseConfig::default(),
            replica_counter: Arc::new(AtomicUsize::new(0)),
        }
    }

    /// Runs `SELECT 1` against the primary.
    ///
    /// # Errors
    ///
    /// Returns an internal error wrapping the sqlx failure.
    pub async fn health_check(&self) -> Result<()> {
        sqlx::query_scalar!("SELECT 1 as \"v!\"")
            .fetch_one(self.primary.as_ref())
            .await
            .map_err(|e| ForgeError::internal_with("Health check failed", e))?;
        Ok(())
    }

    /// Closes the primary pool, then each replica pool in configured order.
    pub async fn close(&self) {
        self.primary.close().await;
        for entry in self.replicas.iter() {
            entry.pool.close().await;
        }
    }
}
/// Oldest PostgreSQL major version Forge accepts.
pub const MIN_POSTGRES_MAJOR: i32 = 18;

/// Reads `server_version_num` through `pool` and rejects servers older than
/// [`MIN_POSTGRES_MAJOR`].
///
/// `role` (e.g. `"primary"`, `"replica"`) only labels errors and log fields.
///
/// # Errors
///
/// Returns an internal error if the setting cannot be read, is NULL, is not an
/// integer, or names a major version below [`MIN_POSTGRES_MAJOR`].
async fn verify_postgres_version(pool: &PgPool, role: &str) -> Result<()> {
    let setting = sqlx::query_scalar!("SELECT current_setting('server_version_num')")
        .fetch_one(pool)
        .await
        .map_err(|err| {
            ForgeError::internal(format!(
                "Failed to read PostgreSQL server_version_num from {role}: {err}"
            ))
        })?;

    let Some(setting) = setting else {
        return Err(ForgeError::internal(format!(
            "PostgreSQL {role} returned NULL for server_version_num"
        )));
    };

    let version_num = setting.parse::<i32>().map_err(|err| {
        ForgeError::internal(format!(
            "Could not parse PostgreSQL server_version_num from {role}: {err}"
        ))
    })?;

    // Since PostgreSQL 10, server_version_num is major * 10_000 + minor.
    let major = version_num / 10_000;
    if major >= MIN_POSTGRES_MAJOR {
        tracing::debug!(role, postgres_major = major, "PostgreSQL version verified");
        return Ok(());
    }

    Err(ForgeError::internal(format!(
        "PostgreSQL {} is at version {} but Forge requires {} or newer. \
         See https://forge.dev/scale/hosting for supported versions.",
        role, major, MIN_POSTGRES_MAJOR
    )))
}
/// Refuses to run behind PgBouncer.
///
/// Runs `SELECT pg_backend_pid(), version()` on `pool`. Either of these is
/// treated as evidence of a PgBouncer proxy:
/// - a backend PID of 0;
/// - a version string containing "pgbouncer" (case-insensitive).
///
/// # Errors
///
/// Returns an internal error if the probe query fails, and a config error if
/// PgBouncer is detected.
async fn detect_pgbouncer(pool: &PgPool) -> Result<()> {
    // NOTE(review): presumably the project's clippy config disallows the
    // runtime-checked `sqlx::query_as` in favour of the compile-time macros;
    // this probe opts out deliberately. Confirm against clippy.toml.
    #[allow(clippy::disallowed_methods)]
    let probe: (i32, String) = sqlx::query_as("SELECT pg_backend_pid(), version()")
        .fetch_one(pool)
        .await
        .map_err(|err| {
            ForgeError::internal(format!(
                "PgBouncer detection query failed: {err}. \
                 Forge requires a direct PostgreSQL connection."
            ))
        })?;
    let (backend_pid, version_str) = probe;

    let behind_pgbouncer =
        backend_pid == 0 || version_str.to_lowercase().contains("pgbouncer");
    if !behind_pgbouncer {
        tracing::debug!(
            backend_pid,
            "Direct PostgreSQL connection confirmed (no PgBouncer)"
        );
        return Ok(());
    }

    Err(ForgeError::config(
        "Forge detected a PgBouncer proxy in the connection path. \
         Forge requires direct PostgreSQL connections because it relies on \
         `pg_try_advisory_lock` (for leader election) and persistent `LISTEN/NOTIFY` \
         listeners (for real-time reactivity). Both break under PgBouncer's transaction \
         pooling mode. Connect directly to PostgreSQL, or use a session-level pooler \
         that preserves connection identity (e.g. pgcat in query-router mode).",
    ))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A cloned `DatabaseConfig` keeps the URL and pool size of its source.
    #[test]
    fn test_database_config_clone() {
        let original = DatabaseConfig::new("postgres://localhost/test");
        let copy = original.clone();

        assert_eq!(copy.url(), original.url());
        assert_eq!(copy.pool_size, original.pool_size);
    }
}