use std::collections::HashMap;
use std::sync::Arc;
use crate::sql::sqlx::postgres::{PgPool, PgPoolOptions};
use tokio::sync::RwLock;
use super::error::TenancyError;
use super::org::{Org, StorageMode};
use super::secrets::{LiteralSecretsResolver, SecretsResolver};
/// Tuning knobs for [`TenantPools`].
///
/// The `database_pool_*` settings are applied to every dedicated
/// (database-mode) tenant pool. Schema-mode tenants run on the shared
/// registry pool and are not affected by them.
#[derive(Debug, Clone)]
pub struct TenantPoolsConfig {
    /// Upper bound on the number of database-mode pools kept in the cache.
    pub max_cached_database_pools: usize,
    /// `max_connections` for each database-mode pool.
    pub database_pool_max_connections: u32,
    /// `min_connections` for each database-mode pool.
    pub database_pool_min_connections: u32,
    /// `acquire_timeout` for each database-mode pool.
    pub database_pool_acquire_timeout: std::time::Duration,
    /// `idle_timeout` for each database-mode pool; `None` leaves the sqlx default.
    pub database_pool_idle_timeout: Option<std::time::Duration>,
    /// `max_lifetime` for each database-mode pool; `None` leaves the sqlx default.
    pub database_pool_max_lifetime: Option<std::time::Duration>,
    /// Whether active database-mode tenants should be pre-warmed.
    /// NOTE(review): not read inside `TenantPools` itself; presumably checked by
    /// whoever calls `TenantPools::prewarm_database_tenants` — confirm.
    pub prewarm_active_tenants: bool,
}

impl Default for TenantPoolsConfig {
    /// 64 cached pools, 0..=4 connections per pool, 30 s acquire timeout,
    /// 10 min idle timeout, 30 min max lifetime, pre-warm off.
    fn default() -> Self {
        use std::time::Duration;

        const MINUTE: u64 = 60;

        Self {
            prewarm_active_tenants: false,
            max_cached_database_pools: 64,
            database_pool_min_connections: 0,
            database_pool_max_connections: 4,
            database_pool_acquire_timeout: Duration::from_secs(30),
            database_pool_idle_timeout: Some(Duration::from_secs(10 * MINUTE)),
            database_pool_max_lifetime: Some(Duration::from_secs(30 * MINUTE)),
        }
    }
}
/// Counters summarising one run of `TenantPools::prewarm_database_tenants`.
#[derive(Debug, Clone, Copy, Default)]
pub struct PrewarmReport {
    /// Number of orgs matched by the pre-warm query
    /// (`storage_mode = "database"` and `active = true`).
    pub total_active: usize,
    /// Orgs for which looking up or creating the database-mode pool succeeded.
    pub warmed: usize,
    /// Orgs whose pool lookup/creation returned an error (logged at `warn`).
    pub failed: usize,
    /// Orgs skipped because the pool cache had reached
    /// `TenantPoolsConfig::max_cached_database_pools`.
    pub skipped_cap: usize,
}
/// The connection source resolved for a single tenant.
#[derive(Debug, Clone)]
pub enum TenantPool {
    /// Schema-mode tenant: queries go through the shared registry pool, and
    /// `schema` is the tenant schema to put on the `search_path`.
    Schema { schema: String, registry: PgPool },
    /// Database-mode tenant: a dedicated pool, shared via `Arc` with the
    /// `TenantPools` cache.
    Database { pool: Arc<PgPool> },
}
impl TenantPool {
    /// Returns the sqlx pool backing this tenant.
    ///
    /// For schema-mode tenants this is the shared registry pool; no
    /// `search_path` is applied here. For database-mode tenants it is the
    /// tenant's dedicated pool.
    #[must_use]
    pub fn pool(&self) -> &PgPool {
        match self {
            Self::Database { pool } => &**pool,
            Self::Schema { registry, .. } => registry,
        }
    }

