// database_replicator/postgres/connection.rs

use std::sync::OnceLock;
use std::time::Duration;

use anyhow::{Context, Result};
use native_tls::TlsConnector;
use postgres_native_tls::MakeTlsConnector;
use tokio_postgres::Client;

use crate::utils;

12static ALLOW_SELF_SIGNED_CERTS: OnceLock<bool> = OnceLock::new();
14
15pub fn init_tls_policy(allow: bool) {
24 let _ = ALLOW_SELF_SIGNED_CERTS.set(allow);
25 if allow {
26 tracing::warn!("TLS policy: Allowing self-signed/invalid certificates (insecure)");
27 }
28}
29
/// Adds default TCP keepalive settings to a PostgreSQL connection string.
///
/// The defaults are `keepalives=1`, `keepalives_idle=60` and
/// `keepalives_interval=10`. Each one is added only if the caller has not
/// already set that parameter (parameter names are matched
/// case-insensitively), so explicit settings always win.
///
/// Both connection-string formats accepted by `tokio_postgres` are handled:
///
/// * URLs (`postgresql://…` / `postgres://…`): parameters are appended to the
///   query string, which is opened with `?` if absent. No empty parameter is
///   produced when the string already ends in `?` or `&`.
/// * Key/value strings (`host=… dbname=…`): parameters are appended as
///   space-separated `key=value` pairs. Appending `?…&…` here would corrupt
///   the last value.
///
/// If nothing needs to be added, the input is returned unchanged.
pub fn add_keepalive_params(connection_string: &str) -> String {
    // Defaults applied only when the caller has not set the parameter.
    const DEFAULTS: [(&str, &str); 3] = [
        ("keepalives", "1"),
        ("keepalives_idle", "60"),
        ("keepalives_interval", "10"),
    ];

    let lower = connection_string.to_lowercase();
    let missing: Vec<String> = DEFAULTS
        .iter()
        .filter(|&&(key, _)| !has_param(&lower, key))
        .map(|&(key, value)| format!("{key}={value}"))
        .collect();

    if missing.is_empty() {
        return connection_string.to_string();
    }

    // tokio_postgres treats only these prefixes as URLs; anything else is
    // parsed as a key/value string.
    let is_url = connection_string.starts_with("postgresql://")
        || connection_string.starts_with("postgres://");

    let (separator, joiner) = if is_url {
        let separator = if !connection_string.contains('?') {
            "?"
        } else if connection_string.ends_with('?') || connection_string.ends_with('&') {
            // Query string is already open; a further `&` would create an
            // empty parameter.
            ""
        } else {
            "&"
        };
        (separator, "&")
    } else {
        let separator = if connection_string.is_empty()
            || connection_string.ends_with(char::is_whitespace)
        {
            ""
        } else {
            " "
        };
        (separator, " ")
    };

    format!("{connection_string}{separator}{}", missing.join(joiner))
}

/// Returns `true` if `lower` (an already lower-cased connection string) sets
/// the parameter `key`. That means `key` appears as a whole word, followed by
/// optional whitespace and `=`.
///
/// The word-boundary check stops `keepalives` from matching inside a longer
/// name such as `x_keepalives`. Requiring `=` stops it from matching the
/// prefix of `keepalives_idle`. Allowing whitespace before `=` accepts the
/// `key = value` spelling permitted in key/value strings.
fn has_param(lower: &str, key: &str) -> bool {
    lower.match_indices(key).any(|(start, _)| {
        let starts_word = !matches!(
            lower[..start].chars().next_back(),
            Some(c) if c.is_alphanumeric() || c == '_'
        );
        let assigns_value = lower[start + key.len()..].trim_start().starts_with('=');
        starts_word && assigns_value
    })
}

