use std::collections::HashMap;
use std::pin::Pin;
use std::sync::OnceLock;
use sqlx::sqlite::{SqliteConnectOptions, SqlitePoolOptions, SqliteSynchronous};
use sqlx::{ConnectOptions, PgPool, SqlitePool};
use std::str::FromStr;
use std::time::Duration;
pub mod route_context;
pub mod router;
pub use route_context::{RouteContext, TenantKey, current as route_context};
pub use router::{Alias, DatabaseRouter, DefaultRouter, RouteOp, Schema, router};
/// A connection pool for one of the supported database backends.
///
/// Backend-agnostic code should match on the variant (or use
/// [`DbPool::as_sqlite`] / [`DbPool::as_postgres`]); code paths that still
/// only understand SQLite can use [`DbPool::sqlite_or_panic`].
#[derive(Debug, Clone)]
pub enum DbPool {
    /// A SQLite pool.
    Sqlite(SqlitePool),
    /// A PostgreSQL pool.
    Postgres(PgPool),
}

impl DbPool {
    /// Borrow the inner [`SqlitePool`], or `None` when this is a Postgres pool.
    pub fn as_sqlite(&self) -> Option<&SqlitePool> {
        if let Self::Sqlite(inner) = self {
            Some(inner)
        } else {
            None
        }
    }

    /// Borrow the inner [`PgPool`], or `None` when this is a SQLite pool.
    pub fn as_postgres(&self) -> Option<&PgPool> {
        if let Self::Postgres(inner) = self {
            Some(inner)
        } else {
            None
        }
    }

    /// Borrow the inner [`SqlitePool`] for code paths that are still SQLite-only.
    ///
    /// # Panics
    ///
    /// Panics when called on [`DbPool::Postgres`].
    pub fn sqlite_or_panic(&self) -> &SqlitePool {
        let maybe_sqlite = self.as_sqlite();
        maybe_sqlite.expect(
            "umbral: a Postgres pool is registered but this code path \
             still reads SqlitePool. Full Postgres support lands in \
             Phase 2 of the rollout — see FEATURES.md and the \
             `DbPool` rustdoc.",
        )
    }

