bestool_postgres/
pool.rs

1use std::time::Duration;
2
3use miette::{IntoDiagnostic, Report, Result, WrapErr, miette};
4use mobc::{Connection, Pool};
5use tokio_postgres::config::SslMode;
6use tracing::debug;
7
8pub use manager::{PgConnectionManager, PgError};
9
10mod manager;
11mod tls;
12mod url;
13
14/// Check if an error is an authentication error
15fn is_auth_error(error: &Report) -> bool {
16	if let Some(db_error) = error.downcast_ref::<tokio_postgres::Error>()
17		&& let Some(db_error) = db_error.as_db_error()
18	{
19		// PostgreSQL error codes for authentication failures:
20		// 28000 - invalid_authorization_specification
21		// 28P01 - invalid_password
22		let code = db_error.code().code();
23		return code == "28000" || code == "28P01";
24	}
25
26	// Check for other connection errors that might indicate auth issues
27	let message = error.to_string();
28	message.contains("password authentication failed")
29		|| message.contains("no password supplied")
30		|| message.contains("authentication failed")
31		|| message.contains("invalid configuration")
32}
33
34pub type PgConnection = Connection<manager::PgConnectionManager>;
35
36#[derive(Debug, Clone)]
37pub struct PgPool {
38	pub manager: manager::PgConnectionManager,
39	pub inner: Pool<manager::PgConnectionManager>,
40}
41
42impl PgPool {
43	/// Returns a single connection by either opening a new connection
44	/// or returning an existing connection from the connection pool. Conn will
45	/// block until either a connection is returned or timeout.
46	pub async fn get(&self) -> Result<PgConnection, mobc::Error<PgError>> {
47		self.inner.get().await
48	}
49
50	/// Retrieves a connection from the pool, waiting for at most `timeout`
51	///
52	/// The given timeout will be used instead of the configured connection
53	/// timeout.
54	pub async fn get_timeout(
55		&self,
56		duration: Duration,
57	) -> Result<PgConnection, mobc::Error<PgError>> {
58		self.inner.get_timeout(duration).await
59	}
60}
61
62/// Create a connection pool from a connection URL
63///
64/// Supports Unix socket connections via:
65/// - Query parameter: `postgresql:///dbname?host=/var/run/postgresql`
66/// - Percent-encoded host: `postgresql://%2Fvar%2Frun%2Fpostgresql/dbname`
67/// - Empty host (auto-detects Unix socket or falls back to localhost): `postgresql:///dbname`
68///
69/// Unix socket connections automatically disable SSL/TLS.
70///
71/// # Password Prompting
72///
73/// If the connection fails with an authentication error and no password was provided
74/// in the connection URL, the function will prompt the user to enter a password
75/// interactively. The password will be read securely without echoing to the terminal.
76///
77/// # Examples
78///
79/// ```no_run
80/// # use bestool_postgres::create_pool;
81/// # async fn example() -> miette::Result<()> {
82/// // Connect via Unix socket using query parameter
83/// let pool = create_pool("postgresql:///postgres?host=/var/run/postgresql", "myapp", "test").await?;
84///
85/// // Connect via percent-encoded Unix socket path
86/// let pool = create_pool("postgresql://%2Fvar%2Frun%2Fpostgresql/postgres", "myapp", "test").await?;
87///
88/// // Connect with auto-detection (tries Unix socket first, then localhost)
89/// let pool = create_pool("postgresql:///postgres", "myapp", "test").await?;
90///
91/// // Traditional TCP connection
92/// let pool = create_pool("postgresql://user:pass@localhost:5432/dbname", "myapp", "test").await?;
93///
94/// // Connection without password (will prompt if authentication required)
95/// let pool = create_pool("postgresql://user@localhost/dbname", "myapp", "test").await?;
96/// # Ok(())
97/// # }
98/// ```
99pub async fn create_pool(url: &str, application_name: &str) -> Result<PgPool> {
100	let mut config = url::parse_connection_url(url)?;
101
102	config.application_name(application_name);
103
104	// Try to connect, and if it fails with auth error, prompt for password
105	let pool = loop {
106		debug!("Creating manager");
107		let tls = config.get_ssl_mode() != SslMode::Disable;
108		let manager = crate::pool::PgConnectionManager::new(config.clone(), tls);
109
110		debug!("Creating pool");
111		let pool = Pool::builder()
112			.max_lifetime(Some(Duration::from_secs(3600)))
113			.build(manager.clone());
114
115		let pool = PgPool {
116			manager,
117			inner: pool,
118		};
119
120		debug!("Checking pool");
121		match check_pool(&pool).await {
122			Ok(_) => break pool,
123			Err(e) => {
124				if is_auth_error(&e) && config.get_password().is_none() {
125					let password = rpassword::prompt_password("Password: ").into_diagnostic()?;
126					config.password(password);
127					// Loop will retry with the new password
128				} else {
129					// Not an auth error or we already have a password, re-throw
130					return Err(e);
131				}
132			}
133		}
134	};
135
136	Ok(pool)
137}
138
139/// Check if we can actually establish a connection
140async fn check_pool(pool: &PgPool) -> Result<()> {
141	let conn = match pool.get().await {
142		Err(mobc::Error::Inner(db_err)) => {
143			return Err(match db_err.as_db_error() {
144				Some(db_err) => miette!(
145					"E{code} at {func} in {file}:{line}",
146					code = db_err.code().code(),
147					func = db_err.routine().unwrap_or("{unknown}"),
148					file = db_err.file().unwrap_or("unknown.c"),
149					line = db_err.line().unwrap_or(0)
150				),
151				_ => miette!("{db_err}"),
152			})
153			.wrap_err(
154				db_err
155					.as_db_error()
156					.map(|e| e.to_string())
157					.unwrap_or_default(),
158			)?;
159		}
160		res @ Err(_) => {
161			let res = res.map(drop).into_diagnostic();
162			return if let Err(ref err) = res
163				&& is_auth_error(err)
164			{
165				res.wrap_err("hint: check the password")
166			} else {
167				res
168			};
169		}
170		Ok(conn) => conn,
171	};
172	conn.simple_query("SELECT 1")
173		.await
174		.into_diagnostic()
175		.wrap_err("checking connection")?;
176	Ok(())
177}
178
#[cfg(test)]
mod tests {
	use super::*;

