bestool_postgres/
pool.rs

1use std::time::Duration;
2
3use miette::{IntoDiagnostic, Report, Result, WrapErr, miette};
4use mobc::{Connection, Pool};
5use tokio_postgres::config::SslMode;
6use tracing::debug;
7
8pub use manager::{PgConnectionManager, PgError};
9
10mod manager;
11mod tls;
12mod url;
13
14/// Check if an error is an authentication error
15fn is_auth_error(error: &Report) -> bool {
16	if let Some(db_error) = error.downcast_ref::<tokio_postgres::Error>()
17		&& let Some(db_error) = db_error.as_db_error()
18	{
19		// PostgreSQL error codes for authentication failures:
20		// 28000 - invalid_authorization_specification
21		// 28P01 - invalid_password
22		let code = db_error.code().code();
23		return code == "28000" || code == "28P01";
24	}
25
26	// Check for other connection errors that might indicate auth issues
27	let message = error.to_string();
28	message.contains("password authentication failed")
29		|| message.contains("no password supplied")
30		|| message.contains("authentication failed")
31		|| message.contains("invalid configuration")
32}
33
34pub type PgConnection = Connection<manager::PgConnectionManager>;
35
36#[derive(Debug, Clone)]
37pub struct PgPool {
38	pub manager: manager::PgConnectionManager,
39	pub inner: Pool<manager::PgConnectionManager>,
40}
41
42impl PgPool {
43	/// Returns a single connection by either opening a new connection
44	/// or returning an existing connection from the connection pool. Conn will
45	/// block until either a connection is returned or timeout.
46	pub async fn get(&self) -> Result<PgConnection, mobc::Error<PgError>> {
47		self.inner.get().await
48	}
49
50	/// Retrieves a connection from the pool, waiting for at most `timeout`
51	///
52	/// The given timeout will be used instead of the configured connection
53	/// timeout.
54	pub async fn get_timeout(
55		&self,
56		duration: Duration,
57	) -> Result<PgConnection, mobc::Error<PgError>> {
58		self.inner.get_timeout(duration).await
59	}
60}
61
62/// Create a connection pool from a connection URL
63///
64/// Supports Unix socket connections via:
65/// - Query parameter: `postgresql:///dbname?host=/var/run/postgresql`
66/// - Percent-encoded host: `postgresql://%2Fvar%2Frun%2Fpostgresql/dbname`
67/// - Empty host (auto-detects Unix socket or falls back to localhost): `postgresql:///dbname`
68///
69/// Unix socket connections automatically disable SSL/TLS.
70///
71/// # Password Prompting
72///
73/// If the connection fails with an authentication error and no password was provided
74/// in the connection URL, the function will prompt the user to enter a password
75/// interactively. The password will be read securely without echoing to the terminal.
76pub async fn create_pool(url: &str, application_name: &str) -> Result<PgPool> {
77	let mut config = url::parse_connection_url(url)?;
78
79	config.application_name(application_name);
80
81	// Try to connect, and if it fails with auth error, prompt for password
82	let pool = loop {
83		debug!("Creating manager");
84		let tls = config.get_ssl_mode() != SslMode::Disable;
85		let manager = crate::pool::PgConnectionManager::new(config.clone(), tls);
86
87		debug!("Creating pool");
88		let pool = Pool::builder()
89			.max_lifetime(Some(Duration::from_secs(3600)))
90			.build(manager.clone());
91
92		let pool = PgPool {
93			manager,
94			inner: pool,
95		};
96
97		debug!("Checking pool");
98		match check_pool(&pool).await {
99			Ok(_) => break pool,
100			Err(e) => {
101				if is_auth_error(&e) && config.get_password().is_none() {
102					let password = rpassword::prompt_password("Password: ").into_diagnostic()?;
103					config.password(password);
104					// Loop will retry with the new password
105				} else {
106					// Not an auth error or we already have a password, re-throw
107					return Err(e);
108				}
109			}
110		}
111	};
112
113	Ok(pool)
114}
115
116/// Check if we can actually establish a connection
117async fn check_pool(pool: &PgPool) -> Result<()> {
118	let conn = match pool.get().await {
119		Err(mobc::Error::Inner(db_err)) => {
120			return Err(match db_err.as_db_error() {
121				Some(db_err) => miette!(
122					"E{code} at {func} in {file}:{line}",
123					code = db_err.code().code(),
124					func = db_err.routine().unwrap_or("{unknown}"),
125					file = db_err.file().unwrap_or("unknown.c"),
126					line = db_err.line().unwrap_or(0)
127				),
128				_ => miette!("{db_err}"),
129			})
130			.wrap_err(
131				db_err
132					.as_db_error()
133					.map(|e| e.to_string())
134					.unwrap_or_default(),
135			)?;
136		}
137		res @ Err(_) => {
138			let res = res.map(drop).into_diagnostic();
139			return if let Err(ref err) = res
140				&& is_auth_error(err)
141			{
142				res.wrap_err("hint: check the password")
143			} else {
144				res
145			};
146		}
147		Ok(conn) => conn,
148	};
149	conn.simple_query("SELECT 1")
150		.await
151		.into_diagnostic()
152		.wrap_err("checking connection")?;
153	Ok(())
154}
155
#[cfg(test)]
mod tests {
	use super::*;

