bestool_postgres/
pool.rs

1use std::time::Duration;
2
3use miette::{IntoDiagnostic, Report, Result, WrapErr, miette};
4use mobc::{Connection, Pool};
5use tokio_postgres::config::SslMode;
6use tracing::debug;
7
8pub use manager::{PgConnectionManager, PgError};
9
10mod manager;
11mod tls;
12mod url;
13
14/// Check if an error is a TLS/SSL error
15fn is_tls_error(error: &Report) -> bool {
16	if error.downcast_ref::<rustls::Error>().is_some() {
17		return true;
18	}
19
20	// Check the error chain for PgError::Tls
21	let mut source = error.source();
22	while let Some(err) = source {
23		if err.downcast_ref::<rustls::Error>().is_some() {
24			return true;
25		}
26		source = err.source();
27	}
28
29	let message = error.to_string();
30	message.contains("tls:")
31		|| message.contains("rustls")
32		|| message.contains("certificate")
33		|| message.contains("TLS handshake")
34		|| message.contains("invalid configuration")
35}
36
37/// Check if an error is an authentication error
38fn is_auth_error(error: &Report) -> bool {
39	if let Some(db_error) = error.downcast_ref::<tokio_postgres::Error>()
40		&& let Some(db_error) = db_error.as_db_error()
41	{
42		// PostgreSQL error codes for authentication failures:
43		// 28000 - invalid_authorization_specification
44		// 28P01 - invalid_password
45		let code = db_error.code().code();
46		return code == "28000" || code == "28P01";
47	}
48
49	// Check for other connection errors that might indicate auth issues
50	let message = error.to_string();
51	message.contains("password authentication failed")
52		|| message.contains("no password supplied")
53		|| message.contains("authentication failed")
54}
55
/// A connection checked out of a [`PgPool`] (see [`PgPool::get`]).
pub type PgConnection = Connection<manager::PgConnectionManager>;
57
/// A PostgreSQL connection pool, bundled with its connection manager.
///
/// Instances are normally obtained from [`create_pool`], which verifies that a
/// connection can be established before returning the pool.
#[derive(Debug, Clone)]
pub struct PgPool {
	/// The connection manager. `create_pool` stores a clone of the manager that `inner`
	/// was built with.
	pub manager: manager::PgConnectionManager,
	/// The underlying `mobc` pool that hands out connections.
	pub inner: Pool<manager::PgConnectionManager>,
}
63
64impl PgPool {
65	/// Returns a single connection by either opening a new connection
66	/// or returning an existing connection from the connection pool. Conn will
67	/// block until either a connection is returned or timeout.
68	pub async fn get(&self) -> Result<PgConnection, mobc::Error<PgError>> {
69		self.inner.get().await
70	}
71
72	/// Retrieves a connection from the pool, waiting for at most `timeout`
73	///
74	/// The given timeout will be used instead of the configured connection
75	/// timeout.
76	pub async fn get_timeout(
77		&self,
78		duration: Duration,
79	) -> Result<PgConnection, mobc::Error<PgError>> {
80		self.inner.get_timeout(duration).await
81	}
82}
83
84/// Create a connection pool from a connection URL
85///
86/// Supports Unix socket connections via:
87/// - Query parameter: `postgresql:///dbname?host=/var/run/postgresql`
88/// - Percent-encoded host: `postgresql://%2Fvar%2Frun%2Fpostgresql/dbname`
89/// - Empty host (auto-detects Unix socket or falls back to localhost): `postgresql:///dbname`
90///
91/// Unix socket connections automatically disable SSL/TLS.
92///
93/// # Password Prompting
94///
95/// If the connection fails with an authentication error and no password was provided
96/// in the connection URL, the function will prompt the user to enter a password
97/// interactively. The password will be read securely without echoing to the terminal.
98pub async fn create_pool(url: &str, application_name: &str) -> Result<PgPool> {
99	let mut config = url::parse_connection_url(url)?;
100
101	config.application_name(application_name);
102
103	let mut tried_ssl_fallback = false;
104
105	// Try to connect, and if it fails with auth error, prompt for password
106	let pool = loop {
107		debug!("Creating manager");
108		let tls = config.get_ssl_mode() != SslMode::Disable;
109		let manager = crate::pool::PgConnectionManager::new(config.clone(), tls);
110
111		debug!("Creating pool");
112		let pool = Pool::builder()
113			.max_lifetime(Some(Duration::from_secs(3600)))
114			.build(manager.clone());
115
116		let pool = PgPool {
117			manager,
118			inner: pool,
119		};
120
121		debug!("Checking pool");
122		match check_pool(&pool).await {
123			Ok(_) => {
124				if tried_ssl_fallback {
125					tracing::info!("Connected successfully with SSL disabled after TLS error");
126				}
127				break pool;
128			}
129			Err(e) => {
130				debug!("Connection error: {:#}", e);
131				debug!(
132					"is_tls_error: {}, is_auth_error: {}",
133					is_tls_error(&e),
134					is_auth_error(&e)
135				);
136
137				if is_tls_error(&e) {
138					// If SSL mode is prefer and we haven't tried fallback yet, retry with SSL disabled
139					if config.get_ssl_mode() == SslMode::Prefer && !tried_ssl_fallback {
140						debug!("TLS failed with prefer mode, retrying with SSL disabled");
141						config.ssl_mode(SslMode::Disable);
142						tried_ssl_fallback = true;
143						continue;
144					}
145
146					// TLS error - suggest disabling SSL
147					return Err(e).wrap_err(
148						"TLS/SSL connection failed. Try using --ssl disable, \
149						or use a connection URL with sslmode=disable: \
150						postgresql://user@host/db?sslmode=disable",
151					);
152				} else if is_auth_error(&e) && config.get_password().is_none() {
153					let password = rpassword::prompt_password("Password: ").into_diagnostic()?;
154					config.password(password);
155					// Loop will retry with the new password
156				} else {
157					// Not an auth error or we already have a password, re-throw
158					return Err(e);
159				}
160			}
161		}
162	};
163
164	Ok(pool)
165}
166
167/// Check if we can actually establish a connection
168async fn check_pool(pool: &PgPool) -> Result<()> {
169	let conn = match pool.get().await {
170		Err(mobc::Error::Inner(db_err)) => {
171			return Err(match db_err.as_db_error() {
172				Some(db_err) => miette!(
173					"E{code} at {func} in {file}:{line}",
174					code = db_err.code().code(),
175					func = db_err.routine().unwrap_or("{unknown}"),
176					file = db_err.file().unwrap_or("unknown.c"),
177					line = db_err.line().unwrap_or(0)
178				),
179				_ => miette!("{db_err}"),
180			})
181			.wrap_err(
182				db_err
183					.as_db_error()
184					.map(|e| e.to_string())
185					.unwrap_or_default(),
186			)?;
187		}
188		res @ Err(_) => {
189			let res = res.map(drop).into_diagnostic();
190			return if let Err(ref err) = res
191				&& is_auth_error(err)
192			{
193				res.wrap_err("hint: check the password")
194			} else {
195				res
196			};
197		}
198		Ok(conn) => conn,
199	};
200	conn.simple_query("SELECT 1")
201		.await
202		.into_diagnostic()
203		.wrap_err("checking connection")?;
204	Ok(())
205}
206
#[cfg(test)]
mod tests {
	use super::*;

