Skip to main content

reinhardt_db/pool/
pool.rs

1//! Connection pool implementation
2
3use super::config::PoolConfig;
4use super::errors::{PoolError, PoolResult};
5use super::events::{PoolEvent, PoolEventListener};
6use sqlx::{Database, MySql, Pool, Postgres, Sqlite};
7use std::mem::ManuallyDrop;
8use std::sync::Arc;
9use std::sync::atomic::{AtomicBool, Ordering};
10use tokio::sync::RwLock;
11
/// Mask the password in a database URL for safe display.
///
/// For URLs shaped like `scheme://user:password@host/db`, the text between the
/// first `:` of the user-info section and the user-info delimiter `@` is
/// replaced with `***`. The *last* `@` after `://` is treated as the delimiter,
/// so passwords that contain a literal `@` are still masked completely.
///
/// Input is returned unchanged when it has no `://`, no `@` after it, or no `:`
/// before that `@` (i.e. no password component).
///
/// Note: because the last `@` anywhere after `://` is used, an `@` that appears
/// in the path or query string is also picked up as the delimiter.
pub(crate) fn mask_url_password(url: &str) -> String {
	const SCHEME_SEPARATOR: &str = "://";

	let Some(separator_at) = url.find(SCHEME_SEPARATOR) else {
		// Not a `scheme://...` URL: nothing to mask.
		return url.to_string();
	};
	let authority_start = separator_at + SCHEME_SEPARATOR.len();
	let remainder = &url[authority_start..];

	// Passwords may contain `@`, so the last one marks the end of the user-info.
	let Some(delimiter) = remainder.rfind('@') else {
		return url.to_string();
	};

	// The first `:` inside the user-info separates the user name from the password.
	match remainder[..delimiter].find(':') {
		Some(colon) => {
			let head = &url[..authority_start + colon + 1];
			let tail = &url[authority_start + delimiter..];
			format!("{}***{}", head, tail)
		}
		None => url.to_string(),
	}
}
39
40/// A database connection pool
41pub struct ConnectionPool<DB: Database> {
42	pool: Pool<DB>,
43	config: PoolConfig,
44	url: String,
45	listeners: Arc<RwLock<Vec<Arc<dyn PoolEventListener>>>>,
46	first_connect_fired: Arc<AtomicBool>,
47}
48
49impl ConnectionPool<Postgres> {
50	/// Create a new PostgreSQL connection pool
51	///
52	/// # Examples
53	///
54	/// ```
55	/// use reinhardt_db::pool::{ConnectionPool, PoolConfig};
56	///
57	/// # async fn example() {
58	/// let config = PoolConfig::default();
59	/// // For doctest purposes, using SQLite in-memory instead of PostgreSQL
60	/// let pool = ConnectionPool::new_sqlite("sqlite::memory:", config).await.unwrap();
61	/// assert!(pool.url().contains("memory"));
62	/// assert_eq!(pool.config().max_connections, 10);
63	/// # }
64	/// # tokio::runtime::Runtime::new().unwrap().block_on(example());
65	/// ```
66	pub async fn new_postgres(url: &str, config: PoolConfig) -> PoolResult<Self> {
67		config.validate().map_err(PoolError::Config)?;
68
69		let pool = sqlx::postgres::PgPoolOptions::new()
70			.min_connections(config.min_connections)
71			.max_connections(config.max_connections)
72			.acquire_timeout(config.acquire_timeout)
73			.idle_timeout(config.idle_timeout)
74			.max_lifetime(config.max_lifetime)
75			.test_before_acquire(config.test_before_acquire)
76			.connect(url)
77			.await?;
78
79		Ok(Self {
80			pool,
81			config,
82			url: url.to_string(),
83			listeners: Arc::new(RwLock::new(Vec::new())),
84			first_connect_fired: Arc::new(AtomicBool::new(false)),
85		})
86	}
87}
88
89impl ConnectionPool<MySql> {
90	/// Create a new MySQL connection pool
91	///
92	/// # Examples
93	///
94	/// ```
95	/// use reinhardt_db::pool::{ConnectionPool, PoolConfig};
96	///
97	/// # async fn example() {
98	/// let config = PoolConfig::default();
99	/// // For doctest purposes, using SQLite in-memory instead of MySQL
100	/// let pool = ConnectionPool::new_sqlite("sqlite::memory:", config).await.unwrap();
101	/// assert!(pool.url().contains("memory"));
102	/// assert_eq!(pool.config().max_connections, 10);
103	/// # }
104	/// # tokio::runtime::Runtime::new().unwrap().block_on(example());
105	/// ```
106	pub async fn new_mysql(url: &str, config: PoolConfig) -> PoolResult<Self> {
107		config.validate().map_err(PoolError::Config)?;
108
109		let pool = sqlx::mysql::MySqlPoolOptions::new()
110			.min_connections(config.min_connections)
111			.max_connections(config.max_connections)
112			.acquire_timeout(config.acquire_timeout)
113			.idle_timeout(config.idle_timeout)
114			.max_lifetime(config.max_lifetime)
115			.test_before_acquire(config.test_before_acquire)
116			.connect(url)
117			.await?;
118
119		Ok(Self {
120			pool,
121			config,
122			url: url.to_string(),
123			listeners: Arc::new(RwLock::new(Vec::new())),
124			first_connect_fired: Arc::new(AtomicBool::new(false)),
125		})
126	}
127}
128
129impl ConnectionPool<Sqlite> {
130	/// Create a new SQLite connection pool
131	///
132	/// # Examples
133	///
134	/// ```
135	/// use reinhardt_db::pool::{ConnectionPool, PoolConfig};
136	///
137	/// # async fn example() {
138	/// let config = PoolConfig::default();
139	/// // Using in-memory SQLite for doctest
140	/// let pool = ConnectionPool::new_sqlite("sqlite::memory:", config).await.unwrap();
141	/// assert!(pool.url().contains("memory"));
142	/// assert!(pool.config().max_connections > 0);
143	/// # }
144	/// # tokio::runtime::Runtime::new().unwrap().block_on(example());
145	/// ```
146	pub async fn new_sqlite(url: &str, config: PoolConfig) -> PoolResult<Self> {
147		config.validate().map_err(PoolError::Config)?;
148
149		let pool = sqlx::sqlite::SqlitePoolOptions::new()
150			.min_connections(config.min_connections)
151			.max_connections(config.max_connections)
152			.acquire_timeout(config.acquire_timeout)
153			.idle_timeout(config.idle_timeout)
154			.max_lifetime(config.max_lifetime)
155			.test_before_acquire(config.test_before_acquire)
156			.connect(url)
157			.await?;
158
159		Ok(Self {
160			pool,
161			config,
162			url: url.to_string(),
163			listeners: Arc::new(RwLock::new(Vec::new())),
164			first_connect_fired: Arc::new(AtomicBool::new(false)),
165		})
166	}
167}
168
169impl<DB> ConnectionPool<DB>
170where
171	DB: sqlx::Database,
172{
173	/// Add an event listener
174	///
175	pub async fn add_listener(&self, listener: Arc<dyn PoolEventListener>) {
176		let mut listeners = self.listeners.write().await;
177		listeners.push(listener);
178	}
179
180	/// Emit an event to all listeners
181	pub(crate) async fn emit_event(&self, event: PoolEvent) {
182		let listeners = self.listeners.read().await;
183		for listener in listeners.iter() {
184			listener.on_event(event.clone()).await;
185		}
186	}
187	/// Acquire a connection from the pool with event emission
188	///
189	/// # Examples
190	///
191	/// ```no_run
192	/// use reinhardt_db::pool::{ConnectionPool, PoolConfig};
193	///
194	/// # async fn example() {
195	/// let config = PoolConfig::default();
196	/// let pool = ConnectionPool::new_postgres("postgresql://user:pass@localhost/test", config)
197	///     .await
198	///     .unwrap();
199	///
200	/// // Acquire a connection
201	/// let conn = pool.acquire().await;
202	/// assert!(conn.is_ok());
203	/// # }
204	/// ```
205	pub async fn acquire(&self) -> PoolResult<PooledConnection<DB>> {
206		// Check if this is the first connection
207		let is_first = !self.first_connect_fired.swap(true, Ordering::SeqCst);
208
209		let conn = self.pool.acquire().await?;
210		let connection_id = uuid::Uuid::new_v4().to_string();
211
212		if is_first {
213			// Emit first_connect event (using ConnectionCreated as proxy)
214			self.emit_event(PoolEvent::connection_created(connection_id.clone()))
215				.await;
216		}
217
218		// Emit checkout event
219		self.emit_event(PoolEvent::connection_acquired(connection_id.clone()))
220			.await;
221
222		Ok(PooledConnection {
223			conn: ManuallyDrop::new(conn),
224			pool_ref: self.clone_arc(),
225			connection_id,
226		})
227	}
228
229	/// Clone as Arc for sharing with PooledConnection
230	fn clone_arc(&self) -> Arc<Self> {
231		Arc::new(Self {
232			pool: self.pool.clone(),
233			config: self.config.clone(),
234			url: self.url.clone(),
235			listeners: self.listeners.clone(),
236			first_connect_fired: self.first_connect_fired.clone(),
237		})
238	}
239	/// Get the underlying pool
240	///
241	pub fn inner(&self) -> &Pool<DB> {
242		&self.pool
243	}
244	/// Get pool configuration
245	///
246	pub fn config(&self) -> &PoolConfig {
247		&self.config
248	}
249	/// Close the pool
250	///
251	/// Attempts to gracefully close the pool with a 5-second timeout.
252	/// If active connections are not returned within this time, the pool
253	/// will be forcefully closed.
254	pub async fn close(&self) {
255		use tokio::time::{Duration, timeout};
256
257		// Try to close gracefully with a timeout
258		let close_future = self.pool.close();
259		if timeout(Duration::from_secs(5), close_future).await.is_err() {
260			// Timeout occurred - pool had active connections
261			// The pool will be forcefully closed when dropped
262		}
263	}
264	/// Get the database URL with password masked for safe display
265	///
266	/// Returns the database URL with any password replaced by `***`
267	/// to prevent credential exposure in logs and debug output.
268	/// Use `url_raw()` when the actual password is needed for reconnection.
269	pub fn url(&self) -> String {
270		mask_url_password(&self.url)
271	}
272
273	/// Get the raw database URL including credentials
274	///
275	/// This method returns the unmasked URL containing the actual password.
276	/// Use with caution - prefer `url()` for logging and display purposes.
277	// Allow dead_code: preserved for internal use by reconnection logic (e.g., `recreate()`)
278	#[allow(dead_code)]
279	pub(crate) fn url_raw(&self) -> &str {
280		&self.url
281	}
282}
283
284// Database-specific recreate implementations
285impl ConnectionPool<Postgres> {
286	/// Recreate the pool with the same configuration
287	///
288	/// # Examples
289	///
290	/// ```
291	/// use reinhardt_db::pool::{ConnectionPool, PoolConfig};
292	///
293	/// # async fn example() {
294	/// let config = PoolConfig::default();
295	/// // For doctest purposes, using SQLite in-memory instead of PostgreSQL
296	/// let mut pool = ConnectionPool::new_sqlite("sqlite::memory:", config)
297	///     .await
298	///     .unwrap();
299	///
300	/// // Recreate the pool
301	/// pool.recreate().await.unwrap();
302	/// assert_eq!(pool.config().max_connections, 10);
303	/// # }
304	/// # tokio::runtime::Runtime::new().unwrap().block_on(example());
305	/// ```
306	pub async fn recreate(&mut self) -> PoolResult<()> {
307		// Close existing pool
308		self.pool.close().await;
309
310		// Create new pool with same configuration
311		let new_pool = sqlx::postgres::PgPoolOptions::new()
312			.min_connections(self.config.min_connections)
313			.max_connections(self.config.max_connections)
314			.acquire_timeout(self.config.acquire_timeout)
315			.idle_timeout(self.config.idle_timeout)
316			.max_lifetime(self.config.max_lifetime)
317			.test_before_acquire(self.config.test_before_acquire)
318			.connect(&self.url)
319			.await?;
320
321		self.pool = new_pool;
322		self.first_connect_fired.store(false, Ordering::SeqCst);
323
324		Ok(())
325	}
326}
327
328impl ConnectionPool<MySql> {
329	/// Recreate the pool with the same configuration
330	///
331	/// # Examples
332	///
333	/// ```
334	/// use reinhardt_db::pool::{ConnectionPool, PoolConfig};
335	///
336	/// # async fn example() {
337	/// let config = PoolConfig::default();
338	/// // For doctest purposes, using SQLite in-memory instead of MySQL
339	/// let mut pool = ConnectionPool::new_sqlite("sqlite::memory:", config)
340	///     .await
341	///     .unwrap();
342	///
343	/// // Recreate the pool
344	/// pool.recreate().await.unwrap();
345	/// assert_eq!(pool.config().max_connections, 10);
346	/// # }
347	/// # tokio::runtime::Runtime::new().unwrap().block_on(example());
348	/// ```
349	pub async fn recreate(&mut self) -> PoolResult<()> {
350		// Close existing pool
351		self.pool.close().await;
352
353		// Create new pool with same configuration
354		let new_pool = sqlx::mysql::MySqlPoolOptions::new()
355			.min_connections(self.config.min_connections)
356			.max_connections(self.config.max_connections)
357			.acquire_timeout(self.config.acquire_timeout)
358			.idle_timeout(self.config.idle_timeout)
359			.max_lifetime(self.config.max_lifetime)
360			.test_before_acquire(self.config.test_before_acquire)
361			.connect(&self.url)
362			.await?;
363
364		self.pool = new_pool;
365		self.first_connect_fired.store(false, Ordering::SeqCst);
366
367		Ok(())
368	}
369}
370
371impl ConnectionPool<Sqlite> {
372	/// Recreate the pool with the same configuration
373	///
374	/// # Examples
375	///
376	/// ```
377	/// use reinhardt_db::pool::{ConnectionPool, PoolConfig};
378	///
379	/// # async fn example() {
380	/// let config = PoolConfig::default();
381	/// let mut pool = ConnectionPool::new_sqlite("sqlite::memory:", config)
382	///     .await
383	///     .unwrap();
384	///
385	/// // Recreate the pool
386	/// pool.recreate().await.unwrap();
387	/// assert!(pool.url().contains("memory"));
388	/// # }
389	/// # tokio::runtime::Runtime::new().unwrap().block_on(example());
390	/// ```
391	pub async fn recreate(&mut self) -> PoolResult<()> {
392		// Close existing pool
393		self.pool.close().await;
394
395		// Create new pool with same configuration
396		let new_pool = sqlx::sqlite::SqlitePoolOptions::new()
397			.min_connections(self.config.min_connections)
398			.max_connections(self.config.max_connections)
399			.acquire_timeout(self.config.acquire_timeout)
400			.idle_timeout(self.config.idle_timeout)
401			.max_lifetime(self.config.max_lifetime)
402			.test_before_acquire(self.config.test_before_acquire)
403			.connect(&self.url)
404			.await?;
405
406		self.pool = new_pool;
407		self.first_connect_fired.store(false, Ordering::SeqCst);
408
409		Ok(())
410	}
411}
412
413/// A pooled connection wrapper with event emission
414pub struct PooledConnection<DB: sqlx::Database> {
415	// Wrapped in ManuallyDrop so we can take ownership in Drop.
416	// When no tokio runtime is available, we detach the connection
417	// to avoid sqlx's PoolConnection::Drop calling rt::spawn().
418	conn: ManuallyDrop<sqlx::pool::PoolConnection<DB>>,
419	pool_ref: Arc<ConnectionPool<DB>>,
420	connection_id: String,
421}
422
423impl<DB: sqlx::Database> PooledConnection<DB> {
424	/// Documentation for `inner`
425	///
426	pub fn inner(&mut self) -> &mut sqlx::pool::PoolConnection<DB> {
427		&mut self.conn
428	}
429	/// Get the unique identifier for this connection
430	///
431	/// # Examples
432	///
433	/// ```no_run
434	/// use reinhardt_db::pool::{ConnectionPool, PoolConfig};
435	///
436	/// # async fn example() {
437	/// let config = PoolConfig::default();
438	/// let pool = ConnectionPool::new_postgres("postgresql://user:pass@localhost/test", config)
439	///     .await
440	///     .unwrap();
441	///
442	/// let mut conn = pool.acquire().await.unwrap();
443	/// let id = conn.connection_id();
444	/// assert!(!id.is_empty());
445	/// # }
446	/// ```
447	pub fn connection_id(&self) -> &str {
448		&self.connection_id
449	}
450	/// Invalidate this connection (hard invalidation - connection is unusable)
451	///
452	pub async fn invalidate(self, reason: String) {
453		self.pool_ref
454			.emit_event(PoolEvent::connection_invalidated(
455				self.connection_id.clone(),
456				reason,
457			))
458			.await;
459		// Connection will be dropped and not returned to pool
460	}
461	/// Soft invalidate this connection (can complete current operation)
462	///
463	pub async fn soft_invalidate(&mut self) {
464		self.pool_ref
465			.emit_event(PoolEvent::connection_soft_invalidated(
466				self.connection_id.clone(),
467			))
468			.await;
469	}
470	/// Reset this connection
471	///
472	pub async fn reset(&mut self) {
473		self.pool_ref
474			.emit_event(PoolEvent::connection_reset(self.connection_id.clone()))
475			.await;
476	}
477}
478
479impl<DB: sqlx::Database> Drop for PooledConnection<DB> {
480	fn drop(&mut self) {
481		// SAFETY: ManuallyDrop::take is called exactly once (in drop).
482		let conn = unsafe { ManuallyDrop::take(&mut self.conn) };
483
484		match tokio::runtime::Handle::try_current() {
485			Ok(handle) => {
486				// Runtime available: drop the connection normally (returns to pool)
487				// and emit the connection-returned event.
488				drop(conn);
489
490				let pool_ref = self.pool_ref.clone();
491				let connection_id = self.connection_id.clone();
492
493				handle.spawn(async move {
494					pool_ref
495						.emit_event(PoolEvent::connection_returned(connection_id))
496						.await;
497				});
498			}
499			Err(_) => {
500				// No runtime available: prevent sqlx's PoolConnection::Drop
501				// from running, as it calls crate::rt::spawn() which panics
502				// without a tokio runtime. The connection is intentionally
503				// leaked to avoid the panic.
504				std::mem::forget(conn);
505			}
506		}
507	}
508}
509
#[cfg(test)]
mod tests {
	use super::*;
	use rstest::rstest;