    /// Short lowercase backend identifier: `"sqlite"` or `"postgres"`.
    pub fn backend_name(&self) -> &'static str {
        match self {
            Self::Sqlite(_) => "sqlite",
            Self::Postgres(_) => "postgres",
        }
    }
}
impl From<SqlitePool> for DbPool {
fn from(pool: SqlitePool) -> Self {
DbPool::Sqlite(pool)
}
}
impl From<PgPool> for DbPool {
fn from(pool: PgPool) -> Self {
DbPool::Postgres(pool)
}
}
/// Pools configured at startup, keyed by alias. Set exactly once by `init`
/// (a second call panics); read by the `pool*` lookup functions.
static POOLS: OnceLock<HashMap<String, DbPool>> = OnceLock::new();
/// Pools added at runtime through `register_tenant_pool`, keyed by alias.
/// Each pool is leaked with `Box::leak` so lookups can return
/// `&'static DbPool` just like `POOLS`. Nothing in this module removes entries.
/// Lookups check `POOLS` first, so an alias present in both is served from `POOLS`.
static DYNAMIC_POOLS: OnceLock<std::sync::RwLock<HashMap<String, &'static DbPool>>> =
OnceLock::new();
/// Value reported by `atomic_default`; first write via `init_atomic_default` wins,
/// and an unset value reads as `false`.
/// NOTE(review): presumably controls whether work is wrapped in a transaction
/// by default — confirm against callers.
static ATOMIC_DEFAULT: OnceLock<bool> = OnceLock::new();
/// Record the value later returned by [`atomic_default`].
///
/// Only the first call takes effect; subsequent calls are silently ignored
/// (in contrast to [`init`], which panics when called twice).
pub(crate) fn init_atomic_default(enabled: bool) {
    // `set` hands the value back in `Err` when already initialised; we keep the first.
    let _: Result<(), bool> = ATOMIC_DEFAULT.set(enabled);
}
/// The value recorded by [`init_atomic_default`], or `false` if it was never called.
pub fn atomic_default() -> bool {
    ATOMIC_DEFAULT.get().copied().unwrap_or(false)
}
/// Install the startup pools, keyed by alias (`"default"` is the primary one).
///
/// # Panics
///
/// Panics if called more than once per process.
pub(crate) fn init(pools: HashMap<String, DbPool>) {
    let outcome = POOLS.set(pools);
    outcome.expect("umbral::db::init called more than once");
}
pub fn pool() -> SqlitePool {
pool_dispatched().sqlite_or_panic().clone()
}
/// The `"default"` startup pool, whichever backend it uses.
///
/// # Panics
///
/// Panics if [`init`] has not run yet, or if no `"default"` alias was registered.
pub fn pool_dispatched() -> &'static DbPool {
    let startup_pools = POOLS
        .get()
        .expect("umbral: db pool not initialised — did you call App::build()?");
    startup_pools
        .get("default")
        .expect("umbral: no default database registered")
}
/// Non-panicking variant of [`pool_dispatched`]: `None` when pools are not
/// initialised or no `"default"` alias exists.
pub fn try_pool_dispatched() -> Option<&'static DbPool> {
    let startup_pools = POOLS.get()?;
    startup_pools.get("default")
}
pub fn pool_for(alias: &str) -> SqlitePool {
pool_for_dispatched(alias).sqlite_or_panic().clone()
}
/// Look up the pool registered under `alias`.
///
/// Startup pools (from [`init`]) take precedence; tenant pools added with
/// [`register_tenant_pool`] are consulted only when no startup pool matches.
/// A poisoned tenant registry is treated as empty.
///
/// # Panics
///
/// Panics if `alias` is unknown. The message says pools are not initialised
/// when [`init`] has not run, and otherwise names the missing alias.
pub fn pool_for_dispatched(alias: &str) -> &'static DbPool {
    let from_startup = POOLS.get().and_then(|pools| pools.get(alias));
    let found = from_startup.or_else(|| {
        DYNAMIC_POOLS
            .get()
            .and_then(|reg| reg.read().ok().and_then(|m| m.get(alias).copied()))
    });
    match (found, POOLS.get()) {
        (Some(db), _) => db,
        (None, None) => panic!("umbral: db pool not initialised — did you call App::build()?"),
        (None, Some(_)) => panic!("umbral: no database registered under alias '{alias}'"),
    }
}
/// Register `pool` under `alias` at runtime (e.g. a per-tenant database).
///
/// If `alias` is already in the tenant registry the call is a no-op and
/// `pool` is dropped. The pool is leaked so lookups can hand out
/// `&'static DbPool`; it is never unregistered. Note that an alias that also
/// exists among the startup pools is shadowed by the startup pool on lookup.
///
/// # Panics
///
/// Panics if the tenant registry lock is poisoned.
pub fn register_tenant_pool(alias: impl Into<String>, pool: DbPool) {
    let alias: String = alias.into();
    let registry = DYNAMIC_POOLS.get_or_init(|| std::sync::RwLock::new(HashMap::new()));
    let mut tenants = registry
        .write()
        .expect("umbral: dynamic pool registry poisoned");
    if let std::collections::hash_map::Entry::Vacant(slot) = tenants.entry(alias) {
        let leaked: &'static DbPool = Box::leak(Box::new(pool));
        slot.insert(leaked);
    }
}
/// Whether `alias` resolves to a startup pool or a registered tenant pool.
///
/// A poisoned tenant registry is treated as empty.
pub fn pool_alias_registered(alias: &str) -> bool {
    let in_startup = POOLS.get().is_some_and(|p| p.contains_key(alias));
    in_startup
        || DYNAMIC_POOLS
            .get()
            .is_some_and(|reg| reg.read().is_ok_and(|m| m.contains_key(alias)))
}
/// Run `SELECT 1` against the `"default"` pool to check connectivity.
///
/// # Errors
///
/// Returns the sqlx error if the query fails.
///
/// # Panics
///
/// Panics under the same conditions as [`pool_dispatched`].
pub async fn ping() -> Result<(), sqlx::Error> {
    const PROBE: &str = "SELECT 1";
    match pool_dispatched() {
        DbPool::Sqlite(p) => {
            sqlx::query(PROBE).execute(p).await?;
        }
        DbPool::Postgres(p) => {
            sqlx::query(PROBE).execute(p).await?;
        }
    }
    Ok(())
}
/// Sorted aliases of the startup pools installed by [`init`].
///
/// Tenant pools added with [`register_tenant_pool`] are *not* included.
///
/// # Panics
///
/// Panics if [`init`] has not run yet.
pub fn registered_aliases() -> Vec<String> {
    let startup_pools = POOLS
        .get()
        .expect("umbral: db pool not initialised — did you call App::build()?");
    let mut aliases: Vec<String> = startup_pools.keys().cloned().collect();
    // Map keys are unique, so an unstable sort yields the same order as a stable one.
    aliases.sort_unstable();
    aliases
}
pub async fn connect(url: &str) -> Result<DbPool, sqlx::Error> {
let scheme = url
.split("://")
.next()
.and_then(|s| s.split(':').next())
.unwrap_or(url);
match scheme {
"sqlite" => Ok(DbPool::Sqlite(connect_sqlite(url).await?)),
"postgres" | "postgresql" => Ok(DbPool::Postgres(connect_postgres(url).await?)),
other => Err(sqlx::Error::Configuration(
format!(
"umbral::db::connect: unsupported URL scheme `{other}://`. \
Phase 1 supports `sqlite://` and `postgres://`."
)
.into(),
)),
}
}
/// Pool sizing and lifetime settings shared by every backend.
struct PoolConfig {
    /// Upper bound on open connections; callers clamp it to at least 1.
    max_connections: u32,
    /// Value passed to the pool's `min_connections` option.
    min_connections: u32,
    /// Seconds to wait for a free connection before failing.
    acquire_timeout_secs: u64,
    /// Idle timeout in seconds; `None` leaves the sqlx default untouched.
    idle_timeout_secs: Option<u64>,
    /// Maximum connection lifetime in seconds; `None` leaves the sqlx default untouched.
    max_lifetime_secs: Option<u64>,
    /// Whether to validate a connection before handing it out.
    test_before_acquire: bool,
}

