database_replicator/
utils.rs

1// ABOUTME: Utility functions for validation and error handling
2// ABOUTME: Provides input validation, retry logic, and resource cleanup
3
4use anyhow::{bail, Context, Result};
5use std::time::Duration;
6use which::which;
7
/// Build the TCP keepalive settings handed to PostgreSQL client tools.
///
/// External tools such as `pg_dump`, `pg_dumpall`, `pg_restore` and `psql` pick
/// these values up from their environment. Enabling keepalives stops idle
/// connections from being dropped when connecting through load balancers such
/// as AWS ELB.
///
/// The returned pairs are, in order:
/// - `PGKEEPALIVES=1`: enable TCP keepalives
/// - `PGKEEPALIVESIDLE=60`: send the first keepalive after 60 seconds of idle time
/// - `PGKEEPALIVESINTERVAL=10`: send subsequent keepalives every 10 seconds
///
/// # Returns
///
/// A vector of `(variable_name, value)` tuples to set on a subprocess command.
///
/// # Examples
///
/// ```
/// # use database_replicator::utils::get_keepalive_env_vars;
/// # use std::process::Command;
/// let mut cmd = Command::new("psql");
/// for (key, value) in get_keepalive_env_vars() {
///     cmd.env(key, value);
/// }
/// ```
pub fn get_keepalive_env_vars() -> Vec<(&'static str, &'static str)> {
    // Fixed table of settings; copied into a fresh Vec for every caller.
    const KEEPALIVE_SETTINGS: [(&str, &str); 3] = [
        ("PGKEEPALIVES", "1"),
        ("PGKEEPALIVESIDLE", "60"),
        ("PGKEEPALIVESINTERVAL", "10"),
    ];

    KEEPALIVE_SETTINGS.to_vec()
}
41
42/// Validate a PostgreSQL connection string
43///
44/// Checks that the connection string has proper format and required components:
45/// - Starts with "postgres://" or "postgresql://"
46/// - Contains user credentials (@ symbol)
47/// - Contains database name (/ separator with at least 3 occurrences)
48///
49/// # Arguments
50///
51/// * `url` - Connection string to validate
52///
53/// # Returns
54///
55/// Returns `Ok(())` if the connection string is valid.
56///
57/// # Errors
58///
59/// Returns an error with helpful message if the connection string is:
60/// - Empty or whitespace only
61/// - Missing proper scheme (postgres:// or postgresql://)
62/// - Missing user credentials (@ symbol)
63/// - Missing database name
64///
65/// # Examples
66///
67/// ```
68/// # use database_replicator::utils::validate_connection_string;
69/// # use anyhow::Result;
70/// # fn example() -> Result<()> {
71/// // Valid connection strings
72/// validate_connection_string("postgresql://user:pass@localhost:5432/mydb")?;
73/// validate_connection_string("postgres://user@host/db")?;
74///
75/// // Invalid - will return error
76/// assert!(validate_connection_string("").is_err());
77/// assert!(validate_connection_string("mysql://localhost/db").is_err());
78/// # Ok(())
79/// # }
80/// ```
81pub fn validate_connection_string(url: &str) -> Result<()> {
82    if url.trim().is_empty() {
83        bail!("Connection string cannot be empty");
84    }
85
86    // Check for common URL schemes
87    if !url.starts_with("postgres://") && !url.starts_with("postgresql://") {
88        bail!(
89            "Invalid connection string format.\n\
90             Expected format: postgresql://user:password@host:port/database\n\
91             Got: {}",
92            url
93        );
94    }
95
96    // Check for minimum required components (user@host/database)
97    if !url.contains('@') {
98        bail!(
99            "Connection string missing user credentials.\n\
100             Expected format: postgresql://user:password@host:port/database"
101        );
102    }
103
104    if !url.contains('/') || url.matches('/').count() < 3 {
105        bail!(
106            "Connection string missing database name.\n\
107             Expected format: postgresql://user:password@host:port/database"
108        );
109    }
110
111    Ok(())
112}
113
114/// Check that required PostgreSQL client tools are available
115///
116/// Verifies that the following tools are installed and in PATH:
117/// - `pg_dump` - For dumping database schema and data
118/// - `pg_dumpall` - For dumping global objects (roles, tablespaces)
119/// - `psql` - For restoring databases
120///
121/// # Returns
122///
123/// Returns `Ok(())` if all required tools are found.
124///
125/// # Errors
126///
127/// Returns an error with installation instructions if any tools are missing.
128///
129/// # Examples
130///
131/// ```
132/// # use database_replicator::utils::check_required_tools;
133/// # use anyhow::Result;
134/// # fn example() -> Result<()> {
135/// // Check if PostgreSQL tools are installed
136/// check_required_tools()?;
137/// # Ok(())
138/// # }
139/// ```
140pub fn check_required_tools() -> Result<()> {
141    let tools = ["pg_dump", "pg_dumpall", "psql"];
142    let mut missing = Vec::new();
143
144    for tool in &tools {
145        if which(tool).is_err() {
146            missing.push(*tool);
147        }
148    }
149
150    if !missing.is_empty() {
151        bail!(
152            "Missing required PostgreSQL client tools: {}\n\
153             \n\
154             Please install PostgreSQL client tools:\n\
155             - Ubuntu/Debian: sudo apt-get install postgresql-client\n\
156             - macOS: brew install postgresql\n\
157             - RHEL/CentOS: sudo yum install postgresql\n\
158             - Windows: Download from https://www.postgresql.org/download/windows/",
159            missing.join(", ")
160        );
161    }
162
163    Ok(())
164}
165
166/// Retry a function with exponential backoff
167///
168/// Executes an async operation with automatic retry on failure. Each retry doubles
169/// the delay (exponential backoff) to handle transient failures gracefully.
170///
171/// # Arguments
172///
173/// * `operation` - Async function to retry (FnMut returning Future\<Output = Result\<T\>\>)
174/// * `max_retries` - Maximum number of retry attempts (0 = no retries, just initial attempt)
175/// * `initial_delay` - Delay before first retry (doubles each subsequent retry)
176///
177/// # Returns
178///
179/// Returns the successful result or the last error after all retries exhausted.
180///
181/// # Examples
182///
183/// ```no_run
184/// # use anyhow::Result;
185/// # use std::time::Duration;
186/// # use database_replicator::utils::retry_with_backoff;
187/// # async fn example() -> Result<()> {
188/// let result = retry_with_backoff(
189///     || async { Ok("success") },
190///     3,  // Try up to 3 times
191///     Duration::from_secs(1)  // Start with 1s delay
192/// ).await?;
193/// # Ok(())
194/// # }
195/// ```
196pub async fn retry_with_backoff<F, Fut, T>(
197    mut operation: F,
198    max_retries: u32,
199    initial_delay: Duration,
200) -> Result<T>
201where
202    F: FnMut() -> Fut,
203    Fut: std::future::Future<Output = Result<T>>,
204{
205    let mut delay = initial_delay;
206    let mut last_error = None;
207
208    for attempt in 0..=max_retries {
209        match operation().await {
210            Ok(result) => return Ok(result),
211            Err(e) => {
212                last_error = Some(e);
213
214                if attempt < max_retries {
215                    tracing::warn!(
216                        "Operation failed (attempt {}/{}), retrying in {:?}...",
217                        attempt + 1,
218                        max_retries + 1,
219                        delay
220                    );
221                    tokio::time::sleep(delay).await;
222                    delay *= 2; // Exponential backoff
223                }
224            }
225        }
226    }
227
228    Err(last_error.unwrap_or_else(|| anyhow::anyhow!("Operation failed after retries")))
229}
230
231/// Retry a subprocess execution with exponential backoff on connection errors
232///
233/// Executes a subprocess command with automatic retry on connection-related failures.
234/// Each retry doubles the delay (exponential backoff) to handle transient connection issues.
235///
236/// Connection errors are detected by checking:
237/// - Non-zero exit codes
238/// - Stderr output containing connection-related error patterns:
239///   - "connection closed"
240///   - "connection refused"
241///   - "could not connect"
242///   - "server closed the connection"
243///   - "timeout"
244///   - "Connection timed out"
245///
246/// # Arguments
247///
248/// * `operation` - Function that executes a Command and returns the exit status
249/// * `max_retries` - Maximum number of retry attempts (0 = no retries, just initial attempt)
250/// * `initial_delay` - Delay before first retry (doubles each subsequent retry)
251/// * `operation_name` - Name of the operation for logging (e.g., "pg_restore", "psql")
252///
253/// # Returns
254///
255/// Returns Ok(()) on success or the last error after all retries exhausted.
256///
257/// # Examples
258///
259/// ```no_run
260/// # use anyhow::Result;
261/// # use std::time::Duration;
262/// # use std::process::Command;
263/// # use database_replicator::utils::retry_subprocess_with_backoff;
264/// # fn example() -> Result<()> {
265/// retry_subprocess_with_backoff(
266///     || {
267///         let mut cmd = Command::new("psql");
268///         cmd.arg("--version");
269///         cmd.status().map_err(anyhow::Error::from)
270///     },
271///     3,  // Try up to 3 times
272///     Duration::from_secs(1),  // Start with 1s delay
273///     "psql"
274/// )?;
275/// # Ok(())
276/// # }
277/// ```
278pub fn retry_subprocess_with_backoff<F>(
279    mut operation: F,
280    max_retries: u32,
281    initial_delay: Duration,
282    operation_name: &str,
283) -> Result<()>
284where
285    F: FnMut() -> Result<std::process::ExitStatus>,
286{
287    let mut delay = initial_delay;
288    let mut last_error = None;
289
290    for attempt in 0..=max_retries {
291        match operation() {
292            Ok(status) => {
293                if status.success() {
294                    return Ok(());
295                } else {
296                    // Non-zero exit code - check if it's a connection error
297                    // We can't easily capture stderr here, so we'll treat all non-zero
298                    // exit codes as potential connection errors for now
299                    let error = anyhow::anyhow!(
300                        "{} failed with exit code: {}",
301                        operation_name,
302                        status.code().unwrap_or(-1)
303                    );
304                    last_error = Some(error);
305
306                    if attempt < max_retries {
307                        tracing::warn!(
308                            "{} failed (attempt {}/{}), retrying in {:?}...",
309                            operation_name,
310                            attempt + 1,
311                            max_retries + 1,
312                            delay
313                        );
314                        std::thread::sleep(delay);
315                        delay *= 2; // Exponential backoff
316                    }
317                }
318            }
319            Err(e) => {
320                last_error = Some(e);
321
322                if attempt < max_retries {
323                    tracing::warn!(
324                        "{} failed (attempt {}/{}): {}, retrying in {:?}...",
325                        operation_name,
326                        attempt + 1,
327                        max_retries + 1,
328                        last_error.as_ref().unwrap(),
329                        delay
330                    );
331                    std::thread::sleep(delay);
332                    delay *= 2; // Exponential backoff
333                }
334            }
335        }
336    }
337
338    Err(last_error.unwrap_or_else(|| {
339        anyhow::anyhow!("{} failed after {} retries", operation_name, max_retries)
340    }))
341}
342
343/// Validate a PostgreSQL identifier (database name, schema name, etc.)
344///
345/// Validates that an identifier follows PostgreSQL naming rules to prevent SQL injection.
346/// PostgreSQL identifiers must:
347/// - Be 1-63 characters long
348/// - Start with a letter (a-z, A-Z) or underscore (_)
349/// - Contain only letters, digits (0-9), or underscores
350///
351/// # Arguments
352///
353/// * `identifier` - The identifier to validate (database name, schema name, etc.)
354///
355/// # Returns
356///
357/// Returns `Ok(())` if the identifier is valid.
358///
359/// # Errors
360///
361/// Returns an error if the identifier:
362/// - Is empty or whitespace-only
363/// - Exceeds 63 characters
364/// - Starts with an invalid character (digit or special character)
365/// - Contains invalid characters (anything except a-z, A-Z, 0-9, _)
366///
367/// # Security
368///
369/// This function is critical for preventing SQL injection attacks. All database
370/// names, schema names, and table names from untrusted sources MUST be validated
371/// before use in SQL statements.
372///
373/// # Examples
374///
375/// ```
376/// # use database_replicator::utils::validate_postgres_identifier;
377/// # use anyhow::Result;
378/// # fn example() -> Result<()> {
379/// // Valid identifiers
380/// validate_postgres_identifier("mydb")?;
381/// validate_postgres_identifier("my_database")?;
382/// validate_postgres_identifier("_private_db")?;
383///
384/// // Invalid - will return error
385/// assert!(validate_postgres_identifier("123db").is_err());
386/// assert!(validate_postgres_identifier("my-database").is_err());
387/// assert!(validate_postgres_identifier("db\"; DROP TABLE users; --").is_err());
388/// # Ok(())
389/// # }
390/// ```
391pub fn validate_postgres_identifier(identifier: &str) -> Result<()> {
392    // Check for empty or whitespace-only
393    let trimmed = identifier.trim();
394    if trimmed.is_empty() {
395        bail!("Identifier cannot be empty or whitespace-only");
396    }
397
398    // Check length (PostgreSQL limit is 63 characters)
399    if trimmed.len() > 63 {
400        bail!(
401            "Identifier '{}' exceeds maximum length of 63 characters (got {})",
402            sanitize_identifier(trimmed),
403            trimmed.len()
404        );
405    }
406
407    // Get first character
408    let first_char = trimmed.chars().next().unwrap();
409
410    // First character must be a letter or underscore
411    if !first_char.is_ascii_alphabetic() && first_char != '_' {
412        bail!(
413            "Identifier '{}' must start with a letter or underscore, not '{}'",
414            sanitize_identifier(trimmed),
415            first_char
416        );
417    }
418
419    // All characters must be alphanumeric or underscore
420    for (i, c) in trimmed.chars().enumerate() {
421        if !c.is_ascii_alphanumeric() && c != '_' {
422            bail!(
423                "Identifier '{}' contains invalid character '{}' at position {}. \
424                 Only letters, digits, and underscores are allowed",
425                sanitize_identifier(trimmed),
426                if c.is_control() {
427                    format!("\\x{:02x}", c as u32)
428                } else {
429                    c.to_string()
430                },
431                i
432            );
433        }
434    }
435
436    Ok(())
437}
438
/// Make an identifier (table name, schema name, etc.) safe to show in messages.
///
/// Control characters are dropped and the result is capped at 100 characters,
/// which keeps error messages readable and prevents log injection.
///
/// **Note**: This is for display purposes only. For SQL safety, use parameterized
/// queries instead.
///
/// # Arguments
///
/// * `identifier` - The identifier to sanitize (table name, schema name, etc.)
///
/// # Returns
///
/// A new string holding at most the first 100 non-control characters of
/// `identifier`, in their original order.
///
/// # Examples
///
/// ```
/// # use database_replicator::utils::sanitize_identifier;
/// assert_eq!(sanitize_identifier("normal_table"), "normal_table");
/// assert_eq!(sanitize_identifier("table\x00name"), "tablename");
/// assert_eq!(sanitize_identifier("table\nname"), "tablename");
///
/// // Length limit
/// let long_name = "a".repeat(200);
/// assert_eq!(sanitize_identifier(&long_name).len(), 100);
/// ```
pub fn sanitize_identifier(identifier: &str) -> String {
    const MAX_DISPLAY_CHARS: usize = 100;

    let mut cleaned = String::new();
    let mut kept = 0usize;

    for ch in identifier.chars() {
        // Control characters never count towards the limit; they are skipped.
        if ch.is_control() {
            continue;
        }
        if kept == MAX_DISPLAY_CHARS {
            break;
        }
        cleaned.push(ch);
        kept += 1;
    }

    cleaned
}
475
/// Quote a PostgreSQL identifier (database, schema, table, column)
///
/// Wraps the identifier in double quotes and doubles every embedded double
/// quote. The identifier is assumed to have been validated already; no
/// validation is performed here.
///
/// # Arguments
///
/// * `identifier` - The identifier to quote
///
/// # Returns
///
/// The quoted identifier as a new `String`.
///
/// # Examples
///
/// ```
/// # use database_replicator::utils::quote_ident;
/// assert_eq!(quote_ident("users"), "\"users\"");
/// assert_eq!(quote_ident("odd\"name"), "\"odd\"\"name\"");
/// ```
pub fn quote_ident(identifier: &str) -> String {
    // An embedded `"` is escaped by doubling it.
    let escaped = identifier.replace('"', "\"\"");
    format!("\"{}\"", escaped)
}
492
493/// Validate that source and target URLs are different to prevent accidental data loss
494///
495/// Compares two PostgreSQL connection URLs to ensure they point to different databases.
496/// This is critical for preventing data loss from operations like `init --drop-existing`
497/// where using the same URL for source and target would destroy the source data.
498///
499/// # Comparison Strategy
500///
501/// URLs are normalized and compared on:
502/// - Host (case-insensitive)
503/// - Port (defaulting to 5432 if not specified)
504/// - Database name (case-sensitive)
505/// - User (if present)
506///
507/// Query parameters (like SSL settings) are ignored as they don't affect database identity.
508///
509/// # Arguments
510///
511/// * `source_url` - Source database connection string
512/// * `target_url` - Target database connection string
513///
514/// # Returns
515///
516/// Returns `Ok(())` if the URLs point to different databases.
517///
518/// # Errors
519///
520/// Returns an error if:
521/// - The URLs point to the same database (same host, port, database name, and user)
522/// - Either URL is malformed and cannot be parsed
523///
524/// # Examples
525///
526/// ```
527/// # use database_replicator::utils::validate_source_target_different;
528/// # use anyhow::Result;
529/// # fn example() -> Result<()> {
530/// // Valid - different hosts
531/// validate_source_target_different(
532///     "postgresql://user:pass@source.com:5432/db",
533///     "postgresql://user:pass@target.com:5432/db"
534/// )?;
535///
536/// // Valid - different databases
537/// validate_source_target_different(
538///     "postgresql://user:pass@host:5432/db1",
539///     "postgresql://user:pass@host:5432/db2"
540/// )?;
541///
542/// // Invalid - same database
543/// assert!(validate_source_target_different(
544///     "postgresql://user:pass@host:5432/db",
545///     "postgresql://user:pass@host:5432/db"
546/// ).is_err());
547/// # Ok(())
548/// # }
549/// ```
550pub fn validate_source_target_different(source_url: &str, target_url: &str) -> Result<()> {
551    // Parse both URLs to extract components
552    let source_parts = parse_postgres_url(source_url)
553        .with_context(|| format!("Failed to parse source URL: {}", source_url))?;
554    let target_parts = parse_postgres_url(target_url)
555        .with_context(|| format!("Failed to parse target URL: {}", target_url))?;
556
557    // Compare normalized components
558    if source_parts.host == target_parts.host
559        && source_parts.port == target_parts.port
560        && source_parts.database == target_parts.database
561        && source_parts.user == target_parts.user
562    {
563        bail!(
564            "Source and target URLs point to the same database!\\n\\\n             \\n\\\n             This would cause DATA LOSS - the target would overwrite the source.\\n\\\n             \\n\\\n             Source: {}@{}:{}/{}\\n\\\n             Target: {}@{}:{}/{}\\n\\\n             \\n\\\n             Please ensure source and target are different databases.\\n\\\n             Common causes:\\n\\\n             - Copy-paste error in connection strings\\n\\\n             - Wrong environment variables (e.g., SOURCE_URL == TARGET_URL)\\n\\\n             - Typo in database name or host",
565            source_parts.user.as_deref().unwrap_or("(no user)"),
566            source_parts.host,
567            source_parts.port,
568            source_parts.database,
569            target_parts.user.as_deref().unwrap_or("(no user)"),
570            target_parts.host,
571            target_parts.port,
572            target_parts.database
573        );
574    }
575
576    Ok(())
577}
578
579/// Parse a PostgreSQL URL into its components
580///
581/// # Arguments
582///
583/// * `url` - PostgreSQL connection URL (postgres:// or postgresql://)
584///
585/// # Returns
586///
587/// Returns a `PostgresUrlParts` struct with normalized components.
588///
589/// # Security
590///
591/// This function extracts passwords from URLs for use with .pgpass files.
592/// Ensure returned values are handled securely and not logged.
593pub fn parse_postgres_url(url: &str) -> Result<PostgresUrlParts> {
594    // Remove scheme
595    let url_without_scheme = url
596        .trim_start_matches("postgres://")
597        .trim_start_matches("postgresql://");
598
599    // Split into base and query params
600    let (base, query_string) = if let Some((b, q)) = url_without_scheme.split_once('?') {
601        (b, Some(q))
602    } else {
603        (url_without_scheme, None)
604    };
605
606    // Parse query parameters into HashMap
607    let mut query_params = std::collections::HashMap::new();
608    if let Some(query) = query_string {
609        for param in query.split('&') {
610            if let Some((key, value)) = param.split_once('=') {
611                query_params.insert(key.to_string(), value.to_string());
612            }
613        }
614    }
615
616    // Parse: [user[:password]@]host[:port]/database
617    let (auth_and_host, database) = base
618        .rsplit_once('/')
619        .ok_or_else(|| anyhow::anyhow!("Missing database name in URL"))?;
620
621    // Parse authentication and host
622    // Use rsplit_once to split from the right, so passwords can contain '@'
623    let (user, password, host_and_port) = if let Some((auth, hp)) = auth_and_host.rsplit_once('@') {
624        // Has authentication
625        let (user, pass) = if let Some((u, p)) = auth.split_once(':') {
626            (Some(u.to_string()), Some(p.to_string()))
627        } else {
628            (Some(auth.to_string()), None)
629        };
630        (user, pass, hp)
631    } else {
632        // No authentication
633        (None, None, auth_and_host)
634    };
635
636    // Parse host and port
637    let (host, port) = if let Some((h, p)) = host_and_port.rsplit_once(':') {
638        // Port specified
639        let port = p
640            .parse::<u16>()
641            .with_context(|| format!("Invalid port number: {}", p))?;
642        (h, port)
643    } else {
644        // Use default PostgreSQL port
645        (host_and_port, 5432)
646    };
647
648    Ok(PostgresUrlParts {
649        host: host.to_lowercase(), // Hostnames are case-insensitive
650        port,
651        database: database.to_string(), // Database names are case-sensitive in PostgreSQL
652        user,
653        password,
654        query_params,
655    })
656}
657
658/// Strip password from PostgreSQL connection URL
659/// Returns a new URL with password removed, preserving all other components
660/// This is useful for storing connection strings in places where passwords should not be visible
661pub fn strip_password_from_url(url: &str) -> Result<String> {
662    let parts = parse_postgres_url(url)?;
663
664    // Reconstruct URL without password
665    let scheme = if url.starts_with("postgresql://") {
666        "postgresql://"
667    } else if url.starts_with("postgres://") {
668        "postgres://"
669    } else {
670        bail!("Invalid PostgreSQL URL scheme");
671    };
672
673    let mut result = String::from(scheme);
674
675    // Add user if present (without password)
676    if let Some(user) = &parts.user {
677        result.push_str(user);
678        result.push('@');
679    }
680
681    // Add host and port
682    result.push_str(&parts.host);
683    result.push(':');
684    result.push_str(&parts.port.to_string());
685
686    // Add database
687    result.push('/');
688    result.push_str(&parts.database);
689
690    // Preserve query parameters if present
691    if let Some(query_start) = url.find('?') {
692        result.push_str(&url[query_start..]);
693    }
694
695    Ok(result)
696}
697
/// Parsed components of a PostgreSQL connection URL
///
/// Note that the derived `Debug` output includes the `password` field, so
/// avoid logging values of this type with `{:?}`.
#[derive(Debug, PartialEq)]
pub struct PostgresUrlParts {
    /// Server host name.
    pub host: String,
    /// Server TCP port.
    pub port: u16,
    /// Database name.
    pub database: String,
    /// User name, if the URL carried credentials.
    pub user: Option<String>,
    /// Password, if the URL carried one.
    pub password: Option<String>,
    /// Query-string parameters, keyed by parameter name.
    pub query_params: std::collections::HashMap<String, String>,
}

