use std::collections::HashMap;
use std::sync::Arc;
#[cfg(feature = "postgres")]
use crate::sql::sqlx::postgres::{PgPool, PgPoolOptions};
use crate::sql::sqlx::{self, Database};
use tokio::sync::RwLock;
use super::error::TenancyError;
use super::org::{Org, StorageMode};
use super::secrets::{LiteralSecretsResolver, SecretsResolver};
/// Sizing and timeout knobs applied to every database-mode tenant pool
/// created by `TenantPools`.
#[derive(Debug, Clone)]
pub struct TenantPoolsConfig {
    /// Maximum number of database-mode pools kept in the per-slug cache;
    /// creating a pool beyond this returns a validation error.
    pub max_cached_database_pools: usize,
    /// Passed to `PoolOptions::max_connections` for each tenant pool.
    pub database_pool_max_connections: u32,
    /// Passed to `PoolOptions::min_connections` for each tenant pool.
    pub database_pool_min_connections: u32,
    /// Passed to `PoolOptions::acquire_timeout` for each tenant pool.
    pub database_pool_acquire_timeout: std::time::Duration,
    /// Applied via `PoolOptions::idle_timeout` only when `Some`; `None` leaves
    /// the sqlx default untouched.
    pub database_pool_idle_timeout: Option<std::time::Duration>,
    /// Applied via `PoolOptions::max_lifetime` only when `Some`; `None` leaves
    /// the sqlx default untouched.
    pub database_pool_max_lifetime: Option<std::time::Duration>,
    /// Not read in this module; presumably consulted by the caller before
    /// invoking `TenantPools::prewarm_database_tenants` — confirm.
    pub prewarm_active_tenants: bool,
}

