use std::sync::{Arc, Mutex};
use axum::extract::{Request, State};
use axum::middleware::Next;
use axum::response::Response;
use sqlx::{Any, AnyPool, Transaction};
use rusty_gasket::db::config::ResolvedBackend;
/// Reduce an arbitrary request ID to a short, SQL-inert tag.
///
/// Only ASCII letters, ASCII digits and `-` are kept; every other character
/// (whitespace, punctuation, non-ASCII) is dropped. At most the first 55
/// surviving characters are returned. Because every kept character is ASCII,
/// that is also at most 55 bytes.
///
/// NOTE(review): 55 presumably keeps `"gasket|" + id` (7 + 55 = 62 bytes)
/// under PostgreSQL's default 63-byte `application_name` limit — confirm.
pub fn sanitize_request_id(id: &str) -> String {
    const MAX_LEN: usize = 55;

    let mut cleaned = String::with_capacity(id.len().min(MAX_LEN));
    for ch in id.chars() {
        // Kept characters are ASCII, so the byte length equals the char count.
        if cleaned.len() == MAX_LEN {
            break;
        }
        if ch == '-' || ch.is_ascii_alphanumeric() {
            cleaned.push(ch);
        }
    }
    cleaned
}
/// Open a transaction on `pool` and tag it with the caller's request ID.
///
/// The request ID is passed through [`sanitize_request_id`], prefixed with
/// `gasket|`, and sent as a bound parameter (never spliced into SQL):
///
/// * **Postgres** — `set_config('application_name', …, true)`. The `true`
///   (`is_local`) flag scopes the setting to the current transaction.
/// * **MySQL** — assigns the user variable `@gasket_request_id`.
///   NOTE(review): MySQL user variables are session-scoped, not
///   transaction-scoped, so the value stays on the pooled connection after
///   commit/rollback until something overwrites it.
///
/// # Errors
///
/// Returns the [`sqlx::Error`] from either `BEGIN` or the tagging
/// statement. If tagging fails, the already-open transaction is dropped here
/// rather than returned.
pub async fn begin_tracked_transaction(
    pool: &AnyPool,
    request_id: &str,
    backend: ResolvedBackend,
) -> Result<Transaction<'static, Any>, sqlx::Error> {
    let tag = format!("gasket|{}", sanitize_request_id(request_id));

    let mut tx = pool.begin().await?;

    // Placeholder syntax differs per driver ($1 vs ?), so pick the full text.
    let tagging_sql = match backend {
        ResolvedBackend::Postgres => "SELECT set_config('application_name', $1, true)",
        ResolvedBackend::MySql => "SET @gasket_request_id = ?",
    };

    sqlx::query(tagging_sql)
        .bind(&tag)
        .execute(&mut *tx)
        .await?;

    Ok(tx)
}
/// Shared configuration handed to [`transaction_middleware`] through axum's
/// `State` extractor (wrapped in an `Arc`).
///
/// `#[non_exhaustive]` stops code outside this crate from building it with a
/// struct literal; such code must use [`TransactionMiddlewareState::new`].
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct TransactionMiddlewareState {
/// Pool each per-request transaction is begun on.
pub pool: AnyPool,
/// Database flavour behind `pool`. It selects which tagging statement
/// `begin_tracked_transaction` issues.
pub backend: ResolvedBackend,
}
impl TransactionMiddlewareState {
    /// Pair a connection pool with the backend kind it connects to.
    ///
    /// This is the only way for other crates to build the state, because the
    /// struct is `#[non_exhaustive]`.
    #[must_use]
    pub const fn new(pool: AnyPool, backend: ResolvedBackend) -> Self {
        Self { backend, pool }
    }
}
/// Per-request handle to a database transaction, stored in request extensions.
///
/// Clones share one `Arc<Mutex<Option<_>>>` slot. The transaction can
/// therefore be moved out (via [`RequestTransaction::take`]) at most once
/// across all clones; later calls see `None`.
#[derive(Debug, Clone)]
pub struct RequestTransaction {
// `Option` so the transaction can be moved out by value; `Mutex` so that
// `take` works through the shared `&self` handle.
inner: Arc<Mutex<Option<Transaction<'static, Any>>>>,
}
impl RequestTransaction {
    /// Wrap an already-begun transaction so it can be shared through request
    /// extensions.
    #[must_use]
    pub fn new(tx: Transaction<'static, Any>) -> Self {
        Self {
            inner: Arc::new(Mutex::new(Some(tx))),
        }
    }

    /// Move the transaction out of the shared slot, leaving `None` behind.
    ///
    /// Returns `Some` for the first caller across all clones and `None` for
    /// every later call.
    ///
    /// A poisoned mutex is recovered instead of discarding the transaction.
    /// The lock only ever guards a plain `Option`, and the only critical
    /// section here is `Option::take`, which cannot panic. A panic elsewhere
    /// therefore cannot leave the slot half-updated, and the handle is still
    /// safe to hand to the caller so it can commit.
    #[must_use]
    pub fn take(&self) -> Option<Transaction<'static, Any>> {
        let mut slot = self.inner.lock().unwrap_or_else(|poisoned| {
            tracing::warn!(
                "RequestTransaction mutex poisoned; recovering transaction handle"
            );
            poisoned.into_inner()
        });
        slot.take()
    }
}
pub async fn transaction_middleware(
State(state): State<Arc<TransactionMiddlewareState>>,
mut request: Request,
next: Next,
) -> Response {
let request_id = request
.extensions()
.get::<rusty_gasket::observability::RequestId>()
.map(|r| r.as_str().to_owned())
.unwrap_or_default();
match begin_tracked_transaction(&state.pool, &request_id, state.backend).await {
Ok(tx) => {
request.extensions_mut().insert(RequestTransaction::new(tx));
next.run(request).await
}
Err(e) => {
tracing::error!(
request_id = %request_id,
error = %e,
"Failed to begin database transaction"
);
rusty_gasket::error::quick_error_response(
http::StatusCode::SERVICE_UNAVAILABLE,
"DATABASE_ERROR",
"Service temporarily unavailable",
)
}
}
}
#[cfg(test)]
mod tests {
    use super::sanitize_request_id;

    #[test]
    fn sanitize_normal_uuid() {
        // A canonical UUID uses only the allowed alphabet, so it round-trips.
        const UUID: &str = "550e8400-e29b-41d4-a716-446655440000";
        assert_eq!(sanitize_request_id(UUID), UUID);
    }

    #[test]
    fn sanitize_strips_special_chars() {
        // Spaces and `;` disappear; `-` is preserved.
        let cleaned = sanitize_request_id("req-123; DROP TABLE users;--");
        assert_eq!(cleaned, "req-123DROPTABLEusers--");
    }

    #[test]
    fn sanitize_truncates_long_ids() {
        let long_id = "a".repeat(100);
        let cleaned = sanitize_request_id(&long_id);
        assert_eq!(cleaned.len(), 55);
    }

    #[test]
    fn sanitize_empty_id() {
        assert!(sanitize_request_id("").is_empty());
    }

    #[test]
    fn sanitize_unicode_strips_non_ascii() {
        // CJK characters are alphanumeric but not ASCII, so they are removed.
        let cleaned = sanitize_request_id("req-\u{65e5}\u{672c}\u{8a9e}-123");
        assert_eq!(cleaned, "req--123");
    }

    #[test]
    fn sanitize_only_special_chars() {
        assert_eq!(sanitize_request_id("!@#$%^&*()"), "");
    }
}