database_replicator/postgres/
connection.rs

1// ABOUTME: PostgreSQL connection utilities for Neon and Seren
2// ABOUTME: Handles connection string parsing, TLS setup, and connection lifecycle
3
4use crate::utils;
5use anyhow::{Context, Result};
6use native_tls::TlsConnector;
7use postgres_native_tls::MakeTlsConnector;
8use std::time::Duration;
9use tokio_postgres::Client;
10
/// Keepalive settings appended by [`add_keepalive_params`], in the order they are added.
const KEEPALIVE_DEFAULTS: [(&str, &str); 3] = [
    ("keepalives", "1"),
    ("keepalives_idle", "60"),
    ("keepalives_interval", "10"),
];

/// Add TCP keepalive parameters to a PostgreSQL connection string
///
/// Automatically adds keepalive parameters to prevent idle connection timeouts
/// when connecting through load balancers (like AWS ELB). These parameters ensure
/// that TCP keepalive packets are sent regularly to keep the connection alive.
///
/// Parameters added:
/// - `keepalives=1`: Enable TCP keepalives
/// - `keepalives_idle=60`: Send first keepalive after 60 seconds of idle time
/// - `keepalives_interval=10`: Send subsequent keepalives every 10 seconds
///
/// A parameter is only added when the connection string does not already set it
/// (names are compared case-insensitively); existing values are never overwritten.
///
/// Both connection string formats are handled:
/// - URLs (`postgres://…` or `postgresql://…`): missing parameters are appended to
///   the query string, joined with `&`. Only query-string parameter *names* are
///   inspected, so text in the user, password, host or path is never mistaken for
///   an existing setting. An empty query (`…/db?`) or a trailing `&` is reused
///   rather than producing an empty parameter.
/// - Key/value strings (`host=… user=…`): missing parameters are appended as
///   space-separated `key=value` pairs.
///
/// # Arguments
///
/// * `connection_string` - Original PostgreSQL connection URL or key/value string
///
/// # Returns
///
/// Connection string with any missing keepalive parameters added
///
/// # Examples
///
/// ```
/// # use database_replicator::postgres::connection::add_keepalive_params;
/// let url = "postgresql://user:pass@host:5432/db";
/// let url_with_keepalives = add_keepalive_params(url);
/// assert!(url_with_keepalives.contains("keepalives=1"));
/// assert!(url_with_keepalives.contains("keepalives_idle=60"));
/// assert!(url_with_keepalives.contains("keepalives_interval=10"));
/// ```
pub fn add_keepalive_params(connection_string: &str) -> String {
    // Same prefixes tokio_postgres uses to decide between URL and key/value parsing.
    let is_url = connection_string.starts_with("postgres://")
        || connection_string.starts_with("postgresql://");

    // Parameter names already present, lower-cased for case-insensitive comparison.
    let existing = if is_url {
        url_query_keys(connection_string)
    } else {
        key_value_keys(connection_string)
    };

    let missing: Vec<String> = KEEPALIVE_DEFAULTS
        .iter()
        .filter(|(key, _)| !existing.iter().any(|k| k.as_str() == *key))
        .map(|(key, value)| format!("{}={}", key, value))
        .collect();

    // If all params already exist, return as-is
    if missing.is_empty() {
        return connection_string.to_string();
    }

    let mut result = connection_string.to_string();
    if is_url {
        match connection_string.split_once('?') {
            None => result.push('?'),
            // An empty query or one ending in '&' already ends with a separator;
            // adding another would create an empty parameter name.
            Some((_, query)) if query.is_empty() || query.ends_with('&') => {}
            Some(_) => result.push('&'),
        }
        result.push_str(&missing.join("&"));
    } else {
        if !result.is_empty() && !result.ends_with(char::is_whitespace) {
            result.push(' ');
        }
        result.push_str(&missing.join(" "));
    }
    result
}

/// Returns the lower-cased parameter names in a URL's query string (empty if there is none).
fn url_query_keys(url: &str) -> Vec<String> {
    url.split_once('?')
        .map(|(_, query)| {
            query
                .split('&')
                .filter(|pair| !pair.is_empty())
                .map(|pair| {
                    pair.split_once('=')
                        .map_or(pair, |(key, _)| key)
                        .to_ascii_lowercase()
                })
                .collect()
        })
        .unwrap_or_default()
}