	/// Fail the test if `result` is an error raised while parsing the connection
	/// string. Connection-level failures are tolerated, since these tests may run
	/// without a reachable PostgreSQL server.
	fn assert_not_parsing_error(result: &Result<PgPool>) {
		let Err(e) = result else {
			return;
		};
		let error_msg = format!("{:?}", e);
		assert!(
			!error_msg.contains("parsing connection string"),
			"Should not be a parsing error: {}",
			error_msg
		);
	}

	#[tokio::test]
	async fn test_create_pool_valid_connection_string() {
		// May fail if the database doesn't exist, but must not be a parsing error.
		let result = create_pool("postgresql://localhost/test", "test").await;
		assert_not_parsing_error(&result);
	}

	#[tokio::test]
	async fn test_create_pool_with_full_url() {
		// May fail if the database doesn't exist or auth fails, but must not be a
		// parsing error.
		let result = create_pool("postgresql://user:pass@localhost:5432/testdb", "test").await;
		assert_not_parsing_error(&result);
	}

	#[tokio::test]
	async fn test_create_pool_with_unix_socket_path() {
		// Unix socket path given via the `host` query parameter. May fail if
		// PostgreSQL isn't running or isn't reachable via the socket.
		let result = create_pool("postgresql:///postgres?host=/var/run/postgresql", "test").await;
		assert_not_parsing_error(&result);
	}

	#[tokio::test]
	async fn test_create_pool_with_encoded_unix_socket() {
		// Unix socket path given as a percent-encoded host.
		let result = create_pool("postgresql://%2Fvar%2Frun%2Fpostgresql/postgres", "test").await;
		assert_not_parsing_error(&result);
	}

	#[tokio::test]
	async fn test_create_pool_with_no_host() {
		// No host: should try a Unix socket or fall back to localhost.
		let result = create_pool("postgresql:///postgres", "test").await;
		assert_not_parsing_error(&result);
	}

	#[tokio::test]
	async fn test_unix_socket_connection_end_to_end() {
		// Connect over a Unix socket and, if that works, run a real query.
		let result = create_pool("postgresql:///postgres?host=/var/run/postgresql", "test").await;

		let pool = match result {
			Ok(pool) => pool,
			Err(e) => {
				// A failure is acceptable, but not a parsing or TLS failure.
				let error_msg = format!("{:?}", e);
				assert!(
					!error_msg.contains("parsing connection string"),
					"Should not be a parsing error: {}",
					error_msg
				);
				assert!(
					!error_msg.contains("TLS handshake"),
					"Should not be a TLS error for Unix socket: {}",
					error_msg
				);
				return;
			}
		};

		if let Ok(conn) = pool.get().await {
			let query_result = conn.simple_query("SELECT 1 as test").await;
			assert!(query_result.is_ok(), "Query should succeed");
		}
	}
}