database_replicator/
utils.rs

1// ABOUTME: Utility functions for validation and error handling
2// ABOUTME: Provides input validation, retry logic, and resource cleanup
3
4use anyhow::{bail, Context, Result};
5use std::time::Duration;
6use which::which;
7
/// Environment variables that turn on TCP keepalives for PostgreSQL CLI tools.
///
/// External libpq-based tools (`pg_dump`, `pg_restore`, `psql`, `pg_dumpall`) read
/// these variables. Enabling keepalives stops load balancers such as AWS ELB from
/// dropping connections that look idle during long-running operations.
///
/// Variables produced, in this order:
/// - `PGKEEPALIVES=1` — keepalives enabled
/// - `PGKEEPALIVESIDLE=60` — first probe after 60 seconds of idleness
/// - `PGKEEPALIVESINTERVAL=10` — subsequent probes every 10 seconds
///
/// # Returns
///
/// `(name, value)` pairs ready to be applied to a subprocess command.
///
/// # Examples
///
/// ```
/// # use database_replicator::utils::get_keepalive_env_vars;
/// # use std::process::Command;
/// let mut cmd = Command::new("psql");
/// cmd.envs(get_keepalive_env_vars());
/// ```
pub fn get_keepalive_env_vars() -> Vec<(&'static str, &'static str)> {
    const KEEPALIVE_SETTINGS: [(&str, &str); 3] = [
        ("PGKEEPALIVES", "1"),
        ("PGKEEPALIVESIDLE", "60"),
        ("PGKEEPALIVESINTERVAL", "10"),
    ];

    KEEPALIVE_SETTINGS.to_vec()
}
41
42/// Validate a PostgreSQL connection string
43///
44/// Checks that the connection string has proper format and required components:
45/// - Starts with "postgres://" or "postgresql://"
46/// - Contains user credentials (@ symbol)
47/// - Contains database name (/ separator with at least 3 occurrences)
48///
49/// # Arguments
50///
51/// * `url` - Connection string to validate
52///
53/// # Returns
54///
55/// Returns `Ok(())` if the connection string is valid.
56///
57/// # Errors
58///
59/// Returns an error with helpful message if the connection string is:
60/// - Empty or whitespace only
61/// - Missing proper scheme (postgres:// or postgresql://)
62/// - Missing user credentials (@ symbol)
63/// - Missing database name
64///
65/// # Examples
66///
67/// ```
68/// # use database_replicator::utils::validate_connection_string;
69/// # use anyhow::Result;
70/// # fn example() -> Result<()> {
71/// // Valid connection strings
72/// validate_connection_string("postgresql://user:pass@localhost:5432/mydb")?;
73/// validate_connection_string("postgres://user@host/db")?;
74///
75/// // Invalid - will return error
76/// assert!(validate_connection_string("").is_err());
77/// assert!(validate_connection_string("mysql://localhost/db").is_err());
78/// # Ok(())
79/// # }
80/// ```
81pub fn validate_connection_string(url: &str) -> Result<()> {
82    if url.trim().is_empty() {
83        bail!("Connection string cannot be empty");
84    }
85
86    // Check for common URL schemes
87    if !url.starts_with("postgres://") && !url.starts_with("postgresql://") {
88        bail!(
89            "Invalid connection string format.\n\
90             Expected format: postgresql://user:password@host:port/database\n\
91             Got: {}",
92            url
93        );
94    }
95
96    // Check for minimum required components (user@host/database)
97    if !url.contains('@') {
98        bail!(
99            "Connection string missing user credentials.\n\
100             Expected format: postgresql://user:password@host:port/database"
101        );
102    }
103
104    if !url.contains('/') || url.matches('/').count() < 3 {
105        bail!(
106            "Connection string missing database name.\n\
107             Expected format: postgresql://user:password@host:port/database"
108        );
109    }
110
111    Ok(())
112}
113
114/// Check that required PostgreSQL client tools are available
115///
116/// Verifies that the following tools are installed and in PATH:
117/// - `pg_dump` - For dumping database schema and data
118/// - `pg_dumpall` - For dumping global objects (roles, tablespaces)
119/// - `psql` - For restoring databases
120///
121/// # Returns
122///
123/// Returns `Ok(())` if all required tools are found.
124///
125/// # Errors
126///
127/// Returns an error with installation instructions if any tools are missing.
128///
129/// # Examples
130///
131/// ```
132/// # use database_replicator::utils::check_required_tools;
133/// # use anyhow::Result;
134/// # fn example() -> Result<()> {
135/// // Check if PostgreSQL tools are installed
136/// check_required_tools()?;
137/// # Ok(())
138/// # }
139/// ```
140pub fn check_required_tools() -> Result<()> {
141    let tools = ["pg_dump", "pg_dumpall", "psql"];
142    let mut missing = Vec::new();
143
144    for tool in &tools {
145        if which(tool).is_err() {
146            missing.push(*tool);
147        }
148    }
149
150    if !missing.is_empty() {
151        bail!(
152            "Missing required PostgreSQL client tools: {}\n\
153             \n\
154             Please install PostgreSQL client tools:\n\
155             - Ubuntu/Debian: sudo apt-get install postgresql-client\n\
156             - macOS: brew install postgresql\n\
157             - RHEL/CentOS: sudo yum install postgresql\n\
158             - Windows: Download from https://www.postgresql.org/download/windows/",
159            missing.join(", ")
160        );
161    }
162
163    Ok(())
164}
165
166/// Retry a function with exponential backoff
167///
168/// Executes an async operation with automatic retry on failure. Each retry doubles
169/// the delay (exponential backoff) to handle transient failures gracefully.
170///
171/// # Arguments
172///
173/// * `operation` - Async function to retry (FnMut returning Future\<Output = Result\<T\>\>)
174/// * `max_retries` - Maximum number of retry attempts (0 = no retries, just initial attempt)
175/// * `initial_delay` - Delay before first retry (doubles each subsequent retry)
176///
177/// # Returns
178///
179/// Returns the successful result or the last error after all retries exhausted.
180///
181/// # Examples
182///
183/// ```no_run
184/// # use anyhow::Result;
185/// # use std::time::Duration;
186/// # use database_replicator::utils::retry_with_backoff;
187/// # async fn example() -> Result<()> {
188/// let result = retry_with_backoff(
189///     || async { Ok("success") },
190///     3,  // Try up to 3 times
191///     Duration::from_secs(1)  // Start with 1s delay
192/// ).await?;
193/// # Ok(())
194/// # }
195/// ```
196pub async fn retry_with_backoff<F, Fut, T>(
197    mut operation: F,
198    max_retries: u32,
199    initial_delay: Duration,
200) -> Result<T>
201where
202    F: FnMut() -> Fut,
203    Fut: std::future::Future<Output = Result<T>>,
204{
205    let mut delay = initial_delay;
206    let mut last_error = None;
207
208    for attempt in 0..=max_retries {
209        match operation().await {
210            Ok(result) => return Ok(result),
211            Err(e) => {
212                last_error = Some(e);
213
214                if attempt < max_retries {
215                    tracing::warn!(
216                        "Operation failed (attempt {}/{}), retrying in {:?}...",
217                        attempt + 1,
218                        max_retries + 1,
219                        delay
220                    );
221                    tokio::time::sleep(delay).await;
222                    delay *= 2; // Exponential backoff
223                }
224            }
225        }
226    }
227
228    Err(last_error.unwrap_or_else(|| anyhow::anyhow!("Operation failed after retries")))
229}
230
231/// Retry a subprocess execution with exponential backoff on connection errors
232///
233/// Executes a subprocess command with automatic retry on connection-related failures.
234/// Each retry doubles the delay (exponential backoff) to handle transient connection issues.
235///
236/// Connection errors are detected by checking:
237/// - Non-zero exit codes
238/// - Stderr output containing connection-related error patterns:
239///   - "connection closed"
240///   - "connection refused"
241///   - "could not connect"
242///   - "server closed the connection"
243///   - "timeout"
244///   - "Connection timed out"
245///
246/// # Arguments
247///
248/// * `operation` - Function that executes a Command and returns the exit status
249/// * `max_retries` - Maximum number of retry attempts (0 = no retries, just initial attempt)
250/// * `initial_delay` - Delay before first retry (doubles each subsequent retry)
251/// * `operation_name` - Name of the operation for logging (e.g., "pg_restore", "psql")
252///
253/// # Returns
254///
255/// Returns Ok(()) on success or the last error after all retries exhausted.
256///
257/// # Examples
258///
259/// ```no_run
260/// # use anyhow::Result;
261/// # use std::time::Duration;
262/// # use std::process::Command;
263/// # use database_replicator::utils::retry_subprocess_with_backoff;
264/// # async fn example() -> Result<()> {
265/// retry_subprocess_with_backoff(
266///     || {
267///         let mut cmd = Command::new("psql");
268///         cmd.arg("--version");
269///         cmd.status().map_err(anyhow::Error::from)
270///     },
271///     3,  // Try up to 3 times
272///     Duration::from_secs(1),  // Start with 1s delay
273///     "psql"
274/// ).await?;
275/// # Ok(())
276/// # }
277/// ```
278pub async fn retry_subprocess_with_backoff<F>(
279    mut operation: F,
280    max_retries: u32,
281    initial_delay: Duration,
282    operation_name: &str,
283) -> Result<()>
284where
285    F: FnMut() -> Result<std::process::ExitStatus>,
286{
287    let mut delay = initial_delay;
288    let mut last_error = None;
289
290    for attempt in 0..=max_retries {
291        match operation() {
292            Ok(status) => {
293                if status.success() {
294                    return Ok(());
295                } else {
296                    // Non-zero exit code - check if it's a connection error
297                    // We can't easily capture stderr here, so we'll treat all non-zero
298                    // exit codes as potential connection errors for now
299                    let error = anyhow::anyhow!(
300                        "{} failed with exit code: {}",
301                        operation_name,
302                        status.code().unwrap_or(-1)
303                    );
304                    last_error = Some(error);
305
306                    if attempt < max_retries {
307                        tracing::warn!(
308                            "{} failed (attempt {}/{}), retrying in {:?}...",
309                            operation_name,
310                            attempt + 1,
311                            max_retries + 1,
312                            delay
313                        );
314                        tokio::time::sleep(delay).await;
315                        delay *= 2; // Exponential backoff
316                    }
317                }
318            }
319            Err(e) => {
320                last_error = Some(e);
321
322                if attempt < max_retries {
323                    tracing::warn!(
324                        "{} failed (attempt {}/{}): {}, retrying in {:?}...",
325                        operation_name,
326                        attempt + 1,
327                        max_retries + 1,
328                        last_error.as_ref().unwrap(),
329                        delay
330                    );
331                    tokio::time::sleep(delay).await;
332                    delay *= 2; // Exponential backoff
333                }
334            }
335        }
336    }
337
338    Err(last_error.unwrap_or_else(|| {
339        anyhow::anyhow!("{} failed after {} retries", operation_name, max_retries)
340    }))
341}
342
343/// Validate a PostgreSQL identifier (database name, schema name, etc.)
344///
345/// Validates that an identifier follows PostgreSQL naming rules to prevent SQL injection.
346/// PostgreSQL identifiers must:
347/// - Be 1-63 characters long
348/// - Start with a letter (a-z, A-Z) or underscore (_)
349/// - Contain only letters, digits (0-9), or underscores
350///
351/// # Arguments
352///
353/// * `identifier` - The identifier to validate (database name, schema name, etc.)
354///
355/// # Returns
356///
357/// Returns `Ok(())` if the identifier is valid.
358///
359/// # Errors
360///
361/// Returns an error if the identifier:
362/// - Is empty or whitespace-only
363/// - Exceeds 63 characters
364/// - Starts with an invalid character (digit or special character)
365/// - Contains invalid characters (anything except a-z, A-Z, 0-9, _)
366///
367/// # Security
368///
369/// This function is critical for preventing SQL injection attacks. All database
370/// names, schema names, and table names from untrusted sources MUST be validated
371/// before use in SQL statements.
372///
373/// # Examples
374///
375/// ```
376/// # use database_replicator::utils::validate_postgres_identifier;
377/// # use anyhow::Result;
378/// # fn example() -> Result<()> {
379/// // Valid identifiers
380/// validate_postgres_identifier("mydb")?;
381/// validate_postgres_identifier("my_database")?;
382/// validate_postgres_identifier("_private_db")?;
383///
384/// // Invalid - will return error
385/// assert!(validate_postgres_identifier("123db").is_err());
386/// assert!(validate_postgres_identifier("my-database").is_err());
387/// assert!(validate_postgres_identifier("db\"; DROP TABLE users; --").is_err());
388/// # Ok(())
389/// # }
390/// ```
391pub fn validate_postgres_identifier(identifier: &str) -> Result<()> {
392    // Check for empty or whitespace-only
393    let trimmed = identifier.trim();
394    if trimmed.is_empty() {
395        bail!("Identifier cannot be empty or whitespace-only");
396    }
397
398    // Check length (PostgreSQL limit is 63 characters)
399    if trimmed.len() > 63 {
400        bail!(
401            "Identifier '{}' exceeds maximum length of 63 characters (got {})",
402            sanitize_identifier(trimmed),
403            trimmed.len()
404        );
405    }
406
407    // Get first character
408    let first_char = trimmed.chars().next().unwrap();
409
410    // First character must be a letter or underscore
411    if !first_char.is_ascii_alphabetic() && first_char != '_' {
412        bail!(
413            "Identifier '{}' must start with a letter or underscore, not '{}'",
414            sanitize_identifier(trimmed),
415            first_char
416        );
417    }
418
419    // All characters must be alphanumeric or underscore
420    for (i, c) in trimmed.chars().enumerate() {
421        if !c.is_ascii_alphanumeric() && c != '_' {
422            bail!(
423                "Identifier '{}' contains invalid character '{}' at position {}. \
424                 Only letters, digits, and underscores are allowed",
425                sanitize_identifier(trimmed),
426                if c.is_control() {
427                    format!("\\x{:02x}", c as u32)
428                } else {
429                    c.to_string()
430                },
431                i
432            );
433        }
434    }
435
436    Ok(())
437}
438
/// Make an identifier (table name, schema name, etc.) safe to show in messages.
///
/// Control characters (newlines, NUL, escape codes, ...) are dropped and the result is
/// capped at 100 characters. This guards against log injection and keeps error
/// messages readable.
///
/// **Note**: This is for display only and does not make a value safe to splice into
/// SQL. For SQL safety, use parameterized queries instead.
///
/// # Arguments
///
/// * `identifier` - The raw identifier to clean up for display
///
/// # Returns
///
/// The identifier without control characters, truncated to at most 100 characters.
///
/// # Examples
///
/// ```
/// # use database_replicator::utils::sanitize_identifier;
/// assert_eq!(sanitize_identifier("normal_table"), "normal_table");
/// assert_eq!(sanitize_identifier("table\x00name"), "tablename");
/// assert_eq!(sanitize_identifier("table\nname"), "tablename");
///
/// // Length limit
/// let long_name = "a".repeat(200);
/// assert_eq!(sanitize_identifier(&long_name).len(), 100);
/// ```
pub fn sanitize_identifier(identifier: &str) -> String {
    const MAX_DISPLAY_CHARS: usize = 100;

    let mut cleaned = String::new();
    let mut kept = 0;

    for ch in identifier.chars() {
        if kept == MAX_DISPLAY_CHARS {
            break;
        }
        if ch.is_control() {
            continue;
        }
        cleaned.push(ch);
        kept += 1;
    }

    cleaned
}
475
/// Wrap a PostgreSQL identifier (database, schema, table, column) in double quotes.
///
/// Every embedded `"` is doubled so it stays part of the name rather than ending the
/// quoted identifier. The input is expected to have been validated already.
pub fn quote_ident(identifier: &str) -> String {
    format!("\"{}\"", identifier.replace('"', "\"\""))
}
492
/// Turn a value into a single-quoted SQL string literal.
///
/// Each embedded single quote is doubled (`'` becomes `''`) and the result is wrapped
/// in single quotes. Intended for string values in SQL, not for identifiers.
///
/// # Examples
///
/// ```
/// use database_replicator::utils::quote_literal;
/// assert_eq!(quote_literal("hello"), "'hello'");
/// assert_eq!(quote_literal("it's"), "'it''s'");
/// assert_eq!(quote_literal(""), "''");
/// ```
pub fn quote_literal(value: &str) -> String {
    let body = value.replace('\'', "''");

    let mut literal = String::with_capacity(body.len() + 2);
    literal.push('\'');
    literal.push_str(&body);
    literal.push('\'');
    literal
}
518
/// Wrap a MySQL identifier (database, table, column) in backticks.
///
/// MySQL quotes identifiers with backticks. Any backtick inside the name is doubled
/// so it cannot terminate the quoted identifier early.
///
/// # Examples
///
/// ```
/// use database_replicator::utils::quote_mysql_ident;
/// assert_eq!(quote_mysql_ident("users"), "`users`");
/// assert_eq!(quote_mysql_ident("user`name"), "`user``name`");
/// ```
pub fn quote_mysql_ident(identifier: &str) -> String {
    // Emit each backtick twice and every other character once.
    let escaped: String = identifier
        .chars()
        .flat_map(|c| std::iter::repeat(c).take(if c == '`' { 2 } else { 1 }))
        .collect();

    format!("`{}`", escaped)
}
543
544/// Validate that source and target URLs are different to prevent accidental data loss
545///
546/// Compares two PostgreSQL connection URLs to ensure they point to different databases.
547/// This is critical for preventing data loss from operations like `init --drop-existing`
548/// where using the same URL for source and target would destroy the source data.
549///
550/// # Comparison Strategy
551///
552/// URLs are normalized and compared on:
553/// - Host (case-insensitive)
554/// - Port (defaulting to 5432 if not specified)
555/// - Database name (case-sensitive)
556/// - User (if present)
557///
558/// Query parameters (like SSL settings) are ignored as they don't affect database identity.
559///
560/// # Arguments
561///
562/// * `source_url` - Source database connection string
563/// * `target_url` - Target database connection string
564///
565/// # Returns
566///
567/// Returns `Ok(())` if the URLs point to different databases.
568///
569/// # Errors
570///
571/// Returns an error if:
572/// - The URLs point to the same database (same host, port, database name, and user)
573/// - Either URL is malformed and cannot be parsed
574///
575/// # Examples
576///
577/// ```
578/// # use database_replicator::utils::validate_source_target_different;
579/// # use anyhow::Result;
580/// # fn example() -> Result<()> {
581/// // Valid - different hosts
582/// validate_source_target_different(
583///     "postgresql://user:pass@source.com:5432/db",
584///     "postgresql://user:pass@target.com:5432/db"
585/// )?;
586///
587/// // Valid - different databases
588/// validate_source_target_different(
589///     "postgresql://user:pass@host:5432/db1",
590///     "postgresql://user:pass@host:5432/db2"
591/// )?;
592///
593/// // Invalid - same database
594/// assert!(validate_source_target_different(
595///     "postgresql://user:pass@host:5432/db",
596///     "postgresql://user:pass@host:5432/db"
597/// ).is_err());
598/// # Ok(())
599/// # }
600/// ```
601pub fn validate_source_target_different(source_url: &str, target_url: &str) -> Result<()> {
602    // Parse both URLs to extract components
603    let source_parts = parse_postgres_url(source_url)
604        .with_context(|| format!("Failed to parse source URL: {}", source_url))?;
605    let target_parts = parse_postgres_url(target_url)
606        .with_context(|| format!("Failed to parse target URL: {}", target_url))?;
607
608    // Compare normalized components
609    if source_parts.host == target_parts.host
610        && source_parts.port == target_parts.port
611        && source_parts.database == target_parts.database
612        && source_parts.user == target_parts.user
613    {
614        bail!(
615            "Source and target URLs point to the same database!\n\
616             \n\
617             This would cause DATA LOSS - the target would overwrite the source.\n\
618             \n\
619             Source: {}@{}:{}/{}\n\
620             Target: {}@{}:{}/{}\n\
621             \n\
622             Please ensure source and target are different databases.\n\
623             Common causes:\n\
624             - Copy-paste error in connection strings\n\
625             - Wrong environment variables (e.g., SOURCE_URL == TARGET_URL)\n\
626             - Typo in database name or host",
627            source_parts.user.as_deref().unwrap_or("(no user)"),
628            source_parts.host,
629            source_parts.port,
630            source_parts.database,
631            target_parts.user.as_deref().unwrap_or("(no user)"),
632            target_parts.host,
633            target_parts.port,
634            target_parts.database
635        );
636    }
637
638    Ok(())
639}
640
641/// Parse a PostgreSQL URL into its components
642///
643/// # Arguments
644///
645/// * `url` - PostgreSQL connection URL (postgres:// or postgresql://)
646///
647/// # Returns
648///
649/// Returns a `PostgresUrlParts` struct with normalized components.
650///
651/// # Security
652///
653/// This function extracts passwords from URLs for use with .pgpass files.
654/// Ensure returned values are handled securely and not logged.
655pub fn parse_postgres_url(url: &str) -> Result<PostgresUrlParts> {
656    // Remove scheme
657    let url_without_scheme = url
658        .trim_start_matches("postgres://")
659        .trim_start_matches("postgresql://");
660
661    // Split into base and query params
662    let (base, query_string) = if let Some((b, q)) = url_without_scheme.split_once('?') {
663        (b, Some(q))
664    } else {
665        (url_without_scheme, None)
666    };
667
668    // Parse query parameters into HashMap
669    let mut query_params = std::collections::HashMap::new();
670    if let Some(query) = query_string {
671        for param in query.split('&') {
672            if let Some((key, value)) = param.split_once('=') {
673                query_params.insert(key.to_string(), value.to_string());
674            }
675        }
676    }
677
678    // Parse: [user[:password]@]host[:port]/database
679    let (auth_and_host, database) = base
680        .rsplit_once('/')
681        .ok_or_else(|| anyhow::anyhow!("Missing database name in URL"))?;
682
683    // Parse authentication and host
684    // Use rsplit_once to split from the right, so passwords can contain '@'
685    let (user, password, host_and_port) = if let Some((auth, hp)) = auth_and_host.rsplit_once('@') {
686        // Has authentication
687        let (user, pass) = if let Some((u, p)) = auth.split_once(':') {
688            (Some(u.to_string()), Some(p.to_string()))
689        } else {
690            (Some(auth.to_string()), None)
691        };
692        (user, pass, hp)
693    } else {
694        // No authentication
695        (None, None, auth_and_host)
696    };
697
698    // Parse host and port
699    let (host, port) = if let Some((h, p)) = host_and_port.rsplit_once(':') {
700        // Port specified
701        let port = p
702            .parse::<u16>()
703            .with_context(|| format!("Invalid port number: {}", p))?;
704        (h, port)
705    } else {
706        // Use default PostgreSQL port
707        (host_and_port, 5432)
708    };
709
710    Ok(PostgresUrlParts {
711        host: host.to_lowercase(), // Hostnames are case-insensitive
712        port,
713        database: database.to_string(), // Database names are case-sensitive in PostgreSQL
714        user,
715        password,
716        query_params,
717    })
718}
719
720/// Strip password from PostgreSQL connection URL
721/// Returns a new URL with password removed, preserving all other components
722/// This is useful for storing connection strings in places where passwords should not be visible
723pub fn strip_password_from_url(url: &str) -> Result<String> {
724    let parts = parse_postgres_url(url)?;
725
726    // Reconstruct URL without password
727    let scheme = if url.starts_with("postgresql://") {
728        "postgresql://"
729    } else if url.starts_with("postgres://") {
730        "postgres://"
731    } else {
732        bail!("Invalid PostgreSQL URL scheme");
733    };
734
735    let mut result = String::from(scheme);
736
737    // Add user if present (without password)
738    if let Some(user) = &parts.user {
739        result.push_str(user);
740        result.push('@');
741    }
742
743    // Add host and port
744    result.push_str(&parts.host);
745    result.push(':');
746    result.push_str(&parts.port.to_string());
747
748    // Add database
749    result.push('/');
750    result.push_str(&parts.database);
751
752    // Preserve query parameters if present
753    if let Some(query_start) = url.find('?') {
754        result.push_str(&url[query_start..]);
755    }
756
757    Ok(result)
758}
759
/// Parsed components of a PostgreSQL connection URL
///
/// The `Debug` implementation masks `password`, so values of this type can appear in
/// logs, `{:?}` output, or assertion failures without exposing the credential.
#[derive(PartialEq)]
pub struct PostgresUrlParts {
    /// Host name or address (`parse_postgres_url` stores it lowercased).
    pub host: String,
    /// TCP port (`parse_postgres_url` defaults it to 5432).
    pub port: u16,
    /// Database name (case-sensitive).
    pub database: String,
    /// User name, if the URL contained credentials.
    pub user: Option<String>,
    /// Password, if the URL contained one. Treat as secret.
    pub password: Option<String>,
    /// Raw `key=value` query parameters from the URL.
    pub query_params: std::collections::HashMap<String, String>,
}

