use axum::extract::FromRequestParts;
use diesel;
use diesel_async::AsyncPgConnection;
use diesel_async::pooled_connection::AsyncDieselConnectionManager;
use diesel_async::pooled_connection::deadpool::Pool;
use futures::FutureExt as _;
use std::any::Any;
use std::future::Future;
use std::panic::AssertUnwindSafe;
use std::pin::Pin;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::{Arc, Mutex};
use std::time::Duration;
use tracing::Instrument as _;
use crate::config::DatabaseConfig;
use crate::error::AutumnError;
/// A deferred after-commit callback.
///
/// The boxed closure is called at most once (`FnOnce`) and returns a boxed,
/// `Send` future resolving to the callback's result. Callbacks are queued by
/// [`register_after_commit`] while an [`AFTER_COMMIT_REGISTRY`] scope is active
/// (for example inside [`Db::tx`]) and are spawned only after that transaction
/// commits successfully.
pub type CommitCallback = Box<
    dyn FnOnce() -> Pin<Box<dyn Future<Output = crate::AutumnResult<()>> + Send + 'static>>
        + Send
        + 'static,
>;

tokio::task_local! {
    /// Per-transaction queue of after-commit callbacks.
    ///
    /// `Db::tx` scopes a fresh registry around the transaction future; code that
    /// runs inside that scope pushes callbacks here via [`register_after_commit`].
    /// Outside any scope, `try_with` fails and callbacks are run eagerly instead.
    pub static AFTER_COMMIT_REGISTRY: Arc<Mutex<Vec<CommitCallback>>>;
}
/// Running total of after-commit callbacks that returned an error or panicked.
///
/// Incremented by [`record_after_commit_failure`]; readable by operators and tests
/// as a signal that post-commit work was lost.
pub static AFTER_COMMIT_FAILURES_TOTAL: AtomicU64 = AtomicU64::new(0);