99pub async fn connect(connection_string: &str) -> Result<Client> {
141 let connection_string_with_keepalive = add_keepalive_params(connection_string);
143
144 let _config = connection_string_with_keepalive
146 .parse::<tokio_postgres::Config>()
147 .context(
148 "Invalid connection string format. Expected: postgresql://user:password@host:port/database",
149 )?;
150
151 let allow_self_signed = ALLOW_SELF_SIGNED_CERTS.get().copied().unwrap_or(false);
154
155 let mut tls_builder = TlsConnector::builder();
156 if allow_self_signed {
157 tls_builder.danger_accept_invalid_certs(true);
158 }
159
160 let tls_connector = tls_builder
161 .build()
162 .context("Failed to build TLS connector")?;
163 let tls = MakeTlsConnector::new(tls_connector);
164
165 let (client, connection) = tokio_postgres::connect(&connection_string_with_keepalive, tls)
167 .await
168 .map_err(|e| {
169 let error_msg = e.to_string();
171
172 if error_msg.contains("password authentication failed") {
173 anyhow::anyhow!(
174 "Authentication failed: Invalid username or password.\n\
175 Please verify your database credentials."
176 )
177 } else if error_msg.contains("database") && error_msg.contains("does not exist") {
178 anyhow::anyhow!(
179 "Database does not exist: {}\n\
180 Please create the database first or check the connection URL.",
181 error_msg
182 )
183 } else if error_msg.contains("Connection refused")
184 || error_msg.contains("could not connect")
185 {
186 anyhow::anyhow!(
187 "Connection refused: Unable to reach database server.\n\
188 Please check:\n\
189 - The host and port are correct\n\
190 - The database server is running\n\
191 - Firewall rules allow connections\n\
192 Error: {}",
193 error_msg
194 )
195 } else if error_msg.contains("timeout") || error_msg.contains("timed out") {
196 anyhow::anyhow!(
197 "Connection timeout: Database server did not respond in time.\n\
198 This could indicate network issues or server overload.\n\
199 Error: {}",
200 error_msg
201 )
202 } else if error_msg.contains("SSL") || error_msg.contains("TLS") {
203 tracing::error!("TLS/SSL connection failed with error: {:?}", e);
205 anyhow::anyhow!(
206 "TLS/SSL error: Failed to establish secure connection.\n\
207 Please verify SSL/TLS configuration.\n\
208 Detailed error: {:?}\n\
209 Original error: {}",
210 e,
211 error_msg
212 )
213 } else if error_msg.contains("no pg_hba.conf entry") {
214 anyhow::anyhow!(
215 "Access denied: No pg_hba.conf entry for host.\n\
216 The database server is not configured to accept connections from your host.\n\
217 Contact your database administrator to update pg_hba.conf.\n\
218 Error: {}",
219 error_msg
220 )
221 } else {
222 anyhow::anyhow!("Failed to connect to database: {}", error_msg)
223 }
224 })?;
225
226 tokio::spawn(async move {
228 if let Err(e) = connection.await {
229 tracing::error!("Connection error: {}", e);
230 }
231 });
232
233 Ok(client)
234}
235
236pub async fn connect_with_retry(connection_string: &str) -> Result<Client> {
264 utils::retry_with_backoff(
265 || connect(connection_string),
266 3, Duration::from_secs(1), )
269 .await
270 .context("Failed to connect after retries")
271}
272
#[cfg(test)]
mod tests {
    use super::*;

    /// Keepalive parameters every augmented connection string should carry.
    const DEFAULT_PARAMS: [&str; 3] = [
        "keepalives=1",
        "keepalives_idle=60",
        "keepalives_interval=10",
    ];

    /// Asserts that `haystack` contains every string in `needles`.
    fn assert_contains_all(haystack: &str, needles: &[&str]) {
        for &needle in needles {
            assert!(
                haystack.contains(needle),
                "expected `{haystack}` to contain `{needle}`"
            );
        }
    }

    #[test]
    fn test_add_keepalive_params_to_url_without_query() {
        let base = "postgresql://user:pass@host:5432/db";
        let augmented = add_keepalive_params(base);

        assert_contains_all(&augmented, &DEFAULT_PARAMS);
        // With no existing query string, the parameters must open one with `?`.
        assert!(augmented.starts_with(&format!("{base}?")));
    }

    #[test]
    fn test_add_keepalive_params_to_url_with_existing_query() {
        let base = "postgresql://user:pass@host:5432/db?sslmode=require";
        let augmented = add_keepalive_params(base);

        assert_contains_all(&augmented, &DEFAULT_PARAMS);
        // The existing parameter is kept and new ones are joined with `&`.
        assert_contains_all(&augmented, &["sslmode=require", "&keepalives=1"]);
    }

    #[test]
    fn test_add_keepalive_params_already_present() {
        let complete = "postgresql://user:pass@host:5432/db?keepalives=1&keepalives_idle=60&keepalives_interval=10";

        assert_eq!(add_keepalive_params(complete), complete);
    }

    #[test]
    fn test_add_keepalive_params_partial_existing() {
        let augmented = add_keepalive_params("postgresql://user:pass@host:5432/db?keepalives=1");

        assert_contains_all(&augmented, &DEFAULT_PARAMS);
        // The caller-supplied `keepalives=1` must not be duplicated.
        assert_eq!(augmented.matches("keepalives=1").count(), 1);
    }

    #[test]
    fn test_add_keepalive_params_case_insensitive() {
        let augmented = add_keepalive_params("postgresql://user:pass@host:5432/db?KEEPALIVES=1");

        assert_contains_all(
            &augmented,
            &["KEEPALIVES=1", "keepalives_idle=60", "keepalives_interval=10"],
        );
        // An upper-case key counts as already present: no lower-case copy added.
        assert_eq!(augmented.to_lowercase().matches("keepalives=1").count(), 1);
    }

    #[tokio::test]
    async fn test_connect_with_invalid_url_returns_error() {
        assert!(connect("invalid-url").await.is_err());
    }

    // Integration test: needs a reachable database named by TEST_DATABASE_URL.
    // Run explicitly with `cargo test -- --ignored`.
    #[tokio::test]
    #[ignore]
    async fn test_connect_with_valid_url_succeeds() {
        let url = std::env::var("TEST_DATABASE_URL")
            .expect("TEST_DATABASE_URL must be set for integration tests");

        assert!(connect(&url).await.is_ok());
    }
}