database_replicator/postgres/
connection.rs

1// ABOUTME: PostgreSQL connection utilities for Neon and Seren
2// ABOUTME: Handles connection string parsing, TLS setup, and connection lifecycle
3
4use crate::utils;
5use anyhow::{Context, Result};
6use native_tls::TlsConnector;
7use postgres_native_tls::MakeTlsConnector;
8use std::sync::OnceLock;
9use std::time::Duration;
10use tokio_postgres::Client;
11
12/// Thread-safe storage for TLS configuration set at startup
13static ALLOW_SELF_SIGNED_CERTS: OnceLock<bool> = OnceLock::new();
14
15/// Initialize the TLS certificate policy (call once at startup)
16///
17/// This must be called before any database connections are made.
18/// It is thread-safe and will only set the value once.
19///
20/// # Arguments
21///
22/// * `allow` - If true, accept self-signed/invalid TLS certificates (insecure)
23pub fn init_tls_policy(allow: bool) {
24    let _ = ALLOW_SELF_SIGNED_CERTS.set(allow);
25    if allow {
26        tracing::warn!("TLS policy: Allowing self-signed/invalid certificates (insecure)");
27    }
28}
29
/// Add TCP keepalive parameters to a PostgreSQL connection string
///
/// Automatically adds keepalive parameters to prevent idle connection timeouts
/// when connecting through load balancers (like AWS ELB). These parameters ensure
/// that TCP keepalive packets are sent regularly to keep the connection alive.
///
/// Parameters added:
/// - `keepalives=1`: Enable TCP keepalives
/// - `keepalives_idle=60`: Send first keepalive after 60 seconds of idle time
/// - `keepalives_interval=10`: Send subsequent keepalives every 10 seconds
///
/// If any of these parameters already exist in the connection string (matched
/// case-insensitively), they are not overwritten and not duplicated.
///
/// Both connection string formats accepted by `tokio_postgres::Config` are handled:
/// - URL form (`postgres://...` or `postgresql://...`): parameters are appended to the
///   query string, introduced by `?` or joined with `&` as appropriate. A URL that already
///   ends in `?` or `&` does not get a second separator.
/// - Key/value form (`host=... user=...`): parameters are appended as space-separated
///   `key=value` pairs. Appending a `?` query here would instead corrupt the last value
///   (e.g. `user=postgres` would become `user=postgres?keepalives=1&...`).
///
/// # Arguments
///
/// * `connection_string` - Original PostgreSQL connection string (URL or key/value form)
///
/// # Returns
///
/// Connection string with any missing keepalive parameters added; the input is returned
/// unchanged when all three are already present.
///
/// # Examples
///
/// ```
/// # use database_replicator::postgres::connection::add_keepalive_params;
/// let url = "postgresql://user:pass@host:5432/db";
/// let url_with_keepalives = add_keepalive_params(url);
/// assert!(url_with_keepalives.contains("keepalives=1"));
/// assert!(url_with_keepalives.contains("keepalives_idle=60"));
/// assert!(url_with_keepalives.contains("keepalives_interval=10"));
///
/// // Key/value strings get space-separated parameters instead of a `?` query.
/// let kv = add_keepalive_params("host=localhost user=postgres");
/// assert_eq!(
///     kv,
///     "host=localhost user=postgres keepalives=1 keepalives_idle=60 keepalives_interval=10"
/// );
/// ```
pub fn add_keepalive_params(connection_string: &str) -> String {
    // (detection key, parameter to append). The detection keys are distinct substrings:
    // "keepalives=" does not occur inside "keepalives_idle=" or "keepalives_interval=".
    const KEEPALIVE_PARAMS: [(&str, &str); 3] = [
        ("keepalives=", "keepalives=1"),
        ("keepalives_idle=", "keepalives_idle=60"),
        ("keepalives_interval=", "keepalives_interval=10"),
    ];

    // Case-insensitive detection so caller-supplied values in any case are respected.
    let lower = connection_string.to_lowercase();
    let missing: Vec<&str> = KEEPALIVE_PARAMS
        .iter()
        .filter(|(key, _)| !lower.contains(*key))
        .map(|&(_, param)| param)
        .collect();

    // If all params already exist, return as-is
    if missing.is_empty() {
        return connection_string.to_string();
    }

    // Same prefixes tokio_postgres uses to distinguish URL form from key/value form.
    let is_url = connection_string.starts_with("postgres://")
        || connection_string.starts_with("postgresql://");

    let mut result = connection_string.to_string();
    if is_url {
        let separator = if !connection_string.contains('?') {
            "?"
        } else if connection_string.ends_with('?') || connection_string.ends_with('&') {
            // Already ends in a separator; adding another would yield `?&` or `&&`.
            ""
        } else {
            "&"
        };
        result.push_str(separator);
        result.push_str(&missing.join("&"));
    } else {
        // Key/value form: pairs are whitespace-separated.
        if !connection_string.is_empty() && !connection_string.ends_with(char::is_whitespace) {
            result.push(' ');
        }
        result.push_str(&missing.join(" "));
    }

    result
}
98
99/// Connect to PostgreSQL database with TLS support
100///
101/// Establishes a connection using the provided connection string with TLS enabled.
102/// The connection lifecycle is managed automatically via tokio spawn.
103///
104/// **Automatic Keepalive:** This function automatically adds TCP keepalive parameters
105/// to prevent idle connection timeouts when connecting through load balancers.
106/// The following parameters are added if not already present:
107/// - `keepalives=1`
108/// - `keepalives_idle=60`
109/// - `keepalives_interval=10`
110///
111/// # Arguments
112///
113/// * `connection_string` - PostgreSQL URL (e.g., "postgresql://user:pass@host:5432/db")
114///
115/// # Returns
116///
117/// Returns a `Client` on success, or an error with context if connection fails.
118///
119/// # Errors
120///
121/// This function will return an error if:
122/// - The connection string format is invalid
123/// - Authentication fails (invalid username or password)
124/// - The database does not exist
125/// - The database server is unreachable
126/// - TLS negotiation fails
127/// - Connection times out
128/// - pg_hba.conf does not allow the connection
129///
130/// # Examples
131///
132/// ```no_run
133/// # use anyhow::Result;
134/// # use database_replicator::postgres::connect;
135/// # async fn example() -> Result<()> {
136/// let client = connect("postgresql://user:pass@localhost:5432/mydb").await?;
137/// # Ok(())
138/// # }
139/// ```
140pub async fn connect(connection_string: &str) -> Result<Client> {
141    // Add keepalive parameters to prevent idle connection timeouts
142    let connection_string_with_keepalive = add_keepalive_params(connection_string);
143
144    // Parse connection string
145    let _config = connection_string_with_keepalive
146        .parse::<tokio_postgres::Config>()
147        .context(
148        "Invalid connection string format. Expected: postgresql://user:password@host:port/database",
149    )?;
150
151    // Set up TLS connector for cloud connections
152    // By default, require valid certificates. Allow opt-in via init_tls_policy() called at startup.
153    let allow_self_signed = ALLOW_SELF_SIGNED_CERTS.get().copied().unwrap_or(false);
154
155    let mut tls_builder = TlsConnector::builder();
156    if allow_self_signed {
157        tls_builder.danger_accept_invalid_certs(true);
158    }
159
160    let tls_connector = tls_builder
161        .build()
162        .context("Failed to build TLS connector")?;
163    let tls = MakeTlsConnector::new(tls_connector);
164
165    // Connect with keepalive parameters
166    let (client, connection) = tokio_postgres::connect(&connection_string_with_keepalive, tls)
167        .await
168        .map_err(|e| {
169            // Parse error and provide helpful context
170            let error_msg = e.to_string();
171
172            if error_msg.contains("password authentication failed") {
173                anyhow::anyhow!(
174                    "Authentication failed: Invalid username or password.\n\
175                     Please verify your database credentials."
176                )
177            } else if error_msg.contains("database") && error_msg.contains("does not exist") {
178                anyhow::anyhow!(
179                    "Database does not exist: {}\n\
180                     Please create the database first or check the connection URL.",
181                    error_msg
182                )
183            } else if error_msg.contains("Connection refused")
184                || error_msg.contains("could not connect")
185            {
186                anyhow::anyhow!(
187                    "Connection refused: Unable to reach database server.\n\
188                     Please check:\n\
189                     - The host and port are correct\n\
190                     - The database server is running\n\
191                     - Firewall rules allow connections\n\
192                     Error: {}",
193                    error_msg
194                )
195            } else if error_msg.contains("timeout") || error_msg.contains("timed out") {
196                anyhow::anyhow!(
197                    "Connection timeout: Database server did not respond in time.\n\
198                     This could indicate network issues or server overload.\n\
199                     Error: {}",
200                    error_msg
201                )
202            } else if error_msg.contains("SSL") || error_msg.contains("TLS") {
203                // Log full error for debugging TLS issues
204                tracing::error!("TLS/SSL connection failed with error: {:?}", e);
205                anyhow::anyhow!(
206                    "TLS/SSL error: Failed to establish secure connection.\n\
207                     Please verify SSL/TLS configuration.\n\
208                     Detailed error: {:?}\n\
209                     Original error: {}",
210                    e,
211                    error_msg
212                )
213            } else if error_msg.contains("no pg_hba.conf entry") {
214                anyhow::anyhow!(
215                    "Access denied: No pg_hba.conf entry for host.\n\
216                     The database server is not configured to accept connections from your host.\n\
217                     Contact your database administrator to update pg_hba.conf.\n\
218                     Error: {}",
219                    error_msg
220                )
221            } else {
222                anyhow::anyhow!("Failed to connect to database: {}", error_msg)
223            }
224        })?;
225
226    // Spawn connection handler
227    tokio::spawn(async move {
228        if let Err(e) = connection.await {
229            tracing::error!("Connection error: {}", e);
230        }
231    });
232
233    Ok(client)
234}
235
236/// Connect to PostgreSQL with automatic retry for transient failures
237///
238/// Attempts to connect up to 3 times with exponential backoff (1s, 2s, 4s).
239/// Useful for handling temporary network issues or server restarts.
240///
241/// # Arguments
242///
243/// * `connection_string` - PostgreSQL URL
244///
245/// # Returns
246///
247/// Returns a `Client` after successful connection, or error after all retries exhausted.
248///
249/// # Errors
250///
251/// Returns the last connection error if all retry attempts fail.
252///
253/// # Examples
254///
255/// ```no_run
256/// # use anyhow::Result;
257/// # use database_replicator::postgres::connection::connect_with_retry;
258/// # async fn example() -> Result<()> {
259/// let client = connect_with_retry("postgresql://user:pass@localhost:5432/mydb").await?;
260/// # Ok(())
261/// # }
262/// ```
263pub async fn connect_with_retry(connection_string: &str) -> Result<Client> {
264    utils::retry_with_backoff(
265        || connect(connection_string),
266        3,                      // Max 3 retries
267        Duration::from_secs(1), // Start with 1 second delay
268    )
269    .await
270    .context("Failed to connect after retries")
271}
272
#[cfg(test)]
mod tests {
    use super::*;

