use dashmap::DashMap;
use once_cell::sync::Lazy;
use sqlx::PgPool;
use super::executor;
use crate::core::model::Model;
use crate::core::query::QueryBuilder;
// Task-local pool handles. In this file they are installed with `.scope(value, future)`
// and read back with `.try_with(...)`.
tokio::task_local! {
// Pool installed by `with_pool`, by `with_pools` (its `primary` argument) and by
// `scope_pool`; read through `try_current_pool`.
static CURRENT_POOL: PgPool;
// Pool installed only by `with_pools` (its `replica` argument); read through
// `try_replica_pool`.
static REPLICA_POOL: PgPool;
}
/// Registry of pools keyed by name, created lazily on first access.
///
/// Written by `register_named_pool`, read by `get_named_pool` and pruned by
/// `unregister_named_pool`.
static NAMED_POOLS: Lazy<DashMap<String, PgPool>> = Lazy::new(DashMap::new);
/// Stores `pool` in the named-pool registry under `name`.
///
/// A pool already registered under the same name is replaced.
pub fn register_named_pool(name: impl Into<String>, pool: PgPool) {
    let key: String = name.into();
    // `insert` hands back the entry it displaced (if any); it is intentionally discarded.
    let _ = NAMED_POOLS.insert(key, pool);
}
/// Looks up the pool registered under `name`.
///
/// Returns a clone of the stored handle, or `None` when nothing is registered
/// under that name.
pub fn get_named_pool(name: &str) -> Option<PgPool> {
    let entry = NAMED_POOLS.get(name)?;
    Some(entry.value().clone())
}
/// Removes the pool registered under `name` from the registry.
///
/// Names that are not registered are ignored.
pub fn unregister_named_pool(name: &str) {
    // The removed `(name, pool)` pair, if any, is not needed by the caller.
    let _ = NAMED_POOLS.remove(name);
}
/// Drives `f` with `pool` installed as the task-local current pool.
///
/// While `f` runs, `try_current_pool` yields a clone of `pool`. Resolves to
/// whatever `f` resolves to.
pub async fn with_pool<F: std::future::Future>(pool: PgPool, f: F) -> F::Output {
    let scoped = CURRENT_POOL.scope(pool, f);
    scoped.await
}
/// Drives `f` with both a primary and a replica pool installed as task-locals.
///
/// `primary` becomes the current pool (see `try_current_pool`) and `replica`
/// becomes the replica pool (see `try_replica_pool`). Resolves to whatever `f`
/// resolves to.
pub async fn with_pools<F: std::future::Future>(
    primary: PgPool,
    replica: PgPool,
    f: F,
) -> F::Output {
    // Innermost scope carries the primary pool; the replica scope wraps it.
    let with_primary = CURRENT_POOL.scope(primary, f);
    let with_both = REPLICA_POOL.scope(replica, with_primary);
    with_both.await
}
/// Wraps `f` in a future that has `pool` installed as the task-local current pool.
///
/// Unlike `with_pool` this is not an `async fn`: it returns the scoped future
/// for the caller to await. Only compiled when the `axum` feature is enabled.
#[cfg(feature = "axum")]
pub(crate) fn scope_pool<F>(pool: PgPool, f: F) -> impl std::future::Future<Output = F::Output>
where
    F: std::future::Future,
{
    CURRENT_POOL.scope(pool, f)
}
pub fn try_current_pool() -> Option<PgPool> {
CURRENT_POOL.try_with(|p| p.clone()).ok()
}
pub fn try_replica_pool() -> Option<PgPool> {
REPLICA_POOL.try_with(|p| p.clone()).ok()
}
pub fn resolve_read_pool<T>(builder: &QueryBuilder<T>) -> Result<PgPool, sqlx::Error> {
if builder.use_replica {
if let Some(replica) = try_replica_pool() {
return Ok(replica);
}
}
try_current_pool().ok_or_else(|| {
sqlx::Error::Configuration(
"no database pool in scope — add OrmLayer to your router or \
call pool::with_pool() in tests"
.to_string()
.into(),
)
})
}
/// Runs `builder` through `executor::fetch_all` on the pool chosen by
/// `resolve_read_pool` and returns every decoded row.
///
/// # Errors
///
/// Fails when no pool is in scope, or with whatever error the executor reports.
pub async fn fetch_all<T>(builder: QueryBuilder<T>) -> Result<Vec<T>, sqlx::Error>
where
    T: Model + for<'r> sqlx::FromRow<'r, sqlx::postgres::PgRow> + Send + Unpin,
{
    let read_pool = resolve_read_pool(&builder)?;
    let rows = executor::fetch_all(&read_pool, builder).await;
    rows
}
/// Runs `builder` through `executor::fetch_optional` on the pool chosen by
/// `resolve_read_pool`.
///
/// # Errors
///
/// Fails when no pool is in scope, or with whatever error the executor reports.
pub async fn fetch_optional<T>(builder: QueryBuilder<T>) -> Result<Option<T>, sqlx::Error>
where
    T: Model + for<'r> sqlx::FromRow<'r, sqlx::postgres::PgRow> + Send + Unpin,
{
    let read_pool = resolve_read_pool(&builder)?;
    let row = executor::fetch_optional(&read_pool, builder).await;
    row
}
pub async fn count<T: Model>(builder: QueryBuilder<T>) -> Result<i64, sqlx::Error> {
let pool = resolve_read_pool(&builder)?;
executor::count(&pool, builder).await
}
pub async fn aggregate<T: Model>(
builder: QueryBuilder<T>,
agg_expr: &str,
) -> Result<Option<f64>, sqlx::Error> {
let pool = resolve_read_pool(&builder)?;
executor::aggregate(&pool, builder, agg_expr).await
}
/// Pre-opens connections by acquiring up to `n` of them from `pool` concurrently
/// and releasing them all once the batch has been acquired.
///
/// `n` is capped at the pool's configured `max_connections`. Every acquired
/// connection stays checked out until the whole batch resolves, so asking for
/// more than the pool can ever hand out could not succeed; it would only wait
/// for the acquire timeout and fail.
///
/// # Errors
///
/// Returns the first error reported by any of the acquisitions.
pub async fn warm(n: u32, pool: &PgPool) -> Result<(), sqlx::Error> {
    let target = n.min(pool.options().get_max_connections());
    let pending = (0..target).map(|_| pool.acquire());
    // Keep the connections alive until every acquisition has finished; dropping
    // the vector at the end of the function hands them back to the pool.
    let _conns: Vec<_> = futures::future::try_join_all(pending).await?;
    Ok(())
}
/// Health probe: executes `SELECT 1` on `pool`.
///
/// Returns `true` when the statement succeeds and `false` on any error; the
/// error itself is discarded.
pub async fn ping(pool: &PgPool) -> bool {
    let probe = sqlx::query("SELECT 1").execute(pool).await;
    probe.is_ok()
}
/// Point-in-time connection counts for a pool, as produced by `snapshot`.
///
/// Snapshots are plain values: they can be copied, compared for equality, and
/// default to all-zero counts.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct PoolMetrics {
    /// Pool size as reported by `pool.size()`.
    pub size: u32,
    /// Idle connection count as reported by `pool.num_idle()`.
    pub idle: u32,
    /// `size - idle`, saturating at zero.
    pub active: u32,
}
/// Captures the current size, idle and active connection counts of `pool`.
///
/// `active` is derived as `size - idle`. The two counters are read one after the
/// other, so the subtraction saturates at zero instead of underflowing should
/// `idle` exceed `size`.
pub fn snapshot(pool: &PgPool) -> PoolMetrics {
    let size = pool.size();
    // `num_idle` reports a `usize`; clamp before narrowing so an out-of-range
    // count saturates at `u32::MAX` instead of silently wrapping.
    let idle = pool.num_idle().min(u32::MAX as usize) as u32;
    PoolMetrics {
        size,
        idle,
        active: size.saturating_sub(idle),
    }
}
/// Publishes a `snapshot` of `pool` as three gauges — `db_pool_size`,
/// `db_pool_idle` and `db_pool_active` — each tagged with `pool = label`.
///
/// Only compiled when the `metrics` feature is enabled.
#[cfg(feature = "metrics")]
pub fn emit_metrics(pool: &PgPool, label: &str) {
    let PoolMetrics { size, idle, active } = snapshot(pool);
    // `u32 -> f64` is lossless, so `f64::from` expresses the conversion exactly.
    metrics::gauge!("db_pool_size", "pool" => label.to_owned()).set(f64::from(size));
    metrics::gauge!("db_pool_idle", "pool" => label.to_owned()).set(f64::from(idle));
    metrics::gauge!("db_pool_active", "pool" => label.to_owned()).set(f64::from(active));
}