/// Counts one after-commit failure and returns the total including this one.
pub(crate) fn record_after_commit_failure() -> u64 {
    let previous_total = AFTER_COMMIT_FAILURES_TOTAL.fetch_add(1, Ordering::Relaxed);
    previous_total + 1
}
/// Refuses to start a transaction while an after-commit registry scope is active.
///
/// An active scope means the caller is already inside a `Db::tx` closure (or an
/// equivalent scope), so starting another transaction would nest.
///
/// # Errors
/// Returns a bad-request [`AutumnError`] when called inside such a scope.
pub(crate) fn reject_ambient_after_commit_registry_for_tx() -> Result<(), AutumnError> {
    let inside_registry_scope = AFTER_COMMIT_REGISTRY.try_with(|_| ()).is_ok();
    if inside_registry_scope {
        Err(AutumnError::bad_request_msg(
            "Nested Db::tx calls are not supported",
        ))
    } else {
        Ok(())
    }
}
/// Spawns a background task that runs committed after-commit callbacks in order.
///
/// Returns `None` (and spawns nothing) when `callbacks` is empty; otherwise
/// returns the handle of the spawned drain task. Callbacks run sequentially in
/// the order they appear in `callbacks`, and each one is awaited to completion
/// before the next starts. A callback that returns `Err` or panics is logged
/// and counted via [`record_after_commit_failure`]. The panic can happen while
/// the closure builds its future or while that future is polled. In either
/// case the remaining callbacks still run.
pub(crate) fn spawn_committed_after_commit_callbacks(
    callbacks: Vec<CommitCallback>,
) -> Option<tokio::task::JoinHandle<()>> {
    if callbacks.is_empty() {
        return None;
    }
    Some(tokio::task::spawn(async move {
        for cb in callbacks {
            // Two panic boundaries: the synchronous closure call that produces the
            // future, and the polling of that future. Both collapse into one
            // `Result<AutumnResult<()>, panic payload>`.
            let result = match std::panic::catch_unwind(AssertUnwindSafe(cb)) {
                Ok(callback) => AssertUnwindSafe(callback).catch_unwind().await,
                Err(panic) => Err(panic),
            };
            match result {
                Ok(Ok(())) => {}
                Ok(Err(e)) => {
                    let failures_total = record_after_commit_failure();
                    tracing::error!(
                        autumn.after_commit.failures_total = failures_total,
                        "after_commit callback failed (tx already committed): {e}"
                    );
                }
                Err(panic) => {
                    let failures_total = record_after_commit_failure();
                    let panic = after_commit_panic_message(&*panic);
                    tracing::error!(
                        autumn.after_commit.failures_total = failures_total,
                        "after_commit callback panicked (tx already committed): {panic}"
                    );
                }
            }
        }
    }))
}
/// Extracts a human-readable message from a panic payload.
///
/// `panic!("literal")` carries a `&'static str` and formatted panics carry a
/// `String`; any other payload type yields a fixed placeholder.
fn after_commit_panic_message(payload: &(dyn Any + Send)) -> String {
    if let Some(text) = payload.downcast_ref::<&'static str>() {
        return (*text).to_owned();
    }
    payload
        .downcast_ref::<String>()
        .cloned()
        .unwrap_or_else(|| "non-string panic payload".to_owned())
}
/// Registers `f` to run after the surrounding transaction commits.
///
/// Inside an [`AFTER_COMMIT_REGISTRY`] scope (e.g. within `Db::tx`), `f` is
/// queued and only runs if the transaction commits. Outside any scope there is
/// no transaction to wait for, so `f` is awaited immediately. An error from
/// that eager run is logged and counted rather than returned.
///
/// # Panics
/// Panics if the registry mutex is poisoned.
pub async fn register_after_commit<F, Fut>(f: F)
where
    F: FnOnce() -> Fut + Send + 'static,
    Fut: Future<Output = crate::AutumnResult<()>> + Send + 'static,
{
    let mut pending = Some(f);
    let _ = AFTER_COMMIT_REGISTRY.try_with(|registry| {
        let callback = pending.take().expect("closure only entered once");
        let deferred: CommitCallback = Box::new(move || Box::pin(callback()));
        registry.lock().expect("registry lock").push(deferred);
    });

    // `pending` is still `Some` only when no registry scope was active.
    let Some(callback) = pending else {
        return;
    };
    tracing::debug!("register_after_commit: no active transaction; running callback eagerly");
    if let Err(e) = callback().await {
        let failures_total = record_after_commit_failure();
        tracing::error!(
            autumn.after_commit.failures_total = failures_total,
            "register_after_commit eager callback failed: {e}"
        );
    }
}
pub trait DbState {
fn pool(&self) -> Option<&Pool<AsyncPgConnection>>;
fn metrics(&self) -> Option<&crate::middleware::MetricsCollector> {
None
}
fn replica_pool(&self) -> Option<&Pool<AsyncPgConnection>> {
None
}
fn read_pool(&self) -> Option<&Pool<AsyncPgConnection>> {
self.replica_pool().or_else(|| self.pool())
}
fn db_interceptors(
&self,
) -> Vec<std::sync::Arc<dyn crate::interceptor::DbConnectionInterceptor>> {
Vec::new()
}
fn statement_timeout(&self) -> Option<std::time::Duration> {
None
}
fn slow_query_threshold(&self) -> std::time::Duration {
std::time::Duration::from_millis(500)
}
}
/// Advances `chars` past the body of a Postgres `E'…'` escape string.
///
/// The opening quote must already be consumed. A backslash skips the following
/// character, `''` is an escaped quote, and a lone `'` ends the literal (it is
/// consumed). An unterminated literal consumes everything.
fn consume_estring_body(chars: &mut std::iter::Peekable<std::str::Chars<'_>>) {
    while let Some(ch) = chars.next() {
        match ch {
            '\\' => {
                // Backslash escape: the next character is part of the literal.
                chars.next();
            }
            '\'' if chars.next_if_eq(&'\'').is_some() => {}
            '\'' => break,
            _ => {}
        }
    }
}
/// Advances `chars` past a dollar-quoted body up to and including `$tag$`.
///
/// The opening delimiter must already be consumed. When the closing delimiter
/// never appears, the whole remaining input is consumed.
fn consume_dollar_quoted_body(chars: &mut std::iter::Peekable<std::str::Chars<'_>>, tag: &str) {
    let delimiter: Vec<char> = format!("${tag}$").chars().collect();
    let mut matched = 0usize;
    while matched < delimiter.len() {
        let Some(ch) = chars.next() else {
            return;
        };
        // Extend the partial match, restart it on a fresh `$`, or reset.
        matched = if ch == delimiter[matched] {
            matched + 1
        } else if ch == delimiter[0] {
            1
        } else {
            0
        };
    }
}
/// Returns `true` for characters after which a digit begins a numeric literal
/// rather than continuing an identifier.
///
/// Covers SQL whitespace (including `\r` and form feed), arithmetic/comparison
/// and bitwise operators, and the punctuation that can directly precede a
/// literal: `(`, `,`, `;`, `[` (array constructors/subscripts such as
/// `ARRAY[42]`) and `:` (array slices such as `a[1:3]`). Any character missing
/// here makes [`scrub_sql`] copy the following literal verbatim into logs.
#[inline]
const fn is_separator(c: char) -> bool {
    matches!(
        c,
        ' ' | '\t' | '\n' | '\r' | '\x0c'
            | '=' | '<' | '>' | '!'
            | '+' | '-' | '*' | '/' | '%'
            | '|' | '&' | '^' | '~'
            | '(' | ',' | ';' | '[' | ':'
    )
}
/// Produces a log-safe fingerprint of `sql` by replacing literal values.
///
/// Replacements:
/// - standard `'…'` strings (with `''` escapes) and `E'…'` / `e'…'` escape
///   strings (with `\` and `''` escapes) become `'?'`;
/// - dollar-quoted strings (`$$…$$`, `$tag$…$tag$`) become `'?'`;
/// - numeric literals that begin at the start of the input or right after a
///   character accepted by [`is_separator`] become `?`. This includes
///   leading-dot forms like `.5`, `_` digit grouping, and `e`/`E` exponents.
///
/// Positional parameters (`$1`, `$2`, …) and identifiers that merely contain
/// digits (`table1`, `col_5`) are kept verbatim. An unterminated literal
/// swallows the rest of the input.
///
/// NOTE(review): SQL comments and double-quoted identifiers are not recognized;
/// an apostrophe inside either one starts a string literal.
#[must_use]
pub fn scrub_sql(sql: &str) -> String {
    let mut out = String::with_capacity(sql.len());
    // True when the previous emitted character lets a digit start a literal;
    // starts true so a literal at the very beginning is scrubbed.
    let mut prev_is_sep = true;
    let mut chars = sql.chars().peekable();
    while let Some(c) = chars.next() {
        // Escape string: E'…' / e'…'.
        if (c == 'E' || c == 'e') && chars.peek() == Some(&'\'') {
            // Consume the opening quote before skipping the body.
            chars.next();
            out.push_str("'?'");
            prev_is_sep = false;
            consume_estring_body(&mut chars);
            continue;
        }
        // Standard string: '…' with '' as an escaped quote.
        if c == '\'' {
            out.push_str("'?'");
            prev_is_sep = false;
            loop {
                match chars.next() {
                    None => break,
                    Some('\'') => {
                        if chars.peek() == Some(&'\'') {
                            chars.next();
                        } else {
                            break;
                        }
                    }
                    Some(_) => {}
                }
            }
            continue;
        }
        if c == '$' {
            let next_ch = chars.peek().copied();
            // `$` + digits is a positional parameter: copy it through unchanged.
            if next_ch.is_some_and(|nc| nc.is_ascii_digit()) {
                out.push('$');
                prev_is_sep = false;
                while chars.peek().is_some_and(char::is_ascii_digit) {
                    if let Some(d) = chars.next() {
                        out.push(d);
                    }
                }
                continue;
            }
            // Try to read a dollar-quote opener: `$$` or `$tag$`.
            let mut tag = String::new();
            let mut found_closing_dollar = false;
            if next_ch == Some('$') {
                chars.next();
                found_closing_dollar = true;
            } else if next_ch.is_some_and(|nc| nc.is_alphabetic() || nc == '_') {
                while let Some(&tc) = chars.peek() {
                    if tc == '$' {
                        // Consume the `$` that closes the opening tag.
                        chars.next();
                        found_closing_dollar = true;
                        break;
                    } else if tc.is_alphanumeric() || tc == '_' {
                        tag.push(tc);
                        chars.next();
                    } else {
                        break;
                    }
                }
            }
            if found_closing_dollar {
                out.push_str("'?'");
                prev_is_sep = false;
                consume_dollar_quoted_body(&mut chars, &tag);
            } else {
                // Not a dollar quote: emit what was read verbatim.
                out.push('$');
                out.push_str(&tag);
                prev_is_sep = false;
            }
            continue;
        }
        // Numeric literal: digits after a separator, or `.digit` after a separator
        // (so `t.col` and `schema.table` are left alone).
        let is_leading_dot =
            c == '.' && prev_is_sep && chars.peek().is_some_and(char::is_ascii_digit);
        if (c.is_ascii_digit() && prev_is_sep) || is_leading_dot {
            out.push('?');
            if is_leading_dot {
                // Consume the first digit after the leading dot.
                chars.next();
            }
            // Mantissa: digits, decimal point and `_` grouping.
            while chars
                .peek()
                .is_some_and(|d| d.is_ascii_digit() || *d == '.' || *d == '_')
            {
                chars.next();
            }
            // Optional exponent: e/E, optional sign, digits.
            if chars.peek().is_some_and(|e| *e == 'e' || *e == 'E') {
                chars.next();
                if chars.peek().is_some_and(|s| *s == '+' || *s == '-') {
                    chars.next();
                }
                while chars.peek().is_some_and(char::is_ascii_digit) {
                    chars.next();
                }
            }
            prev_is_sep = false;
            continue;
        }
        out.push(c);
        prev_is_sep = is_separator(c);
    }
    out
}
/// Runs `query`, records its latency, and maps driver errors to [`AutumnError`].
///
/// The elapsed time is recorded under the metric key `"{route_key} {VERB}"`.
/// `VERB` is the first whitespace-delimited word of `sql`, with leading `(`
/// stripped and then upper-cased, so that `select …`, `SELECT …` and
/// `(SELECT …) UNION …` share one series (matching the upper-case `SELECT` key
/// used when a `Db` handle is dropped). An empty or unparseable statement uses
/// `?`. Queries at or above `slow_threshold` are logged with a [`scrub_sql`]
/// fingerprint, so literal values do not reach the logs.
///
/// # Errors
/// A statement-timeout cancellation becomes a query-timeout error; any other
/// diesel error is converted with `AutumnError::from`.
pub async fn run_instrumented<F, Fut, T>(
    sql: &str,
    route_key: &str,
    slow_threshold: std::time::Duration,
    metrics: &crate::middleware::metrics::MetricsCollector,
    query: F,
) -> Result<T, AutumnError>
where
    F: FnOnce() -> Fut,
    Fut: std::future::Future<Output = Result<T, diesel::result::Error>>,
{
    let start = std::time::Instant::now();
    let result = query().await;
    let elapsed = start.elapsed();
    let elapsed_ms = u64::try_from(elapsed.as_millis()).unwrap_or(u64::MAX);
    // Normalize the leading keyword so metric keys do not fragment by case or
    // by a wrapping parenthesis.
    let verb = sql
        .split_whitespace()
        .next()
        .map(|word| word.trim_start_matches('(').to_ascii_uppercase())
        .filter(|word| !word.is_empty())
        .unwrap_or_else(|| "?".to_owned());
    let metric_key = format!("{route_key} {verb}");
    metrics.record_db_query(&metric_key, elapsed_ms);
    if elapsed >= slow_threshold {
        let fingerprint = scrub_sql(sql);
        tracing::warn!(
            route = %route_key,
            sql = %fingerprint,
            duration_ms = elapsed_ms,
            "slow database query"
        );
    }
    result.map_err(|db_err| {
        if is_query_canceled(&db_err) {
            tracing::warn!(
                route = %route_key,
                duration_ms = elapsed_ms,
                "database query cancelled: statement_timeout exceeded"
            );
            AutumnError::query_timeout(format!(
                "Database query timed out after {elapsed_ms}ms (statement_timeout exceeded)"
            ))
        } else {
            AutumnError::from(db_err)
        }
    })
}
/// Returns `true` when `err` represents a Postgres `query_canceled` (SQLSTATE
/// 57014) error, such as a `statement_timeout` expiry.
///
/// diesel's error type does not expose the SQLSTATE directly, so the check first
/// scans the rendered (lower-cased) message for known phrases. The numeric code
/// `57014` only counts when it is a standalone token. A bare substring match
/// would misclassify unrelated errors whose text merely contains those digits,
/// e.g. a UUID or a larger number echoed in an "invalid input syntax" message.
/// Finally, the `source()` chain is walked looking for typed `tokio_postgres`
/// errors carrying `SqlState::QUERY_CANCELED`.
fn is_query_canceled(err: &diesel::result::Error) -> bool {
    const CANCEL_PHRASES: [&str; 4] = [
        "query_canceled",
        "canceling statement due to statement timeout",
        "statement timeout",
        "query canceled",
    ];
    let err_str = err.to_string().to_lowercase();
    if contains_standalone_sqlstate(&err_str, "57014")
        || CANCEL_PHRASES.iter().any(|phrase| err_str.contains(phrase))
    {
        return true;
    }
    let mut source: Option<&(dyn std::error::Error + 'static)> = Some(err);
    while let Some(e) = source {
        if e.downcast_ref::<tokio_postgres::Error>()
            .and_then(|pg_err| pg_err.code())
            == Some(&tokio_postgres::error::SqlState::QUERY_CANCELED)
        {
            return true;
        }
        if e.downcast_ref::<tokio_postgres::error::DbError>()
            .is_some_and(|db_err| db_err.code() == &tokio_postgres::error::SqlState::QUERY_CANCELED)
        {
            return true;
        }
        source = e.source();
    }
    false
}

/// Returns `true` if `code` occurs in `haystack` without an ASCII alphanumeric
/// character directly before or after it (so `57014` matches in `"code 57014"`
/// but not in `"a57014b"` or `"1570140"`).
fn contains_standalone_sqlstate(haystack: &str, code: &str) -> bool {
    haystack.match_indices(code).any(|(start, matched)| {
        // `match_indices` yields char-boundary offsets, so these slices are valid.
        let before = haystack[..start].chars().next_back();
        let after = haystack[start + matched.len()..].chars().next();
        !before.is_some_and(|c| c.is_ascii_alphanumeric())
            && !after.is_some_and(|c| c.is_ascii_alphanumeric())
    })
}
/// Error returned when a connection pool cannot be built.
pub type PoolError = diesel_async::pooled_connection::deadpool::BuildError;

/// The pools an application talks to: a required primary and an optional
/// read replica.
#[derive(Clone)]
pub struct DatabaseTopology {
    primary: Pool<AsyncPgConnection>,
    replica: Option<Pool<AsyncPgConnection>>,
}

impl DatabaseTopology {
    /// Assembles a topology from an existing primary pool and optional replica.
    #[must_use]
    pub const fn from_pools(
        primary: Pool<AsyncPgConnection>,
        replica: Option<Pool<AsyncPgConnection>>,
    ) -> Self {
        Self { primary, replica }
    }

    /// Assembles a topology that has only a primary pool.
    #[must_use]
    pub const fn primary_only(primary: Pool<AsyncPgConnection>) -> Self {
        Self::from_pools(primary, None)
    }

    /// The primary pool.
    #[must_use]
    pub const fn primary(&self) -> &Pool<AsyncPgConnection> {
        &self.primary
    }

    /// The replica pool, if one is configured.
    #[must_use]
    pub const fn replica(&self) -> Option<&Pool<AsyncPgConnection>> {
        self.replica.as_ref()
    }

    /// Pool for read queries: the replica when configured, otherwise the primary.
    #[must_use]
    pub fn read(&self) -> &Pool<AsyncPgConnection> {
        match &self.replica {
            Some(replica) => replica,
            None => &self.primary,
        }
    }
}
/// Builds a deadpool-backed pool of async Postgres connections for `url`.
///
/// `pool_size` is clamped to at least 1. `connect_timeout_secs` is applied as
/// both the pool's wait timeout and its create timeout.
fn build_pool(
    url: &str,
    pool_size: usize,
    connect_timeout_secs: u64,
) -> Result<Pool<AsyncPgConnection>, PoolError> {
    let manager = AsyncDieselConnectionManager::<AsyncPgConnection>::new(url);
    let max_size = pool_size.max(1);
    let timeout = Some(Duration::from_secs(connect_timeout_secs));
    Pool::builder(manager)
        .max_size(max_size)
        .wait_timeout(timeout)
        .create_timeout(timeout)
        .runtime(deadpool::Runtime::Tokio1)
        .build()
}
/// Builds the primary pool described by `config`.
///
/// Returns `Ok(None)` when no primary URL is configured.
///
/// # Errors
/// Returns [`PoolError`] if the pool builder rejects the configuration.
pub fn create_pool(config: &DatabaseConfig) -> Result<Option<Pool<AsyncPgConnection>>, PoolError> {
    config
        .effective_primary_url()
        .map(|url| {
            build_pool(
                url,
                config.effective_primary_pool_size(),
                config.connect_timeout_secs,
            )
        })
        .transpose()
}
pub fn create_topology(config: &DatabaseConfig) -> Result<Option<DatabaseTopology>, PoolError> {
let Some(primary_url) = config.effective_primary_url() else {
return Ok(None);
};
let primary = build_pool(
primary_url,
config.effective_primary_pool_size(),
config.connect_timeout_secs,
)?;
let replica = config
.replica_url
.as_deref()
.map(|url| {
build_pool(
url,
config.effective_replica_pool_size(),
config.connect_timeout_secs,
)
})
.transpose()?;
Ok(Some(DatabaseTopology { primary, replica }))
}
/// A connection object checked out from a deadpool [`Pool`] of async Postgres connections.
pub type PooledConnection = diesel_async::pooled_connection::deadpool::Object<AsyncPgConnection>;
/// Scope guard that keeps `Db::tx` bookkeeping consistent.
///
/// Dropping the guard always decrements `depth`. If `disarmed` has not been set
/// by then, it also sets `poisoned`, which records that the transaction future
/// was dropped before its outcome was observed.
struct TxDepthGuard<'a> {
    depth: &'a mut usize,
    poisoned: &'a mut bool,
    disarmed: bool,
}

