1use std::time::Duration;
2
3use miette::{IntoDiagnostic, Report, Result, WrapErr, miette};
4use mobc::{Connection, Pool};
5use tokio_postgres::config::SslMode;
6use tracing::debug;
7
8pub use manager::{PgConnectionManager, PgError};
9
10mod manager;
11mod tls;
12mod url;
13
14fn is_tls_error(error: &Report) -> bool {
16 if error.downcast_ref::<rustls::Error>().is_some() {
17 return true;
18 }
19
20 let mut source = error.source();
22 while let Some(err) = source {
23 if err.downcast_ref::<rustls::Error>().is_some() {
24 return true;
25 }
26 source = err.source();
27 }
28
29 let message = error.to_string();
30 message.contains("tls:")
31 || message.contains("rustls")
32 || message.contains("certificate")
33 || message.contains("TLS handshake")
34 || message.contains("invalid configuration")
35}
36
37fn is_auth_error(error: &Report) -> bool {
39 if let Some(db_error) = error.downcast_ref::<tokio_postgres::Error>()
40 && let Some(db_error) = db_error.as_db_error()
41 {
42 let code = db_error.code().code();
46 return code == "28000" || code == "28P01";
47 }
48
49 let message = error.to_string();
51 message.contains("password authentication failed")
52 || message.contains("no password supplied")
53 || message.contains("authentication failed")
54}
55
56pub type PgConnection = Connection<manager::PgConnectionManager>;
57
58#[derive(Debug, Clone)]
59pub struct PgPool {
60 pub manager: manager::PgConnectionManager,
61 pub inner: Pool<manager::PgConnectionManager>,
62}
63
64impl PgPool {
65 pub async fn get(&self) -> Result<PgConnection, mobc::Error<PgError>> {
69 self.inner.get().await
70 }
71
72 pub async fn get_timeout(
77 &self,
78 duration: Duration,
79 ) -> Result<PgConnection, mobc::Error<PgError>> {
80 self.inner.get_timeout(duration).await
81 }
82}
83
84pub async fn create_pool(url: &str, application_name: &str) -> Result<PgPool> {
99 let mut config = url::parse_connection_url(url)?;
100
101 config.application_name(application_name);
102
103 let mut tried_ssl_fallback = false;
104
105 let pool = loop {
107 debug!("Creating manager");
108 let tls = config.get_ssl_mode() != SslMode::Disable;
109 let manager = crate::pool::PgConnectionManager::new(config.clone(), tls);
110
111 debug!("Creating pool");
112 let pool = Pool::builder()
113 .max_lifetime(Some(Duration::from_secs(3600)))
114 .build(manager.clone());
115
116 let pool = PgPool {
117 manager,
118 inner: pool,
119 };
120
121 debug!("Checking pool");
122 match check_pool(&pool).await {
123 Ok(_) => {
124 if tried_ssl_fallback {
125 tracing::info!("Connected successfully with SSL disabled after TLS error");
126 }
127 break pool;
128 }
129 Err(e) => {
130 debug!("Connection error: {:#}", e);
131 debug!(
132 "is_tls_error: {}, is_auth_error: {}",
133 is_tls_error(&e),
134 is_auth_error(&e)
135 );
136
137 if is_tls_error(&e) {
138 if config.get_ssl_mode() == SslMode::Prefer && !tried_ssl_fallback {
140 debug!("TLS failed with prefer mode, retrying with SSL disabled");
141 config.ssl_mode(SslMode::Disable);
142 tried_ssl_fallback = true;
143 continue;
144 }
145
146 return Err(e).wrap_err(
148 "TLS/SSL connection failed. Try using --ssl disable, \
149 or use a connection URL with sslmode=disable: \
150 postgresql://user@host/db?sslmode=disable",
151 );
152 } else if is_auth_error(&e) && config.get_password().is_none() {
153 let password = rpassword::prompt_password("Password: ").into_diagnostic()?;
154 config.password(password);
155 } else {
157 return Err(e);
159 }
160 }
161 }
162 };
163
164 Ok(pool)
165}
166
167async fn check_pool(pool: &PgPool) -> Result<()> {
169 let conn = match pool.get().await {
170 Err(mobc::Error::Inner(db_err)) => {
171 return Err(match db_err.as_db_error() {
172 Some(db_err) => miette!(
173 "E{code} at {func} in {file}:{line}",
174 code = db_err.code().code(),
175 func = db_err.routine().unwrap_or("{unknown}"),
176 file = db_err.file().unwrap_or("unknown.c"),
177 line = db_err.line().unwrap_or(0)
178 ),
179 _ => miette!("{db_err}"),
180 })
181 .wrap_err(
182 db_err
183 .as_db_error()
184 .map(|e| e.to_string())
185 .unwrap_or_default(),
186 )?;
187 }
188 res @ Err(_) => {
189 let res = res.map(drop).into_diagnostic();
190 return if let Err(ref err) = res
191 && is_auth_error(err)
192 {
193 res.wrap_err("hint: check the password")
194 } else {
195 res
196 };
197 }
198 Ok(conn) => conn,
199 };
200 conn.simple_query("SELECT 1")
201 .await
202 .into_diagnostic()
203 .wrap_err("checking connection")?;
204 Ok(())
205}
206
#[cfg(test)]
mod tests {
    use super::*;

    /// Renders `error` with `{:?}`, asserts that it is not a connection-string
    /// parsing failure, and returns the rendered text for further checks.
    fn assert_not_parsing_error(error: &impl std::fmt::Debug) -> String {
        let error_msg = format!("{:?}", error);
        assert!(
            !error_msg.contains("parsing connection string"),
            "Should not be a parsing error: {}",
            error_msg
        );
        error_msg
    }

    /// Calls `create_pool` with `connection_string`. A successful connection is
    /// accepted as-is; a failure is accepted only if it is not a parsing error.
    async fn expect_accepted(connection_string: &str) {
        if let Err(e) = create_pool(connection_string, "test").await {
            assert_not_parsing_error(&e);
        }
    }

    #[tokio::test]
    async fn test_create_pool_valid_connection_string() {
        expect_accepted("postgresql://localhost/test").await;
    }

    #[tokio::test]
    async fn test_create_pool_with_full_url() {
        expect_accepted("postgresql://user:pass@localhost:5432/testdb").await;
    }

    #[tokio::test]
    async fn test_create_pool_with_unix_socket_path() {
        expect_accepted("postgresql:///postgres?host=/var/run/postgresql").await;
    }

    #[tokio::test]
    async fn test_create_pool_with_encoded_unix_socket() {
        expect_accepted("postgresql://%2Fvar%2Frun%2Fpostgresql/postgres").await;
    }

    #[tokio::test]
    async fn test_create_pool_with_no_host() {
        expect_accepted("postgresql:///postgres").await;
    }

    #[tokio::test]
    async fn test_unix_socket_connection_end_to_end() {
        let outcome = create_pool("postgresql:///postgres?host=/var/run/postgresql", "test").await;

        let pool = match outcome {
            Ok(pool) => pool,
            Err(e) => {
                let error_msg = assert_not_parsing_error(&e);
                assert!(
                    !error_msg.contains("TLS handshake"),
                    "Should not be a TLS error for Unix socket: {}",
                    error_msg
                );
                return;
            }
        };

        // The query is only attempted when a connection could be checked out.
        if let Ok(conn) = pool.get().await {
            let result = conn.simple_query("SELECT 1 as test").await;
            assert!(result.is_ok(), "Query should succeed");
        }
    }
}