use std::collections::HashMap;
use std::sync::Arc;
use crate::sql::sqlx::postgres::{PgPool, PgPoolOptions};
use tokio::sync::RwLock;
use super::error::TenancyError;
use super::org::{Org, StorageMode};
use super::secrets::{LiteralSecretsResolver, SecretsResolver};
/// Tuning knobs for [`TenantPools`].
#[derive(Debug, Clone)]
pub struct TenantPoolsConfig {
    /// Maximum number of per-tenant (database-mode) pools kept in the cache.
    pub max_cached_database_pools: usize,
    /// `max_connections` applied to each per-tenant database pool.
    pub database_pool_max_connections: u32,
}
impl Default for TenantPoolsConfig {
    /// Up to 64 cached tenant pools, each capped at 4 connections.
    fn default() -> Self {
        let max_cached_database_pools = 64;
        let database_pool_max_connections = 4;
        Self {
            max_cached_database_pools,
            database_pool_max_connections,
        }
    }
}
/// Where a tenant's connections come from, as chosen by `TenantPools::pool_for_org`.
#[derive(Debug, Clone)]
pub enum TenantPool {
    /// Schema mode: the tenant lives in `schema` inside the shared registry database.
    Schema {
        /// Postgres schema holding the tenant's tables.
        schema: String,
        /// Handle to the shared registry pool.
        registry: PgPool,
    },
    /// Database mode: the tenant has its own pool, shared via `Arc`.
    Database { pool: Arc<PgPool> },
}
impl TenantPool {
    /// Returns the pool backing this tenant.
    ///
    /// This is the registry pool for schema mode, or the tenant's own pool for database mode.
    #[must_use]
    pub fn pool(&self) -> &PgPool {
        match self {
            Self::Database { pool } => pool.as_ref(),
            Self::Schema { registry, .. } => registry,
        }
    }
    /// `true` when this tenant uses schema-mode storage.
    #[must_use]
    pub fn is_schema(&self) -> bool {
        let schema_mode = matches!(self, Self::Schema { .. });
        schema_mode
    }
}
/// Resolves the connection source for each tenant (`Org`).
///
/// Schema-mode tenants share the `registry` pool. Database-mode tenants each
/// get their own pool, which is created on first use and cached by org slug.
pub struct TenantPools {
    /// Shared registry pool; also serves every schema-mode tenant.
    registry: PgPool,
    /// Cache capacity and per-tenant pool sizing.
    config: TenantPoolsConfig,
    /// Turns an org's stored `database_url` reference into a connectable URL.
    secrets: Arc<dyn SecretsResolver>,
    /// Database-mode pools keyed by org slug.
    cache: RwLock<HashMap<String, Arc<PgPool>>>,
}
impl TenantPools {
    /// Creates a manager with the default config and a [`LiteralSecretsResolver`].
    #[must_use]
    pub fn new(registry: PgPool) -> Self {
        Self::with_secrets(registry, LiteralSecretsResolver)
    }
    /// Creates a manager with the default config and a custom secrets resolver.
    #[must_use]
    pub fn with_secrets<R: SecretsResolver>(registry: PgPool, secrets: R) -> Self {
        Self {
            registry,
            config: TenantPoolsConfig::default(),
            secrets: Arc::new(secrets),
            cache: RwLock::new(HashMap::new()),
        }
    }
    /// Builder-style override of the [`TenantPoolsConfig`].
    #[must_use]
    pub fn config(mut self, config: TenantPoolsConfig) -> Self {
        self.config = config;
        self
    }
    /// The shared registry pool.
    #[must_use]
    pub fn registry(&self) -> &PgPool {
        &self.registry
    }
    /// Picks the pool for `org` based on its `storage_mode`.
    ///
    /// In schema mode, the schema is `org.schema_name`, falling back to `org.slug`.
    /// In database mode, the tenant's cached pool is returned, or created on first use.
    ///
    /// # Errors
    ///
    /// - `TenancyError::Validation` if `storage_mode` is not recognised.
    /// - Any error from creating a database-mode pool (see `pool_for_database_mode`).
    pub async fn pool_for_org(&self, org: &Org) -> Result<TenantPool, TenancyError> {
        let mode = StorageMode::parse(&org.storage_mode).map_err(|got| {
            TenancyError::Validation(format!(
                "org `{}` has unknown storage_mode `{got}` (expected `schema` or `database`)",
                org.slug
            ))
        })?;
        match mode {
            StorageMode::Schema => {
                let schema = org
                    .schema_name
                    .clone()
                    .unwrap_or_else(|| org.slug.clone());
                Ok(TenantPool::Schema {
                    schema,
                    registry: self.registry.clone(),
                })
            }
            StorageMode::Database => {
                let pool = self.pool_for_database_mode(org).await?;
                Ok(TenantPool::Database { pool })
            }
        }
    }
    /// Checks out a connection for `org`.
    ///
    /// In schema mode, the connection's `search_path` is set to `"<schema>", public`.
    ///
    /// NOTE(review): a plain `SET` is session-scoped in PostgreSQL, and nothing here
    /// resets it. The setting therefore likely survives after the `TenantConn` is
    /// dropped and the connection returns to the registry pool. Other registry users
    /// could then see a tenant schema first on their path. Confirm this, and consider
    /// an `after_release` hook on the registry pool.
    ///
    /// # Errors
    ///
    /// Propagates errors from pool selection, connection acquisition, and the `SET`
    /// statement.
    pub async fn acquire(&self, org: &Org) -> Result<TenantConn, TenancyError> {
        match self.pool_for_org(org).await? {
            TenantPool::Schema { schema, registry } => {
                let mut conn = registry.acquire().await?;
                // Quote the identifier so a schema name cannot inject SQL.
                let stmt = format!("SET search_path TO {}, public", quote_ident(&schema));
                rustango::sql::sqlx::query(&stmt)
                    .execute(&mut *conn)
                    .await?;
                Ok(TenantConn {
                    inner: conn,
                    schema: Some(schema),
                })
            }
            TenantPool::Database { pool } => {
                let conn = pool.acquire().await?;
                Ok(TenantConn { inner: conn, schema: None })
            }
        }
    }
    /// Drops the cached pool for `slug`, if any.
    ///
    /// `Arc` clones already handed out keep that pool alive until they are dropped.
    pub async fn invalidate(&self, slug: &str) {
        let mut cache = self.cache.write().await;
        cache.remove(slug);
    }
    /// Resolves `org.database_url` through the configured secrets resolver.
    ///
    /// # Errors
    ///
    /// - `TenancyError::Validation` if the org has no `database_url`.
    /// - Any error from the secrets resolver.
    pub async fn resolved_database_url(&self, org: &Org) -> Result<String, TenancyError> {
        let reference = org.database_url.as_deref().ok_or_else(|| {
            TenancyError::Validation(format!(
                "org `{}` has no `database_url` to resolve (schema mode?)",
                org.slug
            ))
        })?;
        let url = self.secrets.resolve(reference).await?;
        Ok(url)
    }
    /// Number of database-mode pools currently cached.
    pub async fn cached_database_pool_count(&self) -> usize {
        self.cache.read().await.len()
    }
    /// Error returned when caching another pool would exceed `max_cached_database_pools`.
    fn cache_full_error(cached: usize) -> TenancyError {
        TenancyError::Validation(format!(
            "tenant pool cache is full ({cached} cached); raise \
             `TenantPoolsConfig::max_cached_database_pools` or \
             invalidate idle tenants",
        ))
    }
    /// Returns the cached pool for a database-mode `org`, creating and caching it if absent.
    ///
    /// # Errors
    ///
    /// - `TenancyError::Validation` if the org has no `database_url`.
    /// - `TenancyError::Validation` if the cache is full.
    /// - Secrets-resolution and connection errors.
    async fn pool_for_database_mode(&self, org: &Org) -> Result<Arc<PgPool>, TenancyError> {
        // Fast path: reuse a cached pool. Also reject up front when the cache is
        // already full, so we never open connections only to throw them away.
        {
            let cache = self.cache.read().await;
            if let Some(pool) = cache.get(&org.slug) {
                return Ok(Arc::clone(pool));
            }
            if cache.len() >= self.config.max_cached_database_pools {
                return Err(Self::cache_full_error(cache.len()));
            }
        }
        let reference = org.database_url.as_deref().ok_or_else(|| {
            TenancyError::Validation(format!(
                "org `{}` is `storage_mode = database` but has no `database_url`",
                org.slug
            ))
        })?;
        let url = self.secrets.resolve(reference).await?;
        let pool = PgPoolOptions::new()
            .max_connections(self.config.database_pool_max_connections)
            .connect(&url)
            .await?;
        let pool = Arc::new(pool);
        // Re-check under the write lock. While we were connecting, another task
        // may have inserted this slug or filled the cache.
        let outcome = {
            let mut cache = self.cache.write().await;
            if let Some(existing) = cache.get(&org.slug) {
                Ok(Arc::clone(existing))
            } else if cache.len() >= self.config.max_cached_database_pools {
                Err(Self::cache_full_error(cache.len()))
            } else {
                cache.insert(org.slug.clone(), Arc::clone(&pool));
                return Ok(pool);
            }
        };
        // Our freshly opened pool was not cached and nobody else holds it.
        // Close its connections gracefully, after the lock guard is released.
        pool.close().await;
        outcome
    }
}
/// A pooled Postgres connection checked out for one tenant.
///
/// Dereferences to the underlying `PoolConnection`, so it can be used wherever
/// a connection is expected (e.g. `&mut *conn`).
pub struct TenantConn {
    /// The checked-out connection.
    inner: rustango::sql::sqlx::pool::PoolConnection<rustango::sql::sqlx::Postgres>,
    /// Schema the `search_path` was pointed at (schema mode), or `None` (database mode).
    schema: Option<String>,
}
impl TenantConn {
    /// The tenant schema this connection was scoped to, if any.
    #[must_use]
    pub fn schema(&self) -> Option<&str> {
        let Self { schema, .. } = self;
        schema.as_deref()
    }
}
impl std::ops::Deref for TenantConn {
    type Target = rustango::sql::sqlx::pool::PoolConnection<rustango::sql::sqlx::Postgres>;
    fn deref(&self) -> &Self::Target {
        let Self { inner, .. } = self;
        inner
    }
}
impl std::ops::DerefMut for TenantConn {
    fn deref_mut(&mut self) -> &mut Self::Target {
        let Self { inner, .. } = self;
        inner
    }
}
/// Quotes `name` as a PostgreSQL identifier.
///
/// The name is wrapped in double quotes, and each embedded `"` is doubled.
fn quote_ident(name: &str) -> String {
    let mut quoted = String::with_capacity(name.len() + 2);
    quoted.push('"');
    for ch in name.chars() {
        if ch == '"' {
            // A doubled quote is the escape for a literal `"` inside a quoted identifier.
            quoted.push('"');
        }
        quoted.push(ch);
    }
    quoted.push('"');
    quoted
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn quote_ident_wraps_and_escapes() {
        // (input, expected quoted identifier)
        let cases = [("acme", "\"acme\""), ("a\"b", "\"a\"\"b\""), ("", "\"\"")];
        for &(input, expected) in &cases {
            assert_eq!(quote_ident(input), expected, "input: {input:?}");
        }
    }

    #[test]
    fn config_defaults_are_sane() {
        let TenantPoolsConfig {
            max_cached_database_pools,
            database_pool_max_connections,
        } = TenantPoolsConfig::default();
        assert!(max_cached_database_pools >= 16);
        assert!(database_pool_max_connections >= 1);
    }
}