impl Drop for TxDepthGuard<'_> {
    fn drop(&mut self) {
        *self.depth -= 1;
        // Only an armed guard poisons; a disarmed one leaves the flag untouched.
        *self.poisoned |= !self.disarmed;
    }
}
/// Per-request override for the Postgres `statement_timeout`.
///
/// When present in the request extensions, the [`Db`] extractor uses this value
/// instead of [`DbState::statement_timeout`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct StatementTimeout(pub std::time::Duration);

/// Request-scoped database handle produced by the axum extractor.
///
/// Derefs to the underlying [`AsyncPgConnection`]; use [`Db::tx`] to run a
/// transaction with after-commit callback support.
pub struct Db {
    /// The pooled connection checked out for this request.
    conn: PooledConnection,
    /// The `db.connection` tracing span created during checkout.
    span: tracing::Span,
    /// Number of `Db::tx` calls in progress (0 or 1; nesting is rejected).
    tx_depth: usize,
    /// Set when a transaction future was dropped before finishing; the
    /// connection's transaction state is then unknown.
    tx_poisoned: bool,
    /// `"{method} {path}"` label (matched route pattern when available) used for
    /// metrics and slow-query logs.
    route_key: Option<String>,
    /// Collector that receives timing samples on drop, if configured.
    metrics: Option<crate::middleware::MetricsCollector>,
    /// The handle's lifetime is logged as a slow query on drop at or above this.
    slow_query_threshold: std::time::Duration,
    /// When the handle was created; used to time its lifetime on drop.
    start_time: std::time::Instant,
    /// True when a transactional-test interceptor is active; after-commit
    /// callbacks are then not spawned.
    is_test_tx: bool,
}
impl Db {
    /// The tracing span created for this connection's checkout.
    #[must_use]
    pub const fn span(&self) -> &tracing::Span {
        &self.span
    }

