use super::config::PoolConfig;
use super::errors::{PoolError, PoolResult};
use super::events::{PoolEvent, PoolEventListener};
use sqlx::{Database, MySql, Pool, Postgres, Sqlite};
use std::mem::ManuallyDrop;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
use tokio::sync::RwLock;
/// Returns `url` with the password portion of its credentials replaced by `***`.
///
/// The credentials are taken to be everything between the first `://` and the
/// last `@` that follows it; the password is whatever comes after the first
/// `:` inside those credentials. When any of the three separators is missing,
/// the input is returned unchanged.
pub(crate) fn mask_url_password(url: &str) -> String {
    url.split_once("://")
        .and_then(|(scheme, remainder)| {
            // The last '@' is used so that a password containing '@' is still
            // masked in full.
            let (credentials, host_part) = remainder.rsplit_once('@')?;
            let (user, _password) = credentials.split_once(':')?;
            Some(format!("{}://{}:***@{}", scheme, user, host_part))
        })
        .unwrap_or_else(|| url.to_string())
}
/// Wrapper around an sqlx [`Pool`] that also carries the configuration and
/// URL it was built from, plus a list of event listeners.
pub struct ConnectionPool<DB: Database> {
/// The underlying sqlx pool that connections are acquired from.
pool: Pool<DB>,
/// Settings applied to the pool options when the pool is built.
config: PoolConfig,
/// Connection URL exactly as supplied by the caller (unmasked).
url: String,
/// Listeners that receive every emitted `PoolEvent`.
listeners: Arc<RwLock<Vec<Arc<dyn PoolEventListener>>>>,
/// Flag swapped to `true` by `acquire` and reset to `false` by `recreate`;
/// it gates emission of the connection-created event.
first_connect_fired: Arc<AtomicBool>,
}
impl ConnectionPool<Postgres> {
    /// Builds a PostgreSQL pool for `url` using the limits and timeouts in `config`.
    ///
    /// The configuration is validated before any connection attempt is made.
    /// The new pool starts with no listeners and with the first-connect flag
    /// cleared.
    ///
    /// # Errors
    ///
    /// Returns `PoolError::Config` when `config.validate()` fails, and
    /// propagates the error produced by connecting the pool.
    pub async fn new_postgres(url: &str, config: PoolConfig) -> PoolResult<Self> {
        config.validate().map_err(PoolError::Config)?;

        let options = sqlx::postgres::PgPoolOptions::new()
            .min_connections(config.min_connections)
            .max_connections(config.max_connections)
            .acquire_timeout(config.acquire_timeout)
            .idle_timeout(config.idle_timeout)
            .max_lifetime(config.max_lifetime)
            .test_before_acquire(config.test_before_acquire);
        let pool = options.connect(url).await?;

        let listeners = Arc::new(RwLock::new(Vec::new()));
        let first_connect_fired = Arc::new(AtomicBool::new(false));
        Ok(Self {
            pool,
            config,
            url: url.to_owned(),
            listeners,
            first_connect_fired,
        })
    }
}
impl ConnectionPool<MySql> {
    /// Builds a MySQL pool for `url` using the limits and timeouts in `config`.
    ///
    /// The configuration is validated before any connection attempt is made.
    /// The new pool starts with no listeners and with the first-connect flag
    /// cleared.
    ///
    /// # Errors
    ///
    /// Returns `PoolError::Config` when `config.validate()` fails, and
    /// propagates the error produced by connecting the pool.
    pub async fn new_mysql(url: &str, config: PoolConfig) -> PoolResult<Self> {
        config.validate().map_err(PoolError::Config)?;

        let options = sqlx::mysql::MySqlPoolOptions::new()
            .min_connections(config.min_connections)
            .max_connections(config.max_connections)
            .acquire_timeout(config.acquire_timeout)
            .idle_timeout(config.idle_timeout)
            .max_lifetime(config.max_lifetime)
            .test_before_acquire(config.test_before_acquire);
        let pool = options.connect(url).await?;

        let listeners = Arc::new(RwLock::new(Vec::new()));
        let first_connect_fired = Arc::new(AtomicBool::new(false));
        Ok(Self {
            pool,
            config,
            url: url.to_owned(),
            listeners,
            first_connect_fired,
        })
    }
}
impl ConnectionPool<Sqlite> {
    /// Builds a SQLite pool for `url` using the limits and timeouts in `config`.
    ///
    /// The configuration is validated before any connection attempt is made.
    /// The new pool starts with no listeners and with the first-connect flag
    /// cleared.
    ///
    /// # Errors
    ///
    /// Returns `PoolError::Config` when `config.validate()` fails, and
    /// propagates the error produced by connecting the pool.
    pub async fn new_sqlite(url: &str, config: PoolConfig) -> PoolResult<Self> {
        config.validate().map_err(PoolError::Config)?;

        let options = sqlx::sqlite::SqlitePoolOptions::new()
            .min_connections(config.min_connections)
            .max_connections(config.max_connections)
            .acquire_timeout(config.acquire_timeout)
            .idle_timeout(config.idle_timeout)
            .max_lifetime(config.max_lifetime)
            .test_before_acquire(config.test_before_acquire);
        let pool = options.connect(url).await?;

        let listeners = Arc::new(RwLock::new(Vec::new()));
        let first_connect_fired = Arc::new(AtomicBool::new(false));
        Ok(Self {
            pool,
            config,
            url: url.to_owned(),
            listeners,
            first_connect_fired,
        })
    }
}
impl<DB> ConnectionPool<DB>
where
    DB: sqlx::Database,
{
    /// Registers `listener` so that it receives every event emitted afterwards.
    pub async fn add_listener(&self, listener: Arc<dyn PoolEventListener>) {
        self.listeners.write().await.push(listener);
    }

    /// Delivers a clone of `event` to each registered listener, one after
    /// another, in registration order.
    ///
    /// The read lock on the listener list is held until every listener has
    /// finished handling the event.
    pub(crate) async fn emit_event(&self, event: PoolEvent) {
        let listeners = self.listeners.read().await;
        for listener in listeners.iter() {
            listener.on_event(event.clone()).await;
        }
    }

    /// Checks a connection out of the underlying pool and wraps it in a
    /// [`PooledConnection`] tagged with a freshly generated UUID v7 string.
    ///
    /// Emits a connection-created event for the first successful acquisition
    /// (since construction or the last `recreate`), followed by a
    /// connection-acquired event on every successful acquisition.
    ///
    /// # Errors
    ///
    /// Propagates the error returned by the underlying pool's `acquire`.
    /// No event is emitted in that case.
    pub async fn acquire(&self) -> PoolResult<PooledConnection<DB>> {
        let conn = self.pool.acquire().await?;

        // Claim the first-connect flag only once a connection has actually
        // been handed out. Claiming it before the fallible acquire would let
        // a failed first attempt consume the flag, and the connection-created
        // event would then never fire for the next successful acquisition.
        let is_first = !self.first_connect_fired.swap(true, Ordering::SeqCst);

        let connection_id = uuid::Uuid::now_v7().to_string();
        if is_first {
            self.emit_event(PoolEvent::connection_created(connection_id.clone()))
                .await;
        }
        self.emit_event(PoolEvent::connection_acquired(connection_id.clone()))
            .await;

        Ok(PooledConnection {
            conn: ManuallyDrop::new(conn),
            pool_ref: self.clone_arc(),
            connection_id,
        })
    }

    /// Builds a new `Arc<Self>` from clones of this pool's fields.
    ///
    /// `listeners` and `first_connect_fired` are themselves `Arc`s, so the
    /// returned value shares them with `self`.
    fn clone_arc(&self) -> Arc<Self> {
        Arc::new(Self {
            pool: self.pool.clone(),
            config: self.config.clone(),
            url: self.url.clone(),
            listeners: Arc::clone(&self.listeners),
            first_connect_fired: Arc::clone(&self.first_connect_fired),
        })
    }

    /// Returns a reference to the underlying sqlx pool.
    pub fn inner(&self) -> &Pool<DB> {
        &self.pool
    }

    /// Returns the configuration this pool was built with.
    pub fn config(&self) -> &PoolConfig {
        &self.config
    }

    /// Closes the underlying pool, waiting at most five seconds for it to
    /// finish.
    ///
    /// This is best-effort: if the wait times out the method simply returns,
    /// and the timeout is not reported to the caller.
    pub async fn close(&self) {
        use tokio::time::{Duration, timeout};

        // The timeout outcome is deliberately discarded (see above).
        let _ = timeout(Duration::from_secs(5), self.pool.close()).await;
    }

    /// Returns the connection URL with its password replaced by `***`.
    pub fn url(&self) -> String {
        mask_url_password(&self.url)
    }

    /// Returns the connection URL exactly as stored, without masking.
    #[allow(dead_code)]
    pub(crate) fn url_raw(&self) -> &str {
        &self.url
    }
}
impl ConnectionPool<Postgres> {
    /// Closes the current pool and replaces it with a new PostgreSQL pool
    /// built from the stored configuration and URL.
    ///
    /// On success the first-connect flag is cleared, so the next `acquire`
    /// emits the connection-created event again.
    ///
    /// # Errors
    ///
    /// Propagates the error produced by connecting the new pool. The old pool
    /// has already been closed at that point and stays in place.
    pub async fn recreate(&mut self) -> PoolResult<()> {
        self.pool.close().await;

        let settings = &self.config;
        let options = sqlx::postgres::PgPoolOptions::new()
            .min_connections(settings.min_connections)
            .max_connections(settings.max_connections)
            .acquire_timeout(settings.acquire_timeout)
            .idle_timeout(settings.idle_timeout)
            .max_lifetime(settings.max_lifetime)
            .test_before_acquire(settings.test_before_acquire);

        self.pool = options.connect(&self.url).await?;
        self.first_connect_fired.store(false, Ordering::SeqCst);
        Ok(())
    }
}
impl ConnectionPool<MySql> {
    /// Closes the current pool and replaces it with a new MySQL pool built
    /// from the stored configuration and URL.
    ///
    /// On success the first-connect flag is cleared, so the next `acquire`
    /// emits the connection-created event again.
    ///
    /// # Errors
    ///
    /// Propagates the error produced by connecting the new pool. The old pool
    /// has already been closed at that point and stays in place.
    pub async fn recreate(&mut self) -> PoolResult<()> {
        self.pool.close().await;

        let settings = &self.config;
        let options = sqlx::mysql::MySqlPoolOptions::new()
            .min_connections(settings.min_connections)
            .max_connections(settings.max_connections)
            .acquire_timeout(settings.acquire_timeout)
            .idle_timeout(settings.idle_timeout)
            .max_lifetime(settings.max_lifetime)
            .test_before_acquire(settings.test_before_acquire);

        self.pool = options.connect(&self.url).await?;
        self.first_connect_fired.store(false, Ordering::SeqCst);
        Ok(())
    }
}
impl ConnectionPool<Sqlite> {
    /// Closes the current pool and replaces it with a new SQLite pool built
    /// from the stored configuration and URL.
    ///
    /// On success the first-connect flag is cleared, so the next `acquire`
    /// emits the connection-created event again.
    ///
    /// # Errors
    ///
    /// Propagates the error produced by connecting the new pool. The old pool
    /// has already been closed at that point and stays in place.
    pub async fn recreate(&mut self) -> PoolResult<()> {
        self.pool.close().await;

        let settings = &self.config;
        let options = sqlx::sqlite::SqlitePoolOptions::new()
            .min_connections(settings.min_connections)
            .max_connections(settings.max_connections)
            .acquire_timeout(settings.acquire_timeout)
            .idle_timeout(settings.idle_timeout)
            .max_lifetime(settings.max_lifetime)
            .test_before_acquire(settings.test_before_acquire);

        self.pool = options.connect(&self.url).await?;
        self.first_connect_fired.store(false, Ordering::SeqCst);
        Ok(())
    }
}
/// A connection checked out of a [`ConnectionPool`], tagged with a generated
/// identifier that is attached to the events emitted for it.
pub struct PooledConnection<DB: sqlx::Database> {
/// The sqlx connection, wrapped in `ManuallyDrop` so that the `Drop` impl
/// can choose between dropping it and forgetting it.
conn: ManuallyDrop<sqlx::pool::PoolConnection<DB>>,
/// Pool handle used to emit events to the registered listeners.
pool_ref: Arc<ConnectionPool<DB>>,
/// Identifier assigned in `acquire` (a UUID v7 rendered as a string).
connection_id: String,
}
impl<DB: sqlx::Database> PooledConnection<DB> {
    /// Gives mutable access to the wrapped sqlx connection.
    pub fn inner(&mut self) -> &mut sqlx::pool::PoolConnection<DB> {
        &mut self.conn
    }