impl PoolConfig {
    /// Read pool settings from the global settings, or fall back to built-in
    /// defaults (10 max, 0 min, 30 s acquire, 600 s idle, 1800 s lifetime,
    /// test-before-acquire on) when no settings are loaded.
    fn resolve() -> Self {
        let Some(s) = crate::settings::get_opt() else {
            return PoolConfig {
                max_connections: 10,
                min_connections: 0,
                acquire_timeout_secs: 30,
                idle_timeout_secs: Some(600),
                max_lifetime_secs: Some(1800),
                test_before_acquire: true,
            };
        };
        PoolConfig {
            max_connections: s.db_max_connections,
            min_connections: s.db_min_connections,
            acquire_timeout_secs: s.db_acquire_timeout_secs,
            idle_timeout_secs: s.db_idle_timeout_secs,
            max_lifetime_secs: s.db_max_lifetime_secs,
            test_before_acquire: s.db_test_before_acquire,
        }
    }

    /// Emit an info-level log describing the pool about to be opened.
    /// `max_connections` is logged after the same `max(1)` clamp the connectors apply.
    fn log(&self, backend: &str) {
        let max_connections = self.max_connections.max(1);
        let min_connections = self.min_connections;
        let acquire_timeout_secs = self.acquire_timeout_secs;
        let test_before_acquire = self.test_before_acquire;
        tracing::info!(
            backend,
            max_connections,
            min_connections,
            acquire_timeout_secs,
            idle_timeout_secs = ?self.idle_timeout_secs,
            max_lifetime_secs = ?self.max_lifetime_secs,
            test_before_acquire,
            "umbral: opening database pool"
        );
    }
}
pub async fn connect_postgres(url: &str) -> Result<PgPool, sqlx::Error> {
use std::time::Duration;
let cfg = PoolConfig::resolve();
cfg.log("postgres");
let mut opts = sqlx::postgres::PgPoolOptions::new()
.max_connections(cfg.max_connections.max(1))
.min_connections(cfg.min_connections)
.acquire_timeout(Duration::from_secs(cfg.acquire_timeout_secs))
.test_before_acquire(cfg.test_before_acquire);
if let Some(secs) = cfg.idle_timeout_secs {
opts = opts.idle_timeout(Duration::from_secs(secs));
}
if let Some(secs) = cfg.max_lifetime_secs {
opts = opts.max_lifetime(Duration::from_secs(secs));
}
opts.connect(url).await
}
/// Open a SQLite pool for `url`, sized from [`PoolConfig::resolve`].
///
/// Every connection uses WAL journaling, `synchronous = NORMAL`, a 5-second
/// busy timeout and foreign-key enforcement, with statement logging turned off.
///
/// In-memory URLs (case-insensitively containing `:memory:` or `mode=memory`)
/// are redirected to a fresh file `umbral_mem_<pid>_<seq>.sqlite` in the
/// system temp directory. SQLite gives every connection to `:memory:` its own
/// private database, so using a shared file keeps all connections in the pool
/// on the same data. Any other options present in such a URL are ignored.
/// NOTE(review): the temp file is never deleted when the pool closes.
///
/// # Errors
///
/// Returns the sqlx error if `url` cannot be parsed or the pool cannot connect.
pub async fn connect_sqlite(url: &str) -> Result<SqlitePool, sqlx::Error> {
    use std::sync::atomic::{AtomicU64, Ordering};
    // Per-process counter so each in-memory connect gets its own file.
    static MEM_SEQ: AtomicU64 = AtomicU64::new(0);

    let lower = url.to_ascii_lowercase();
    let in_memory = lower.contains(":memory:") || lower.contains("mode=memory");
    let opts = if in_memory {
        let n = MEM_SEQ.fetch_add(1, Ordering::Relaxed);
        let path =
            std::env::temp_dir().join(format!("umbral_mem_{}_{n}.sqlite", std::process::id()));
        // A crashed earlier process with the same (recycled) PID may have left
        // files at this path. Remove the database *and* its side files: SQLite
        // can treat a leftover `-wal` / `-journal` as belonging to whatever
        // database it finds there, which would leak stale pages into the
        // "fresh" database. Missing files are expected, so errors are ignored.
        for suffix in ["", "-wal", "-shm", "-journal"] {
            let mut stale = path.clone().into_os_string();
            stale.push(suffix);
            let _ = std::fs::remove_file(&stale);
        }
        SqliteConnectOptions::new()
            .filename(&path)
            .create_if_missing(true)
    } else {
        SqliteConnectOptions::from_str(url)?
    };
    let opts = opts
        .journal_mode(sqlx::sqlite::SqliteJournalMode::Wal)
        .synchronous(SqliteSynchronous::Normal)
        .busy_timeout(Duration::from_secs(5))
        .foreign_keys(true)
        .log_statements(tracing::log::LevelFilter::Off);

    let cfg = PoolConfig::resolve();
    cfg.log("sqlite");
    let mut pool_opts = SqlitePoolOptions::new()
        .max_connections(cfg.max_connections.max(1))
        .min_connections(cfg.min_connections)
        .acquire_timeout(Duration::from_secs(cfg.acquire_timeout_secs))
        .test_before_acquire(cfg.test_before_acquire);
    // Only override sqlx's defaults when a value is configured.
    if let Some(secs) = cfg.idle_timeout_secs {
        pool_opts = pool_opts.idle_timeout(Duration::from_secs(secs));
    }
    if let Some(secs) = cfg.max_lifetime_secs {
        pool_opts = pool_opts.max_lifetime(Duration::from_secs(secs));
    }
    pool_opts.connect_with(opts).await
}
pub async fn close() {
if let Some(pools) = POOLS.get() {
for db in pools.values() {
match db {
DbPool::Sqlite(p) => p.close().await,
DbPool::Postgres(p) => p.close().await,
}
}
}
}
/// A database transaction on either backend.
///
/// Obtain one from [`begin`], [`begin_sqlite`] or [`begin_pg`], and finish it
/// with [`Transaction::commit`] or [`Transaction::rollback`]. Both consume
/// the transaction.
pub struct Transaction {
    inner: TransactionInner,
}