    /// Runs `f` inside a database transaction on this connection.
    ///
    /// Callbacks registered with [`register_after_commit`] while `f` runs are
    /// collected. If the transaction commits, they are spawned to run in
    /// registration order. On rollback they are dropped. When a
    /// transactional-test interceptor is active they are never run.
    ///
    /// # Errors
    /// - service-unavailable error if an earlier cancelled transaction poisoned
    ///   the connection;
    /// - bad-request error for nested calls, either on this `Db` or inside
    ///   another after-commit registry scope;
    /// - otherwise the error produced by `f` or by the commit/rollback,
    ///   converted into `AutumnError`.
    pub async fn tx<'a, T, E, F>(&'a mut self, f: F) -> Result<T, crate::error::AutumnError>
    where
        T: Send + 'a,
        E: From<diesel::result::Error> + Send + Sync + 'a,
        crate::error::AutumnError: From<E>,
        F: for<'r> FnOnce(
                &'r mut PooledConnection,
            ) -> scoped_futures::ScopedBoxFuture<'a, 'r, Result<T, E>>
            + Send
            + 'a,
    {
        use diesel_async::AsyncConnection as _;
        // A previous transaction future was dropped mid-flight; the connection's
        // transaction state is unknown, so refuse to use it.
        if self.tx_poisoned {
            return Err(crate::error::AutumnError::service_unavailable_msg(
                "Database connection is in an invalid transaction state",
            ));
        }
        if self.tx_depth > 0 {
            return Err(crate::error::AutumnError::bad_request_msg(
                "Nested Db::tx calls are not supported",
            ));
        }
        // Also reject being started inside another after-commit registry scope.
        reject_ambient_after_commit_registry_for_tx()?;
        self.tx_depth += 1;
        // Dropping the guard decrements `tx_depth`. If it is still armed at that
        // point (this future was dropped before the transaction finished), it
        // also marks the connection as poisoned.
        let mut guard = TxDepthGuard {
            depth: &mut self.tx_depth,
            poisoned: &mut self.tx_poisoned,
            disarmed: false,
        };
        // Fresh per-transaction queue; `register_after_commit` calls made while
        // `f` runs push into it through the task-local scope.
        let registry: Arc<Mutex<Vec<CommitCallback>>> = Arc::new(Mutex::new(Vec::new()));
        let result = AFTER_COMMIT_REGISTRY
            .scope(registry.clone(), self.conn.transaction::<T, E, _>(f))
            .await
            .map_err(Into::into);
        // The transaction completed (commit or rollback), so there is nothing to poison.
        guard.disarmed = true;
        if result.is_ok() {
            let callbacks: Vec<CommitCallback> = {
                let mut reg = registry.lock().expect("registry lock");
                std::mem::take(&mut *reg)
            };
            if !callbacks.is_empty() && !self.is_test_tx {
                // Fire-and-forget: the drain task's handle is intentionally dropped.
                let _ = spawn_committed_after_commit_callbacks(callbacks);
            }
        }
        result
    }
}
impl std::ops::Deref for Db {
    type Target = AsyncPgConnection;