    /// Returns the identifier assigned to this connection when it was acquired.
    pub fn connection_id(&self) -> &str {
        self.connection_id.as_str()
    }

    /// Emits a connection-invalidated event carrying `reason`.
    ///
    /// The method consumes `self`, so the `Drop` impl runs when it returns.
    pub async fn invalidate(self, reason: String) {
        let event = PoolEvent::connection_invalidated(self.connection_id.clone(), reason);
        self.pool_ref.emit_event(event).await;
    }

    /// Emits a connection-soft-invalidated event; the connection stays usable.
    pub async fn soft_invalidate(&mut self) {
        let event = PoolEvent::connection_soft_invalidated(self.connection_id.clone());
        self.pool_ref.emit_event(event).await;
    }

    /// Emits a connection-reset event for this connection.
    pub async fn reset(&mut self) {
        let event = PoolEvent::connection_reset(self.connection_id.clone());
        self.pool_ref.emit_event(event).await;
    }
}
impl<DB: sqlx::Database> Drop for PooledConnection<DB> {
    /// Releases the wrapped connection and, when a Tokio runtime is available,
    /// spawns a task that emits the connection-returned event.
    fn drop(&mut self) {
        // SAFETY: `self.conn` is taken exactly once, here, and is never
        // accessed again for the remainder of `drop`.
        let released = unsafe { ManuallyDrop::take(&mut self.conn) };

        if let Ok(runtime) = tokio::runtime::Handle::try_current() {
            // Drop the connection first, then notify listeners from a
            // detached task because `drop` cannot await.
            drop(released);
            let pool = self.pool_ref.clone();
            let id = self.connection_id.clone();
            runtime.spawn(async move {
                pool.emit_event(PoolEvent::connection_returned(id)).await;
            });
        } else {
            // No runtime on this thread: the connection is forgotten rather
            // than dropped, and no event is emitted.
            // NOTE(review): presumably dropping the sqlx connection needs a
            // runtime - confirm before changing this.
            std::mem::forget(released);
        }
    }
}
// Unit tests for URL password masking and for dropping a pooled connection
// when no Tokio runtime is active.
#[cfg(test)]
mod tests {
use super::*;
use rstest::rstest;
// URLs carrying `user:password@` credentials get the password replaced by
// `***`; the second case has an `@` inside the password itself.
#[rstest]
#[case(
"postgresql://user:secret@localhost:5432/mydb",
"postgresql://user:***@localhost:5432/mydb"
)]
#[case(
"mysql://admin:p@ssw0rd@db.example.com/app",
"mysql://admin:***@db.example.com/app"
)]
#[case(
"postgres://user:pass@host:5432/db?sslmode=require",
"postgres://user:***@host:5432/db?sslmode=require"
)]
fn test_mask_url_password_with_credentials(#[case] input: &str, #[case] expected: &str) {
let masked = mask_url_password(input);
assert_eq!(masked, expected);
}
// Inputs without a password (no `://`, no `@`, or no `:` in the user info)
// must come back untouched.
#[rstest]
#[case("sqlite::memory:")]
#[case("sqlite:///path/to/db.sqlite")]
#[case("postgresql://user@localhost:5432/mydb")]
fn test_mask_url_password_without_password(#[case] input: &str) {
let masked = mask_url_password(input);
assert_eq!(masked, input, "URL without password should be unchanged");
}
// An empty password (`user:@host`) is still rendered as `***`.
#[rstest]
fn test_mask_url_password_empty_password() {
let url = "postgresql://user:@localhost:5432/mydb";
let masked = mask_url_password(url);
assert_eq!(masked, "postgresql://user:***@localhost:5432/mydb");
}
// Percent-encoded characters in the password are masked along with the rest.
#[rstest]
fn test_mask_url_password_special_chars_in_password() {
let url = "postgresql://user:p%40ss%3Aw0rd@localhost:5432/mydb";
let masked = mask_url_password(url);
assert_eq!(masked, "postgresql://user:***@localhost:5432/mydb");
assert!(
!masked.contains("p%40ss"),
"Password should be fully masked"
);
}
// A string with no `://` separator is passed through as-is.
#[rstest]
fn test_mask_url_password_preserves_non_url() {
let non_url = "not-a-url-just-a-string";
let masked = mask_url_password(non_url);
assert_eq!(
masked, non_url,
"Non-URL strings should pass through unchanged"
);
}
// Checks the assumption the `Drop` impl relies on: on a plain OS thread,
// `Handle::try_current()` reports an error.
#[rstest]
fn test_handle_try_current_returns_err_outside_runtime() {
let handle = std::thread::spawn(|| {
let result = tokio::runtime::Handle::try_current();
assert!(
result.is_err(),
"Handle::try_current() should return Err outside of a tokio runtime"
);
});
handle.join().expect("thread should not panic");
}
// Acquires a connection inside a runtime, shuts the runtime down, and only
// then drops the connection and the pool; the test passes if nothing panics.
#[rstest]
fn test_drop_pooled_connection_outside_runtime_does_not_panic() {
let rt = tokio::runtime::Runtime::new().expect("failed to create Tokio runtime");
let (pool, conn) = rt.block_on(async {
let config = PoolConfig::default();
let pool = ConnectionPool::new_sqlite("sqlite::memory:", config)
.await
.expect("failed to create ConnectionPool");
let conn = pool.acquire().await.expect("failed to acquire connection");
(pool, conn)
});
// Drop order matters: the runtime goes away before the connection does.
drop(rt);
drop(conn);
drop(pool);
}
}