impl PostgresUrlParts {
    /// Convert query parameters to PostgreSQL environment variables
    ///
    /// Translates the connection-URL query parameters that have an
    /// environment-variable equivalent, so that SSL/TLS and other connection
    /// settings can be passed to pg_dump, pg_dumpall, psql, etc.
    ///
    /// # Supported Parameters
    ///
    /// - `sslmode` → `PGSSLMODE`
    /// - `sslcert` → `PGSSLCERT`
    /// - `sslkey` → `PGSSLKEY`
    /// - `sslrootcert` → `PGSSLROOTCERT`
    /// - `channel_binding` → `PGCHANNELBINDING`
    /// - `connect_timeout` → `PGCONNECT_TIMEOUT`
    /// - `application_name` → `PGAPPNAME`
    /// - `client_encoding` → `PGCLIENTENCODING`
    ///
    /// Parameters not in this list are ignored.
    ///
    /// # Returns
    ///
    /// Vec of (env_var_name, value) pairs, in the order listed above, for
    /// every supported parameter present in `query_params`.
    pub fn to_pg_env_vars(&self) -> Vec<(&'static str, String)> {
        // (query parameter, environment variable) pairs, in output order.
        const PARAM_TO_ENV: [(&str, &str); 8] = [
            ("sslmode", "PGSSLMODE"),
            ("sslcert", "PGSSLCERT"),
            ("sslkey", "PGSSLKEY"),
            ("sslrootcert", "PGSSLROOTCERT"),
            ("channel_binding", "PGCHANNELBINDING"),
            ("connect_timeout", "PGCONNECT_TIMEOUT"),
            ("application_name", "PGAPPNAME"),
            ("client_encoding", "PGCLIENTENCODING"),
        ];

        PARAM_TO_ENV
            .iter()
            .filter_map(|&(param, env_var)| {
                self.query_params
                    .get(param)
                    .map(|value| (env_var, value.clone()))
            })
            .collect()
    }
}
754
755/// Managed .pgpass file for secure password passing to PostgreSQL tools
756///
757/// This struct creates a temporary .pgpass file with secure permissions (0600)
758/// and automatically cleans it up when dropped. PostgreSQL command-line tools
759/// read credentials from this file instead of accepting passwords in URLs,
760/// which prevents command injection vulnerabilities.
761///
762/// # Security
763///
764/// - File permissions are set to 0600 (owner read/write only)
765/// - File is automatically removed on Drop
766/// - Credentials are never passed on command line
767///
768/// # Format
769///
770/// .pgpass file format: hostname:port:database:username:password
771/// Wildcards (*) are used for maximum compatibility
772///
773/// # Examples
774///
775/// ```no_run
776/// # use database_replicator::utils::{PgPassFile, parse_postgres_url};
777/// # use anyhow::Result;
778/// # fn example() -> Result<()> {
779/// let url = "postgresql://user:pass@localhost:5432/mydb";
780/// let parts = parse_postgres_url(url)?;
781/// let pgpass = PgPassFile::new(&parts)?;
782///
783/// // Use pgpass.path() with PGPASSFILE environment variable
784/// // File is automatically cleaned up when pgpass goes out of scope
785/// # Ok(())
786/// # }
787/// ```
788pub struct PgPassFile {
789    path: std::path::PathBuf,
790}
791
792impl PgPassFile {
793    /// Create a new .pgpass file with credentials from URL parts
794    ///
795    /// # Arguments
796    ///
797    /// * `parts` - Parsed PostgreSQL URL components
798    ///
799    /// # Returns
800    ///
801    /// Returns a PgPassFile that will be automatically cleaned up on Drop
802    ///
803    /// # Errors
804    ///
805    /// Returns an error if the file cannot be created or permissions cannot be set
806    pub fn new(parts: &PostgresUrlParts) -> Result<Self> {
807        use std::fs;
808        use std::io::Write;
809
810        // Create temp file with secure name
811        let temp_dir = std::env::temp_dir();
812        let random: u32 = rand::random();
813        let filename = format!("pgpass-{:08x}", random);
814        let path = temp_dir.join(filename);
815
816        // Write .pgpass entry
817        // Format: hostname:port:database:username:password
818        let username = parts.user.as_deref().unwrap_or("*");
819        let password = parts.password.as_deref().unwrap_or("");
820        let entry = format!(
821            "{}:{}:{}:{}:{}\n",
822            parts.host, parts.port, parts.database, username, password
823        );
824
825        let mut file = fs::File::create(&path)
826            .with_context(|| format!("Failed to create .pgpass file at {}", path.display()))?;
827
828        file.write_all(entry.as_bytes())
829            .with_context(|| format!("Failed to write to .pgpass file at {}", path.display()))?;
830
831        // Set secure permissions (0600) - owner read/write only
832        #[cfg(unix)]
833        {
834            use std::os::unix::fs::PermissionsExt;
835            let permissions = fs::Permissions::from_mode(0o600);
836            fs::set_permissions(&path, permissions).with_context(|| {
837                format!(
838                    "Failed to set permissions on .pgpass file at {}",
839                    path.display()
840                )
841            })?;
842        }
843
844        // On Windows, .pgpass is stored in %APPDATA%\postgresql\pgpass.conf
845        // but for our temporary use case, we'll just use a temp file
846        // PostgreSQL on Windows also checks permissions but less strictly
847
848        Ok(Self { path })
849    }
850
851    /// Get the path to the .pgpass file
852    ///
853    /// Use this with the PGPASSFILE environment variable when running
854    /// PostgreSQL command-line tools
855    pub fn path(&self) -> &std::path::Path {
856        &self.path
857    }
858}
859
860impl Drop for PgPassFile {
861    fn drop(&mut self) {
862        // Best effort cleanup - don't panic if removal fails
863        let _ = std::fs::remove_file(&self.path);
864    }
865}
866
867/// Create a managed temporary directory with explicit cleanup support
868///
869/// Creates a temporary directory with a timestamped name that can be cleaned up
870/// even if the process is killed with SIGKILL. Unlike `TempDir::new()` which
871/// relies on the Drop trait, this function creates named directories that can
872/// be cleaned up on next process startup.
873///
874/// Directory naming format: `postgres-seren-replicator-{timestamp}-{random}`
875/// Example: `postgres-seren-replicator-20250106-120534-a3b2c1d4`
876///
877/// # Returns
878///
879/// Returns the path to the created temporary directory.
880///
881/// # Errors
882///
883/// Returns an error if the directory cannot be created.
884///
885/// # Examples
886///
887/// ```no_run
888/// # use database_replicator::utils::create_managed_temp_dir;
889/// # use anyhow::Result;
890/// # fn example() -> Result<()> {
891/// let temp_path = create_managed_temp_dir()?;
892/// println!("Using temp directory: {}", temp_path.display());
893/// // ... do work ...
894/// // Cleanup happens automatically on next startup via cleanup_stale_temp_dirs()
895/// # Ok(())
896/// # }
897/// ```
898pub fn create_managed_temp_dir() -> Result<std::path::PathBuf> {
899    use std::fs;
900    use std::time::SystemTime;
901
902    let system_temp = std::env::temp_dir();
903
904    // Generate timestamp for directory name
905    let timestamp = SystemTime::now()
906        .duration_since(SystemTime::UNIX_EPOCH)
907        .unwrap()
908        .as_secs();
909
910    // Generate random suffix for uniqueness
911    let random: u32 = rand::random();
912
913    // Create directory name with timestamp and random suffix
914    let dir_name = format!("postgres-seren-replicator-{}-{:08x}", timestamp, random);
915
916    let temp_path = system_temp.join(dir_name);
917
918    // Create the directory
919    fs::create_dir_all(&temp_path)
920        .with_context(|| format!("Failed to create temp directory at {}", temp_path.display()))?;
921
922    tracing::debug!("Created managed temp directory: {}", temp_path.display());
923
924    Ok(temp_path)
925}
926
927/// Clean up stale temporary directories from previous runs
928///
929/// Removes temporary directories created by `create_managed_temp_dir()` that are
930/// older than the specified age. This should be called on process startup to clean
931/// up directories left behind by processes killed with SIGKILL.
932///
933/// Only directories matching the pattern `postgres-seren-replicator-*` are removed.
934///
935/// # Arguments
936///
937/// * `max_age_secs` - Maximum age in seconds before a directory is considered stale
938///   (recommended: 86400 for 24 hours)
939///
940/// # Returns
941///
942/// Returns the number of directories cleaned up.
943///
944/// # Errors
945///
946/// Returns an error if the system temp directory cannot be read. Individual
947/// directory removal errors are logged but don't fail the entire operation.
948///
949/// # Examples
950///
951/// ```no_run
952/// # use database_replicator::utils::cleanup_stale_temp_dirs;
953/// # use anyhow::Result;
954/// # fn example() -> Result<()> {
955/// // Clean up temp directories older than 24 hours
956/// let cleaned = cleanup_stale_temp_dirs(86400)?;
957/// println!("Cleaned up {} stale temp directories", cleaned);
958/// # Ok(())
959/// # }
960/// ```
961pub fn cleanup_stale_temp_dirs(max_age_secs: u64) -> Result<usize> {
962    use std::fs;
963    use std::time::SystemTime;
964
965    let system_temp = std::env::temp_dir();
966    let now = SystemTime::now();
967    let mut cleaned_count = 0;
968
969    // Read all entries in system temp directory
970    let entries = fs::read_dir(&system_temp).with_context(|| {
971        format!(
972            "Failed to read system temp directory: {}",
973            system_temp.display()
974        )
975    })?;
976
977    for entry in entries.flatten() {
978        let path = entry.path();
979
980        // Only process directories matching our naming pattern
981        if let Some(name) = path.file_name().and_then(|n| n.to_str()) {
982            if !name.starts_with("postgres-seren-replicator-") {
983                continue;
984            }
985
986            // Check directory age
987            match entry.metadata() {
988                Ok(metadata) => {
989                    if let Ok(modified) = metadata.modified() {
990                        if let Ok(age) = now.duration_since(modified) {
991                            if age.as_secs() > max_age_secs {
992                                // Directory is stale, remove it
993                                match fs::remove_dir_all(&path) {
994                                    Ok(_) => {
995                                        tracing::info!(
996                                            "Cleaned up stale temp directory: {} (age: {}s)",
997                                            path.display(),
998                                            age.as_secs()
999                                        );
1000                                        cleaned_count += 1;
1001                                    }
1002                                    Err(e) => {
1003                                        tracing::warn!(
1004                                            "Failed to remove stale temp directory {}: {}",
1005                                            path.display(),
1006                                            e
1007                                        );
1008                                    }
1009                                }
1010                            }
1011                        }
1012                    }
1013                }
1014                Err(e) => {
1015                    tracing::warn!(
1016                        "Failed to get metadata for temp directory {}: {}",
1017                        path.display(),
1018                        e
1019                    );
1020                }
1021            }
1022        }
1023    }
1024
1025    if cleaned_count > 0 {
1026        tracing::info!(
1027            "Cleaned up {} stale temp directory(ies) older than {} seconds",
1028            cleaned_count,
1029            max_age_secs
1030        );
1031    }
1032
1033    Ok(cleaned_count)
1034}
1035
1036/// Remove a managed temporary directory
1037///
1038/// Explicitly removes a temporary directory created by `create_managed_temp_dir()`.
1039/// This should be called when the directory is no longer needed.
1040///
1041/// # Arguments
1042///
1043/// * `path` - Path to the temporary directory to remove
1044///
1045/// # Errors
1046///
1047/// Returns an error if the directory cannot be removed.
1048///
1049/// # Examples
1050///
1051/// ```no_run
1052/// # use database_replicator::utils::{create_managed_temp_dir, remove_managed_temp_dir};
1053/// # use anyhow::Result;
1054/// # fn example() -> Result<()> {
1055/// let temp_path = create_managed_temp_dir()?;
1056/// // ... do work ...
1057/// remove_managed_temp_dir(&temp_path)?;
1058/// # Ok(())
1059/// # }
1060/// ```
1061pub fn remove_managed_temp_dir(path: &std::path::Path) -> Result<()> {
1062    use std::fs;
1063
1064    // Verify this is one of our temp directories (safety check)
1065    if let Some(name) = path.file_name().and_then(|n| n.to_str()) {
1066        if !name.starts_with("postgres-seren-replicator-") {
1067            bail!(
1068                "Refusing to remove directory that doesn't match our naming pattern: {}",
1069                path.display()
1070            );
1071        }
1072    } else {
1073        bail!("Invalid temp directory path: {}", path.display());
1074    }
1075
1076    tracing::debug!("Removing managed temp directory: {}", path.display());
1077
1078    fs::remove_dir_all(path)
1079        .with_context(|| format!("Failed to remove temp directory at {}", path.display()))?;
1080
1081    Ok(())
1082}
1083
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_validate_connection_string_valid() {
        assert!(validate_connection_string("postgresql://user:pass@localhost:5432/dbname").is_ok());
        assert!(validate_connection_string("postgres://user@host/db").is_ok());
    }

    #[test]
    fn test_check_required_tools() {
        // This test depends on the environment it runs in:
        // - With the PostgreSQL client tools installed the call succeeds and
        //   there is nothing further to check
        // - Without them the call must fail with an error that names the
        //   missing tools, which is what the assertions below verify
        let result = check_required_tools();

        if let Err(err) = result {
            let err_msg = err.to_string();
            assert!(err_msg.contains("Missing required PostgreSQL client tools"));
            assert!(
                err_msg.contains("pg_dump")
                    || err_msg.contains("pg_dumpall")
                    || err_msg.contains("psql")
            );
        }
    }

    #[test]
    fn test_validate_connection_string_invalid() {
        assert!(validate_connection_string("").is_err());
        assert!(validate_connection_string("   ").is_err());
        assert!(validate_connection_string("mysql://localhost/db").is_err());
        assert!(validate_connection_string("postgresql://localhost").is_err());
        // Missing user
        assert!(validate_connection_string("postgresql://localhost/db").is_err());
    }

    #[test]
    fn test_sanitize_identifier() {
        assert_eq!(sanitize_identifier("normal_table"), "normal_table");
        assert_eq!(sanitize_identifier("table\x00name"), "tablename");
        assert_eq!(sanitize_identifier("table\nname"), "tablename");

        // Test length limit
        let long_name = "a".repeat(200);
        assert_eq!(sanitize_identifier(&long_name).len(), 100);
    }

    #[tokio::test]
    async fn test_retry_with_backoff_success() {
        let mut attempts = 0;
        let result = retry_with_backoff(
            || {
                attempts += 1;
                async move {
                    if attempts < 3 {
                        anyhow::bail!("Temporary failure")
                    } else {
                        Ok("Success")
                    }
                }
            },
            5,
            Duration::from_millis(10),
        )
        .await;

        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "Success");
        assert_eq!(attempts, 3);
    }

    #[tokio::test]
    async fn test_retry_with_backoff_failure() {
        let mut attempts = 0;
        let result: Result<&str> = retry_with_backoff(
            || {
                attempts += 1;
                async move { anyhow::bail!("Permanent failure") }
            },
            2,
            Duration::from_millis(10),
        )
        .await;

        assert!(result.is_err());
        assert_eq!(attempts, 3); // Initial + 2 retries
    }

    #[test]
    fn test_validate_source_target_different_valid() {
        // Different hosts
        assert!(validate_source_target_different(
            "postgresql://user:pass@source.com:5432/db",
            "postgresql://user:pass@target.com:5432/db"
        )
        .is_ok());

        // Different databases on same host
        assert!(validate_source_target_different(
            "postgresql://user:pass@host:5432/db1",
            "postgresql://user:pass@host:5432/db2"
        )
        .is_ok());

        // Different ports on same host
        assert!(validate_source_target_different(
            "postgresql://user:pass@host:5432/db",
            "postgresql://user:pass@host:5433/db"
        )
        .is_ok());

        // Different users on same host/db (edge case but allowed)
        assert!(validate_source_target_different(
            "postgresql://user1:pass@host:5432/db",
            "postgresql://user2:pass@host:5432/db"
        )
        .is_ok());
    }

    #[test]
    fn test_validate_source_target_different_invalid() {
        // Exact same URL
        assert!(validate_source_target_different(
            "postgresql://user:pass@host:5432/db",
            "postgresql://user:pass@host:5432/db"
        )
        .is_err());

        // Same URL with different scheme (postgres vs postgresql)
        assert!(validate_source_target_different(
            "postgres://user:pass@host:5432/db",
            "postgresql://user:pass@host:5432/db"
        )
        .is_err());

        // Same URL with default port vs explicit port
        assert!(validate_source_target_different(
            "postgresql://user:pass@host/db",
            "postgresql://user:pass@host:5432/db"
        )
        .is_err());

        // Same URL with different query parameters (still same database)
        assert!(validate_source_target_different(
            "postgresql://user:pass@host:5432/db?sslmode=require",
            "postgresql://user:pass@host:5432/db?sslmode=prefer"
        )
        .is_err());

        // Same host with different case (hostnames are case-insensitive)
        assert!(validate_source_target_different(
            "postgresql://user:pass@HOST.COM:5432/db",
            "postgresql://user:pass@host.com:5432/db"
        )
        .is_err());
    }

    #[test]
    fn test_parse_postgres_url() {
        // Full URL with all components including password
        let parts = parse_postgres_url("postgresql://myuser:mypass@localhost:5432/mydb").unwrap();
        assert_eq!(parts.host, "localhost");
        assert_eq!(parts.port, 5432);
        assert_eq!(parts.database, "mydb");
        assert_eq!(parts.user, Some("myuser".to_string()));
        assert_eq!(parts.password, Some("mypass".to_string()));

        // URL without port (should default to 5432)
        let parts = parse_postgres_url("postgresql://user:pass@host/db").unwrap();
        assert_eq!(parts.host, "host");
        assert_eq!(parts.port, 5432);
        assert_eq!(parts.database, "db");
        assert_eq!(parts.user, Some("user".to_string()));
        assert_eq!(parts.password, Some("pass".to_string()));

        // URL with user but no password
        let parts = parse_postgres_url("postgresql://user@host/db").unwrap();
        assert_eq!(parts.host, "host");
        assert_eq!(parts.user, Some("user".to_string()));
        assert_eq!(parts.password, None);

        // URL without authentication
        let parts = parse_postgres_url("postgresql://host:5433/db").unwrap();
        assert_eq!(parts.host, "host");
        assert_eq!(parts.port, 5433);
        assert_eq!(parts.database, "db");
        assert_eq!(parts.user, None);
        assert_eq!(parts.password, None);

        // URL with query parameters
        let parts = parse_postgres_url("postgresql://user:pass@host/db?sslmode=require").unwrap();
        assert_eq!(parts.host, "host");
        assert_eq!(parts.database, "db");
        assert_eq!(parts.password, Some("pass".to_string()));

        // URL with postgres:// scheme (alternative)
        let parts = parse_postgres_url("postgres://user:pass@host/db").unwrap();
        assert_eq!(parts.host, "host");
        assert_eq!(parts.database, "db");
        assert_eq!(parts.password, Some("pass".to_string()));

        // Host normalization (lowercase)
        let parts = parse_postgres_url("postgresql://user:pass@HOST.COM/db").unwrap();
        assert_eq!(parts.host, "host.com");
        assert_eq!(parts.password, Some("pass".to_string()));

        // Password with special characters
        let parts = parse_postgres_url("postgresql://user:p@ss!word@host/db").unwrap();
        assert_eq!(parts.password, Some("p@ss!word".to_string()));
    }

    #[test]
    fn test_validate_postgres_identifier_valid() {
        // Valid identifiers
        assert!(validate_postgres_identifier("mydb").is_ok());
        assert!(validate_postgres_identifier("my_database").is_ok());
        assert!(validate_postgres_identifier("_private_db").is_ok());
        assert!(validate_postgres_identifier("db123").is_ok());
        assert!(validate_postgres_identifier("Database_2024").is_ok());

        // Maximum length (63 characters)
        let max_length_name = "a".repeat(63);
        assert!(validate_postgres_identifier(&max_length_name).is_ok());
    }

    #[test]
    fn test_pgpass_file_creation() {
        let parts = PostgresUrlParts {
            host: "localhost".to_string(),
            port: 5432,
            database: "testdb".to_string(),
            user: Some("testuser".to_string()),
            password: Some("testpass".to_string()),
            query_params: std::collections::HashMap::new(),
        };

        let pgpass = PgPassFile::new(&parts).unwrap();
        assert!(pgpass.path().exists());

        // Verify file content
        let content = std::fs::read_to_string(pgpass.path()).unwrap();
        assert_eq!(content, "localhost:5432:testdb:testuser:testpass\n");

        // Verify permissions on Unix
        #[cfg(unix)]
        {
            use std::os::unix::fs::PermissionsExt;
            let metadata = std::fs::metadata(pgpass.path()).unwrap();
            let permissions = metadata.permissions();
            assert_eq!(permissions.mode() & 0o777, 0o600);
        }

        // File should be cleaned up when pgpass is dropped
        let path = pgpass.path().to_path_buf();
        drop(pgpass);
        assert!(!path.exists());
    }

    #[test]
    fn test_pgpass_file_without_password() {
        let parts = PostgresUrlParts {
            host: "localhost".to_string(),
            port: 5432,
            database: "testdb".to_string(),
            user: Some("testuser".to_string()),
            password: None,
            query_params: std::collections::HashMap::new(),
        };

        let pgpass = PgPassFile::new(&parts).unwrap();
        let content = std::fs::read_to_string(pgpass.path()).unwrap();
        // Should use empty password
        assert_eq!(content, "localhost:5432:testdb:testuser:\n");
    }

    #[test]
    fn test_pgpass_file_without_user() {
        let parts = PostgresUrlParts {
            host: "localhost".to_string(),
            port: 5432,
            database: "testdb".to_string(),
            user: None,
            password: Some("testpass".to_string()),
            query_params: std::collections::HashMap::new(),
        };

        let pgpass = PgPassFile::new(&parts).unwrap();
        let content = std::fs::read_to_string(pgpass.path()).unwrap();
        // Should use wildcard for user
        assert_eq!(content, "localhost:5432:testdb:*:testpass\n");
    }

    #[test]
    fn test_strip_password_from_url() {
        // With password
        let url = "postgresql://user:p@ssw0rd@host:5432/db";
        let stripped = strip_password_from_url(url).unwrap();
        assert_eq!(stripped, "postgresql://user@host:5432/db");

        // With special characters in password
        let url = "postgresql://user:p@ss!w0rd@host:5432/db";
        let stripped = strip_password_from_url(url).unwrap();
        assert_eq!(stripped, "postgresql://user@host:5432/db");

        // Without password
        let url = "postgresql://user@host:5432/db";
        let stripped = strip_password_from_url(url).unwrap();
        assert_eq!(stripped, "postgresql://user@host:5432/db");

        // With query parameters
        let url = "postgresql://user:pass@host:5432/db?sslmode=require";
        let stripped = strip_password_from_url(url).unwrap();
        assert_eq!(stripped, "postgresql://user@host:5432/db?sslmode=require");

        // No user
        let url = "postgresql://host:5432/db";
        let stripped = strip_password_from_url(url).unwrap();
        assert_eq!(stripped, "postgresql://host:5432/db");
    }

    #[test]
    fn test_validate_postgres_identifier_invalid() {
        // SQL injection attempts
        assert!(validate_postgres_identifier("mydb\"; DROP DATABASE production; --").is_err());
        assert!(validate_postgres_identifier("db'; DELETE FROM users; --").is_err());

        // Invalid start characters
        assert!(validate_postgres_identifier("123db").is_err()); // Starts with digit
        assert!(validate_postgres_identifier("$db").is_err()); // Starts with special char
        assert!(validate_postgres_identifier("-db").is_err()); // Starts with dash

        // Contains invalid characters
        assert!(validate_postgres_identifier("my-database").is_err()); // Contains dash
        assert!(validate_postgres_identifier("my.database").is_err()); // Contains dot
        assert!(validate_postgres_identifier("my database").is_err()); // Contains space
        assert!(validate_postgres_identifier("my@db").is_err()); // Contains @
        assert!(validate_postgres_identifier("my#db").is_err()); // Contains #

        // Empty or whitespace-only
        assert!(validate_postgres_identifier("").is_err());
        assert!(validate_postgres_identifier("   ").is_err());

        // Over maximum length (64+ characters)
        let too_long = "a".repeat(64);
        assert!(validate_postgres_identifier(&too_long).is_err());

        // Control characters
        assert!(validate_postgres_identifier("my\ndb").is_err());
        assert!(validate_postgres_identifier("my\tdb").is_err());
        assert!(validate_postgres_identifier("my\x00db").is_err());
    }

    #[test]
    fn test_managed_temp_dir_create_and_remove() {
        let dir = create_managed_temp_dir().unwrap();
        assert!(dir.is_dir());

        // The name must carry the prefix the cleanup functions look for
        let name = dir.file_name().and_then(|n| n.to_str()).unwrap();
        assert!(name.starts_with("postgres-seren-replicator-"));

        remove_managed_temp_dir(&dir).unwrap();
        assert!(!dir.exists());
    }

    #[test]
    fn test_remove_managed_temp_dir_rejects_foreign_path() {
        // Final component without our prefix must be refused before any I/O
        let foreign = std::env::temp_dir().join("not-a-replicator-dir");
        let err = remove_managed_temp_dir(&foreign).unwrap_err();
        assert!(err.to_string().contains("Refusing to remove directory"));

        // A path with no final component is rejected as invalid
        let err = remove_managed_temp_dir(std::path::Path::new("..")).unwrap_err();
        assert!(err.to_string().contains("Invalid temp directory path"));
    }

    #[test]
    fn test_cleanup_stale_temp_dirs_keeps_fresh_dirs() {
        let dir = create_managed_temp_dir().unwrap();

        // With the largest possible threshold nothing can count as stale, so
        // the sweep must not remove anything (including directories that
        // belong to tests running in parallel)
        let cleaned = cleanup_stale_temp_dirs(u64::MAX).unwrap();
        assert_eq!(cleaned, 0);
        assert!(dir.is_dir());

        remove_managed_temp_dir(&dir).unwrap();
    }
}