impl Default for TenantPoolsConfig {
    /// Small per-tenant pools (max 4 connections, none kept warm), a 64-entry
    /// cache, 30 s acquire timeout, 10 min idle timeout, 30 min max lifetime,
    /// and prewarming off.
    fn default() -> Self {
        use std::time::Duration;

        const MINUTE_SECS: u64 = 60;

        let idle = Duration::from_secs(10 * MINUTE_SECS);
        let lifetime = Duration::from_secs(30 * MINUTE_SECS);

        Self {
            prewarm_active_tenants: false,
            max_cached_database_pools: 64,
            database_pool_min_connections: 0,
            database_pool_max_connections: 4,
            database_pool_acquire_timeout: Duration::from_secs(30),
            database_pool_idle_timeout: Some(idle),
            database_pool_max_lifetime: Some(lifetime),
        }
    }
}
/// Outcome counters returned by `TenantPools::prewarm_database_tenants`.
#[derive(Debug, Clone, Copy, Default)]
pub struct PrewarmReport {
    /// Number of active, database-mode orgs returned by the registry query.
    pub total_active: usize,
    /// Orgs for which a pool was obtained (newly connected or already cached).
    pub warmed: usize,
    /// Orgs whose pool initialisation returned an error (logged, not propagated).
    pub failed: usize,
    /// Orgs skipped because the pool cache had reached
    /// `TenantPoolsConfig::max_cached_database_pools`.
    pub skipped_cap: usize,
}
// Default backend for the generic tenancy types, picked by feature precedence:
// `postgres` wins, then `sqlite`, then `mysql`. With none of the three
// features enabled, no alias is defined.
/// Default database backend: Postgres (the `postgres` feature is enabled).
#[cfg(feature = "postgres")]
pub type DefaultTenantDb = sqlx::Postgres;
/// Default database backend: SQLite (`sqlite` enabled, `postgres` not).
#[cfg(all(not(feature = "postgres"), feature = "sqlite"))]
pub type DefaultTenantDb = sqlx::Sqlite;
/// Default database backend: MySQL (only `mysql` of the three is enabled).
#[cfg(all(not(feature = "postgres"), not(feature = "sqlite"), feature = "mysql"))]
pub type DefaultTenantDb = sqlx::MySql;
/// Handle describing how a tenant's data is reached.
#[derive(Debug)]
pub enum TenantPool<DB: Database = DefaultTenantDb> {
    /// Postgres schema mode: the tenant shares the registry pool and is
    /// scoped by pointing `search_path` at `schema`.
    #[cfg(feature = "postgres")]
    Schema { schema: String, registry: PgPool },
    /// Database mode: the tenant's own pool, shared through an `Arc` with the
    /// `TenantPools` cache.
    Database { pool: Arc<sqlx::Pool<DB>> },
}
// Implemented by hand: `#[derive(Clone)]` would add a `DB: Clone` bound, but
// cloning only needs the `String`, `PgPool` and `Arc` fields to be cloneable.
impl<DB: Database> Clone for TenantPool<DB> {
    /// Cheap clone: the `Arc` is shared, and the schema name and registry
    /// handle are cloned.
    fn clone(&self) -> Self {
        match self {
            Self::Database { pool } => Self::Database {
                pool: Arc::clone(pool),
            },
            #[cfg(feature = "postgres")]
            Self::Schema { schema, registry } => {
                let schema = schema.clone();
                let registry = registry.clone();
                Self::Schema { schema, registry }
            }
        }
    }
}
impl<DB: Database> TenantPool<DB> {
    /// Returns `true` for the Postgres schema-mode variant; always `false`
    /// when the `postgres` feature is disabled, since that variant is absent.
    #[must_use]
    pub fn is_schema(&self) -> bool {
        match self {
            #[cfg(feature = "postgres")]
            Self::Schema { .. } => true,
            Self::Database { .. } => false,
        }
    }
}
#[cfg(feature = "postgres")]
impl TenantPool<sqlx::Postgres> {
    /// Borrows the underlying `PgPool`.
    ///
    /// For schema mode this is the shared registry pool itself; it does not
    /// carry the tenant's `search_path`.
    #[must_use]
    pub fn pool(&self) -> &PgPool {
        let underlying: &PgPool = match self {
            Self::Database { pool } => Arc::as_ref(pool),
            Self::Schema { registry, .. } => registry,
        };
        underlying
    }
}
/// Manager for per-tenant connection pools, backed by a registry database.
///
/// Database-mode tenants get their own `sqlx::Pool`, created lazily from the
/// org's `database_url` (resolved through `secrets`) and cached by org slug.
pub struct TenantPools<DB: Database = DefaultTenantDb> {
    /// Pool on the registry database: `Org` rows are read from it during
    /// prewarm, and schema-mode tenants share it.
    registry: sqlx::Pool<DB>,
    /// Sizing / timeout settings applied to every tenant pool.
    config: TenantPoolsConfig,
    /// Turns an org's `database_url` reference into a connectable URL.
    secrets: Arc<dyn SecretsResolver>,
    /// Database-mode pools keyed by org slug; size capped at
    /// `config.max_cached_database_pools`.
    cache: RwLock<HashMap<String, Arc<sqlx::Pool<DB>>>>,
}
impl<DB: Database> TenantPools<DB> {
    /// Builds a manager over `registry` using `LiteralSecretsResolver`
    /// (presumably passes `database_url` through unchanged — confirm) and
    /// `TenantPoolsConfig::default()`.
    #[must_use]
    pub fn new(registry: sqlx::Pool<DB>) -> Self {
        let literal = LiteralSecretsResolver;
        Self::with_secrets(registry, literal)
    }

    /// Builds a manager over `registry` that resolves tenant URLs through
    /// `secrets`, with default configuration and an empty pool cache.
    #[must_use]
    pub fn with_secrets<R: SecretsResolver>(registry: sqlx::Pool<DB>, secrets: R) -> Self {
        let resolver: Arc<dyn SecretsResolver> = Arc::new(secrets);
        let empty_cache = RwLock::new(HashMap::new());
        Self {
            registry,
            config: TenantPoolsConfig::default(),
            secrets: resolver,
            cache: empty_cache,
        }
    }

    /// Builder-style override of the pool configuration.
    #[must_use]
    pub fn config(mut self, config: TenantPoolsConfig) -> Self {
        let _replaced = std::mem::replace(&mut self.config, config);
        self
    }

    /// Current pool configuration.
    #[must_use]
    pub fn pool_config(&self) -> &TenantPoolsConfig {
        let cfg = &self.config;
        cfg
    }

