reinhardt_db/pool/
pool.rs1use super::config::PoolConfig;
4use super::errors::{PoolError, PoolResult};
5use super::events::{PoolEvent, PoolEventListener};
6use sqlx::{Database, MySql, Pool, Postgres, Sqlite};
7use std::mem::ManuallyDrop;
8use std::sync::Arc;
9use std::sync::atomic::{AtomicBool, Ordering};
10use tokio::sync::RwLock;
11
/// Returns `url` with the password portion of its userinfo replaced by `***`.
///
/// Only the authority component is inspected: the text between `://` and the
/// first `/`, `?` or `#` (RFC 3986 §3.2). Within it, the *last* `@` ends the
/// userinfo, so an unencoded `@` inside the password is still masked. An `@`
/// appearing in the path, query or fragment is left alone. The password is
/// everything after the first `:` of the userinfo.
///
/// The input is returned unchanged in three cases:
/// - it has no `://`;
/// - its authority has no userinfo;
/// - its userinfo has no `:`.
///
/// NOTE: a password containing an unencoded `/`, `?` or `#` is not valid
/// RFC 3986, and such a password will not be recognised as userinfo.
pub(crate) fn mask_url_password(url: &str) -> String {
    const SCHEME_SEPARATOR: &str = "://";

    let Some(scheme_end) = url.find(SCHEME_SEPARATOR) else {
        return url.to_string();
    };
    // Byte offset in `url` where the authority component begins.
    let authority_start = scheme_end + SCHEME_SEPARATOR.len();
    let after_scheme = &url[authority_start..];

    // The authority ends at the first '/', '?' or '#'. Restricting the search to
    // it keeps an '@' in the path or query (e.g. `?application_name=me@host`)
    // from being mistaken for the userinfo delimiter.
    let authority_len = after_scheme
        .find(|c: char| matches!(c, '/' | '?' | '#'))
        .unwrap_or(after_scheme.len());
    let authority = &after_scheme[..authority_len];

    // The last '@' in the authority terminates the userinfo.
    let Some(at_pos) = authority.rfind('@') else {
        return url.to_string();
    };
    let Some(colon_pos) = authority[..at_pos].find(':') else {
        return url.to_string();
    };

    let password_start = authority_start + colon_pos + 1;
    let password_end = authority_start + at_pos;
    format!("{}***{}", &url[..password_start], &url[password_end..])
}
39
40pub struct ConnectionPool<DB: Database> {
42 pool: Pool<DB>,
43 config: PoolConfig,
44 url: String,
45 listeners: Arc<RwLock<Vec<Arc<dyn PoolEventListener>>>>,
46 first_connect_fired: Arc<AtomicBool>,
47}
48
49impl ConnectionPool<Postgres> {
50 pub async fn new_postgres(url: &str, config: PoolConfig) -> PoolResult<Self> {
67 config.validate().map_err(PoolError::Config)?;
68
69 let pool = sqlx::postgres::PgPoolOptions::new()
70 .min_connections(config.min_connections)
71 .max_connections(config.max_connections)
72 .acquire_timeout(config.acquire_timeout)
73 .idle_timeout(config.idle_timeout)
74 .max_lifetime(config.max_lifetime)
75 .test_before_acquire(config.test_before_acquire)
76 .connect(url)
77 .await?;
78
79 Ok(Self {
80 pool,
81 config,
82 url: url.to_string(),
83 listeners: Arc::new(RwLock::new(Vec::new())),
84 first_connect_fired: Arc::new(AtomicBool::new(false)),
85 })
86 }
87}
88
89impl ConnectionPool<MySql> {
90 pub async fn new_mysql(url: &str, config: PoolConfig) -> PoolResult<Self> {
107 config.validate().map_err(PoolError::Config)?;
108
109 let pool = sqlx::mysql::MySqlPoolOptions::new()
110 .min_connections(config.min_connections)
111 .max_connections(config.max_connections)
112 .acquire_timeout(config.acquire_timeout)
113 .idle_timeout(config.idle_timeout)
114 .max_lifetime(config.max_lifetime)
115 .test_before_acquire(config.test_before_acquire)
116 .connect(url)
117 .await?;
118
119 Ok(Self {
120 pool,
121 config,
122 url: url.to_string(),
123 listeners: Arc::new(RwLock::new(Vec::new())),
124 first_connect_fired: Arc::new(AtomicBool::new(false)),
125 })
126 }
127}
128
129impl ConnectionPool<Sqlite> {
130 pub async fn new_sqlite(url: &str, config: PoolConfig) -> PoolResult<Self> {
147 config.validate().map_err(PoolError::Config)?;
148
149 let pool = sqlx::sqlite::SqlitePoolOptions::new()
150 .min_connections(config.min_connections)
151 .max_connections(config.max_connections)
152 .acquire_timeout(config.acquire_timeout)
153 .idle_timeout(config.idle_timeout)
154 .max_lifetime(config.max_lifetime)
155 .test_before_acquire(config.test_before_acquire)
156 .connect(url)
157 .await?;
158
159 Ok(Self {
160 pool,
161 config,
162 url: url.to_string(),
163 listeners: Arc::new(RwLock::new(Vec::new())),
164 first_connect_fired: Arc::new(AtomicBool::new(false)),
165 })
166 }
167}
168
169impl<DB> ConnectionPool<DB>
170where
171 DB: sqlx::Database,
172{
173 pub async fn add_listener(&self, listener: Arc<dyn PoolEventListener>) {
176 let mut listeners = self.listeners.write().await;
177 listeners.push(listener);
178 }
179
180 pub(crate) async fn emit_event(&self, event: PoolEvent) {
182 let listeners = self.listeners.read().await;
183 for listener in listeners.iter() {
184 listener.on_event(event.clone()).await;
185 }
186 }
187 pub async fn acquire(&self) -> PoolResult<PooledConnection<DB>> {
206 let is_first = !self.first_connect_fired.swap(true, Ordering::SeqCst);
208
209 let conn = self.pool.acquire().await?;
210 let connection_id = uuid::Uuid::new_v4().to_string();
211
212 if is_first {
213 self.emit_event(PoolEvent::connection_created(connection_id.clone()))
215 .await;
216 }
217
218 self.emit_event(PoolEvent::connection_acquired(connection_id.clone()))
220 .await;
221
222 Ok(PooledConnection {
223 conn: ManuallyDrop::new(conn),
224 pool_ref: self.clone_arc(),
225 connection_id,
226 })
227 }
228
229 fn clone_arc(&self) -> Arc<Self> {
231 Arc::new(Self {
232 pool: self.pool.clone(),
233 config: self.config.clone(),
234 url: self.url.clone(),
235 listeners: self.listeners.clone(),
236 first_connect_fired: self.first_connect_fired.clone(),
237 })
238 }
239 pub fn inner(&self) -> &Pool<DB> {
242 &self.pool
243 }
244 pub fn config(&self) -> &PoolConfig {
247 &self.config
248 }
249 pub async fn close(&self) {
255 use tokio::time::{Duration, timeout};
256
257 let close_future = self.pool.close();
259 if timeout(Duration::from_secs(5), close_future).await.is_err() {
260 }
263 }
264 pub fn url(&self) -> String {
270 mask_url_password(&self.url)
271 }
272
273 #[allow(dead_code)]
279 pub(crate) fn url_raw(&self) -> &str {
280 &self.url
281 }
282}
283
284impl ConnectionPool<Postgres> {
286 pub async fn recreate(&mut self) -> PoolResult<()> {
307 self.pool.close().await;
309
310 let new_pool = sqlx::postgres::PgPoolOptions::new()
312 .min_connections(self.config.min_connections)
313 .max_connections(self.config.max_connections)
314 .acquire_timeout(self.config.acquire_timeout)
315 .idle_timeout(self.config.idle_timeout)
316 .max_lifetime(self.config.max_lifetime)
317 .test_before_acquire(self.config.test_before_acquire)
318 .connect(&self.url)
319 .await?;
320
321 self.pool = new_pool;
322 self.first_connect_fired.store(false, Ordering::SeqCst);
323
324 Ok(())
325 }
326}
327
328impl ConnectionPool<MySql> {
329 pub async fn recreate(&mut self) -> PoolResult<()> {
350 self.pool.close().await;
352
353 let new_pool = sqlx::mysql::MySqlPoolOptions::new()
355 .min_connections(self.config.min_connections)
356 .max_connections(self.config.max_connections)
357 .acquire_timeout(self.config.acquire_timeout)
358 .idle_timeout(self.config.idle_timeout)
359 .max_lifetime(self.config.max_lifetime)
360 .test_before_acquire(self.config.test_before_acquire)
361 .connect(&self.url)
362 .await?;
363
364 self.pool = new_pool;
365 self.first_connect_fired.store(false, Ordering::SeqCst);
366
367 Ok(())
368 }
369}
370
371impl ConnectionPool<Sqlite> {
372 pub async fn recreate(&mut self) -> PoolResult<()> {
392 self.pool.close().await;
394
395 let new_pool = sqlx::sqlite::SqlitePoolOptions::new()
397 .min_connections(self.config.min_connections)
398 .max_connections(self.config.max_connections)
399 .acquire_timeout(self.config.acquire_timeout)
400 .idle_timeout(self.config.idle_timeout)
401 .max_lifetime(self.config.max_lifetime)
402 .test_before_acquire(self.config.test_before_acquire)
403 .connect(&self.url)
404 .await?;
405
406 self.pool = new_pool;
407 self.first_connect_fired.store(false, Ordering::SeqCst);
408
409 Ok(())
410 }
411}
412
413pub struct PooledConnection<DB: sqlx::Database> {
415 conn: ManuallyDrop<sqlx::pool::PoolConnection<DB>>,
419 pool_ref: Arc<ConnectionPool<DB>>,
420 connection_id: String,
421}
422
423impl<DB: sqlx::Database> PooledConnection<DB> {
424 pub fn inner(&mut self) -> &mut sqlx::pool::PoolConnection<DB> {
427 &mut self.conn
428 }
429 pub fn connection_id(&self) -> &str {
448 &self.connection_id
449 }
450 pub async fn invalidate(self, reason: String) {
453 self.pool_ref
454 .emit_event(PoolEvent::connection_invalidated(
455 self.connection_id.clone(),
456 reason,
457 ))
458 .await;
459 }
461 pub async fn soft_invalidate(&mut self) {
464 self.pool_ref
465 .emit_event(PoolEvent::connection_soft_invalidated(
466 self.connection_id.clone(),
467 ))
468 .await;
469 }
470 pub async fn reset(&mut self) {
473 self.pool_ref
474 .emit_event(PoolEvent::connection_reset(self.connection_id.clone()))
475 .await;
476 }
477}
478
479impl<DB: sqlx::Database> Drop for PooledConnection<DB> {
480 fn drop(&mut self) {
481 let conn = unsafe { ManuallyDrop::take(&mut self.conn) };
483
484 match tokio::runtime::Handle::try_current() {
485 Ok(handle) => {
486 drop(conn);
489
490 let pool_ref = self.pool_ref.clone();
491 let connection_id = self.connection_id.clone();
492
493 handle.spawn(async move {
494 pool_ref
495 .emit_event(PoolEvent::connection_returned(connection_id))
496 .await;
497 });
498 }
499 Err(_) => {
500 std::mem::forget(conn);
505 }
506 }
507 }
508}
509
#[cfg(test)]
mod tests {
    use super::*;
    use rstest::rstest;