	/// Fails the test if `result` is an error produced while parsing the connection
	/// string. Other connection failures are tolerated, because a database server may
	/// not be reachable where the tests run.
	fn assert_not_parse_error(result: &Result<PgPool>) {
		let Err(err) = result else {
			return;
		};
		let error_msg = format!("{:?}", err);
		assert!(
			!error_msg.contains("parsing connection string"),
			"Should not be a parsing error: {}",
			error_msg
		);
	}

	#[tokio::test]
	async fn test_create_pool_valid_connection_string() {
		// May fail if the database doesn't exist, but must not be a parsing error.
		let result = create_pool("postgresql://localhost/test", "test").await;
		assert_not_parse_error(&result);
	}

	#[tokio::test]
	async fn test_create_pool_with_full_url() {
		// May fail if the database doesn't exist or auth fails, but must not be a parsing error.
		let result = create_pool("postgresql://user:pass@localhost:5432/testdb", "test").await;
		assert_not_parse_error(&result);
	}

	#[tokio::test]
	async fn test_create_pool_with_unix_socket_path() {
		// Unix socket given via the `host` query parameter; the server may be absent.
		let result = create_pool("postgresql:///postgres?host=/var/run/postgresql", "test").await;
		assert_not_parse_error(&result);
	}

	#[tokio::test]
	async fn test_create_pool_with_encoded_unix_socket() {
		// Unix socket given as a percent-encoded host; the server may be absent.
		let result = create_pool("postgresql://%2Fvar%2Frun%2Fpostgresql/postgres", "test").await;
		assert_not_parse_error(&result);
	}

	#[tokio::test]
	async fn test_create_pool_with_no_host() {
		// No host: should try a Unix socket or fall back to localhost.
		let result = create_pool("postgresql:///postgres", "test").await;
		assert_not_parse_error(&result);
	}

	#[tokio::test]
	async fn test_unix_socket_connection_end_to_end() {
		// Connect over the Unix socket and, if that works, run a real query.
		let result = create_pool("postgresql:///postgres?host=/var/run/postgresql", "test").await;
		assert_not_parse_error(&result);

		match result {
			Ok(pool) => {
				if let Ok(conn) = pool.get().await {
					let query = conn.simple_query("SELECT 1 as test").await;
					assert!(query.is_ok(), "Query should succeed");
				}
			}
			Err(err) => {
				// A Unix socket connection must never fail during a TLS handshake.
				let error_msg = format!("{:?}", err);
				assert!(
					!error_msg.contains("TLS handshake"),
					"Should not be a TLS error for Unix socket: {}",
					error_msg
				);
			}
		}
	}
}