    /// Backend-typed handle to the registry pool.
    #[must_use]
    pub fn registry_inner(&self) -> &sqlx::Pool<DB> {
        let registry = &self.registry;
        registry
    }
}
#[cfg(feature = "postgres")]
impl TenantPools<sqlx::Postgres> {
    /// The registry pool as a concrete `PgPool`.
    #[must_use]
    pub fn registry(&self) -> &PgPool {
        self.registry_inner()
    }
}
impl<DB: Database> TenantPools<DB>
where
    crate::sql::Pool: From<sqlx::Pool<DB>>,
{
    /// Backend-erased `crate::sql::Pool` wrapping a clone of the registry pool.
    #[must_use]
    pub fn registry_pool(&self) -> crate::sql::Pool {
        let typed = self.registry.clone();
        typed.into()
    }
}
impl<DB: Database> TenantPools<DB> {
    /// Resolves `org` to its dedicated database pool.
    ///
    /// Only `storage_mode = 'database'` orgs are accepted; schema mode is
    /// rejected here for every backend.
    ///
    /// # Errors
    /// - `TenancyError::Validation` when `storage_mode` does not parse, when it
    ///   is `schema`, when the org has no `database_url`, or when the pool
    ///   cache is full.
    /// - Any error from the secrets resolver or from connecting the pool.
    pub async fn database_pool_for_org(&self, org: &Org) -> Result<TenantPool<DB>, TenancyError> {
        let mode = StorageMode::parse(&org.storage_mode).map_err(|got| {
            TenancyError::Validation(format!(
                "org `{}` has unknown storage_mode `{got}` (expected `schema` or `database`)",
                org.slug
            ))
        })?;
        match mode {
            StorageMode::Schema => Err(TenancyError::Validation(format!(
                "org `{}` has `storage_mode = 'schema'` but TenantPools<{dbname}> is non-Postgres. \
                 Schema-mode is a Postgres-only optimization (uses `SET search_path` — no \
                 equivalent on MySQL/SQLite). Switch this org to `storage_mode = 'database'` and \
                 set `database_url` to its dedicated database / file; isolation semantics are \
                 equivalent.",
                org.slug,
                dbname = std::any::type_name::<DB>(),
            ))),
            StorageMode::Database => {
                let pool = self.pool_for_database_mode(org).await?;
                Ok(TenantPool::Database { pool })
            }
        }
    }

    /// Acquires a connection from `org`'s dedicated database pool.
    ///
    /// The returned `TenantConn` has no schema (`schema()` is `None`).
    ///
    /// # Errors
    /// Same as `database_pool_for_org`, plus any error from `Pool::acquire`.
    pub async fn database_acquire(&self, org: &Org) -> Result<TenantConn<DB>, TenancyError> {
        // A `match` (not `let ... else`) so that builds without the `postgres`
        // feature, where `Database` is the only variant, don't trip the
        // `irrefutable_let_patterns` lint.
        let pool = match self.database_pool_for_org(org).await? {
            TenantPool::Database { pool } => pool,
            #[cfg(feature = "postgres")]
            TenantPool::Schema { .. } => {
                unreachable!("database_pool_for_org rejects schema-mode")
            }
        };
        let conn = pool.acquire().await?;
        Ok(TenantConn {
            inner: conn,
            schema: None,
        })
    }

    /// Drops the cached pool for `slug`, if any.
    ///
    /// Callers already holding an `Arc` to that pool keep it alive. The next
    /// lookup for the slug builds a fresh pool.
    pub async fn invalidate(&self, slug: &str) {
        let mut cache = self.cache.write().await;
        cache.remove(slug);
    }
}
/// Object-safe hook for evicting a tenant's cached pool by slug, so holders of
/// an `Arc<dyn TenantPoolInvalidator>` need not name the backend type.
pub trait TenantPoolInvalidator: Send + Sync {
    /// Invalidate any cached pool for `slug`. The work happens in the
    /// returned boxed `Send` future, which borrows `self` and `slug` for `'a`.
    fn invalidate<'a>(
        &'a self,
        slug: &'a str,
    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send + 'a>>;
}
impl<DB: Database> TenantPoolInvalidator for TenantPools<DB> {
    /// Boxes the inherent `TenantPools::invalidate` future.
    fn invalidate<'a>(
        &'a self,
        slug: &'a str,
    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send + 'a>> {
        // Path-call syntax resolves to the inherent async method (inherent
        // items take precedence over trait items), so this does not recurse.
        let eviction = TenantPools::<DB>::invalidate(self, slug);
        Box::pin(eviction)
    }
}
// NOTE(review): the reason for this blanket allow is not visible in this file;
// confirm whether it is still needed.
#[allow(dead_code)]
impl<DB: Database> TenantPools<DB> {
    /// Upcasts a shared manager to the object-safe invalidation hook.
    #[must_use]
    pub fn into_invalidator(self: Arc<Self>) -> Arc<dyn TenantPoolInvalidator> {
        self
    }