	#[rstest]
	#[case(
		"postgresql://user:secret@localhost:5432/mydb",
		"postgresql://user:***@localhost:5432/mydb"
	)]
	#[case(
		"mysql://admin:p@ssw0rd@db.example.com/app",
		"mysql://admin:***@db.example.com/app"
	)]
	#[case(
		"postgres://user:pass@host:5432/db?sslmode=require",
		"postgres://user:***@host:5432/db?sslmode=require"
	)]
	fn test_mask_url_password_with_credentials(#[case] input: &str, #[case] expected: &str) {
		// Only the password segment may change; everything else is preserved.
		let actual = mask_url_password(input);
		assert_eq!(actual, expected);
	}

	#[rstest]
	#[case("sqlite::memory:")]
	#[case("sqlite:///path/to/db.sqlite")]
	#[case("postgresql://user@localhost:5432/mydb")]
	fn test_mask_url_password_without_password(#[case] input: &str) {
		// With no password component there is nothing to hide.
		let actual = mask_url_password(input);
		assert_eq!(actual, input, "URL without password should be unchanged");
	}

	#[rstest]
	fn test_mask_url_password_empty_password() {
		// An empty password is still replaced by the mask marker.
		let actual = mask_url_password("postgresql://user:@localhost:5432/mydb");
		assert_eq!(actual, "postgresql://user:***@localhost:5432/mydb");
	}

	#[rstest]
	fn test_mask_url_password_special_chars_in_password() {
		// Percent-encoded characters inside the password must not leak.
		let actual = mask_url_password("postgresql://user:p%40ss%3Aw0rd@localhost:5432/mydb");

		assert_eq!(actual, "postgresql://user:***@localhost:5432/mydb");
		assert!(
			!actual.contains("p%40ss"),
			"Password should be fully masked"
		);
	}

	#[rstest]
	fn test_mask_url_password_preserves_non_url() {
		// Strings without a `scheme://` prefix pass through untouched.
		let plain = "not-a-url-just-a-string";
		let actual = mask_url_password(plain);
		assert_eq!(
			actual, plain,
			"Non-URL strings should pass through unchanged"
		);
	}

	#[rstest]
	fn test_handle_try_current_returns_err_outside_runtime() {
		// A freshly spawned thread does not inherit the test runner's runtime context.
		std::thread::spawn(|| {
			let lookup = tokio::runtime::Handle::try_current();
			assert!(
				lookup.is_err(),
				"Handle::try_current() should return Err outside of a tokio runtime"
			);
		})
		.join()
		.expect("thread should not panic");
	}

	#[rstest]
	fn test_drop_pooled_connection_outside_runtime_does_not_panic() {
		// Check out a connection while a Tokio runtime is alive.
		let runtime = tokio::runtime::Runtime::new().expect("failed to create Tokio runtime");

		let (pool, conn) = runtime.block_on(async {
			let pool = ConnectionPool::new_sqlite("sqlite::memory:", PoolConfig::default())
				.await
				.expect("failed to create ConnectionPool");
			let conn = pool.acquire().await.expect("failed to acquire connection");
			(pool, conn)
		});

		// Tear the runtime down first so that no Tokio context remains.
		drop(runtime);

		// Neither the connection nor the pool may panic when dropped now.
		drop(conn);
		drop(pool);
	}
}