use std::collections::HashMap;
use std::sync::Arc;
use rustango::sql::sqlx;
use sqlx::pool::{Pool, PoolConnection};
use sqlx::Database;
use tokio::sync::RwLock;
use super::error::TenancyError;
use super::org::{BackendKind, Org, StorageMode};
use super::pools::TenantPoolsConfig;
use super::secrets::{LiteralSecretsResolver, SecretsResolver};
/// Cheaply clonable handle to one tenant's connection pool.
///
/// The pool lives behind an [`Arc`], so every clone of this handle refers to
/// the same underlying [`Pool`].
#[derive(Debug)]
pub struct DatabasePool<DB>
where
    DB: Database,
{
    pool: Arc<Pool<DB>>,
}

// Written by hand instead of `#[derive(Clone)]`: the derive would demand
// `DB: Clone`, but only the `Arc` is cloned here, so no such bound is needed.
impl<DB> Clone for DatabasePool<DB>
where
    DB: Database,
{
    fn clone(&self) -> Self {
        let pool = Arc::clone(&self.pool);
        Self { pool }
    }
}

impl<DB> DatabasePool<DB>
where
    DB: Database,
{
    /// Borrows the underlying pool.
    #[must_use]
    pub fn pool(&self) -> &Pool<DB> {
        &*self.pool
    }

    /// Returns another shared reference (refcount bump) to the underlying pool.
    #[must_use]
    pub fn pool_arc(&self) -> Arc<Pool<DB>> {
        Arc::clone(&self.pool)
    }
}
/// A connection checked out of a tenant's pool.
///
/// Dereferences (shared and mutably) to the wrapped [`PoolConnection`], so
/// `&*conn` / `&mut *conn` can be passed wherever the connection is expected.
pub struct DatabaseConn<DB>
where
    DB: Database,
{
    inner: PoolConnection<DB>,
}

impl<DB> std::ops::Deref for DatabaseConn<DB>
where
    DB: Database,
{
    type Target = PoolConnection<DB>;

    fn deref(&self) -> &Self::Target {
        let Self { inner } = self;
        inner
    }
}

impl<DB> std::ops::DerefMut for DatabaseConn<DB>
where
    DB: Database,
{
    fn deref_mut(&mut self) -> &mut Self::Target {
        let Self { inner } = self;
        inner
    }
}
/// Registry of database-mode connection pools for a single backend `DB`.
///
/// Pools are built lazily on first use (see [`DatabasePools::pool_for_org`])
/// and cached by org slug, up to
/// `TenantPoolsConfig::max_cached_database_pools` entries.
pub struct DatabasePools<DB: Database> {
    /// Pool sizing/timeouts plus the cache cap.
    config: TenantPoolsConfig,
    /// Turns a stored URL reference into the URL actually connected to.
    secrets: Arc<dyn SecretsResolver>,
    /// Built pools keyed by org slug.
    cache: RwLock<HashMap<String, Arc<Pool<DB>>>>,
    /// The only backend this registry serves; orgs on other backends are rejected.
    backend: BackendKind,
    /// Fallback URL containing `{slug}` placeholders, used when an org has no
    /// `database_url`.
    url_template: Option<String>,
}

impl<DB: Database> DatabasePools<DB> {
    /// Creates a registry for `backend` that uses [`LiteralSecretsResolver`]
    /// for URL resolution and the default [`TenantPoolsConfig`].
    #[must_use]
    pub fn new(backend: BackendKind) -> Self {
        Self::with_secrets(backend, LiteralSecretsResolver)
    }

    /// Creates a registry for `backend` that passes each org's URL through
    /// `secrets` before connecting. Starts with the default
    /// [`TenantPoolsConfig`] and no URL template.
    #[must_use]
    pub fn with_secrets<R: SecretsResolver>(backend: BackendKind, secrets: R) -> Self {
        Self {
            config: TenantPoolsConfig::default(),
            secrets: Arc::new(secrets),
            cache: RwLock::new(HashMap::new()),
            backend,
            url_template: None,
        }
    }