    /// The keepalive parameters `add_keepalive_params` appends when none are present.
    const DEFAULT_KEEPALIVES: [&str; 3] = [
        "keepalives=1",
        "keepalives_idle=60",
        "keepalives_interval=10",
    ];

    /// Assert that every default keepalive parameter occurs somewhere in `s`.
    fn assert_has_default_keepalives(s: &str) {
        for param in DEFAULT_KEEPALIVES {
            assert!(s.contains(param));
        }
    }

    #[test]
    fn test_add_keepalive_params_to_url_without_query() {
        let base = "postgresql://user:pass@host:5432/db";
        let augmented = add_keepalive_params(base);

        assert_has_default_keepalives(&augmented);
        // No existing query string, so the first parameter must be introduced by `?`.
        assert!(augmented.starts_with(&format!("{base}?")));
    }

    #[test]
    fn test_add_keepalive_params_to_url_with_existing_query() {
        let augmented =
            add_keepalive_params("postgresql://user:pass@host:5432/db?sslmode=require");

        assert_has_default_keepalives(&augmented);
        assert!(augmented.contains("sslmode=require"));
        // An existing query string means new params are joined with `&`, not a second `?`.
        assert!(augmented.contains("&keepalives=1"));
    }

    #[test]
    fn test_add_keepalive_params_already_present() {
        let complete =
            "postgresql://user:pass@host:5432/db?keepalives=1&keepalives_idle=60&keepalives_interval=10";

        // Nothing is missing, so the string must come back unchanged.
        assert_eq!(add_keepalive_params(complete), complete);
    }