/// Backend-specific sqlx transaction wrapped by [`Transaction`].
enum TransactionInner {
    Sqlite(sqlx::Transaction<'static, sqlx::Sqlite>),
    Postgres(sqlx::Transaction<'static, sqlx::Postgres>),
}

impl Transaction {
    /// Mutable access to the SQLite transaction, or `None` on Postgres.
    pub fn as_sqlite_mut(&mut self) -> Option<&mut sqlx::Transaction<'static, sqlx::Sqlite>> {
        if let TransactionInner::Sqlite(tx) = &mut self.inner {
            Some(tx)
        } else {
            None
        }
    }

    /// Mutable access to the PostgreSQL transaction, or `None` on SQLite.
    pub fn as_pg_mut(&mut self) -> Option<&mut sqlx::Transaction<'static, sqlx::Postgres>> {
        if let TransactionInner::Postgres(tx) = &mut self.inner {
            Some(tx)
        } else {
            None
        }
    }

    /// Short lowercase backend identifier: `"sqlite"` or `"postgres"`.
    pub fn backend_name(&self) -> &'static str {
        if let TransactionInner::Sqlite(_) = &self.inner {
            "sqlite"
        } else {
            "postgres"
        }
    }

    /// Commit the transaction.
    ///
    /// # Errors
    ///
    /// Returns the sqlx error if the commit fails.
    pub async fn commit(self) -> Result<(), sqlx::Error> {
        let Self { inner } = self;
        match inner {
            TransactionInner::Sqlite(tx) => tx.commit().await,
            TransactionInner::Postgres(tx) => tx.commit().await,
        }
    }

    /// Roll the transaction back.
    ///
    /// # Errors
    ///
    /// Returns the sqlx error if the rollback fails.
    pub async fn rollback(self) -> Result<(), sqlx::Error> {
        let Self { inner } = self;
        match inner {
            TransactionInner::Sqlite(tx) => tx.rollback().await,
            TransactionInner::Postgres(tx) => tx.rollback().await,
        }
    }
}
pub async fn begin() -> Result<Transaction, sqlx::Error> {
match pool_dispatched() {
DbPool::Sqlite(pool) => {
let tx = pool.begin().await?;
Ok(Transaction {
inner: TransactionInner::Sqlite(tx),
})
}
DbPool::Postgres(pool) => {
let tx = pool.begin().await?;
Ok(Transaction {
inner: TransactionInner::Postgres(tx),
})
}
}
}
/// Start a transaction on an explicit SQLite pool.
///
/// # Errors
///
/// Returns the sqlx error if the transaction cannot be started.
pub async fn begin_sqlite(pool: &sqlx::SqlitePool) -> Result<Transaction, sqlx::Error> {
    pool.begin().await.map(|tx| Transaction {
        inner: TransactionInner::Sqlite(tx),
    })
}
/// Start a transaction on an explicit PostgreSQL pool.
///
/// # Errors
///
/// Returns the sqlx error if the transaction cannot be started.
pub async fn begin_pg(pool: &sqlx::PgPool) -> Result<Transaction, sqlx::Error> {
    pool.begin().await.map(|tx| Transaction {
        inner: TransactionInner::Postgres(tx),
    })
}
/// Boxed, `Send` future returned by the closures passed to `transaction`,
/// `transaction_sqlite` and `transaction_pg`. The `'a` lifetime lets the
/// future borrow the `&'a mut Transaction` it was handed.
pub type TxFuture<'a, T, E> = Pin<Box<dyn std::future::Future<Output = Result<T, E>> + Send + 'a>>;
pub async fn transaction<F, T, E>(f: F) -> Result<T, E>
where
for<'a> F: FnOnce(&'a mut Transaction) -> TxFuture<'a, T, E>,
E: From<sqlx::Error>,
{
let mut tx = begin().await.map_err(E::from)?;
match f(&mut tx).await {
Ok(val) => {
tx.commit().await.map_err(E::from)?;
Ok(val)
}
Err(e) => {
let _ = tx.rollback().await;
Err(e)
}
}
}
pub async fn transaction_sqlite<F, T, E>(pool: &sqlx::SqlitePool, f: F) -> Result<T, E>
where
for<'a> F: FnOnce(&'a mut Transaction) -> TxFuture<'a, T, E>,
E: From<sqlx::Error>,
{
let mut tx = begin_sqlite(pool).await.map_err(E::from)?;
match f(&mut tx).await {
Ok(val) => {
tx.commit().await.map_err(E::from)?;
Ok(val)
}
Err(e) => {
let _ = tx.rollback().await;
Err(e)
}
}
}
pub async fn transaction_pg<F, T, E>(pool: &sqlx::PgPool, f: F) -> Result<T, E>
where
for<'a> F: FnOnce(&'a mut Transaction) -> TxFuture<'a, T, E>,
E: From<sqlx::Error>,
{
let mut tx = begin_pg(pool).await.map_err(E::from)?;
match f(&mut tx).await {
Ok(val) => {
tx.commit().await.map_err(E::from)?;
Ok(val)
}
Err(e) => {
let _ = tx.rollback().await;
Err(e)
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `connect` on an in-memory SQLite URL yields a usable SQLite pool.
    #[tokio::test]
    async fn connect_returns_a_working_pool_against_in_memory_sqlite() {
        let pool = connect("sqlite::memory:")
            .await
            .expect("in-memory sqlite should always connect");
        let sqlite = pool.as_sqlite().expect("should be Sqlite variant");
        let (one,): (i64,) = sqlx::query_as("SELECT 1")
            .fetch_one(sqlite)
            .await
            .expect("SELECT 1 should succeed on a fresh pool");
        assert_eq!(one, 1);
    }

    /// A string with no recognised scheme is rejected by umbral's scheme
    /// check before sqlx is ever involved.
    #[tokio::test]
    async fn connect_errors_on_malformed_url() {
        let result = connect("not-a-real-url").await;
        assert!(
            result.is_err(),
            "expected a URL without a recognised scheme to be rejected, got Ok"
        );
    }

    /// Unsupported schemes produce a `Configuration` error naming the scheme.
    #[tokio::test]
    async fn connect_rejects_unsupported_scheme() {
        let result = connect("mysql://user:pass@host/db").await;
        match result {
            Err(sqlx::Error::Configuration(msg)) => {
                assert!(msg.to_string().contains("mysql"));
            }
            other => panic!("expected Configuration error, got {other:?}"),
        }
    }

    /// Each in-memory `connect` gets its own backing file, so two pools must
    /// not see each other's schema.
    #[tokio::test]
    async fn separate_in_memory_connects_do_not_share_data() {
        let pool_a = connect("sqlite::memory:").await.expect("first connect");
        let pool_b = connect("sqlite::memory:").await.expect("second connect");
        let sqlite_a = pool_a.as_sqlite().expect("should be Sqlite variant");
        let sqlite_b = pool_b.as_sqlite().expect("should be Sqlite variant");
        sqlx::query("CREATE TABLE only_in_a (id INTEGER)")
            .execute(sqlite_a)
            .await
            .expect("CREATE TABLE should succeed");
        let (count,): (i64,) =
            sqlx::query_as("SELECT COUNT(*) FROM sqlite_master WHERE name = 'only_in_a'")
                .fetch_one(sqlite_b)
                .await
                .expect("schema query should succeed");
        assert_eq!(count, 0, "second in-memory pool must not see the first's table");
    }

    /// `From<SqlitePool>` produces the Sqlite variant and the accessors agree.
    #[tokio::test]
    async fn sqlite_pool_round_trips_through_dbpool() {
        let sp = SqlitePool::connect("sqlite::memory:").await.unwrap();
        let dp: DbPool = sp.into();
        assert_eq!(dp.backend_name(), "sqlite");
        assert!(dp.as_sqlite().is_some());
        assert!(dp.as_postgres().is_none());
    }
}