    // Passwords in the userinfo are replaced by `***`. This includes a password
    // with an unencoded `@` and a URL carrying a query string.
    #[rstest]
    #[case(
        "postgresql://user:secret@localhost:5432/mydb",
        "postgresql://user:***@localhost:5432/mydb"
    )]
    #[case(
        "mysql://admin:p@ssw0rd@db.example.com/app",
        "mysql://admin:***@db.example.com/app"
    )]
    #[case(
        "postgres://user:pass@host:5432/db?sslmode=require",
        "postgres://user:***@host:5432/db?sslmode=require"
    )]
    fn test_mask_url_password_with_credentials(#[case] input: &str, #[case] expected: &str) {
        assert_eq!(mask_url_password(input), expected);
    }

    // URLs that carry no password must come back byte-for-byte identical.
    #[rstest]
    #[case("sqlite::memory:")]
    #[case("sqlite:///path/to/db.sqlite")]
    #[case("postgresql://user@localhost:5432/mydb")]
    fn test_mask_url_password_without_password(#[case] input: &str) {
        let output = mask_url_password(input);
        assert_eq!(output, input, "URL without password should be unchanged");
    }

    // An empty password (`user:@`) is still normalised to the mask.
    #[rstest]
    fn test_mask_url_password_empty_password() {
        assert_eq!(
            mask_url_password("postgresql://user:@localhost:5432/mydb"),
            "postgresql://user:***@localhost:5432/mydb"
        );
    }

    // Percent-encoded characters inside the password are masked along with it.
    #[rstest]
    fn test_mask_url_password_special_chars_in_password() {
        let output = mask_url_password("postgresql://user:p%40ss%3Aw0rd@localhost:5432/mydb");

        assert_eq!(output, "postgresql://user:***@localhost:5432/mydb");
        assert!(
            !output.contains("p%40ss"),
            "Password should be fully masked"
        );
    }

    // Strings without a scheme separator pass through untouched.
    #[rstest]
    fn test_mask_url_password_preserves_non_url() {
        const NON_URL: &str = "not-a-url-just-a-string";

        assert_eq!(
            mask_url_password(NON_URL),
            NON_URL,
            "Non-URL strings should pass through unchanged"
        );
    }

    // A fresh OS thread has no Tokio context. This is the condition that the
    // `Drop` impl's fallback branch relies on detecting.
    #[rstest]
    fn test_handle_try_current_returns_err_outside_runtime() {
        std::thread::spawn(|| {
            let current = tokio::runtime::Handle::try_current();
            assert!(
                current.is_err(),
                "Handle::try_current() should return Err outside of a tokio runtime"
            );
        })
        .join()
        .expect("thread should not panic");
    }

    // Releasing a connection after its runtime is gone must not panic.
    #[rstest]
    fn test_drop_pooled_connection_outside_runtime_does_not_panic() {
        // Arrange: open a pool and check out a connection inside a runtime.
        let rt = tokio::runtime::Runtime::new().expect("failed to create Tokio runtime");
        let (pool, conn) = rt.block_on(async {
            let pool = ConnectionPool::new_sqlite("sqlite::memory:", PoolConfig::default())
                .await
                .expect("failed to create ConnectionPool");
            let conn = pool.acquire().await.expect("failed to acquire connection");
            (pool, conn)
        });

        // Act: tear the runtime down first, then release the connection and
        // the pool with no runtime available.
        drop(rt);
        drop(conn);
        drop(pool);
    }
}