impl std::fmt::Debug for PostgresUrlParts {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("PostgresUrlParts")
            .field("host", &self.host)
            .field("port", &self.port)
            .field("database", &self.database)
            .field("user", &self.user)
            // Reveal only whether a password is present, never its value.
            .field("password", &self.password.as_ref().map(|_| "<redacted>"))
            .field("query_params", &self.query_params)
            .finish()
    }
}

impl PostgresUrlParts {
    /// Convert query parameters to PostgreSQL environment variables
    ///
    /// Maps common connection URL query parameters to their corresponding
    /// PostgreSQL environment variable names. This allows SSL/TLS and other
    /// connection settings to be passed to pg_dump, pg_dumpall, psql, etc.
    ///
    /// # Supported Parameters
    ///
    /// - `sslmode` → `PGSSLMODE`
    /// - `sslcert` → `PGSSLCERT`
    /// - `sslkey` → `PGSSLKEY`
    /// - `sslrootcert` → `PGSSLROOTCERT`
    /// - `channel_binding` → `PGCHANNELBINDING`
    /// - `connect_timeout` → `PGCONNECT_TIMEOUT`
    /// - `application_name` → `PGAPPNAME`
    /// - `client_encoding` → `PGCLIENTENCODING`
    ///
    /// Parameters not in this list are ignored.
    ///
    /// # Returns
    ///
    /// Vec of (env_var_name, value) pairs, in the order of the list above.
    pub fn to_pg_env_vars(&self) -> Vec<(&'static str, String)> {
        const PARAM_MAPPING: [(&str, &str); 8] = [
            ("sslmode", "PGSSLMODE"),
            ("sslcert", "PGSSLCERT"),
            ("sslkey", "PGSSLKEY"),
            ("sslrootcert", "PGSSLROOTCERT"),
            ("channel_binding", "PGCHANNELBINDING"),
            ("connect_timeout", "PGCONNECT_TIMEOUT"),
            ("application_name", "PGAPPNAME"),
            ("client_encoding", "PGCLIENTENCODING"),
        ];

        PARAM_MAPPING
            .iter()
            .filter_map(|&(param_name, env_var_name)| {
                self.query_params
                    .get(param_name)
                    .map(|value| (env_var_name, value.clone()))
            })
            .collect()
    }
}
816
817/// Managed .pgpass file for secure password passing to PostgreSQL tools
818///
819/// This struct creates a temporary .pgpass file with secure permissions (0600)
820/// and automatically cleans it up when dropped. PostgreSQL command-line tools
821/// read credentials from this file instead of accepting passwords in URLs,
822/// which prevents command injection vulnerabilities.
823///
824/// # Security
825///
826/// - File permissions are set to 0600 (owner read/write only)
827/// - File is automatically removed on Drop
828/// - Credentials are never passed on command line
829///
830/// # Format
831///
832/// .pgpass file format: hostname:port:database:username:password
833/// Wildcards (*) are used for maximum compatibility
834///
835/// # Examples
836///
837/// ```no_run
838/// # use database_replicator::utils::{PgPassFile, parse_postgres_url};
839/// # use anyhow::Result;
840/// # fn example() -> Result<()> {
841/// let url = "postgresql://user:pass@localhost:5432/mydb";
842/// let parts = parse_postgres_url(url)?;
843/// let pgpass = PgPassFile::new(&parts)?;
844///
845/// // Use pgpass.path() with PGPASSFILE environment variable
846/// // File is automatically cleaned up when pgpass goes out of scope
847/// # Ok(())
848/// # }
849/// ```
850pub struct PgPassFile {
851    path: std::path::PathBuf,
852}
853
854impl PgPassFile {
855    /// Create a new .pgpass file with credentials from URL parts
856    ///
857    /// # Arguments
858    ///
859    /// * `parts` - Parsed PostgreSQL URL components
860    ///
861    /// # Returns
862    ///
863    /// Returns a PgPassFile that will be automatically cleaned up on Drop
864    ///
865    /// # Errors
866    ///
867    /// Returns an error if the file cannot be created or permissions cannot be set
868    pub fn new(parts: &PostgresUrlParts) -> Result<Self> {
869        use std::fs;
870        use std::io::Write;
871
872        // Create temp file with secure name
873        let temp_dir = std::env::temp_dir();
874        let random: u32 = rand::random();
875        let filename = format!("pgpass-{:08x}", random);
876        let path = temp_dir.join(filename);
877
878        // Write .pgpass entry
879        // Format: hostname:port:database:username:password
880        let username = parts.user.as_deref().unwrap_or("*");
881        let password = parts.password.as_deref().unwrap_or("");
882        let entry = format!(
883            "{}:{}:{}:{}:{}\n",
884            parts.host, parts.port, parts.database, username, password
885        );
886
887        let mut file = fs::File::create(&path)
888            .with_context(|| format!("Failed to create .pgpass file at {}", path.display()))?;
889
890        file.write_all(entry.as_bytes())
891            .with_context(|| format!("Failed to write to .pgpass file at {}", path.display()))?;
892
893        // Set secure permissions (0600) - owner read/write only
894        #[cfg(unix)]
895        {
896            use std::os::unix::fs::PermissionsExt;
897            let permissions = fs::Permissions::from_mode(0o600);
898            fs::set_permissions(&path, permissions).with_context(|| {
899                format!(
900                    "Failed to set permissions on .pgpass file at {}",
901                    path.display()
902                )
903            })?;
904        }
905
906        // On Windows, .pgpass is stored in %APPDATA%\postgresql\pgpass.conf
907        // but for our temporary use case, we'll just use a temp file
908        // PostgreSQL on Windows also checks permissions but less strictly
909
910        Ok(Self { path })
911    }
912
913    /// Get the path to the .pgpass file
914    ///
915    /// Use this with the PGPASSFILE environment variable when running
916    /// PostgreSQL command-line tools
917    pub fn path(&self) -> &std::path::Path {
918        &self.path
919    }
920}
921
922impl Drop for PgPassFile {
923    fn drop(&mut self) {
924        // Best effort cleanup - don't panic if removal fails
925        let _ = std::fs::remove_file(&self.path);
926    }
927}
928
929/// Create a managed temporary directory with explicit cleanup support
930///
931/// Creates a temporary directory with a timestamped name that can be cleaned up
932/// even if the process is killed with SIGKILL. Unlike `TempDir::new()` which
933/// relies on the Drop trait, this function creates named directories that can
934/// be cleaned up on next process startup.
935///
936/// Directory naming format: `postgres-seren-replicator-{timestamp}-{random}`
937/// Example: `postgres-seren-replicator-20250106-120534-a3b2c1d4`
938///
939/// # Returns
940///
941/// Returns the path to the created temporary directory.
942///
943/// # Errors
944///
945/// Returns an error if the directory cannot be created.
946///
947/// # Examples
948///
949/// ```no_run
950/// # use database_replicator::utils::create_managed_temp_dir;
951/// # use anyhow::Result;
952/// # fn example() -> Result<()> {
953/// let temp_path = create_managed_temp_dir()?;
954/// println!("Using temp directory: {}", temp_path.display());
955/// // ... do work ...
956/// // Cleanup happens automatically on next startup via cleanup_stale_temp_dirs()
957/// # Ok(())
958/// # }
959/// ```
960pub fn create_managed_temp_dir() -> Result<std::path::PathBuf> {
961    use std::fs;
962    use std::time::SystemTime;
963
964    let system_temp = std::env::temp_dir();
965
966    // Generate timestamp for directory name
967    let timestamp = SystemTime::now()
968        .duration_since(SystemTime::UNIX_EPOCH)
969        .unwrap()
970        .as_secs();
971
972    // Generate random suffix for uniqueness
973    let random: u32 = rand::random();
974
975    // Create directory name with timestamp and random suffix
976    let dir_name = format!("postgres-seren-replicator-{}-{:08x}", timestamp, random);
977
978    let temp_path = system_temp.join(dir_name);
979
980    // Create the directory
981    fs::create_dir_all(&temp_path)
982        .with_context(|| format!("Failed to create temp directory at {}", temp_path.display()))?;
983
984    tracing::debug!("Created managed temp directory: {}", temp_path.display());
985
986    Ok(temp_path)
987}
988
989/// Clean up stale temporary directories from previous runs
990///
991/// Removes temporary directories created by `create_managed_temp_dir()` that are
992/// older than the specified age. This should be called on process startup to clean
993/// up directories left behind by processes killed with SIGKILL.
994///
995/// Only directories matching the pattern `postgres-seren-replicator-*` are removed.
996///
997/// # Arguments
998///
999/// * `max_age_secs` - Maximum age in seconds before a directory is considered stale
1000///   (recommended: 86400 for 24 hours)
1001///
1002/// # Returns
1003///
1004/// Returns the number of directories cleaned up.
1005///
1006/// # Errors
1007///
1008/// Returns an error if the system temp directory cannot be read. Individual
1009/// directory removal errors are logged but don't fail the entire operation.
1010///
1011/// # Examples
1012///
1013/// ```no_run
1014/// # use database_replicator::utils::cleanup_stale_temp_dirs;
1015/// # use anyhow::Result;
1016/// # fn example() -> Result<()> {
1017/// // Clean up temp directories older than 24 hours
1018/// let cleaned = cleanup_stale_temp_dirs(86400)?;
1019/// println!("Cleaned up {} stale temp directories", cleaned);
1020/// # Ok(())
1021/// # }
1022/// ```
1023pub fn cleanup_stale_temp_dirs(max_age_secs: u64) -> Result<usize> {
1024    use std::fs;
1025    use std::time::SystemTime;
1026
1027    let system_temp = std::env::temp_dir();
1028    let now = SystemTime::now();
1029    let mut cleaned_count = 0;
1030
1031    // Read all entries in system temp directory
1032    let entries = fs::read_dir(&system_temp).with_context(|| {
1033        format!(
1034            "Failed to read system temp directory: {}",
1035            system_temp.display()
1036        )
1037    })?;
1038
1039    for entry in entries.flatten() {
1040        let path = entry.path();
1041
1042        // Only process directories matching our naming pattern
1043        if let Some(name) = path.file_name().and_then(|n| n.to_str()) {
1044            if !name.starts_with("postgres-seren-replicator-") {
1045                continue;
1046            }
1047
1048            // Check directory age
1049            match entry.metadata() {
1050                Ok(metadata) => {
1051                    if let Ok(modified) = metadata.modified() {
1052                        if let Ok(age) = now.duration_since(modified) {
1053                            if age.as_secs() > max_age_secs {
1054                                // Directory is stale, remove it
1055                                match fs::remove_dir_all(&path) {
1056                                    Ok(_) => {
1057                                        tracing::info!(
1058                                            "Cleaned up stale temp directory: {} (age: {}s)",
1059                                            path.display(),
1060                                            age.as_secs()
1061                                        );
1062                                        cleaned_count += 1;
1063                                    }
1064                                    Err(e) => {
1065                                        tracing::warn!(
1066                                            "Failed to remove stale temp directory {}: {}",
1067                                            path.display(),
1068                                            e
1069                                        );
1070                                    }
1071                                }
1072                            }
1073                        }
1074                    }
1075                }
1076                Err(e) => {
1077                    tracing::warn!(
1078                        "Failed to get metadata for temp directory {}: {}",
1079                        path.display(),
1080                        e
1081                    );
1082                }
1083            }
1084        }
1085    }
1086
1087    if cleaned_count > 0 {
1088        tracing::info!(
1089            "Cleaned up {} stale temp directory(ies) older than {} seconds",
1090            cleaned_count,
1091            max_age_secs
1092        );
1093    }
1094
1095    Ok(cleaned_count)
1096}
1097
1098/// Remove a managed temporary directory
1099///
1100/// Explicitly removes a temporary directory created by `create_managed_temp_dir()`.
1101/// This should be called when the directory is no longer needed.
1102///
1103/// # Arguments
1104///
1105/// * `path` - Path to the temporary directory to remove
1106///
1107/// # Errors
1108///
1109/// Returns an error if the directory cannot be removed.
1110///
1111/// # Examples
1112///
1113/// ```no_run
1114/// # use database_replicator::utils::{create_managed_temp_dir, remove_managed_temp_dir};
1115/// # use anyhow::Result;
1116/// # fn example() -> Result<()> {
1117/// let temp_path = create_managed_temp_dir()?;
1118/// // ... do work ...
1119/// remove_managed_temp_dir(&temp_path)?;
1120/// # Ok(())
1121/// # }
1122/// ```
1123pub fn remove_managed_temp_dir(path: &std::path::Path) -> Result<()> {
1124    use std::fs;
1125
1126    // Verify this is one of our temp directories (safety check)
1127    if let Some(name) = path.file_name().and_then(|n| n.to_str()) {
1128        if !name.starts_with("postgres-seren-replicator-") {
1129            bail!(
1130                "Refusing to remove directory that doesn't match our naming pattern: {}",
1131                path.display()
1132            );
1133        }
1134    } else {
1135        bail!("Invalid temp directory path: {}", path.display());
1136    }
1137
1138    tracing::debug!("Removing managed temp directory: {}", path.display());
1139
1140    fs::remove_dir_all(path)
1141        .with_context(|| format!("Failed to remove temp directory at {}", path.display()))?;
1142
1143    Ok(())
1144}
1145
1146/// Check if a PostgreSQL URL points to a SerenDB instance
1147///
1148/// SerenDB hosts have domains ending with `.serendb.com`
1149///
1150/// # Arguments
1151///
1152/// * `url` - PostgreSQL connection string to check
1153///
1154/// # Returns
1155///
1156/// Returns `true` if the URL points to a SerenDB host.
1157///
1158/// # Examples
1159///
1160/// ```
1161/// use database_replicator::utils::is_serendb_target;
1162///
1163/// assert!(is_serendb_target("postgresql://user:pass@db.serendb.com/mydb"));
1164/// assert!(is_serendb_target("postgresql://user:pass@cluster-123.console.serendb.com/mydb"));
1165/// assert!(!is_serendb_target("postgresql://user:pass@localhost/mydb"));
1166/// assert!(!is_serendb_target("postgresql://user:pass@rds.amazonaws.com/mydb"));
1167/// ```
1168pub fn is_serendb_target(url: &str) -> bool {
1169    match parse_postgres_url(url) {
1170        Ok(parts) => parts.host.ends_with(".serendb.com") || parts.host == "serendb.com",
1171        Err(_) => false,
1172    }
1173}
1174
1175/// Get the major version of a PostgreSQL client tool (pg_dump, psql, etc.)
1176///
1177/// Executes `<tool> --version` and parses the output.
1178///
1179/// # Arguments
1180///
1181/// * `tool` - Name of the tool (e.g., "pg_dump", "psql")
1182///
1183/// # Returns
1184///
1185/// The major version number (e.g., 16 for pg_dump 16.10)
1186///
1187/// # Errors
1188///
1189/// Returns an error if:
1190/// - Tool is not found in PATH
1191/// - Tool execution fails
1192/// - Version output cannot be parsed
1193///
1194/// # Examples
1195///
1196/// ```no_run
1197/// use database_replicator::utils::get_pg_tool_version;
1198/// use anyhow::Result;
1199///
1200/// fn example() -> Result<()> {
1201///     let version = get_pg_tool_version("pg_dump")?;
1202///     println!("pg_dump major version: {}", version); // e.g., 16
1203///     Ok(())
1204/// }
1205/// ```
1206pub fn get_pg_tool_version(tool: &str) -> Result<u32> {
1207    use std::process::Command;
1208
1209    let path = which(tool).with_context(|| format!("{} not found in PATH", tool))?;
1210
1211    let output = Command::new(&path)
1212        .arg("--version")
1213        .output()
1214        .with_context(|| format!("Failed to execute {} --version", tool))?;
1215
1216    let version_str = String::from_utf8_lossy(&output.stdout);
1217    parse_pg_version_string(&version_str)
1218}
1219
1220/// Parse major version from PostgreSQL version string
1221///
1222/// Handles formats like:
1223/// - "pg_dump (PostgreSQL) 16.10 (Ubuntu 16.10-0ubuntu0.24.04.1)"
1224/// - "psql (PostgreSQL) 17.2"
1225/// - "17.2 (Debian 17.2-1.pgdg120+1)"
1226///
1227/// # Arguments
1228///
1229/// * `version_str` - Version string output from a PostgreSQL tool
1230///
1231/// # Returns
1232///
1233/// The major version number (e.g., 16, 17)
1234///
1235/// # Errors
1236///
1237/// Returns an error if the version cannot be parsed.
1238pub fn parse_pg_version_string(version_str: &str) -> Result<u32> {
1239    // Look for version pattern: major.minor
1240    for word in version_str.split_whitespace() {
1241        if let Some(major_str) = word.split('.').next() {
1242            if let Ok(major) = major_str.parse::<u32>() {
1243                // Valid PostgreSQL versions are between 9 and 99
1244                if (9..=99).contains(&major) {
1245                    return Ok(major);
1246                }
1247            }
1248        }
1249    }
1250    bail!("Could not parse PostgreSQL version from: {}", version_str)
1251}
1252
// Unit tests for the validation, URL-parsing, retry, .pgpass, SerenDB-detection
// and version-parsing helpers in this module. Tests that need external
// PostgreSQL tools tolerate their absence.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_validate_connection_string_valid() {
        assert!(validate_connection_string("postgresql://user:pass@localhost:5432/dbname").is_ok());
        assert!(validate_connection_string("postgres://user@host/db").is_ok());
    }

    #[test]
    fn test_check_required_tools() {
        // Passes either way: with the client tools installed the check succeeds;
        // without them it must fail with the expected, tool-naming error message.
        let result = check_required_tools();

        // On systems with PostgreSQL installed, this should pass
        // On systems without it, we expect a specific error message
        if let Err(err) = result {
            let err_msg = err.to_string();
            assert!(err_msg.contains("Missing required PostgreSQL client tools"));
            assert!(
                err_msg.contains("pg_dump")
                    || err_msg.contains("pg_dumpall")
                    || err_msg.contains("psql")
            );
        }
    }

    #[test]
    fn test_validate_connection_string_invalid() {
        assert!(validate_connection_string("").is_err());
        assert!(validate_connection_string("   ").is_err());
        assert!(validate_connection_string("mysql://localhost/db").is_err());
        assert!(validate_connection_string("postgresql://localhost").is_err());
        // Missing user
        assert!(validate_connection_string("postgresql://localhost/db").is_err());
    }

    #[test]
    fn test_sanitize_identifier() {
        assert_eq!(sanitize_identifier("normal_table"), "normal_table");
        assert_eq!(sanitize_identifier("table\x00name"), "tablename");
        assert_eq!(sanitize_identifier("table\nname"), "tablename");

        // Test length limit
        let long_name = "a".repeat(200);
        assert_eq!(sanitize_identifier(&long_name).len(), 100);
    }

    #[tokio::test]
    async fn test_retry_with_backoff_success() {
        // Fails twice, then succeeds on the third attempt (within 5 retries).
        let mut attempts = 0;
        let result = retry_with_backoff(
            || {
                attempts += 1;
                async move {
                    if attempts < 3 {
                        anyhow::bail!("Temporary failure")
                    } else {
                        Ok("Success")
                    }
                }
            },
            5,
            Duration::from_millis(10),
        )
        .await;

        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "Success");
        assert_eq!(attempts, 3);
    }

    #[tokio::test]
    async fn test_retry_with_backoff_failure() {
        // Always fails; the helper must give up after the retry budget.
        let mut attempts = 0;
        let result: Result<&str> = retry_with_backoff(
            || {
                attempts += 1;
                async move { anyhow::bail!("Permanent failure") }
            },
            2,
            Duration::from_millis(10),
        )
        .await;

        assert!(result.is_err());
        assert_eq!(attempts, 3); // Initial + 2 retries
    }

    #[test]
    fn test_validate_source_target_different_valid() {
        // Different hosts
        assert!(validate_source_target_different(
            "postgresql://user:pass@source.com:5432/db",
            "postgresql://user:pass@target.com:5432/db"
        )
        .is_ok());

        // Different databases on same host
        assert!(validate_source_target_different(
            "postgresql://user:pass@host:5432/db1",
            "postgresql://user:pass@host:5432/db2"
        )
        .is_ok());

        // Different ports on same host
        assert!(validate_source_target_different(
            "postgresql://user:pass@host:5432/db",
            "postgresql://user:pass@host:5433/db"
        )
        .is_ok());

        // Different users on same host/db (edge case but allowed)
        assert!(validate_source_target_different(
            "postgresql://user1:pass@host:5432/db",
            "postgresql://user2:pass@host:5432/db"
        )
        .is_ok());
    }

    #[test]
    fn test_validate_source_target_different_invalid() {
        // Exact same URL
        assert!(validate_source_target_different(
            "postgresql://user:pass@host:5432/db",
            "postgresql://user:pass@host:5432/db"
        )
        .is_err());

        // Same URL with different scheme (postgres vs postgresql)
        assert!(validate_source_target_different(
            "postgres://user:pass@host:5432/db",
            "postgresql://user:pass@host:5432/db"
        )
        .is_err());

        // Same URL with default port vs explicit port
        assert!(validate_source_target_different(
            "postgresql://user:pass@host/db",
            "postgresql://user:pass@host:5432/db"
        )
        .is_err());

        // Same URL with different query parameters (still same database)
        assert!(validate_source_target_different(
            "postgresql://user:pass@host:5432/db?sslmode=require",
            "postgresql://user:pass@host:5432/db?sslmode=prefer"
        )
        .is_err());

        // Same host with different case (hostnames are case-insensitive)
        assert!(validate_source_target_different(
            "postgresql://user:pass@HOST.COM:5432/db",
            "postgresql://user:pass@host.com:5432/db"
        )
        .is_err());
    }

    #[test]
    fn test_parse_postgres_url() {
        // Full URL with all components including password
        let parts = parse_postgres_url("postgresql://myuser:mypass@localhost:5432/mydb").unwrap();
        assert_eq!(parts.host, "localhost");
        assert_eq!(parts.port, 5432);
        assert_eq!(parts.database, "mydb");
        assert_eq!(parts.user, Some("myuser".to_string()));
        assert_eq!(parts.password, Some("mypass".to_string()));

        // URL without port (should default to 5432)
        let parts = parse_postgres_url("postgresql://user:pass@host/db").unwrap();
        assert_eq!(parts.host, "host");
        assert_eq!(parts.port, 5432);
        assert_eq!(parts.database, "db");
        assert_eq!(parts.user, Some("user".to_string()));
        assert_eq!(parts.password, Some("pass".to_string()));

        // URL with user but no password
        let parts = parse_postgres_url("postgresql://user@host/db").unwrap();
        assert_eq!(parts.host, "host");
        assert_eq!(parts.user, Some("user".to_string()));
        assert_eq!(parts.password, None);

        // URL without authentication
        let parts = parse_postgres_url("postgresql://host:5433/db").unwrap();
        assert_eq!(parts.host, "host");
        assert_eq!(parts.port, 5433);
        assert_eq!(parts.database, "db");
        assert_eq!(parts.user, None);
        assert_eq!(parts.password, None);

        // URL with query parameters
        let parts = parse_postgres_url("postgresql://user:pass@host/db?sslmode=require").unwrap();
        assert_eq!(parts.host, "host");
        assert_eq!(parts.database, "db");
        assert_eq!(parts.password, Some("pass".to_string()));

        // URL with postgres:// scheme (alternative)
        let parts = parse_postgres_url("postgres://user:pass@host/db").unwrap();
        assert_eq!(parts.host, "host");
        assert_eq!(parts.database, "db");
        assert_eq!(parts.password, Some("pass".to_string()));

        // Host normalization (lowercase)
        let parts = parse_postgres_url("postgresql://user:pass@HOST.COM/db").unwrap();
        assert_eq!(parts.host, "host.com");
        assert_eq!(parts.password, Some("pass".to_string()));

        // Password with special characters (including an unencoded '@')
        let parts = parse_postgres_url("postgresql://user:p@ss!word@host/db").unwrap();
        assert_eq!(parts.password, Some("p@ss!word".to_string()));
    }

    #[test]
    fn test_validate_postgres_identifier_valid() {
        // Valid identifiers
        assert!(validate_postgres_identifier("mydb").is_ok());
        assert!(validate_postgres_identifier("my_database").is_ok());
        assert!(validate_postgres_identifier("_private_db").is_ok());
        assert!(validate_postgres_identifier("db123").is_ok());
        assert!(validate_postgres_identifier("Database_2024").is_ok());

        // Maximum length (63 characters)
        let max_length_name = "a".repeat(63);
        assert!(validate_postgres_identifier(&max_length_name).is_ok());
    }

    #[test]
    fn test_pgpass_file_creation() {
        let parts = PostgresUrlParts {
            host: "localhost".to_string(),
            port: 5432,
            database: "testdb".to_string(),
            user: Some("testuser".to_string()),
            password: Some("testpass".to_string()),
            query_params: std::collections::HashMap::new(),
        };

        let pgpass = PgPassFile::new(&parts).unwrap();
        assert!(pgpass.path().exists());

        // Verify file content
        let content = std::fs::read_to_string(pgpass.path()).unwrap();
        assert_eq!(content, "localhost:5432:testdb:testuser:testpass\n");

        // Verify permissions on Unix
        #[cfg(unix)]
        {
            use std::os::unix::fs::PermissionsExt;
            let metadata = std::fs::metadata(pgpass.path()).unwrap();
            let permissions = metadata.permissions();
            assert_eq!(permissions.mode() & 0o777, 0o600);
        }

        // File should be cleaned up when pgpass is dropped
        let path = pgpass.path().to_path_buf();
        drop(pgpass);
        assert!(!path.exists());
    }

    #[test]
    fn test_pgpass_file_without_password() {
        let parts = PostgresUrlParts {
            host: "localhost".to_string(),
            port: 5432,
            database: "testdb".to_string(),
            user: Some("testuser".to_string()),
            password: None,
            query_params: std::collections::HashMap::new(),
        };

        let pgpass = PgPassFile::new(&parts).unwrap();
        let content = std::fs::read_to_string(pgpass.path()).unwrap();
        // Should use empty password
        assert_eq!(content, "localhost:5432:testdb:testuser:\n");
    }

    #[test]
    fn test_pgpass_file_without_user() {
        let parts = PostgresUrlParts {
            host: "localhost".to_string(),
            port: 5432,
            database: "testdb".to_string(),
            user: None,
            password: Some("testpass".to_string()),
            query_params: std::collections::HashMap::new(),
        };

        let pgpass = PgPassFile::new(&parts).unwrap();
        let content = std::fs::read_to_string(pgpass.path()).unwrap();
        // Should use wildcard for user
        assert_eq!(content, "localhost:5432:testdb:*:testpass\n");
    }

    #[test]
    fn test_strip_password_from_url() {
        // With password
        let url = "postgresql://user:p@ssw0rd@host:5432/db";
        let stripped = strip_password_from_url(url).unwrap();
        assert_eq!(stripped, "postgresql://user@host:5432/db");

        // With special characters in password
        let url = "postgresql://user:p@ss!w0rd@host:5432/db";
        let stripped = strip_password_from_url(url).unwrap();
        assert_eq!(stripped, "postgresql://user@host:5432/db");

        // Without password
        let url = "postgresql://user@host:5432/db";
        let stripped = strip_password_from_url(url).unwrap();
        assert_eq!(stripped, "postgresql://user@host:5432/db");

        // With query parameters
        let url = "postgresql://user:pass@host:5432/db?sslmode=require";
        let stripped = strip_password_from_url(url).unwrap();
        assert_eq!(stripped, "postgresql://user@host:5432/db?sslmode=require");

        // No user
        let url = "postgresql://host:5432/db";
        let stripped = strip_password_from_url(url).unwrap();
        assert_eq!(stripped, "postgresql://host:5432/db");
    }

    #[test]
    fn test_validate_postgres_identifier_invalid() {
        // SQL injection attempts
        assert!(validate_postgres_identifier("mydb\"; DROP DATABASE production; --").is_err());
        assert!(validate_postgres_identifier("db'; DELETE FROM users; --").is_err());

        // Invalid start characters
        assert!(validate_postgres_identifier("123db").is_err()); // Starts with digit
        assert!(validate_postgres_identifier("$db").is_err()); // Starts with special char
        assert!(validate_postgres_identifier("-db").is_err()); // Starts with dash

        // Contains invalid characters
        assert!(validate_postgres_identifier("my-database").is_err()); // Contains dash
        assert!(validate_postgres_identifier("my.database").is_err()); // Contains dot
        assert!(validate_postgres_identifier("my database").is_err()); // Contains space
        assert!(validate_postgres_identifier("my@db").is_err()); // Contains @
        assert!(validate_postgres_identifier("my#db").is_err()); // Contains #

        // Empty or whitespace-only
        assert!(validate_postgres_identifier("").is_err());
        assert!(validate_postgres_identifier("   ").is_err());

        // Over maximum length (64+ characters)
        let too_long = "a".repeat(64);
        assert!(validate_postgres_identifier(&too_long).is_err());

        // Control characters
        assert!(validate_postgres_identifier("my\ndb").is_err());
        assert!(validate_postgres_identifier("my\tdb").is_err());
        assert!(validate_postgres_identifier("my\x00db").is_err());
    }

    #[test]
    fn test_is_serendb_target() {
        // Positive cases - SerenDB hosts
        assert!(is_serendb_target(
            "postgresql://user:pass@db.serendb.com/mydb"
        ));
        assert!(is_serendb_target(
            "postgresql://user:pass@cluster.console.serendb.com/mydb"
        ));
        assert!(is_serendb_target(
            "postgres://u:p@x.serendb.com:5432/db?sslmode=require"
        ));
        assert!(is_serendb_target("postgresql://user:pass@serendb.com/mydb"));

        // Negative cases - not SerenDB
        assert!(!is_serendb_target("postgresql://user:pass@localhost/mydb"));
        assert!(!is_serendb_target(
            "postgresql://user:pass@rds.amazonaws.com/mydb"
        ));
        assert!(!is_serendb_target("postgresql://user:pass@neon.tech/mydb"));
        // Domain spoofing attempt - should NOT match
        assert!(!is_serendb_target(
            "postgresql://user:pass@serendb.com.evil.com/mydb"
        ));
        assert!(!is_serendb_target(
            "postgresql://user:pass@notserendb.com/mydb"
        ));
        // Invalid URL
        assert!(!is_serendb_target("not-a-url"));
    }

    #[test]
    fn test_parse_pg_version_string() {
        // Standard pg_dump output
        assert_eq!(
            parse_pg_version_string("pg_dump (PostgreSQL) 16.10 (Ubuntu 16.10-0ubuntu0.24.04.1)")
                .unwrap(),
            16
        );

        // Standard psql output
        assert_eq!(
            parse_pg_version_string("psql (PostgreSQL) 17.2").unwrap(),
            17
        );

        // pg_restore output
        assert_eq!(
            parse_pg_version_string("pg_restore (PostgreSQL) 15.4").unwrap(),
            15
        );

        // Debian-style version
        assert_eq!(
            parse_pg_version_string("17.2 (Debian 17.2-1.pgdg120+1)").unwrap(),
            17
        );

        // Should fail on invalid input
        assert!(parse_pg_version_string("not a version").is_err());
        assert!(parse_pg_version_string("version 1.2.3").is_err()); // 1 is < 9
        assert!(parse_pg_version_string("").is_err());
    }

    #[test]
    fn test_get_pg_tool_version() {
        // The version checks run only when pg_dump is on PATH (skipped otherwise)
        // and assume an installed pg_dump of version 12 or newer.
        if which("pg_dump").is_ok() {
            let version = get_pg_tool_version("pg_dump").unwrap();
            assert!(
                version >= 12,
                "Expected pg_dump version >= 12, got {}",
                version
            );
            assert!(
                version <= 99,
                "Expected pg_dump version <= 99, got {}",
                version
            );
        }

        // Non-existent tool should fail
        assert!(get_pg_tool_version("nonexistent_pg_tool_xyz").is_err());
    }
}