    /// Sets the fallback URL used for orgs without a `database_url`.
    ///
    /// Every `{slug}` in `template` is replaced with the org's slug.
    #[must_use]
    pub fn with_url_template(mut self, template: impl Into<String>) -> Self {
        self.url_template = Some(template.into());
        self
    }

    /// Replaces the pool configuration.
    ///
    /// The new settings apply to pools built afterwards. Pools already in the
    /// cache are not rebuilt.
    #[must_use]
    pub fn config(mut self, config: TenantPoolsConfig) -> Self {
        self.config = config;
        self
    }

    /// Returns the current pool configuration.
    #[must_use]
    pub fn pool_config(&self) -> &TenantPoolsConfig {
        &self.config
    }

    /// Returns the backend this registry serves.
    #[must_use]
    pub fn backend_kind(&self) -> BackendKind {
        self.backend
    }

    /// Returns the cached pool for `org`, building and caching one if needed.
    ///
    /// # Errors
    ///
    /// - `TenancyError::Validation` in four cases:
    ///   - the org's backend or storage mode is unknown;
    ///   - the org does not match this registry (wrong backend, or not
    ///     database-mode);
    ///   - the org has no usable URL;
    ///   - the pool cache is at its cap.
    /// - `TenancyError::Secrets` if URL resolution fails.
    /// - Connection errors from building the pool.
    pub async fn pool_for_org(&self, org: &Org) -> Result<DatabasePool<DB>, TenancyError> {
        self.check_org_compatible(org)?;

        // Fast path under the shared lock. Also reject early when the cache
        // is full, so we neither resolve secrets nor open connections for a
        // pool that could not be cached anyway.
        {
            let cache = self.cache.read().await;
            if let Some(pool) = cache.get(&org.slug) {
                return Ok(DatabasePool {
                    pool: Arc::clone(pool),
                });
            }
            if cache.len() >= self.config.max_cached_database_pools {
                return Err(self.cache_full_error());
            }
        }

        let url_ref = self.url_ref_for(org)?;
        let resolved = self
            .secrets
            .resolve(&url_ref)
            .await
            .map_err(TenancyError::Secrets)?;
        // Connect without holding any lock; this can be slow.
        let pool = self.build_pool(&resolved).await?;

        let mut cache = self.cache.write().await;
        // Re-check under the exclusive lock: another task may have cached
        // this slug, or filled the cache, while we were connecting. In either
        // case, close the pool we just built instead of silently dropping it,
        // and release the lock first so closing doesn't block other callers.
        if let Some(existing) = cache.get(&org.slug).cloned() {
            drop(cache);
            pool.close().await;
            return Ok(DatabasePool { pool: existing });
        }
        if cache.len() >= self.config.max_cached_database_pools {
            drop(cache);
            pool.close().await;
            return Err(self.cache_full_error());
        }
        let pool = Arc::new(pool);
        cache.insert(org.slug.clone(), Arc::clone(&pool));
        Ok(DatabasePool { pool })
    }

    /// Checks out one connection from `org`'s pool, building the pool if
    /// needed.
    ///
    /// # Errors
    ///
    /// Anything [`DatabasePools::pool_for_org`] returns, plus connection
    /// acquisition failures.
    pub async fn acquire(&self, org: &Org) -> Result<DatabaseConn<DB>, TenancyError> {
        let pool = self.pool_for_org(org).await?;
        let inner = pool.pool().acquire().await?;
        Ok(DatabaseConn { inner })
    }

    /// Removes `slug`'s pool from the cache, so the next lookup builds a new
    /// one.
    ///
    /// The removed pool is not closed here. `DatabasePool` handles already
    /// handed out keep working until they are dropped.
    pub async fn invalidate(&self, slug: &str) {
        let mut cache = self.cache.write().await;
        cache.remove(slug);
    }