    /// `true` when this tenant is served from a schema in the registry database.
    #[must_use]
    pub fn is_schema(&self) -> bool {
        match self {
            Self::Schema { .. } => true,
            Self::Database { .. } => false,
        }
    }
}
/// Hands out per-tenant database access for orgs in either storage mode.
///
/// Schema-mode tenants reuse the registry pool. Database-mode tenants get a
/// dedicated pool that is created on first use and cached by org slug.
pub struct TenantPools {
    /// Pool for the registry database; also serves schema-mode tenants.
    registry: PgPool,
    /// Pool sizing/caching settings applied to database-mode pools.
    config: TenantPoolsConfig,
    /// Turns an org's `database_url` reference into a connection URL.
    secrets: Arc<dyn SecretsResolver>,
    /// Database-mode pools keyed by org slug.
    cache: RwLock<HashMap<String, Arc<PgPool>>>,
}
impl TenantPools {
#[must_use]
pub fn new(registry: PgPool) -> Self {
Self::with_secrets(registry, LiteralSecretsResolver)
}
#[must_use]
pub fn with_secrets<R: SecretsResolver>(registry: PgPool, secrets: R) -> Self {
Self {
registry,
config: TenantPoolsConfig::default(),
secrets: Arc::new(secrets),
cache: RwLock::new(HashMap::new()),
}
}
#[must_use]
pub fn config(mut self, config: TenantPoolsConfig) -> Self {
self.config = config;
self
}
#[must_use]
pub fn pool_config(&self) -> &TenantPoolsConfig {
&self.config
}
#[must_use]
pub fn registry(&self) -> &PgPool {
&self.registry
}
pub async fn pool_for_org(&self, org: &Org) -> Result<TenantPool, TenancyError> {
let mode = StorageMode::parse(&org.storage_mode).map_err(|got| {
TenancyError::Validation(format!(
"org `{}` has unknown storage_mode `{got}` (expected `schema` or `database`)",
org.slug
))
})?;
match mode {
StorageMode::Schema => {
let schema = org.schema_name.clone().unwrap_or_else(|| org.slug.clone());
Ok(TenantPool::Schema {
schema,
registry: self.registry.clone(),
})
}
StorageMode::Database => {
let pool = self.pool_for_database_mode(org).await?;
Ok(TenantPool::Database { pool })
}
}
}
pub async fn acquire(&self, org: &Org) -> Result<TenantConn, TenancyError> {
let pool = self.pool_for_org(org).await?;
match &pool {
TenantPool::Schema { schema, registry } => {
let mut conn = registry.acquire().await?;
let stmt = format!("SET search_path TO {}, public", quote_ident(schema));
rustango::sql::sqlx::query(&stmt)
.execute(&mut *conn)
.await?;
Ok(TenantConn {
inner: conn,
schema: Some(schema.clone()),
})
}
TenantPool::Database { pool } => {
let conn = pool.acquire().await?;
Ok(TenantConn {
inner: conn,
schema: None,
})
}
}
}
pub async fn scoped_pool(&self, org: &Org) -> Result<PgPool, TenancyError> {
match self.pool_for_org(org).await? {
TenantPool::Schema { schema, registry } => {
let mut opts = (*registry.connect_options()).clone();
opts = opts.options([("search_path", &format!("{schema},public") as &str)]);
let scoped = PgPoolOptions::new()
.max_connections(2)
.connect_with(opts)
.await?;
Ok(scoped)
}
TenantPool::Database { pool } => Ok((*pool).clone()),
}
}
pub async fn invalidate(&self, slug: &str) {
let mut cache = self.cache.write().await;
cache.remove(slug);
}
pub async fn prewarm_database_tenants(&self) -> Result<PrewarmReport, TenancyError> {
use crate::core::Column as _;
use crate::sql::Fetcher;
let span = tracing::info_span!("tenant_pools_prewarm");
let _enter = span.enter();
let started = std::time::Instant::now();
let orgs: Vec<Org> = Org::objects()
.where_(Org::storage_mode.eq("database".to_owned()))
.where_(Org::active.eq(true))
.fetch(&self.registry)
.await?;
let total = orgs.len();
let mut report = PrewarmReport {
total_active: total,
warmed: 0,
failed: 0,
skipped_cap: 0,
};
for org in orgs {
if self.cache.read().await.len() >= self.config.max_cached_database_pools {
tracing::warn!(
target: "crate::tenancy::pools",
slug = %org.slug,
cap = self.config.max_cached_database_pools,
"skipping pre-warm: cache cap reached",
);
report.skipped_cap += 1;
continue;
}
match self.pool_for_database_mode(&org).await {
Ok(_) => report.warmed += 1,
Err(e) => {
tracing::warn!(
target: "crate::tenancy::pools",
slug = %org.slug,
error = %e,
"pre-warm failed for tenant",
);
report.failed += 1;
}
}
}
tracing::info!(
target: "crate::tenancy::pools",
elapsed_ms = started.elapsed().as_millis() as u64,
total = report.total_active,
warmed = report.warmed,
failed = report.failed,
skipped_cap = report.skipped_cap,
"prewarm complete",
);
Ok(report)
}
pub async fn resolved_database_url(&self, org: &Org) -> Result<String, TenancyError> {
let reference = org.database_url.as_deref().ok_or_else(|| {
TenancyError::Validation(format!(
"org `{}` has no `database_url` to resolve (schema mode?)",
org.slug
))
})?;
let url = self.secrets.resolve(reference).await?;
Ok(url)
}
pub async fn cached_database_pool_count(&self) -> usize {
self.cache.read().await.len()
}
async fn pool_for_database_mode(&self, org: &Org) -> Result<Arc<PgPool>, TenancyError> {
{
let cache = self.cache.read().await;
if let Some(pool) = cache.get(&org.slug) {
return Ok(Arc::clone(pool));
}
}
let span = tracing::info_span!("tenant_pool_init", slug = %org.slug, mode = "database");
let _enter = span.enter();
let resolve_start = std::time::Instant::now();
let reference = org.database_url.as_deref().ok_or_else(|| {
TenancyError::Validation(format!(
"org `{}` is `storage_mode = database` but has no `database_url`",
org.slug
))
})?;
let url = self.secrets.resolve(reference).await?;
tracing::debug!(
target: "crate::tenancy::pools",
slug = %org.slug,
elapsed_ms = resolve_start.elapsed().as_millis() as u64,
"secrets resolver resolved tenant URL",
);
let connect_start = std::time::Instant::now();
let pool = build_database_pool(&url, &self.config).await?;
tracing::info!(
target: "crate::tenancy::pools",
slug = %org.slug,
elapsed_ms = connect_start.elapsed().as_millis() as u64,
min_conn = self.config.database_pool_min_connections,
max_conn = self.config.database_pool_max_connections,
"tenant pool connected (database mode)",
);
let pool = Arc::new(pool);
let mut cache = self.cache.write().await;
if let Some(existing) = cache.get(&org.slug) {
return Ok(Arc::clone(existing));
}
if cache.len() >= self.config.max_cached_database_pools {
return Err(TenancyError::Validation(format!(
"tenant pool cache is full ({} cached); raise \
`TenantPoolsConfig::max_cached_database_pools` or \
invalidate idle tenants",
cache.len(),
)));
}
cache.insert(org.slug.clone(), Arc::clone(&pool));
Ok(pool)
}
}
async fn build_database_pool(
url: &str,
config: &TenantPoolsConfig,
) -> Result<PgPool, TenancyError> {
let mut opts = PgPoolOptions::new()
.max_connections(config.database_pool_max_connections)
.min_connections(config.database_pool_min_connections)
.acquire_timeout(config.database_pool_acquire_timeout);
if let Some(idle) = config.database_pool_idle_timeout {
opts = opts.idle_timeout(idle);
}
if let Some(lifetime) = config.database_pool_max_lifetime {
opts = opts.max_lifetime(lifetime);
}
Ok(opts.connect(url).await?)
}
/// A pooled connection checked out for a single tenant by `TenantPools::acquire`.
///
/// Derefs to the underlying sqlx `PoolConnection`.
pub struct TenantConn {
    /// The checked-out connection: from the registry pool in schema mode, or
    /// from the tenant's dedicated pool in database mode.
    inner: rustango::sql::sqlx::pool::PoolConnection<rustango::sql::sqlx::Postgres>,
    /// Schema placed on the `search_path` in schema mode; `None` in database mode.
    schema: Option<String>,
}
impl TenantConn {
    /// The tenant schema this connection was scoped to, or `None` when the
    /// connection belongs to a dedicated (database-mode) pool.
    #[must_use]
    pub fn schema(&self) -> Option<&str> {
        self.schema.as_ref().map(String::as_str)
    }
}
/// Lets a `TenantConn` be used wherever a `&PoolConnection` is expected.
impl std::ops::Deref for TenantConn {
    type Target = rustango::sql::sqlx::pool::PoolConnection<rustango::sql::sqlx::Postgres>;