    #[test]
    fn test_add_keepalive_params_partial_existing() {
        let augmented = add_keepalive_params("postgresql://user:pass@host:5432/db?keepalives=1");

        assert_has_default_keepalives(&augmented);
        // Only the missing params are appended; the caller's `keepalives=1` is not repeated.
        assert_eq!(augmented.matches("keepalives=1").count(), 1);
    }

    #[test]
    fn test_add_keepalive_params_case_insensitive() {
        let augmented = add_keepalive_params("postgresql://user:pass@host:5432/db?KEEPALIVES=1");

        // The caller's uppercase key is kept exactly as written...
        assert!(augmented.contains("KEEPALIVES=1"));
        // ...and only the genuinely missing params are appended.
        assert!(augmented.contains("keepalives_idle=60"));
        assert!(augmented.contains("keepalives_interval=10"));
        // A case-folded count proves no lowercase `keepalives=1` duplicate was added.
        assert_eq!(augmented.to_lowercase().matches("keepalives=1").count(), 1);
    }

    #[tokio::test]
    async fn test_connect_with_invalid_url_returns_error() {
        assert!(connect("invalid-url").await.is_err());
    }

    // Integration test: needs a live PostgreSQL instance reachable via TEST_DATABASE_URL.
    // Ignored by default; run explicitly with `cargo test -- --ignored`.
    #[tokio::test]
    #[ignore]
    async fn test_connect_with_valid_url_succeeds() {
        let url = std::env::var("TEST_DATABASE_URL")
            .expect("TEST_DATABASE_URL must be set for integration tests");

        let outcome = connect(&url).await;
        assert!(outcome.is_ok());
    }
}