    /// Resolves `org.database_url` through the secrets resolver without
    /// opening a pool.
    ///
    /// # Errors
    /// `TenancyError::Validation` when the org has no `database_url`, or the
    /// resolver's error.
    pub async fn resolved_database_url(&self, org: &Org) -> Result<String, TenancyError> {
        let reference = org.database_url.as_deref().ok_or_else(|| {
            TenancyError::Validation(format!(
                "org `{}` has no `database_url` to resolve (schema mode?)",
                org.slug
            ))
        })?;
        let url = self.secrets.resolve(reference).await?;
        Ok(url)
    }

    /// Number of database-mode pools currently cached.
    pub async fn cached_database_pool_count(&self) -> usize {
        self.cache.read().await.len()
    }

    /// Returns the cached pool for `org.slug`, creating and caching one on a miss.
    ///
    /// Two concurrent misses may both connect. The loser's pool is dropped
    /// and the cached one is returned.
    ///
    /// # Errors
    /// `TenancyError::Validation` when the cache is full or the org has no
    /// `database_url`; otherwise resolver / connection errors.
    async fn pool_for_database_mode(&self, org: &Org) -> Result<Arc<sqlx::Pool<DB>>, TenancyError> {
        /// Error returned when a new pool would exceed the cache cap.
        fn cache_full_error(len: usize) -> TenancyError {
            TenancyError::Validation(format!(
                "tenant pool cache is full ({len} cached); raise \
                 `TenantPoolsConfig::max_cached_database_pools` or \
                 invalidate idle tenants",
            ))
        }

        {
            let cache = self.cache.read().await;
            if let Some(pool) = cache.get(&org.slug) {
                return Ok(Arc::clone(pool));
            }
            // Fail before resolving secrets / opening connections for a pool
            // that could not be cached anyway.
            if cache.len() >= self.config.max_cached_database_pools {
                return Err(cache_full_error(cache.len()));
            }
        }

        let span = tracing::info_span!("tenant_pool_init", slug = %org.slug, mode = "database");
        // Attach the span to the future instead of holding `Span::enter`'s
        // guard across `.await`. Holding it yields wrong traces, and the guard
        // is `!Send`, which would make every caller's future `!Send`.
        let pool = tracing::Instrument::instrument(self.connect_database_pool(org), span).await?;
        let pool = Arc::new(pool);

        let mut cache = self.cache.write().await;
        // Re-check under the write lock: another task may have won the race.
        if let Some(existing) = cache.get(&org.slug) {
            return Ok(Arc::clone(existing));
        }
        if cache.len() >= self.config.max_cached_database_pools {
            return Err(cache_full_error(cache.len()));
        }
        cache.insert(org.slug.clone(), Arc::clone(&pool));
        Ok(pool)
    }

