database_replicator/postgres/
connection.rs1use crate::utils;
5use anyhow::{Context, Result};
6use native_tls::TlsConnector;
7use postgres_native_tls::MakeTlsConnector;
8use std::sync::OnceLock;
9use std::time::Duration;
10use tokio_postgres::Client;
11
12static ALLOW_SELF_SIGNED_CERTS: OnceLock<bool> = OnceLock::new();
14
15pub fn init_tls_policy(allow: bool) {
24 let _ = ALLOW_SELF_SIGNED_CERTS.set(allow);
25 if allow {
26 tracing::warn!("TLS policy: Allowing self-signed/invalid certificates (insecure)");
27 }
28}
29
/// Returns `connection_string` with default TCP keepalive parameters added.
///
/// The defaults are `keepalives=1`, `keepalives_idle=60` and
/// `keepalives_interval=10`. A parameter is only added when its key is not
/// already present (compared case-insensitively), so caller-supplied values
/// are never overridden or duplicated.
///
/// Both connection string styles are handled:
///
/// * URL style (`postgres://…` / `postgresql://…`): parameters are appended
///   to the query string using `?` / `&`.
/// * Key/value style (`host=… user=…`): parameters are appended as
///   space-separated `key=value` pairs.
///
/// # Arguments
///
/// * `connection_string` - The connection string to augment.
///
/// # Returns
///
/// A new `String`; identical to the input when all three keys are present.
pub fn add_keepalive_params(connection_string: &str) -> String {
    // (key to look for, parameter appended when that key is absent)
    const KEEPALIVE_DEFAULTS: [(&str, &str); 3] = [
        ("keepalives=", "keepalives=1"),
        ("keepalives_idle=", "keepalives_idle=60"),
        ("keepalives_interval=", "keepalives_interval=10"),
    ];

    // Lowercase once so key detection is case-insensitive.
    let lower = connection_string.to_lowercase();
    let missing: Vec<&str> = KEEPALIVE_DEFAULTS
        .iter()
        .filter(|(key, _)| !lower.contains(*key))
        .map(|&(_, param)| param)
        .collect();

    if missing.is_empty() {
        return connection_string.to_string();
    }

    let is_url = connection_string.starts_with("postgres://")
        || connection_string.starts_with("postgresql://");

    let (lead, joiner) = if is_url {
        let lead = if !connection_string.contains('?') {
            "?"
        } else if connection_string.ends_with('?') || connection_string.ends_with('&') {
            // A separator is already in place; adding another would create
            // an empty parameter such as `?&keepalives=1`.
            ""
        } else {
            "&"
        };
        (lead, "&")
    } else {
        // Key/value format: pairs are whitespace-separated. Appending a URL
        // query here would be glued onto the last value and corrupt it.
        let lead = if connection_string.is_empty()
            || connection_string.ends_with(char::is_whitespace)
        {
            ""
        } else {
            " "
        };
        (lead, " ")
    };

    format!("{}{}{}", connection_string, lead, missing.join(joiner))
}
98
99pub async fn connect(connection_string: &str) -> Result<Client> {
141 let connection_string_with_keepalive = add_keepalive_params(connection_string);
143
144 let _config = connection_string_with_keepalive
146 .parse::<tokio_postgres::Config>()
147 .context(
148 "Invalid connection string format. Expected: postgresql://user:password@host:port/database",
149 )?;
150
151 let allow_self_signed = ALLOW_SELF_SIGNED_CERTS.get().copied().unwrap_or(false);
154
155 let mut tls_builder = TlsConnector::builder();
156 if allow_self_signed {
157 tls_builder.danger_accept_invalid_certs(true);
158 }
159
160 let tls_connector = tls_builder
161 .build()
162 .context("Failed to build TLS connector")?;
163 let tls = MakeTlsConnector::new(tls_connector);
164
165 let (client, connection) = tokio_postgres::connect(&connection_string_with_keepalive, tls)
167 .await
168 .map_err(|e| {
169 let error_msg = e.to_string();
171
172 let detailed_msg = if let Some(db_error) = e.as_db_error() {
174 let mut details = format!(
176 "PostgreSQL error {}: {}",
177 db_error.code().code(),
178 db_error.message()
179 );
180 if let Some(detail) = db_error.detail() {
181 details.push_str(&format!("\nDetail: {}", detail));
182 }
183 if let Some(hint) = db_error.hint() {
184 details.push_str(&format!("\nHint: {}", hint));
185 }
186 details
187 } else if let Some(source) = std::error::Error::source(&e) {
188 format!("{} (caused by: {})", error_msg, source)
190 } else {
191 if error_msg == "db error" || error_msg.len() < 20 {
193 format!("{:?}", e)
194 } else {
195 error_msg.clone()
196 }
197 };
198
199 if error_msg.contains("password authentication failed") {
200 anyhow::anyhow!(
201 "Authentication failed: Invalid username or password.\n\
202 Please verify your database credentials."
203 )
204 } else if error_msg.contains("database") && error_msg.contains("does not exist") {
205 anyhow::anyhow!(
206 "Database does not exist: {}\n\
207 Please create the database first or check the connection URL.",
208 error_msg
209 )
210 } else if error_msg.contains("Connection refused")
211 || error_msg.contains("could not connect")
212 {
213 anyhow::anyhow!(
214 "Connection refused: Unable to reach database server.\n\
215 Please check:\n\
216 - The host and port are correct\n\
217 - The database server is running\n\
218 - Firewall rules allow connections\n\
219 Error: {}",
220 detailed_msg
221 )
222 } else if error_msg.contains("timeout") || error_msg.contains("timed out") {
223 anyhow::anyhow!(
224 "Connection timeout: Database server did not respond in time.\n\
225 This could indicate network issues or server overload.\n\
226 Error: {}",
227 detailed_msg
228 )
229 } else if error_msg.contains("SSL") || error_msg.contains("TLS") {
230 tracing::error!("TLS/SSL connection failed with error: {:?}", e);
232 anyhow::anyhow!(
233 "TLS/SSL error: Failed to establish secure connection.\n\
234 Please verify SSL/TLS configuration.\n\
235 Detailed error: {:?}\n\
236 Original error: {}",
237 e,
238 error_msg
239 )
240 } else if error_msg.contains("no pg_hba.conf entry") {
241 anyhow::anyhow!(
242 "Access denied: No pg_hba.conf entry for host.\n\
243 The database server is not configured to accept connections from your host.\n\
244 Contact your database administrator to update pg_hba.conf.\n\
245 Error: {}",
246 error_msg
247 )
248 } else {
249 anyhow::anyhow!("Failed to connect to database: {}", detailed_msg)
251 }
252 })?;
253
254 tokio::spawn(async move {
256 if let Err(e) = connection.await {
257 tracing::error!("Connection error: {}", e);
258 }
259 });
260
261 Ok(client)
262}
263
264pub async fn connect_with_retry(connection_string: &str) -> Result<Client> {
292 utils::retry_with_backoff(
293 || connect(connection_string),
294 3, Duration::from_secs(1), )
297 .await
298 .context("Failed to connect after retries")
299}
300
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that every fragment in `needles` occurs in `haystack`.
    fn assert_contains_all(haystack: &str, needles: &[&str]) {
        for needle in needles {
            assert!(haystack.contains(needle));
        }
    }

    // A URL without a query string gets all three defaults after a `?`.
    #[test]
    fn test_add_keepalive_params_to_url_without_query() {
        let augmented = add_keepalive_params("postgresql://user:pass@host:5432/db");

        assert_contains_all(
            &augmented,
            &["keepalives=1", "keepalives_idle=60", "keepalives_interval=10"],
        );
        assert!(augmented.starts_with("postgresql://user:pass@host:5432/db?"));
    }

    // Existing query parameters are kept and the defaults are joined with `&`.
    #[test]
    fn test_add_keepalive_params_to_url_with_existing_query() {
        let augmented =
            add_keepalive_params("postgresql://user:pass@host:5432/db?sslmode=require");

        assert_contains_all(
            &augmented,
            &[
                "keepalives=1",
                "keepalives_idle=60",
                "keepalives_interval=10",
                "sslmode=require",
                "&keepalives=1",
            ],
        );
    }

    // A URL that already carries every keepalive parameter is returned as-is.
    #[test]
    fn test_add_keepalive_params_already_present() {
        let original =
            "postgresql://user:pass@host:5432/db?keepalives=1&keepalives_idle=60&keepalives_interval=10";

        assert_eq!(add_keepalive_params(original), original);
    }

    // Only the missing parameters are added; the present one is not duplicated.
    #[test]
    fn test_add_keepalive_params_partial_existing() {
        let augmented = add_keepalive_params("postgresql://user:pass@host:5432/db?keepalives=1");

        assert_contains_all(
            &augmented,
            &["keepalives=1", "keepalives_idle=60", "keepalives_interval=10"],
        );
        assert_eq!(augmented.matches("keepalives=1").count(), 1);
    }

    // Key detection ignores case, and the caller's spelling is preserved.
    #[test]
    fn test_add_keepalive_params_case_insensitive() {
        let augmented = add_keepalive_params("postgresql://user:pass@host:5432/db?KEEPALIVES=1");

        assert_contains_all(
            &augmented,
            &["KEEPALIVES=1", "keepalives_idle=60", "keepalives_interval=10"],
        );
        let occurrences = augmented.to_lowercase().matches("keepalives=1").count();
        assert_eq!(occurrences, 1);
    }

    // A string that is not a valid connection string must produce an error.
    #[tokio::test]
    async fn test_connect_with_invalid_url_returns_error() {
        assert!(connect("invalid-url").await.is_err());
    }

    // Integration test: needs a reachable database, so it is ignored by default.
    #[tokio::test]
    #[ignore]
    async fn test_connect_with_valid_url_succeeds() {
        let database_url = std::env::var("TEST_DATABASE_URL")
            .expect("TEST_DATABASE_URL must be set for integration tests");

        assert!(connect(&database_url).await.is_ok());
    }
}