1use std::time::Duration;
2
3use miette::{IntoDiagnostic, Report, Result, WrapErr, miette};
4use mobc::{Connection, Pool};
5use tokio_postgres::config::SslMode;
6use tracing::debug;
7
8pub use manager::{PgConnectionManager, PgError};
9
10mod manager;
11mod tls;
12mod url;
13
14fn is_auth_error(error: &Report) -> bool {
16 if let Some(db_error) = error.downcast_ref::<tokio_postgres::Error>()
17 && let Some(db_error) = db_error.as_db_error()
18 {
19 let code = db_error.code().code();
23 return code == "28000" || code == "28P01";
24 }
25
26 let message = error.to_string();
28 message.contains("password authentication failed")
29 || message.contains("no password supplied")
30 || message.contains("authentication failed")
31 || message.contains("invalid configuration")
32}
33
34pub type PgConnection = Connection<manager::PgConnectionManager>;
35
36#[derive(Debug, Clone)]
37pub struct PgPool {
38 pub manager: manager::PgConnectionManager,
39 pub inner: Pool<manager::PgConnectionManager>,
40}
41
42impl PgPool {
43 pub async fn get(&self) -> Result<PgConnection, mobc::Error<PgError>> {
47 self.inner.get().await
48 }
49
50 pub async fn get_timeout(
55 &self,
56 duration: Duration,
57 ) -> Result<PgConnection, mobc::Error<PgError>> {
58 self.inner.get_timeout(duration).await
59 }
60}
61
62pub async fn create_pool(url: &str, application_name: &str) -> Result<PgPool> {
100 let mut config = url::parse_connection_url(url)?;
101
102 config.application_name(application_name);
103
104 let pool = loop {
106 debug!("Creating manager");
107 let tls = config.get_ssl_mode() != SslMode::Disable;
108 let manager = crate::pool::PgConnectionManager::new(config.clone(), tls);
109
110 debug!("Creating pool");
111 let pool = Pool::builder()
112 .max_lifetime(Some(Duration::from_secs(3600)))
113 .build(manager.clone());
114
115 let pool = PgPool {
116 manager,
117 inner: pool,
118 };
119
120 debug!("Checking pool");
121 match check_pool(&pool).await {
122 Ok(_) => break pool,
123 Err(e) => {
124 if is_auth_error(&e) && config.get_password().is_none() {
125 let password = rpassword::prompt_password("Password: ").into_diagnostic()?;
126 config.password(password);
127 } else {
129 return Err(e);
131 }
132 }
133 }
134 };
135
136 Ok(pool)
137}
138
139async fn check_pool(pool: &PgPool) -> Result<()> {
141 let conn = match pool.get().await {
142 Err(mobc::Error::Inner(db_err)) => {
143 return Err(match db_err.as_db_error() {
144 Some(db_err) => miette!(
145 "E{code} at {func} in {file}:{line}",
146 code = db_err.code().code(),
147 func = db_err.routine().unwrap_or("{unknown}"),
148 file = db_err.file().unwrap_or("unknown.c"),
149 line = db_err.line().unwrap_or(0)
150 ),
151 _ => miette!("{db_err}"),
152 })
153 .wrap_err(
154 db_err
155 .as_db_error()
156 .map(|e| e.to_string())
157 .unwrap_or_default(),
158 )?;
159 }
160 res @ Err(_) => {
161 let res = res.map(drop).into_diagnostic();
162 return if let Err(ref err) = res
163 && is_auth_error(err)
164 {
165 res.wrap_err("hint: check the password")
166 } else {
167 res
168 };
169 }
170 Ok(conn) => conn,
171 };
172 conn.simple_query("SELECT 1")
173 .await
174 .into_diagnostic()
175 .wrap_err("checking connection")?;
176 Ok(())
177}
178
#[cfg(test)]
mod tests {
    use super::*;

    /// Renders `error` with `{:?}`, asserts that the text does not mention a
    /// connection-string parsing failure, and returns the rendered text so
    /// callers can make further assertions on it.
    fn assert_not_parsing_error(error: &impl std::fmt::Debug) -> String {
        let error_msg = format!("{:?}", error);
        assert!(
            !error_msg.contains("parsing connection string"),
            "Should not be a parsing error: {}",
            error_msg
        );
        error_msg
    }

    // These tests do not require a reachable server: a connection failure is
    // accepted, only a URL parsing failure is treated as a test failure.

    #[tokio::test]
    async fn test_create_pool_valid_connection_string() {
        let connection_string = "postgresql://localhost/test";
        if let Err(e) = create_pool(connection_string, "test").await {
            assert_not_parsing_error(&e);
        }
    }

    #[tokio::test]
    async fn test_create_pool_with_full_url() {
        let connection_string = "postgresql://user:pass@localhost:5432/testdb";
        if let Err(e) = create_pool(connection_string, "test").await {
            assert_not_parsing_error(&e);
        }
    }

    #[tokio::test]
    async fn test_create_pool_with_unix_socket_path() {
        let url = "postgresql:///postgres?host=/var/run/postgresql";
        if let Err(e) = create_pool(url, "test").await {
            assert_not_parsing_error(&e);
        }
    }

    #[tokio::test]
    async fn test_create_pool_with_encoded_unix_socket() {
        let url = "postgresql://%2Fvar%2Frun%2Fpostgresql/postgres";
        if let Err(e) = create_pool(url, "test").await {
            assert_not_parsing_error(&e);
        }
    }

    #[tokio::test]
    async fn test_create_pool_with_no_host() {
        let url = "postgresql:///postgres";
        if let Err(e) = create_pool(url, "test").await {
            assert_not_parsing_error(&e);
        }
    }

    #[tokio::test]
    async fn test_unix_socket_connection_end_to_end() {
        let url = "postgresql:///postgres?host=/var/run/postgresql";

        match create_pool(url, "test").await {
            Ok(pool) => {
                // Only exercise the query when a connection is actually available.
                if let Ok(conn) = pool.get().await {
                    let outcome = conn.simple_query("SELECT 1 as test").await;
                    assert!(outcome.is_ok(), "Query should succeed");
                }
            }
            Err(e) => {
                let error_msg = assert_not_parsing_error(&e);
                assert!(
                    !error_msg.contains("TLS handshake"),
                    "Should not be a TLS error for Unix socket: {}",
                    error_msg
                );
            }
        }
    }
}
311}