    /// Resolves `org.database_url` and opens a new pool for it, logging the
    /// timing of each step. Does not touch the cache.
    async fn connect_database_pool(&self, org: &Org) -> Result<sqlx::Pool<DB>, TenancyError> {
        let resolve_start = std::time::Instant::now();
        let reference = org.database_url.as_deref().ok_or_else(|| {
            TenancyError::Validation(format!(
                "org `{}` is `storage_mode = database` but has no `database_url`",
                org.slug
            ))
        })?;
        let url = self.secrets.resolve(reference).await?;
        tracing::debug!(
            target: "crate::tenancy::pools",
            slug = %org.slug,
            elapsed_ms = resolve_start.elapsed().as_millis() as u64,
            "secrets resolver resolved tenant URL",
        );
        let connect_start = std::time::Instant::now();
        let pool = build_database_pool::<DB>(&url, &self.config).await?;
        tracing::info!(
            target: "crate::tenancy::pools",
            slug = %org.slug,
            elapsed_ms = connect_start.elapsed().as_millis() as u64,
            min_conn = self.config.database_pool_min_connections,
            max_conn = self.config.database_pool_max_connections,
            "tenant pool connected (database mode)",
        );
        Ok(pool)
    }
}
impl<DB: Database> TenantPools<DB>
where
    crate::sql::Pool: From<sqlx::Pool<DB>>,
{
    /// Returns a backend-erased `crate::sql::Pool` scoped to `org`.
    ///
    /// - Database mode: the cached (or newly created) tenant pool.
    /// - Schema mode: supported only when `DB` is Postgres, in which case this
    ///   delegates to `TenantPools::<Postgres>::scoped_pool`. Otherwise it
    ///   returns a validation error.
    ///
    /// # Errors
    /// `TenancyError::Validation` for an unparseable `storage_mode` or schema
    /// mode on a non-Postgres backend; otherwise pool resolution / connection
    /// errors.
    pub async fn scoped_pool_dyn(&self, org: &Org) -> Result<crate::sql::Pool, TenancyError> {
        let mode = StorageMode::parse(&org.storage_mode).map_err(|got| {
            TenancyError::Validation(format!(
                "org `{}` has unknown storage_mode `{got}` (expected `schema` or `database`)",
                org.slug
            ))
        })?;
        match mode {
            StorageMode::Schema => {
                #[cfg(feature = "postgres")]
                {
                    // Runtime check for `DB == Postgres`; only then is schema
                    // mode available.
                    if let Some(pg_pools) =
                        (self as &dyn std::any::Any).downcast_ref::<TenantPools<sqlx::Postgres>>()
                    {
                        let scoped = pg_pools.scoped_pool(org).await?;
                        return Ok(crate::sql::Pool::Postgres(scoped));
                    }
                }
                Err(TenancyError::Validation(format!(
                    "org `{}` has `storage_mode = 'schema'` but TenantPools<{dbname}> is \
                     non-Postgres. Schema-mode is a Postgres-only optimization (uses \
                     `SET search_path` — no equivalent on MySQL/SQLite). Switch this org to \
                     `storage_mode = 'database'` and set `database_url` to its dedicated \
                     database / file; isolation semantics are equivalent.",
                    org.slug,
                    dbname = std::any::type_name::<DB>(),
                )))
            }
            StorageMode::Database => {
                let pool = self.pool_for_database_mode(org).await?;
                Ok(crate::sql::Pool::from((*pool).clone()))
            }
        }
    }

    /// Eagerly opens pools for every active database-mode org in the registry,
    /// up to `max_cached_database_pools`.
    ///
    /// Per-org failures are logged and counted in the report, not returned.
    ///
    /// # Errors
    /// Only the registry query itself can make this return `Err`.
    pub async fn prewarm_database_tenants(&self) -> Result<PrewarmReport, TenancyError> {
        let span = tracing::info_span!("tenant_pools_prewarm");
        // Attach the span to the future instead of holding `Span::enter`'s
        // guard across `.await` (wrong traces, and the guard is `!Send`).
        tracing::Instrument::instrument(self.prewarm_database_tenants_in_span(), span).await
    }

    /// Body of `prewarm_database_tenants`, run inside its tracing span.
    async fn prewarm_database_tenants_in_span(&self) -> Result<PrewarmReport, TenancyError> {
        use crate::core::Column as _;
        use crate::sql::FetcherPool as _;
        let started = std::time::Instant::now();
        let registry_pool = self.registry_pool();
        let orgs: Vec<Org> = Org::objects()
            .where_(Org::storage_mode.eq("database".to_owned()))
            .where_(Org::active.eq(true))
            .fetch_pool(&registry_pool)
            .await?;
        let mut report = PrewarmReport {
            total_active: orgs.len(),
            ..PrewarmReport::default()
        };
        for org in orgs {
            // An org that is already cached is served from the cache even when
            // the cache is full, so only count a skip when a new entry is needed.
            let cap_reached = {
                let cache = self.cache.read().await;
                cache.len() >= self.config.max_cached_database_pools
                    && !cache.contains_key(&org.slug)
            };
            if cap_reached {
                tracing::warn!(
                    target: "crate::tenancy::pools",
                    slug = %org.slug,
                    cap = self.config.max_cached_database_pools,
                    "skipping pre-warm: cache cap reached",
                );
                report.skipped_cap += 1;
                continue;
            }
            match self.pool_for_database_mode(&org).await {
                Ok(_) => report.warmed += 1,
                Err(e) => {
                    tracing::warn!(
                        target: "crate::tenancy::pools",
                        slug = %org.slug,
                        error = %e,
                        "pre-warm failed for tenant",
                    );
                    report.failed += 1;
                }
            }
        }
        tracing::info!(
            target: "crate::tenancy::pools",
            elapsed_ms = started.elapsed().as_millis() as u64,
            total = report.total_active,
            warmed = report.warmed,
            failed = report.failed,
            skipped_cap = report.skipped_cap,
            "prewarm complete",
        );
        Ok(report)
    }
}
#[cfg(feature = "postgres")]
impl TenantPools<sqlx::Postgres> {
    /// Resolves `org` to a schema-mode handle or a database-mode pool.
    ///
    /// In schema mode the schema is `org.schema_name`, falling back to
    /// `org.slug` when unset, and the handle shares the registry pool.
    ///
    /// # Errors
    /// `TenancyError::Validation` for an unparseable `storage_mode`; database
    /// mode can also fail with resolver / connection / cache-full errors.
    pub async fn pool_for_org(
        &self,
        org: &Org,
    ) -> Result<TenantPool<sqlx::Postgres>, TenancyError> {
        let mode = StorageMode::parse(&org.storage_mode).map_err(|got| {
            TenancyError::Validation(format!(
                "org `{}` has unknown storage_mode `{got}` (expected `schema` or `database`)",
                org.slug
            ))
        })?;
        match mode {
            StorageMode::Schema => {
                let schema = org.schema_name.clone().unwrap_or_else(|| org.slug.clone());
                Ok(TenantPool::Schema {
                    schema,
                    registry: self.registry.clone(),
                })
            }
            StorageMode::Database => {
                let pool = self.pool_for_database_mode(org).await?;
                Ok(TenantPool::Database { pool })
            }
        }
    }

