1use std::time::Duration;
2
3use miette::{IntoDiagnostic, Report, Result, WrapErr, miette};
4use mobc::{Connection, Pool};
5use tokio_postgres::config::SslMode;
6use tracing::debug;
7
8pub use manager::{PgConnectionManager, PgError};
9
10mod manager;
11mod tls;
12mod url;
13
14fn is_auth_error(error: &Report) -> bool {
16 if let Some(db_error) = error.downcast_ref::<tokio_postgres::Error>()
17 && let Some(db_error) = db_error.as_db_error()
18 {
19 let code = db_error.code().code();
23 return code == "28000" || code == "28P01";
24 }
25
26 let message = error.to_string();
28 message.contains("password authentication failed")
29 || message.contains("no password supplied")
30 || message.contains("authentication failed")
31 || message.contains("invalid configuration")
32}
33
34pub type PgConnection = Connection<manager::PgConnectionManager>;
35
36#[derive(Debug, Clone)]
37pub struct PgPool {
38 pub manager: manager::PgConnectionManager,
39 pub inner: Pool<manager::PgConnectionManager>,
40}
41
42impl PgPool {
43 pub async fn get(&self) -> Result<PgConnection, mobc::Error<PgError>> {
47 self.inner.get().await
48 }
49
50 pub async fn get_timeout(
55 &self,
56 duration: Duration,
57 ) -> Result<PgConnection, mobc::Error<PgError>> {
58 self.inner.get_timeout(duration).await
59 }
60}
61
62pub async fn create_pool(url: &str, application_name: &str) -> Result<PgPool> {
77 let mut config = url::parse_connection_url(url)?;
78
79 config.application_name(application_name);
80
81 let pool = loop {
83 debug!("Creating manager");
84 let tls = config.get_ssl_mode() != SslMode::Disable;
85 let manager = crate::pool::PgConnectionManager::new(config.clone(), tls);
86
87 debug!("Creating pool");
88 let pool = Pool::builder()
89 .max_lifetime(Some(Duration::from_secs(3600)))
90 .build(manager.clone());
91
92 let pool = PgPool {
93 manager,
94 inner: pool,
95 };
96
97 debug!("Checking pool");
98 match check_pool(&pool).await {
99 Ok(_) => break pool,
100 Err(e) => {
101 if is_auth_error(&e) && config.get_password().is_none() {
102 let password = rpassword::prompt_password("Password: ").into_diagnostic()?;
103 config.password(password);
104 } else {
106 return Err(e);
108 }
109 }
110 }
111 };
112
113 Ok(pool)
114}
115
116async fn check_pool(pool: &PgPool) -> Result<()> {
118 let conn = match pool.get().await {
119 Err(mobc::Error::Inner(db_err)) => {
120 return Err(match db_err.as_db_error() {
121 Some(db_err) => miette!(
122 "E{code} at {func} in {file}:{line}",
123 code = db_err.code().code(),
124 func = db_err.routine().unwrap_or("{unknown}"),
125 file = db_err.file().unwrap_or("unknown.c"),
126 line = db_err.line().unwrap_or(0)
127 ),
128 _ => miette!("{db_err}"),
129 })
130 .wrap_err(
131 db_err
132 .as_db_error()
133 .map(|e| e.to_string())
134 .unwrap_or_default(),
135 )?;
136 }
137 res @ Err(_) => {
138 let res = res.map(drop).into_diagnostic();
139 return if let Err(ref err) = res
140 && is_auth_error(err)
141 {
142 res.wrap_err("hint: check the password")
143 } else {
144 res
145 };
146 }
147 Ok(conn) => conn,
148 };
149 conn.simple_query("SELECT 1")
150 .await
151 .into_diagnostic()
152 .wrap_err("checking connection")?;
153 Ok(())
154}
155
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `err` did not come from connection-string parsing and returns its
    /// `Debug` rendering so callers can make further checks on it.
    ///
    /// These tests may run without a reachable server, so connection failures are
    /// tolerated; only a parse failure of the URL counts as a test failure.
    fn assert_not_parse_error(err: &Report) -> String {
        let rendered = format!("{err:?}");
        assert!(
            !rendered.contains("parsing connection string"),
            "Should not be a parsing error: {}",
            rendered
        );
        rendered
    }

    #[tokio::test]
    async fn test_create_pool_valid_connection_string() {
        if let Err(err) = create_pool("postgresql://localhost/test", "test").await {
            assert_not_parse_error(&err);
        }
    }

    #[tokio::test]
    async fn test_create_pool_with_full_url() {
        if let Err(err) = create_pool("postgresql://user:pass@localhost:5432/testdb", "test").await
        {
            assert_not_parse_error(&err);
        }
    }

    #[tokio::test]
    async fn test_create_pool_with_unix_socket_path() {
        if let Err(err) = create_pool("postgresql:///postgres?host=/var/run/postgresql", "test").await
        {
            assert_not_parse_error(&err);
        }
    }

    #[tokio::test]
    async fn test_create_pool_with_encoded_unix_socket() {
        if let Err(err) = create_pool("postgresql://%2Fvar%2Frun%2Fpostgresql/postgres", "test").await
        {
            assert_not_parse_error(&err);
        }
    }

    #[tokio::test]
    async fn test_create_pool_with_no_host() {
        if let Err(err) = create_pool("postgresql:///postgres", "test").await {
            assert_not_parse_error(&err);
        }
    }

    #[tokio::test]
    async fn test_unix_socket_connection_end_to_end() {
        let url = "postgresql:///postgres?host=/var/run/postgresql";

        match create_pool(url, "test").await {
            Err(err) => {
                let rendered = assert_not_parse_error(&err);
                assert!(
                    !rendered.contains("TLS handshake"),
                    "Should not be a TLS error for Unix socket: {}",
                    rendered
                );
            }
            Ok(pool) => {
                // A checkout failure is tolerated; only a failing query on a live
                // connection is an error.
                let Ok(conn) = pool.get().await else {
                    return;
                };
                let outcome = conn.simple_query("SELECT 1 as test").await;
                assert!(outcome.is_ok(), "Query should succeed");
            }
        }
    }
}