    fn deref(&self) -> &Self::Target {
        let Self { inner, .. } = self;
        inner
    }
}
/// Mutable access to the underlying connection, e.g. `&mut *conn` as a sqlx executor.
impl std::ops::DerefMut for TenantConn {
    fn deref_mut(&mut self) -> &mut Self::Target {
        let Self { inner, .. } = self;
        inner
    }
}
/// Quotes `name` as a PostgreSQL identifier.
///
/// Wraps it in double quotes and doubles every embedded double quote, so the
/// result is always a single identifier token.
fn quote_ident(name: &str) -> String {
    let mut quoted = String::with_capacity(name.len() + 2);
    quoted.push('"');
    for ch in name.chars() {
        if ch == '"' {
            quoted.push('"');
        }
        quoted.push(ch);
    }
    quoted.push('"');
    quoted
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn quote_ident_wraps_and_escapes() {
        let cases = [("acme", "\"acme\""), ("a\"b", "\"a\"\"b\""), ("", "\"\"")];
        for (input, expected) in cases {
            assert_eq!(quote_ident(input), expected, "input: {input:?}");
        }
    }

    #[test]
    fn config_defaults_are_sane() {
        let TenantPoolsConfig {
            max_cached_database_pools,
            database_pool_max_connections,
            ..
        } = TenantPoolsConfig::default();
        assert!(max_cached_database_pools >= 16);
        assert!(database_pool_max_connections >= 1);
    }

    #[test]
    fn config_pool_timeout_defaults_preserve_pre_0_27_7_behavior() {
        let cfg = TenantPoolsConfig::default();
        assert!(!cfg.prewarm_active_tenants, "pre-warm must stay opt-in");
        assert_eq!(cfg.database_pool_min_connections, 0);
        assert!(cfg.database_pool_acquire_timeout >= std::time::Duration::from_secs(5));
        assert!(cfg.database_pool_idle_timeout.is_some());
        assert!(cfg.database_pool_max_lifetime.is_some());
    }

    #[test]
    fn prewarm_report_zeroed_default() {
        let PrewarmReport {
            total_active,
            warmed,
            failed,
            skipped_cap,
        } = PrewarmReport::default();
        assert_eq!([total_active, warmed, failed, skipped_cap], [0; 4]);
    }
}