    /// Acquires a tenant-scoped connection.
    ///
    /// In schema mode, a registry connection gets
    /// `SET search_path TO "<schema>", public` and the returned `TenantConn`
    /// reports the schema. In database mode, the connection comes from the
    /// tenant's own pool.
    ///
    /// # Errors
    /// Errors from `pool_for_org`, from acquiring, or from the `SET` statement.
    pub async fn acquire(&self, org: &Org) -> Result<TenantConn<sqlx::Postgres>, TenancyError> {
        match self.pool_for_org(org).await? {
            TenantPool::Schema { schema, registry } => {
                let mut conn = registry.acquire().await?;
                // NOTE(review): a plain `SET` is session-scoped in Postgres,
                // so this search_path stays on the connection after it returns
                // to the registry pool. Other users of that pool will observe
                // it unless something resets it (e.g. an after_release hook).
                // Confirm.
                let stmt = format!("SET search_path TO {}, public", quote_ident(&schema));
                sqlx::query(&stmt).execute(&mut *conn).await?;
                Ok(TenantConn {
                    inner: conn,
                    schema: Some(schema),
                })
            }
            TenantPool::Database { pool } => {
                let conn = pool.acquire().await?;
                Ok(TenantConn {
                    inner: conn,
                    schema: None,
                })
            }
        }
    }