	/// Render `err` with `{:?}` and assert that it is not a connection-URL parsing failure.
	///
	/// These tests run against whatever PostgreSQL (if any) is reachable, so connection
	/// failures are tolerated. Only a parsing failure indicates a bug. The rendered text is
	/// returned so callers can make further assertions on it.
	fn assert_not_parse_error(err: &miette::Report) -> String {
		let rendered = format!("{err:?}");
		assert!(
			!rendered.contains("parsing connection string"),
			"Should not be a parsing error: {}",
			rendered
		);
		rendered
	}

	#[tokio::test]
	async fn test_create_pool_valid_connection_string() {
		// May fail if database doesn't exist, but should not be a parsing error
		if let Err(err) = create_pool("postgresql://localhost/test", "test").await {
			assert_not_parse_error(&err);
		}
	}

	#[tokio::test]
	async fn test_create_pool_with_full_url() {
		// May fail if database doesn't exist or auth fails, but should not be a parsing error
		if let Err(err) = create_pool("postgresql://user:pass@localhost:5432/testdb", "test").await {
			assert_not_parse_error(&err);
		}
	}

	#[tokio::test]
	async fn test_create_pool_with_unix_socket_path() {
		// Unix socket path given as a query parameter. PostgreSQL may not be running or
		// reachable via the socket, but the URL must still parse.
		if let Err(err) = create_pool("postgresql:///postgres?host=/var/run/postgresql", "test").await {
			assert_not_parse_error(&err);
		}
	}

	#[tokio::test]
	async fn test_create_pool_with_encoded_unix_socket() {
		// Unix socket path percent-encoded into the host. Connecting may fail, parsing must not.
		if let Err(err) = create_pool("postgresql://%2Fvar%2Frun%2Fpostgresql/postgres", "test").await {
			assert_not_parse_error(&err);
		}
	}

	#[tokio::test]
	async fn test_create_pool_with_no_host() {
		// No host: should try a Unix socket or fall back to localhost. Either success or a
		// connection error is fine; a parsing error is not.
		if let Err(err) = create_pool("postgresql:///postgres", "test").await {
			assert_not_parse_error(&err);
		}
	}

	#[tokio::test]
	async fn test_unix_socket_connection_end_to_end() {
		// Connect and query via Unix socket when a server is available.
		match create_pool("postgresql:///postgres?host=/var/run/postgresql", "test").await {
			Ok(pool) => {
				// The pool connected, so a simple query should work too.
				if let Ok(conn) = pool.get().await {
					let outcome = conn.simple_query("SELECT 1 as test").await;
					assert!(outcome.is_ok(), "Query should succeed");
				}
			}
			Err(err) => {
				// Failure is acceptable, but not from parsing or from a TLS handshake.
				let rendered = assert_not_parse_error(&err);
				assert!(
					!rendered.contains("TLS handshake"),
					"Should not be a TLS error for Unix socket: {}",
					rendered
				);
			}
		}
	}
}