/// Returns the lower-cased keywords of a libpq-style key/value connection string.
///
/// Follows the libpq key/value syntax: whitespace may surround `=`, and values are
/// either bare (ending at whitespace, `\` escaping the next character) or
/// single-quoted (`\` escapes allowed). Scanning stops at the first keyword not
/// followed by `=`, since such a string cannot be parsed by the driver anyway.
fn key_value_keys(connection_string: &str) -> Vec<String> {
    let mut keys = Vec::new();
    let mut chars = connection_string.chars().peekable();
    loop {
        while chars.next_if(|c| c.is_whitespace()).is_some() {}
        if chars.peek().is_none() {
            break;
        }

        let mut key = String::new();
        while let Some(c) = chars.next_if(|c| !c.is_whitespace() && *c != '=') {
            key.push(c);
        }

        while chars.next_if(|c| c.is_whitespace()).is_some() {}
        if chars.next_if_eq(&'=').is_none() {
            break;
        }
        while chars.next_if(|c| c.is_whitespace()).is_some() {}

        if chars.next_if_eq(&'\'').is_some() {
            // Quoted value: ends at the first unescaped single quote.
            while let Some(c) = chars.next() {
                match c {
                    '\\' => {
                        chars.next();
                    }
                    '\'' => break,
                    _ => {}
                }
            }
        } else {
            // Bare value: ends at whitespace; `\` escapes the following character.
            while let Some(c) = chars.next_if(|c| !c.is_whitespace()) {
                if c == '\\' {
                    chars.next();
                }
            }
        }

        keys.push(key.to_ascii_lowercase());
    }
    keys
}
79
80/// Connect to PostgreSQL database with TLS support
81///
82/// Establishes a connection using the provided connection string with TLS enabled.
83/// The connection lifecycle is managed automatically via tokio spawn.
84///
85/// **Automatic Keepalive:** This function automatically adds TCP keepalive parameters
86/// to prevent idle connection timeouts when connecting through load balancers.
87/// The following parameters are added if not already present:
88/// - `keepalives=1`
89/// - `keepalives_idle=60`
90/// - `keepalives_interval=10`
91///
92/// # Arguments
93///
94/// * `connection_string` - PostgreSQL URL (e.g., "postgresql://user:pass@host:5432/db")
95///
96/// # Returns
97///
98/// Returns a `Client` on success, or an error with context if connection fails.
99///
100/// # Errors
101///
102/// This function will return an error if:
103/// - The connection string format is invalid
104/// - Authentication fails (invalid username or password)
105/// - The database does not exist
106/// - The database server is unreachable
107/// - TLS negotiation fails
108/// - Connection times out
109/// - pg_hba.conf does not allow the connection
110///
111/// # Examples
112///
113/// ```no_run
114/// # use anyhow::Result;
115/// # use database_replicator::postgres::connect;
116/// # async fn example() -> Result<()> {
117/// let client = connect("postgresql://user:pass@localhost:5432/mydb").await?;
118/// # Ok(())
119/// # }
120/// ```
121pub async fn connect(connection_string: &str) -> Result<Client> {
122    // Add keepalive parameters to prevent idle connection timeouts
123    let connection_string_with_keepalive = add_keepalive_params(connection_string);
124
125    // Parse connection string
126    let _config = connection_string_with_keepalive
127        .parse::<tokio_postgres::Config>()
128        .context(
129        "Invalid connection string format. Expected: postgresql://user:password@host:port/database",
130    )?;
131
132    // Set up TLS connector for cloud connections
133    // TEMPORARY: Accept invalid certs to debug TLS issues
134    // TODO: Remove this once we identify the certificate validation issue
135    let tls_connector = TlsConnector::builder()
136        .danger_accept_invalid_certs(true)
137        .build()
138        .context("Failed to build TLS connector")?;
139    let tls = MakeTlsConnector::new(tls_connector);
140
141    // Connect with keepalive parameters
142    let (client, connection) = tokio_postgres::connect(&connection_string_with_keepalive, tls)
143        .await
144        .map_err(|e| {
145            // Parse error and provide helpful context
146            let error_msg = e.to_string();
147
148            if error_msg.contains("password authentication failed") {
149                anyhow::anyhow!(
150                    "Authentication failed: Invalid username or password.\n\
151                     Please verify your database credentials."
152                )
153            } else if error_msg.contains("database") && error_msg.contains("does not exist") {
154                anyhow::anyhow!(
155                    "Database does not exist: {}\n\
156                     Please create the database first or check the connection URL.",
157                    error_msg
158                )
159            } else if error_msg.contains("Connection refused")
160                || error_msg.contains("could not connect")
161            {
162                anyhow::anyhow!(
163                    "Connection refused: Unable to reach database server.\n\
164                     Please check:\n\
165                     - The host and port are correct\n\
166                     - The database server is running\n\
167                     - Firewall rules allow connections\n\
168                     Error: {}",
169                    error_msg
170                )
171            } else if error_msg.contains("timeout") || error_msg.contains("timed out") {
172                anyhow::anyhow!(
173                    "Connection timeout: Database server did not respond in time.\n\
174                     This could indicate network issues or server overload.\n\
175                     Error: {}",
176                    error_msg
177                )
178            } else if error_msg.contains("SSL") || error_msg.contains("TLS") {
179                // Log full error for debugging TLS issues
180                tracing::error!("TLS/SSL connection failed with error: {:?}", e);
181                anyhow::anyhow!(
182                    "TLS/SSL error: Failed to establish secure connection.\n\
183                     Please verify SSL/TLS configuration.\n\
184                     Detailed error: {:?}\n\
185                     Original error: {}",
186                    e,
187                    error_msg
188                )
189            } else if error_msg.contains("no pg_hba.conf entry") {
190                anyhow::anyhow!(
191                    "Access denied: No pg_hba.conf entry for host.\n\
192                     The database server is not configured to accept connections from your host.\n\
193                     Contact your database administrator to update pg_hba.conf.\n\
194                     Error: {}",
195                    error_msg
196                )
197            } else {
198                anyhow::anyhow!("Failed to connect to database: {}", error_msg)
199            }
200        })?;
201
202    // Spawn connection handler
203    tokio::spawn(async move {
204        if let Err(e) = connection.await {
205            tracing::error!("Connection error: {}", e);
206        }
207    });
208
209    Ok(client)
210}
211
212/// Connect to PostgreSQL with automatic retry for transient failures
213///
214/// Attempts to connect up to 3 times with exponential backoff (1s, 2s, 4s).
215/// Useful for handling temporary network issues or server restarts.
216///
217/// # Arguments
218///
219/// * `connection_string` - PostgreSQL URL
220///
221/// # Returns
222///
223/// Returns a `Client` after successful connection, or error after all retries exhausted.
224///
225/// # Errors
226///
227/// Returns the last connection error if all retry attempts fail.
228///
229/// # Examples
230///
231/// ```no_run
232/// # use anyhow::Result;
233/// # use database_replicator::postgres::connection::connect_with_retry;
234/// # async fn example() -> Result<()> {
235/// let client = connect_with_retry("postgresql://user:pass@localhost:5432/mydb").await?;
236/// # Ok(())
237/// # }
238/// ```
239pub async fn connect_with_retry(connection_string: &str) -> Result<Client> {
240    utils::retry_with_backoff(
241        || connect(connection_string),
242        3,                      // Max 3 retries
243        Duration::from_secs(1), // Start with 1 second delay
244    )
245    .await
246    .context("Failed to connect after retries")
247}
248
#[cfg(test)]
mod tests {
    use super::*;