    /// Returns a `PgPool` whose connections are scoped to `org`.
    ///
    /// In schema mode, every call connects a NEW pool (max 2 connections)
    /// with `search_path` set via connection options. It is not cached, so
    /// callers should reuse it. In database mode, this returns a clone of the
    /// cached tenant pool.
    ///
    /// # Errors
    /// Errors from `pool_for_org` or from connecting the scoped pool.
    pub async fn scoped_pool(&self, org: &Org) -> Result<PgPool, TenancyError> {
        match self.pool_for_org(org).await? {
            TenantPool::Schema { schema, registry } => {
                // Quote the schema exactly as `acquire` does, so both paths
                // resolve the same schema. This preserves case and keeps a
                // comma in the name from adding extra search_path entries.
                let search_path = format!("{},public", quote_ident(&schema));
                let opts = (*registry.connect_options())
                    .clone()
                    .options([("search_path", search_path.as_str())]);
                let scoped = PgPoolOptions::new()
                    .max_connections(2)
                    .connect_with(opts)
                    .await?;
                Ok(scoped)
            }
            TenantPool::Database { pool } => Ok((*pool).clone()),
        }
    }
}
/// Connects a new pool to `url` using the sizing/timeout settings in `config`.
///
/// Idle timeout and max lifetime are applied only when configured (`Some`);
/// otherwise sqlx's defaults stay in effect.
///
/// # Errors
/// Whatever `PoolOptions::connect` fails with, converted into `TenancyError`.
async fn build_database_pool<DB: Database>(
    url: &str,
    config: &TenantPoolsConfig,
) -> Result<sqlx::Pool<DB>, TenancyError> {
    let builder = sqlx::pool::PoolOptions::<DB>::new()
        .max_connections(config.database_pool_max_connections)
        .min_connections(config.database_pool_min_connections)
        .acquire_timeout(config.database_pool_acquire_timeout);
    let builder = match config.database_pool_idle_timeout {
        Some(idle) => builder.idle_timeout(idle),
        None => builder,
    };
    let builder = match config.database_pool_max_lifetime {
        Some(lifetime) => builder.max_lifetime(lifetime),
        None => builder,
    };
    let pool = builder.connect(url).await?;
    Ok(pool)
}
/// A pooled connection handed out by `TenantPools`, tagged with the schema
/// whose `search_path` was applied to it (schema mode only).
pub struct TenantConn<DB: Database = DefaultTenantDb> {
    inner: sqlx::pool::PoolConnection<DB>,
    schema: Option<String>,
}
impl<DB: Database> TenantConn<DB> {
    /// Schema set via `SET search_path` in schema mode; `None` for
    /// database-mode connections.
    #[must_use]
    pub fn schema(&self) -> Option<&str> {
        let schema = &self.schema;
        schema.as_deref()
    }
}
impl<DB: Database> std::ops::Deref for TenantConn<DB> {
    type Target = sqlx::pool::PoolConnection<DB>;
    /// Exposes the wrapped pooled connection.
    fn deref(&self) -> &Self::Target {
        let TenantConn { inner, .. } = self;
        inner
    }
}
impl<DB: Database> std::ops::DerefMut for TenantConn<DB> {
    /// Mutable access to the wrapped pooled connection (needed to execute queries).
    fn deref_mut(&mut self) -> &mut Self::Target {
        let TenantConn { inner, .. } = self;
        inner
    }
}
#[cfg(feature = "postgres")]
/// Quotes `name` as a Postgres identifier: wraps it in double quotes and
/// doubles any embedded `"`.
fn quote_ident(name: &str) -> String {
    let mut quoted = String::with_capacity(name.len() + 2);
    quoted.push('"');
    for ch in name.chars() {
        if ch == '"' {
            quoted.push('"');
        }
        quoted.push(ch);
    }
    quoted.push('"');
    quoted
}
#[cfg(test)]
mod tests {
    use super::*;

    #[cfg(feature = "postgres")]
    #[test]
    fn quote_ident_wraps_and_escapes() {
        assert_eq!(quote_ident("acme"), "\"acme\"");
        assert_eq!(quote_ident("a\"b"), "\"a\"\"b\"");
        assert_eq!(quote_ident(""), "\"\"");
    }

    #[test]
    fn config_defaults_are_sane() {
        let c = TenantPoolsConfig::default();
        assert!(c.max_cached_database_pools >= 16);
        assert!(c.database_pool_max_connections >= 1);
    }

    #[test]
    fn config_pool_timeout_defaults_preserve_pre_0_27_7_behavior() {
        let c = TenantPoolsConfig::default();
        assert!(!c.prewarm_active_tenants);
        assert_eq!(c.database_pool_min_connections, 0);
        assert!(c.database_pool_acquire_timeout >= std::time::Duration::from_secs(5));
        assert!(c.database_pool_idle_timeout.is_some());
        assert!(c.database_pool_max_lifetime.is_some());
    }

    // `PrewarmReport` is backend-independent, so this test runs under every
    // feature set instead of only with `postgres`.
    #[test]
    fn prewarm_report_zeroed_default() {
        let r = PrewarmReport::default();
        assert_eq!(r.total_active, 0);
        assert_eq!(r.warmed, 0);
        assert_eq!(r.failed, 0);
        assert_eq!(r.skipped_cap, 0);
    }
}