database_replicator/postgres/
connection.rs1use crate::utils;
5use anyhow::{Context, Result};
6use native_tls::TlsConnector;
7use postgres_native_tls::MakeTlsConnector;
8use std::time::Duration;
9use tokio_postgres::Client;
10
/// Appends TCP keepalive parameters to a PostgreSQL connection string unless
/// the caller already set them.
///
/// Adds `keepalives=1`, `keepalives_idle=60` and `keepalives_interval=10`.
/// Any parameter whose `key=` already appears in the input (case-insensitive
/// substring match) is skipped. If all three are present, the input is
/// returned unchanged.
///
/// Both connection-string formats accepted by `tokio_postgres::Config` are
/// handled:
///
/// * URLs (`postgres://…` / `postgresql://…`): parameters go into the query
///   string. It is started with `?` when absent and extended with `&`
///   otherwise. No separator is doubled when the URL already ends in `?` or
///   `&`.
/// * Key/value strings (`host=localhost dbname=app`): parameters are appended
///   as whitespace-separated `key=value` pairs.
pub fn add_keepalive_params(connection_string: &str) -> String {
    // (key, default value). Keys are matched with a trailing `=` so that
    // `keepalives` does not match `keepalives_idle=` and vice versa.
    const DEFAULTS: [(&str, &str); 3] = [
        ("keepalives", "1"),
        ("keepalives_idle", "60"),
        ("keepalives_interval", "10"),
    ];

    let lower = connection_string.to_lowercase();
    let missing: Vec<String> = DEFAULTS
        .iter()
        .filter(|(key, _)| !lower.contains(&format!("{}=", key)))
        .map(|(key, value)| format!("{}={}", key, value))
        .collect();

    if missing.is_empty() {
        return connection_string.to_string();
    }

    // Same prefixes tokio-postgres uses to decide between URL and key/value parsing.
    let is_url = connection_string.starts_with("postgres://")
        || connection_string.starts_with("postgresql://");

    let mut result = connection_string.to_string();
    if is_url {
        if !result.contains('?') {
            result.push('?');
        } else if !result.ends_with('?') && !result.ends_with('&') {
            result.push('&');
        }
        result.push_str(&missing.join("&"));
    } else {
        // Key/value format: pairs are separated by whitespace, never `?`/`&`.
        if !result.is_empty() && !result.ends_with(char::is_whitespace) {
            result.push(' ');
        }
        result.push_str(&missing.join(" "));
    }
    result
}
79
80pub async fn connect(connection_string: &str) -> Result<Client> {
122 let connection_string_with_keepalive = add_keepalive_params(connection_string);
124
125 let _config = connection_string_with_keepalive
127 .parse::<tokio_postgres::Config>()
128 .context(
129 "Invalid connection string format. Expected: postgresql://user:password@host:port/database",
130 )?;
131
132 let tls_connector = TlsConnector::builder()
136 .danger_accept_invalid_certs(true)
137 .build()
138 .context("Failed to build TLS connector")?;
139 let tls = MakeTlsConnector::new(tls_connector);
140
141 let (client, connection) = tokio_postgres::connect(&connection_string_with_keepalive, tls)
143 .await
144 .map_err(|e| {
145 let error_msg = e.to_string();
147
148 if error_msg.contains("password authentication failed") {
149 anyhow::anyhow!(
150 "Authentication failed: Invalid username or password.\n\
151 Please verify your database credentials."
152 )
153 } else if error_msg.contains("database") && error_msg.contains("does not exist") {
154 anyhow::anyhow!(
155 "Database does not exist: {}\n\
156 Please create the database first or check the connection URL.",
157 error_msg
158 )
159 } else if error_msg.contains("Connection refused")
160 || error_msg.contains("could not connect")
161 {
162 anyhow::anyhow!(
163 "Connection refused: Unable to reach database server.\n\
164 Please check:\n\
165 - The host and port are correct\n\
166 - The database server is running\n\
167 - Firewall rules allow connections\n\
168 Error: {}",
169 error_msg
170 )
171 } else if error_msg.contains("timeout") || error_msg.contains("timed out") {
172 anyhow::anyhow!(
173 "Connection timeout: Database server did not respond in time.\n\
174 This could indicate network issues or server overload.\n\
175 Error: {}",
176 error_msg
177 )
178 } else if error_msg.contains("SSL") || error_msg.contains("TLS") {
179 tracing::error!("TLS/SSL connection failed with error: {:?}", e);
181 anyhow::anyhow!(
182 "TLS/SSL error: Failed to establish secure connection.\n\
183 Please verify SSL/TLS configuration.\n\
184 Detailed error: {:?}\n\
185 Original error: {}",
186 e,
187 error_msg
188 )
189 } else if error_msg.contains("no pg_hba.conf entry") {
190 anyhow::anyhow!(
191 "Access denied: No pg_hba.conf entry for host.\n\
192 The database server is not configured to accept connections from your host.\n\
193 Contact your database administrator to update pg_hba.conf.\n\
194 Error: {}",
195 error_msg
196 )
197 } else {
198 anyhow::anyhow!("Failed to connect to database: {}", error_msg)
199 }
200 })?;
201
202 tokio::spawn(async move {
204 if let Err(e) = connection.await {
205 tracing::error!("Connection error: {}", e);
206 }
207 });
208
209 Ok(client)
210}
211
212pub async fn connect_with_retry(connection_string: &str) -> Result<Client> {
240 utils::retry_with_backoff(
241 || connect(connection_string),
242 3, Duration::from_secs(1), )
245 .await
246 .context("Failed to connect after retries")
247}
248
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `s` carries all three default keepalive parameters.
    fn assert_default_keepalives(s: &str) {
        for expected in ["keepalives=1", "keepalives_idle=60", "keepalives_interval=10"] {
            assert!(s.contains(expected), "missing `{}` in `{}`", expected, s);
        }
    }

    #[test]
    fn test_add_keepalive_params_to_url_without_query() {
        let augmented = add_keepalive_params("postgresql://user:pass@host:5432/db");

        assert_default_keepalives(&augmented);
        assert!(augmented.starts_with("postgresql://user:pass@host:5432/db?"));
    }

    #[test]
    fn test_add_keepalive_params_to_url_with_existing_query() {
        let augmented =
            add_keepalive_params("postgresql://user:pass@host:5432/db?sslmode=require");

        assert_default_keepalives(&augmented);
        assert!(augmented.contains("sslmode=require"));
        // The existing query string must be extended with `&`, not restarted with `?`.
        assert!(augmented.contains("&keepalives=1"));
    }

    #[test]
    fn test_add_keepalive_params_already_present() {
        let complete =
            "postgresql://user:pass@host:5432/db?keepalives=1&keepalives_idle=60&keepalives_interval=10";

        assert_eq!(add_keepalive_params(complete), complete);
    }

    #[test]
    fn test_add_keepalive_params_partial_existing() {
        let augmented = add_keepalive_params("postgresql://user:pass@host:5432/db?keepalives=1");

        assert_default_keepalives(&augmented);
        // The caller's own `keepalives=1` must not be duplicated.
        assert_eq!(augmented.matches("keepalives=1").count(), 1);
    }

    #[test]
    fn test_add_keepalive_params_case_insensitive() {
        let augmented = add_keepalive_params("postgresql://user:pass@host:5432/db?KEEPALIVES=1");

        // The upper-case key is kept verbatim and counts as already present.
        assert!(augmented.contains("KEEPALIVES=1"));
        assert!(augmented.contains("keepalives_idle=60"));
        assert!(augmented.contains("keepalives_interval=10"));
        assert_eq!(augmented.to_lowercase().matches("keepalives=1").count(), 1);
    }

    #[tokio::test]
    async fn test_connect_with_invalid_url_returns_error() {
        assert!(connect("invalid-url").await.is_err());
    }

    /// Integration test against a live server. Run it with
    /// `cargo test -- --ignored` and `TEST_DATABASE_URL` set.
    #[tokio::test]
    #[ignore]
    async fn test_connect_with_valid_url_succeeds() {
        let url = std::env::var("TEST_DATABASE_URL")
            .expect("TEST_DATABASE_URL must be set for integration tests");

        assert!(connect(&url).await.is_ok());
    }
}