    /// Base URL shared by the keepalive tests.
    const BASE_URL: &str = "postgresql://user:pass@host:5432/db";

    /// Asserts that all three default keepalive settings appear verbatim in `url`.
    fn assert_default_keepalives(url: &str) {
        for expected in ["keepalives=1", "keepalives_idle=60", "keepalives_interval=10"] {
            assert!(url.contains(expected), "missing `{}` in {}", expected, url);
        }
    }

    #[test]
    fn test_add_keepalive_params_to_url_without_query() {
        let result = add_keepalive_params(BASE_URL);

        assert_default_keepalives(&result);
        assert!(result.starts_with(&format!("{}?", BASE_URL)));
    }

    #[test]
    fn test_add_keepalive_params_to_url_with_existing_query() {
        let url = format!("{}?sslmode=require", BASE_URL);
        let result = add_keepalive_params(&url);

        assert_default_keepalives(&result);
        assert!(result.contains("sslmode=require"));
        // An existing query string must be extended with `&`, not a second `?`.
        assert!(result.contains("&keepalives=1"));
    }

    #[test]
    fn test_add_keepalive_params_already_present() {
        let url = format!(
            "{}?keepalives=1&keepalives_idle=60&keepalives_interval=10",
            BASE_URL
        );

        // Nothing is missing, so the input must come back untouched.
        assert_eq!(add_keepalive_params(&url), url);
    }

    #[test]
    fn test_add_keepalive_params_partial_existing() {
        let result = add_keepalive_params(&format!("{}?keepalives=1", BASE_URL));

        assert_default_keepalives(&result);
        // The pre-existing `keepalives=1` must not be duplicated.
        assert_eq!(result.matches("keepalives=1").count(), 1);
    }

    #[test]
    fn test_add_keepalive_params_case_insensitive() {
        let result = add_keepalive_params(&format!("{}?KEEPALIVES=1", BASE_URL));

        // The upper-case parameter counts as present; only the other two are added.
        assert!(result.contains("KEEPALIVES=1"));
        assert!(result.contains("keepalives_idle=60"));
        assert!(result.contains("keepalives_interval=10"));
        assert_eq!(result.to_lowercase().matches("keepalives=1").count(), 1);
    }

    #[tokio::test]
    async fn test_connect_with_invalid_url_returns_error() {
        assert!(connect("invalid-url").await.is_err());
    }

    // Integration test against a real PostgreSQL instance. Ignored by default; run it
    // with `cargo test -- --ignored` and TEST_DATABASE_URL set. It panics (rather than
    // skipping) when TEST_DATABASE_URL is unset.
    #[tokio::test]
    #[ignore]
    async fn test_connect_with_valid_url_succeeds() {
        let url = std::env::var("TEST_DATABASE_URL")
            .expect("TEST_DATABASE_URL must be set for integration tests");

        assert!(connect(&url).await.is_ok());
    }
}