    /// Verifies that `org` targets this registry's backend and uses
    /// database-mode storage.
    fn check_org_compatible(&self, org: &Org) -> Result<(), TenancyError> {
        let want = BackendKind::parse(&org.backend_kind).map_err(|got| {
            TenancyError::Validation(format!(
                "org `{}` has unknown backend_kind `{got}` \
                 (expected `postgres`, `mysql`, or `sqlite`)",
                org.slug
            ))
        })?;
        if want != self.backend {
            return Err(TenancyError::Validation(format!(
                "org `{}` is configured for backend `{want}` \
                 but this DatabasePools is serving `{}`. \
                 The process boots one backend; route through the \
                 matching server instance.",
                org.slug, self.backend
            )));
        }
        let mode = StorageMode::parse(&org.storage_mode).map_err(|got| {
            TenancyError::Validation(format!(
                "org `{}` has unknown storage_mode `{got}` \
                 (expected `schema` or `database`)",
                org.slug
            ))
        })?;
        if mode != StorageMode::Database {
            return Err(TenancyError::Validation(format!(
                "org `{}` storage_mode is `{mode}` but DatabasePools \
                 only supports database-mode. \
                 (Schema-mode is Postgres-only — use TenantPools.)",
                org.slug
            )));
        }
        Ok(())
    }

    /// Chooses the URL reference (before secret resolution) for `org`.
    ///
    /// Uses the org's explicit `database_url` if set. Otherwise uses the URL
    /// template with each `{slug}` replaced by the org slug.
    fn url_ref_for(&self, org: &Org) -> Result<String, TenancyError> {
        if let Some(url) = org.database_url.as_deref() {
            return Ok(url.to_owned());
        }
        // NOTE(review): the slug is substituted verbatim, without URL escaping.
        // This assumes slugs are restricted to URL-safe characters elsewhere.
        self.url_template
            .as_deref()
            .map(|tpl| tpl.replace("{slug}", &org.slug))
            .ok_or_else(|| {
                TenancyError::Validation(format!(
                    "org `{}` is database-mode but has no database_url \
                     and no url_template is configured on DatabasePools \
                     (call `.with_url_template(...)` at boot)",
                    org.slug
                ))
            })
    }

    /// Builds the error returned when the pool cache has reached its cap.
    fn cache_full_error(&self) -> TenancyError {
        TenancyError::Validation(format!(
            "database-pool cache is at cap ({}); \
             bump TenantPoolsConfig::max_cached_database_pools",
            self.config.max_cached_database_pools
        ))
    }

    /// Connects a new pool to `url` using the sizing and timeout settings in
    /// `self.config`.
    ///
    /// The idle timeout and max lifetime are applied only when configured.
    async fn build_pool(&self, url: &str) -> Result<Pool<DB>, TenancyError> {
        let mut opts = sqlx::pool::PoolOptions::<DB>::new()
            .max_connections(self.config.database_pool_max_connections)
            .min_connections(self.config.database_pool_min_connections)
            .acquire_timeout(self.config.database_pool_acquire_timeout);
        if let Some(idle) = self.config.database_pool_idle_timeout {
            opts = opts.idle_timeout(idle);
        }
        if let Some(lifetime) = self.config.database_pool_max_lifetime {
            opts = opts.max_lifetime(lifetime);
        }
        opts.connect(url).await.map_err(TenancyError::from)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // Smoke test: a SQLite-backed registry can be built with defaults.
    #[cfg(feature = "sqlite")]
    #[test]
    fn sqlite_pool_registry_constructible() {
        let _registry = DatabasePools::<sqlx::Sqlite>::new(BackendKind::Sqlite);
    }

    // Smoke test: a MySQL-backed registry can be built with defaults.
    #[cfg(feature = "mysql")]
    #[test]
    fn mysql_pool_registry_constructible() {
        let _registry = DatabasePools::<sqlx::MySql>::new(BackendKind::MySql);
    }

    // The backend passed at construction is the one reported back.
    #[cfg(feature = "sqlite")]
    #[test]
    fn backend_kind_round_trips() {
        let registry = DatabasePools::<sqlx::Sqlite>::new(BackendKind::Sqlite);
        let reported = registry.backend_kind();
        assert_eq!(reported, BackendKind::Sqlite);
    }
}