    /// Borrows the underlying connection.
    ///
    /// # Panics
    /// Panics if the connection was poisoned by a cancelled/dropped transaction.
    fn deref(&self) -> &Self::Target {
        if self.tx_poisoned {
            panic!("Db connection is poisoned due to a cancelled/dropped transaction");
        }
        &self.conn
    }
}

impl std::ops::DerefMut for Db {
    /// Mutably borrows the underlying connection.
    ///
    /// # Panics
    /// Panics if the connection was poisoned by a cancelled/dropped transaction.
    fn deref_mut(&mut self) -> &mut Self::Target {
        if self.tx_poisoned {
            panic!("Db connection is poisoned due to a cancelled/dropped transaction");
        }
        &mut self.conn
    }
}
/// Extracts a [`Db`] handle for the current request.
///
/// Steps:
/// 1. Check out a connection from [`DbState::pool`], wrapped by every
///    configured checkout interceptor.
/// 2. Apply the effective `statement_timeout`: the per-request
///    [`StatementTimeout`] extension, else [`DbState::statement_timeout`], else 0.
/// 3. Capture the route label, metrics collector and slow-query threshold used
///    for instrumentation.
///
/// # Errors
/// Returns a service-unavailable [`AutumnError`] when no pool is configured,
/// when checkout fails, or when setting `statement_timeout` fails.
impl<S> FromRequestParts<S> for Db
where
    S: DbState + Send + Sync,
{
    type Rejection = AutumnError;

    async fn from_request_parts(
        parts: &mut axum::http::request::Parts,
        state: &S,
    ) -> Result<Self, Self::Rejection> {
        // Upper bound for `statement_timeout`, in milliseconds (i32::MAX).
        const PG_TIMEOUT_MAX_MS: u64 = i32::MAX as u64;
        use diesel_async::RunQueryDsl as _;
        let pool = state
            .pool()
            .ok_or_else(|| AutumnError::service_unavailable_msg("Database not configured"))?;
        let span = tracing::info_span!(
            "db.connection",
            otel.kind = "client",
            db.system = "postgresql",
        );
        let interceptors = state.db_interceptors();
        let mut checkout_future: std::pin::Pin<
            Box<
                dyn std::future::Future<Output = Result<PooledConnection, AutumnError>> + Send + '_,
            >,
        > = Box::pin(async {
            pool.get().await.map_err(|e| {
                tracing::error!("Failed to acquire database connection: {e}");
                AutumnError::service_unavailable_msg(e.to_string())
            })
        });
        // Each interceptor wraps the previous future, so the last one is outermost.
        for interceptor in &interceptors {
            let ctx = crate::interceptor::DbCheckoutContext {
                pool_name: "primary".to_string(),
            };
            checkout_future = interceptor.intercept_checkout(ctx, checkout_future);
        }
        let mut conn = checkout_future.instrument(span.clone()).await?;
        let timeout_override = parts.extensions.get::<StatementTimeout>().copied();
        let timeout_ms = timeout_override
            .map(|t| t.0)
            .or_else(|| state.statement_timeout())
            .map_or(0u64, |d| {
                let ms = u64::try_from(d.as_millis()).unwrap_or(PG_TIMEOUT_MAX_MS);
                // `0` disables the timeout, so a non-zero sub-millisecond
                // duration must round up to 1ms instead of truncating to
                // "no limit".
                let ms = if ms == 0 && !d.is_zero() { 1 } else { ms };
                ms.min(PG_TIMEOUT_MAX_MS)
            });
        // Always SET (even to 0) so a pooled connection never keeps a timeout
        // from a previous request.
        diesel::sql_query(format!("SET statement_timeout = {timeout_ms}"))
            .execute(&mut conn)
            .await
            .map_err(|e| {
                tracing::error!("Failed to set database statement_timeout to {timeout_ms}ms: {e}");
                AutumnError::service_unavailable_msg(format!("Database initialization error: {e}"))
            })?;
        let matched_path = parts
            .extensions
            .get::<axum::extract::MatchedPath>()
            .map_or_else(|| parts.uri.path(), axum::extract::MatchedPath::as_str);
        let route_key = format!("{} {}", parts.method, matched_path);
        let metrics = state.metrics();
        let slow_query_threshold = state.slow_query_threshold();
        let start_time = std::time::Instant::now();
        let is_test_tx = interceptors.iter().any(|i| i.is_transactional_test());
        Ok(Self {
            conn,
            span,
            tx_depth: 0,
            tx_poisoned: false,
            route_key: Some(route_key),
            metrics: metrics.cloned(),
            slow_query_threshold,
            start_time,
            is_test_tx,
        })
    }
}
impl Drop for Db {
    /// Records the handle's whole lifetime as one `SELECT` sample for the route
    /// and logs it as a slow query when it reaches the configured threshold.
    /// Does nothing when the route label or metrics collector is absent.
    fn drop(&mut self) {
        let (Some(route_key), Some(metrics)) = (&self.route_key, &self.metrics) else {
            return;
        };
        let elapsed = self.start_time.elapsed();
        let elapsed_ms = u64::try_from(elapsed.as_millis()).unwrap_or(u64::MAX);
        metrics.record_db_query(&format!("{route_key} SELECT"), elapsed_ms);
        if elapsed < self.slow_query_threshold {
            return;
        }
        tracing::warn!(
            route = %route_key,
            sql = "SELECT ?",
            duration_ms = elapsed_ms,
            "slow database query"
        );
    }
}
/// Pluggable factory for the application's connection pools.
pub trait DatabasePoolProvider: Send + Sync + 'static {
    /// Builds the primary pool, or returns `Ok(None)` when no database is configured.
    fn create_pool(
        &self,
        config: &DatabaseConfig,
    ) -> impl std::future::Future<Output = Result<Option<Pool<AsyncPgConnection>>, PoolError>> + Send;

    /// Builds the full topology.
    ///
    /// The default obtains the primary from [`DatabasePoolProvider::create_pool`].
    /// When `replica_url` is set, it builds the replica pool with this module's
    /// own deadpool builder, not through `create_pool`.
    fn create_topology(
        &self,
        config: &DatabaseConfig,
    ) -> impl std::future::Future<Output = Result<Option<DatabaseTopology>, PoolError>> + Send {
        async move {
            let primary = match self.create_pool(config).await? {
                Some(pool) => pool,
                None => return Ok(None),
            };
            let replica = match config.replica_url.as_deref() {
                Some(url) => Some(build_pool(
                    url,
                    config.effective_replica_pool_size(),
                    config.connect_timeout_secs,
                )?),
                None => None,
            };
            Ok(Some(DatabaseTopology::from_pools(primary, replica)))
        }
    }
}
#[derive(Debug, Default, Clone, Copy)]
pub struct DieselDeadpoolPoolProvider;
impl DieselDeadpoolPoolProvider {
#[must_use]
pub const fn new() -> Self {
Self
}
}
impl DatabasePoolProvider for DieselDeadpoolPoolProvider {
async fn create_pool(
&self,
config: &DatabaseConfig,
) -> Result<Option<Pool<AsyncPgConnection>>, PoolError> {
create_pool(config)
}
}
// Unit tests for after-commit callbacks, pool/topology construction, the `Db`
// extractor's no-pool rejection, and `scrub_sql` fingerprinting. None of these
// tests need a live Postgres server: pools are built lazily, and the extractor
// test uses a state without a pool.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::DatabaseConfig;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
    use std::time::Duration;

    // ---- after-commit callbacks ----

    #[tokio::test]
    async fn register_after_commit_outside_tx_runs_eagerly() {
        let counter = Arc::new(AtomicUsize::new(0));
        let c = counter.clone();
        register_after_commit(move || async move {
            c.fetch_add(1, Ordering::SeqCst);
            Ok(())
        })
        .await;
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn register_after_commit_eager_failure_increments_failure_counter() {
        // The counter is global and other tests may bump it concurrently, so only
        // a strict increase is asserted.
        let before = AFTER_COMMIT_FAILURES_TOTAL.load(Ordering::Relaxed);
        register_after_commit(|| async {
            Err(crate::AutumnError::internal_server_error_msg(
                "deliberate eager after-commit failure",
            ))
        })
        .await;
        let after = AFTER_COMMIT_FAILURES_TOTAL.load(Ordering::Relaxed);
        assert!(
            after > before,
            "eager after_commit failures should be counted for recovery signals"
        );
    }

    #[tokio::test]
    async fn register_after_commit_inside_scope_defers_until_drained() {
        let counter = Arc::new(AtomicUsize::new(0));
        let c = counter.clone();
        let registry = Arc::new(std::sync::Mutex::new(Vec::<CommitCallback>::new()));
        AFTER_COMMIT_REGISTRY
            .scope(registry.clone(), async {
                register_after_commit(move || async move {
                    c.fetch_add(1, Ordering::SeqCst);
                    Ok(())
                })
                .await;
            })
            .await;
        assert_eq!(counter.load(Ordering::SeqCst), 0);
        let callbacks: Vec<CommitCallback> = {
            let mut reg = registry.lock().unwrap();
            std::mem::take(&mut *reg)
        };
        for cb in callbacks {
            cb().await.unwrap();
        }
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn register_after_commit_on_rollback_callbacks_dropped() {
        let counter = Arc::new(AtomicUsize::new(0));
        let c = counter.clone();
        let registry = Arc::new(std::sync::Mutex::new(Vec::<CommitCallback>::new()));
        AFTER_COMMIT_REGISTRY
            .scope(registry.clone(), async {
                register_after_commit(move || async move {
                    c.fetch_add(1, Ordering::SeqCst);
                    Ok(())
                })
                .await;
            })
            .await;
        // Dropping the registry without draining simulates a rollback.
        drop(registry);
        assert_eq!(counter.load(Ordering::SeqCst), 0);
    }

    #[tokio::test]
    async fn register_after_commit_callbacks_run_in_registration_order() {
        let order = Arc::new(std::sync::Mutex::new(Vec::<u32>::new()));
        let registry = Arc::new(std::sync::Mutex::new(Vec::<CommitCallback>::new()));
        let o1 = order.clone();
        let o2 = order.clone();
        let o3 = order.clone();
        AFTER_COMMIT_REGISTRY
            .scope(registry.clone(), async {
                register_after_commit(move || async move {
                    o1.lock().unwrap().push(1);
                    Ok(())
                })
                .await;
                register_after_commit(move || async move {
                    o2.lock().unwrap().push(2);
                    Ok(())
                })
                .await;
                register_after_commit(move || async move {
                    o3.lock().unwrap().push(3);
                    Ok(())
                })
                .await;
            })
            .await;
        let callbacks: Vec<CommitCallback> = {
            let mut reg = registry.lock().unwrap();
            std::mem::take(&mut *reg)
        };
        for cb in callbacks {
            cb().await.unwrap();
        }
        assert_eq!(*order.lock().unwrap(), vec![1, 2, 3]);
    }

    #[tokio::test]
    async fn production_after_commit_drain_preserves_registration_order() {
        // The first callback blocks on a oneshot so the test can observe that the
        // second one has not started yet (sequential, not concurrent, draining).
        let order = Arc::new(std::sync::Mutex::new(Vec::<u32>::new()));
        let (release_first, wait_first) = tokio::sync::oneshot::channel::<()>();
        let first_order = order.clone();
        let second_order = order.clone();
        let callbacks: Vec<CommitCallback> = vec![
            Box::new(move || {
                Box::pin(async move {
                    wait_first
                        .await
                        .expect("test should release first callback");
                    first_order.lock().unwrap().push(1);
                    Ok(())
                })
            }),
            Box::new(move || {
                Box::pin(async move {
                    second_order.lock().unwrap().push(2);
                    Ok(())
                })
            }),
        ];
        let drain = spawn_committed_after_commit_callbacks(callbacks)
            .expect("non-empty callback list should spawn a drain task");
        tokio::task::yield_now().await;
        assert_eq!(
            *order.lock().unwrap(),
            Vec::<u32>::new(),
            "later callbacks must wait for earlier callbacks to finish"
        );
        release_first
            .send(())
            .expect("first callback receiver alive");
        drain.await.expect("drain task should not panic");
        assert_eq!(*order.lock().unwrap(), vec![1, 2]);
    }

    #[tokio::test]
    async fn production_after_commit_drain_isolates_panicking_callbacks() {
        let before = AFTER_COMMIT_FAILURES_TOTAL.load(Ordering::Relaxed);
        let ran_later = Arc::new(AtomicU64::new(0));
        let later = ran_later.clone();
        let callbacks: Vec<CommitCallback> = vec![
            Box::new(|| Box::pin(async { panic!("deliberate after_commit panic") })),
            Box::new(move || {
                Box::pin(async move {
                    later.fetch_add(1, Ordering::SeqCst);
                    Ok(())
                })
            }),
        ];
        let drain = spawn_committed_after_commit_callbacks(callbacks)
            .expect("non-empty callback list should spawn a drain task");
        drain.await.expect("panicking callback should be isolated");
        assert_eq!(
            ran_later.load(Ordering::SeqCst),
            1,
            "later callbacks must still run after an earlier callback panics"
        );
        let after = AFTER_COMMIT_FAILURES_TOTAL.load(Ordering::Relaxed);
        assert!(
            after > before,
            "panicking after_commit callbacks must increment the failure counter"
        );
    }

    #[tokio::test]
    async fn db_tx_rejects_ambient_after_commit_registry() {
        let registry = Arc::new(std::sync::Mutex::new(Vec::<CommitCallback>::new()));
        let err = AFTER_COMMIT_REGISTRY
            .scope(registry, async {
                reject_ambient_after_commit_registry_for_tx().expect_err(
                    "starting Db::tx inside an ambient transaction registry should fail",
                )
            })
            .await;
        assert!(
            err.to_string().contains("Nested Db::tx calls"),
            "unexpected nested transaction error: {err}"
        );
    }

    #[tokio::test]
    async fn register_after_commit_callback_error_is_swallowed() {
        let registry = Arc::new(std::sync::Mutex::new(Vec::<CommitCallback>::new()));
        AFTER_COMMIT_REGISTRY
            .scope(registry.clone(), async {
                register_after_commit(|| async {
                    Err(crate::AutumnError::internal_server_error_msg(
                        "deliberate error",
                    ))
                })
                .await;
            })
            .await;
        let callbacks: Vec<CommitCallback> = {
            let mut reg = registry.lock().unwrap();
            std::mem::take(&mut *reg)
        };
        for cb in callbacks {
            let _ = cb().await;
        }
    }

    // ---- pool providers and topology ----

    // Provider that never builds a pool, used to prove trait overrides take effect.
    struct NoOpPoolProvider;

    impl DatabasePoolProvider for NoOpPoolProvider {
        async fn create_pool(
            &self,
            _config: &DatabaseConfig,
        ) -> Result<Option<Pool<AsyncPgConnection>>, PoolError> {
            Ok(None)
        }
    }

    #[tokio::test]
    async fn pool_provider_trait_returns_supplied_pool() {
        let config = DatabaseConfig {
            url: Some("postgres://localhost/ignored".to_owned()),
            ..Default::default()
        };
        let provider = NoOpPoolProvider;
        let pool = provider
            .create_pool(&config)
            .await
            .expect("no-op provider should succeed");
        assert!(
            pool.is_none(),
            "no-op provider must override default behaviour"
        );
    }

    #[tokio::test]
    async fn default_pool_provider_matches_free_function() {
        let config = DatabaseConfig::default();
        let via_provider = DieselDeadpoolPoolProvider::new()
            .create_pool(&config)
            .await
            .expect("default provider should succeed");
        let via_function = create_pool(&config).expect("free fn should succeed");
        assert_eq!(via_provider.is_none(), via_function.is_none());
    }

    #[tokio::test]
    async fn default_pool_provider_respects_url_config() {
        let config = DatabaseConfig {
            url: Some("postgres://localhost/test".into()),
            ..Default::default()
        };
        let provider = DieselDeadpoolPoolProvider::new();
        let pool = provider
            .create_pool(&config)
            .await
            .expect("default provider should succeed");
        assert!(
            pool.is_some(),
            "default provider should return Some when url is provided"
        );
    }

    #[test]
    fn create_pool_with_no_url_returns_none() {
        let config = DatabaseConfig::default();
        let pool = create_pool(&config).expect("should not fail with no URL");
        assert!(pool.is_none());
    }

    #[test]
    fn create_pool_with_url_returns_some() {
        let config = DatabaseConfig {
            url: Some("postgres://localhost/test".into()),
            ..Default::default()
        };
        let pool = create_pool(&config).expect("should build pool from valid config");
        assert!(pool.is_some());
    }

    #[test]
    fn pool_respects_max_size() {
        let config = DatabaseConfig {
            url: Some("postgres://localhost/test".into()),
            pool_size: 5,
            ..Default::default()
        };
        let pool = create_pool(&config)
            .expect("should build pool")
            .expect("should be Some");
        assert_eq!(pool.status().max_size, 5);
    }

    #[test]
    fn pool_clamps_size_to_one_if_zero() {
        let config = DatabaseConfig {
            url: Some("postgres://localhost/test".into()),
            pool_size: 0,
            ..Default::default()
        };
        let pool = create_pool(&config)
            .expect("should build pool")
            .expect("should be Some");
        assert_eq!(
            pool.status().max_size,
            1,
            "Pool size should be clamped to 1"
        );
    }

    #[test]
    fn database_topology_builds_primary_and_replica_pools() {
        let config = DatabaseConfig {
            primary_url: Some("postgres://localhost/primary".into()),
            replica_url: Some("postgres://localhost/replica".into()),
            primary_pool_size: Some(6),
            replica_pool_size: Some(2),
            ..Default::default()
        };
        let topology = create_topology(&config)
            .expect("topology should build")
            .expect("topology should be configured");
        assert_eq!(topology.primary().status().max_size, 6);
        assert_eq!(
            topology.replica().expect("replica pool").status().max_size,
            2
        );
        assert_eq!(topology.read().status().max_size, 2);
    }

    #[test]
    fn database_topology_single_url_builds_only_primary_pool() {
        let config = DatabaseConfig {
            url: Some("postgres://localhost/single".into()),
            pool_size: 5,
            ..Default::default()
        };
        let topology = create_topology(&config)
            .expect("topology should build")
            .expect("topology should be configured");
        assert_eq!(topology.primary().status().max_size, 5);
        assert!(topology.replica().is_none());
        assert_eq!(topology.read().status().max_size, 5);
    }

    #[test]
    fn config_runtime_drift_pool_applies_connect_timeout_to_wait_and_create() {
        let config = DatabaseConfig {
            url: Some("postgres://localhost/test".into()),
            connect_timeout_secs: 7,
            ..Default::default()
        };
        let pool = create_pool(&config)
            .expect("should build pool")
            .expect("should be Some");
        let timeouts = pool.timeouts();
        assert_eq!(timeouts.wait, Some(Duration::from_secs(7)));
        assert_eq!(timeouts.create, Some(Duration::from_secs(7)));
    }

    // State with no configured pool: the extractor must reject requests.
    #[derive(Clone)]
    struct TestDbState;

    impl DbState for TestDbState {
        fn pool(&self) -> Option<&Pool<AsyncPgConnection>> {
            None
        }
    }

    // State with only a primary pool, used to check `read_pool` fallback.
    #[derive(Clone)]
    struct TestReadState {
        primary: Pool<AsyncPgConnection>,
    }

    impl DbState for TestReadState {
        fn pool(&self) -> Option<&Pool<AsyncPgConnection>> {
            Some(&self.primary)
        }
    }

    #[test]
    fn database_topology_read_pool_falls_back_to_primary() {
        let config = DatabaseConfig {
            url: Some("postgres://localhost/read-fallback".into()),
            pool_size: 3,
            ..Default::default()
        };
        let primary = create_pool(&config).unwrap().unwrap();
        let state = TestReadState { primary };
        assert_eq!(state.read_pool().expect("read pool").status().max_size, 3);
    }

    #[tokio::test]
    async fn db_extractor_rejects_when_no_pool() {
        use axum::Router;
        use axum::body::Body;
        use axum::http::{Request, StatusCode};
        use axum::routing::get;
        use tower::ServiceExt;
        async fn handler(_db: Db) -> &'static str {
            "ok"
        }
        let app = Router::new()
            .route("/", get(handler))
            .with_state(TestDbState);
        let response = app
            .oneshot(Request::builder().uri("/").body(Body::empty()).unwrap())
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::SERVICE_UNAVAILABLE);
    }

    #[tokio::test]
    async fn database_topology_primary_only_has_no_replica() {
        let config = DatabaseConfig {
            primary_url: Some("postgres://user:pass@localhost/db".to_string()),
            ..DatabaseConfig::default()
        };
        let topology = create_topology(&config).unwrap().unwrap();
        let primary = topology.primary().clone();
        let new_topology = DatabaseTopology::primary_only(primary);
        assert!(
            new_topology.replica().is_none(),
            "primary_only must set replica to None"
        );
    }

    #[tokio::test]
    async fn database_topology_from_pools_retains_replica() {
        let config = DatabaseConfig {
            primary_url: Some("postgres://user:pass@localhost/db".to_string()),
            replica_url: Some("postgres://user:pass@localhost/db_replica".to_string()),
            ..DatabaseConfig::default()
        };
        let topology = create_topology(&config).unwrap().unwrap();
        let primary = topology.primary().clone();
        let replica = topology.replica().cloned();
        let new_topology = DatabaseTopology::from_pools(primary, replica);
        assert!(
            new_topology.replica().is_some(),
            "from_pools must preserve the replica pool"
        );
    }

    // ---- scrub_sql fingerprinting ----

    #[test]
    fn scrub_sql_strips_string_literals() {
        assert_eq!(
            super::scrub_sql("SELECT * FROM users WHERE name = 'Alice'"),
            "SELECT * FROM users WHERE name = '?'"
        );
    }

    #[test]
    fn scrub_sql_strips_numeric_literals() {
        assert_eq!(
            super::scrub_sql("SELECT * FROM orders WHERE id = 42"),
            "SELECT * FROM orders WHERE id = ?"
        );
    }

    #[test]
    fn scrub_sql_preserves_pg_positional_params() {
        assert_eq!(
            super::scrub_sql("SELECT * FROM t WHERE x = $1 AND y = $2"),
            "SELECT * FROM t WHERE x = $1 AND y = $2"
        );
    }

    #[test]
    fn scrub_sql_does_not_stomp_identifiers() {
        assert_eq!(
            super::scrub_sql("SELECT * FROM table1 WHERE active = true"),
            "SELECT * FROM table1 WHERE active = true"
        );
    }

    #[test]
    fn scrub_sql_multiple_literals_in_one_query() {
        assert_eq!(
            super::scrub_sql("INSERT INTO users (name, age) VALUES ('Bob', 30)"),
            "INSERT INTO users (name, age) VALUES ('?', ?)"
        );
    }

    #[test]
    fn scrub_sql_handles_escaped_single_quotes() {
        assert_eq!(
            super::scrub_sql("SELECT * FROM t WHERE s = 'it''s a test'"),
            "SELECT * FROM t WHERE s = '?'"
        );
    }

    #[test]
    fn scrub_sql_empty_string() {
        assert_eq!(super::scrub_sql(""), "");
    }

    #[test]
    fn scrub_sql_scientific_notation_integer_exponent() {
        assert_eq!(
            super::scrub_sql("SELECT * FROM t WHERE n = 1e6"),
            "SELECT * FROM t WHERE n = ?"
        );
    }

    #[test]
    fn scrub_sql_scientific_notation_float_exponent() {
        assert_eq!(
            super::scrub_sql("SELECT * FROM t WHERE n = 2.5E-4"),
            "SELECT * FROM t WHERE n = ?"
        );
    }

    #[test]
    fn scrub_sql_scientific_notation_uppercase_positive_exponent() {
        assert_eq!(
            super::scrub_sql("SELECT * FROM t WHERE n = 3E+10"),
            "SELECT * FROM t WHERE n = ?"
        );
    }

    #[test]
    fn scrub_sql_dollar_quoted_anonymous() {
        assert_eq!(super::scrub_sql("SELECT $$secret value$$"), "SELECT '?'");
    }

    #[test]
    fn scrub_sql_dollar_quoted_with_tag() {
        assert_eq!(
            super::scrub_sql("SELECT $body$hello world$body$"),
            "SELECT '?'"
        );
    }

    #[test]
    fn scrub_sql_dollar_quoted_does_not_affect_positional_params() {
        assert_eq!(
            super::scrub_sql("SELECT $1, $2 FROM $$secret$$ WHERE id = $3"),
            "SELECT $1, $2 FROM '?' WHERE id = $3"
        );
    }

    #[test]
    fn scrub_sql_estring_backslash_escaped_quote() {
        assert_eq!(
            super::scrub_sql(r"SELECT E'it\'s secret' FROM t"),
            "SELECT '?' FROM t"
        );
    }

    #[test]
    fn scrub_sql_estring_uppercase() {
        assert_eq!(
            super::scrub_sql("SELECT E'hello world' FROM t"),
            "SELECT '?' FROM t"
        );
    }

    #[test]
    fn scrub_sql_estring_multiple_backslash_escapes() {
        assert_eq!(
            super::scrub_sql(r"SELECT E'line1\nline2' FROM t"),
            "SELECT '?' FROM t"
        );
    }

    #[test]
    fn scrub_sql_leading_dot_numeric_literals() {
        assert_eq!(super::scrub_sql("SELECT .5"), "SELECT ?");
        assert_eq!(super::scrub_sql("SELECT .25 + .75"), "SELECT ? + ?");
        assert_eq!(super::scrub_sql("SELECT t.col"), "SELECT t.col");
        assert_eq!(
            super::scrub_sql("SELECT schema.table.col"),
            "SELECT schema.table.col"
        );
    }

    #[test]
    fn scrub_sql_estring_doubled_quote_escape() {
        assert_eq!(
            super::scrub_sql("SELECT E'it''s secret' FROM t"),
            "SELECT '?' FROM t"
        );
    }

    #[test]
    fn scrub_sql_numeric_literal_underscore_grouping() {
        assert_eq!(super::scrub_sql("SELECT 5_432_000"), "SELECT ?");
        assert_eq!(super::scrub_sql("SELECT 1_000.5_0"), "SELECT ?");
        assert_eq!(super::scrub_sql("SELECT col_5_val"), "SELECT col_5_val");
        assert_eq!(super::scrub_sql("SELECT col_5"), "SELECT col_5");
    }
}