use crate::config::database::{ShardConnection, ShardingConfig};
use crate::config::DatabaseConfig;
use crate::sharding::shard_for;
use sqlx::postgres::PgPoolOptions;
use sqlx::PgPool;
use std::sync::Arc;
use std::time::Duration;
/// Postgres connection pool type used throughout this module (alias of `sqlx::PgPool`).
pub type DbPool = PgPool;
/// Open a Postgres connection pool described by `config`.
///
/// Pool sizing comes from `config.max_connections`, `config.min_connections`
/// and `config.acquire_timeout` (whole seconds). Connection parameters come from
/// `config.connect_options()`. Once the pool is up, an info event is logged with
/// the target host, port, database and connection cap.
///
/// # Errors
///
/// Propagates any `sqlx::Error` raised while establishing the pool.
pub async fn create_pool(config: &DatabaseConfig) -> Result<DbPool, sqlx::Error> {
    let acquire_timeout = Duration::from_secs(config.acquire_timeout);
    let options = PgPoolOptions::new()
        .max_connections(config.max_connections)
        .min_connections(config.min_connections)
        .acquire_timeout(acquire_timeout);
    let pool = options.connect_with(config.connect_options()).await?;

    tracing::info!(
        host = %config.host,
        port = %config.port,
        database = %config.database,
        max_connections = config.max_connections,
        "Database connection pool created"
    );
    Ok(pool)
}
/// Probe `pool` by executing `SELECT 1`.
///
/// Returns `true` when the query executes successfully and `false` on any
/// error. The error itself is discarded.
pub async fn health_check(pool: &DbPool) -> bool {
    let outcome = sqlx::query("SELECT 1").execute(pool).await;
    outcome.is_ok()
}
/// Routing table of Postgres pools: one pool per shard, plus a pool for
/// cluster-wide (non execution-scoped) queries.
///
/// The shard list is held behind an `Arc`, so cloning the map does not copy
/// the `Vec` of pools.
#[derive(Debug, Clone)]
pub struct DbPoolMap {
    /// Pools indexed by shard number; `pool_for` selects one via `shard_for`.
    shards: Arc<Vec<DbPool>>,
    /// Pool for cluster-wide queries. This may be a clone of shard 0's pool
    /// (no cluster DSN configured) or the single legacy pool (fallback mode).
    cluster: DbPool,
    /// Number of entries in `shards`, cached as `u32` for `shard_for`.
    shard_count: u32,
    /// `true` when the map was built in single-pool fallback mode.
    single_pool_mode: bool,
}
impl DbPoolMap {
pub async fn new(
legacy: &DatabaseConfig,
sharding: &ShardingConfig,
) -> Result<Self, sqlx::Error> {
if sharding.is_disabled() {
let pool = create_pool(legacy).await?;
tracing::info!("DbPoolMap: single-pool fallback (NOETL_SHARDS empty)");
return Ok(Self {
shards: Arc::new(vec![pool.clone()]),
cluster: pool,
shard_count: 1,
single_pool_mode: true,
});
}
let mut shard_pools = Vec::with_capacity(sharding.shards.len());
for (idx, conn) in sharding.shards.iter().enumerate() {
let pool = build_pool(legacy, conn).await.inspect_err(|e| {
tracing::error!(
shard_index = idx,
host = %conn.host,
error = %e,
"DbPoolMap: failed to build shard pool"
);
})?;
tracing::info!(
shard_index = idx,
host = %conn.host,
port = %conn.port,
database = %conn.database,
"DbPoolMap: shard pool ready"
);
shard_pools.push(pool);
}
let cluster = match &sharding.cluster {
Some(conn) => {
let pool = build_pool(legacy, conn).await.inspect_err(|e| {
tracing::error!(
host = %conn.host,
error = %e,
"DbPoolMap: failed to build cluster pool"
);
})?;
tracing::info!(
host = %conn.host,
port = %conn.port,
database = %conn.database,
"DbPoolMap: cluster pool ready"
);
pool
}
None => {
tracing::warn!(
"DbPoolMap: NOETL_CLUSTER_DSN unset; cluster-wide queries \
ride shard 0's pool (single-node kind topology)"
);
shard_pools[0].clone()
}
};
let shard_count = shard_pools.len() as u32;
Ok(Self {
shards: Arc::new(shard_pools),
cluster,
shard_count,
single_pool_mode: false,
})
}
pub fn from_single_pool(pool: DbPool) -> Self {
Self {
shards: Arc::new(vec![pool.clone()]),
cluster: pool,
shard_count: 1,
single_pool_mode: true,
}
}
pub fn shard_count(&self) -> u32 {
self.shard_count
}
pub fn is_single_pool(&self) -> bool {
self.single_pool_mode
}
pub fn pool_for(&self, execution_id: i64) -> &DbPool {
if self.shard_count <= 1 {
return &self.shards[0];
}
let idx = shard_for(execution_id, self.shard_count) as usize;
&self.shards[idx]
}
pub fn cluster(&self) -> &DbPool {
&self.cluster
}
pub fn all_shards(&self) -> impl Iterator<Item = (u32, &DbPool)> {
self.shards
.iter()
.enumerate()
.map(|(idx, pool)| (idx as u32, pool))
}
pub async fn for_each_shard<F, Fut, T, E>(&self, mut f: F) -> Result<Vec<(u32, T)>, E>
where
F: FnMut(u32, DbPool) -> Fut,
Fut: std::future::Future<Output = Result<T, E>>,
{
let mut out = Vec::with_capacity(self.shard_count as usize);
for (idx, pool) in self.all_shards() {
let result = f(idx, pool.clone()).await?;
out.push((idx, result));
}
Ok(out)
}
pub async fn find_first<F, Fut, T, E>(&self, mut f: F) -> Result<Option<(u32, T)>, E>
where
F: FnMut(u32, DbPool) -> Fut,
Fut: std::future::Future<Output = Result<Option<T>, E>>,
{
let results = self.for_each_shard(|idx, pool| f(idx, pool)).await?;
Ok(results
.into_iter()
.find_map(|(idx, opt)| opt.map(|t| (idx, t))))
}
}
/// Open a pool for one shard or cluster endpoint.
///
/// The connection target comes from `conn.connect_options()`. Pool sizing
/// (`max_connections`, `min_connections`, `acquire_timeout` in whole seconds)
/// is read from the legacy `DatabaseConfig`, the same fields `create_pool`
/// uses.
///
/// # Errors
///
/// Propagates any `sqlx::Error` raised while establishing the pool.
async fn build_pool(
    legacy: &DatabaseConfig,
    conn: &ShardConnection,
) -> Result<DbPool, sqlx::Error> {
    let builder = PgPoolOptions::new()
        .max_connections(legacy.max_connections)
        .min_connections(legacy.min_connections)
        .acquire_timeout(Duration::from_secs(legacy.acquire_timeout));
    let target = conn.connect_options();
    builder.connect_with(target).await
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Compile-time check that `DbPool` is a nameable concrete type.
    #[test]
    fn test_pool_type_alias() {
        fn _assert_type(_: DbPool) {}
    }

    /// Pins `shard_for` against known (execution_id, shard_count) pairs so that
    /// routing drift is caught.
    #[test]
    fn pool_for_routing_math_matches_drift_guard_pairs() {
        assert_eq!(shard_for(1, 2), 1);
        assert_eq!(shard_for(1, 4), 1);
        assert_eq!(shard_for(1, 16), 5);
        assert_eq!(shard_for(1, 64), 21);
        assert_eq!(shard_for(1, 1024), 405);
    }

    #[test]
    fn pool_for_degenerate_shard_count_short_circuits() {
        assert_eq!(shard_for(42, 1), 0);
        assert_eq!(shard_for(9_999_999_999, 1), 0);
        assert_eq!(shard_for(-1, 1), 0);
    }

    /// Pool built with `connect_lazy_with`, which defers connecting until first
    /// use, so tests that never run a query need no database.
    fn dummy_pool() -> DbPool {
        use sqlx::postgres::PgConnectOptions;
        PgPoolOptions::new()
            .max_connections(1)
            .connect_lazy_with(PgConnectOptions::new().host("localhost"))
    }

    /// Build a non-fallback map with `count` lazy shard pools, bypassing `new`
    /// (which needs live databases). The cluster pool reuses shard 0, mirroring
    /// the no-cluster-DSN topology.
    fn multi_shard_map(count: usize) -> DbPoolMap {
        let shards: Vec<DbPool> = (0..count).map(|_| dummy_pool()).collect();
        let cluster = shards[0].clone();
        let shard_count = shards.len() as u32;
        DbPoolMap {
            shards: Arc::new(shards),
            cluster,
            shard_count,
            single_pool_mode: false,
        }
    }

    #[tokio::test]
    async fn from_single_pool_marks_fallback_mode() {
        let pool = dummy_pool();
        let map = DbPoolMap::from_single_pool(pool);
        assert!(map.is_single_pool());
        assert_eq!(map.shard_count(), 1);
        assert_eq!(map.all_shards().count(), 1);
    }

    #[tokio::test]
    async fn from_single_pool_pool_for_does_not_panic_on_negative_eid() {
        let map = DbPoolMap::from_single_pool(dummy_pool());
        let _ = map.pool_for(-1);
        let _ = map.pool_for(i64::MAX);
        let _ = map.pool_for(0);
    }

    #[tokio::test]
    async fn multi_shard_pool_for_routes_via_shard_for() {
        let map = multi_shard_map(4);
        assert!(!map.is_single_pool());
        assert_eq!(map.shard_count(), 4);
        for eid in [1_i64, 42, 9_999_999_999] {
            let expected = shard_for(eid, 4) as usize;
            assert!(std::ptr::eq(map.pool_for(eid), &map.shards[expected]));
        }
    }

    #[tokio::test]
    async fn for_each_shard_runs_closure_once_per_shard_in_order() {
        let map = DbPoolMap::from_single_pool(dummy_pool());
        let observed: Vec<u32> = map
            .for_each_shard::<_, _, u32, sqlx::Error>(|idx, _pool| async move { Ok(idx) })
            .await
            .expect("ok")
            .into_iter()
            .map(|(idx, _)| idx)
            .collect();
        assert_eq!(observed, vec![0]);
    }

    #[tokio::test]
    async fn for_each_shard_visits_every_shard_in_order_when_sharded() {
        let map = multi_shard_map(3);
        let observed: Vec<u32> = map
            .for_each_shard::<_, _, u32, sqlx::Error>(|idx, _pool| async move { Ok(idx) })
            .await
            .expect("ok")
            .into_iter()
            .map(|(idx, _)| idx)
            .collect();
        assert_eq!(observed, vec![0, 1, 2]);
    }

    #[tokio::test]
    async fn for_each_shard_propagates_first_error() {
        let map = DbPoolMap::from_single_pool(dummy_pool());
        let err = map
            .for_each_shard::<_, _, (), &'static str>(|_idx, _pool| async move {
                Err("kaboom")
            })
            .await
            .unwrap_err();
        assert_eq!(err, "kaboom");
    }

    #[tokio::test]
    async fn find_first_returns_none_when_no_shard_matches() {
        let map = DbPoolMap::from_single_pool(dummy_pool());
        let out: Option<(u32, i64)> = map
            .find_first::<_, _, i64, sqlx::Error>(|_idx, _pool| async move { Ok(None) })
            .await
            .expect("ok");
        assert!(out.is_none());
    }

    #[tokio::test]
    async fn find_first_returns_first_match_with_shard_index() {
        let map = DbPoolMap::from_single_pool(dummy_pool());
        let out: Option<(u32, &'static str)> = map
            .find_first::<_, _, &'static str, sqlx::Error>(|_idx, _pool| async move {
                Ok(Some("hit"))
            })
            .await
            .expect("ok");
        assert_eq!(out, Some((0, "hit")));
    }

    #[tokio::test]
    async fn find_first_reports_index_of_matching_non_zero_shard() {
        let map = multi_shard_map(3);
        let out: Option<(u32, u32)> = map
            .find_first::<_, _, u32, sqlx::Error>(|idx, _pool| async move {
                Ok((idx == 1).then_some(idx * 10))
            })
            .await
            .expect("ok");
        assert_eq!(out, Some((1, 10)));
    }
}