database_replicator/
utils.rs

1// ABOUTME: Utility functions for validation and error handling
2// ABOUTME: Provides input validation, retry logic, and resource cleanup
3
4use anyhow::{bail, Context, Result};
5use std::time::Duration;
6use url::Url;
7use which::which;
8
/// Build the TCP keepalive settings handed to PostgreSQL client subprocesses.
///
/// The returned pairs are meant to be exported into the environment of external
/// PostgreSQL tools (pg_dump, pg_restore, psql, pg_dumpall) so that idle
/// connections are not dropped when going through load balancers like AWS ELB.
///
/// Variables produced, in this order:
/// - `PGKEEPALIVES=1`: turn TCP keepalives on
/// - `PGKEEPALIVESIDLE=60`: send the first keepalive after 60 seconds of idle time
/// - `PGKEEPALIVESINTERVAL=10`: send subsequent keepalives every 10 seconds
///
/// # Returns
///
/// A freshly allocated vector of `(variable_name, value)` tuples.
///
/// # Examples
///
/// ```
/// # use database_replicator::utils::get_keepalive_env_vars;
/// # use std::process::Command;
/// let keepalive_vars = get_keepalive_env_vars();
/// let mut cmd = Command::new("psql");
/// for (key, value) in keepalive_vars {
///     cmd.env(key, value);
/// }
/// ```
pub fn get_keepalive_env_vars() -> Vec<(&'static str, &'static str)> {
    // Fixed table of libpq keepalive variables; copied out on every call.
    const KEEPALIVE_SETTINGS: [(&str, &str); 3] = [
        ("PGKEEPALIVES", "1"),
        ("PGKEEPALIVESIDLE", "60"),
        ("PGKEEPALIVESINTERVAL", "10"),
    ];

    KEEPALIVE_SETTINGS.to_vec()
}
42
43/// Validate a PostgreSQL connection string
44///
45/// Checks that the connection string has proper format and required components:
46/// - Starts with "postgres://" or "postgresql://"
47/// - Contains user credentials (@ symbol)
48/// - Contains database name (/ separator with at least 3 occurrences)
49///
50/// # Arguments
51///
52/// * `url` - Connection string to validate
53///
54/// # Returns
55///
56/// Returns `Ok(())` if the connection string is valid.
57///
58/// # Errors
59///
60/// Returns an error with helpful message if the connection string is:
61/// - Empty or whitespace only
62/// - Missing proper scheme (postgres:// or postgresql://)
63/// - Missing user credentials (@ symbol)
64/// - Missing database name
65///
66/// # Examples
67///
68/// ```
69/// # use database_replicator::utils::validate_connection_string;
70/// # use anyhow::Result;
71/// # fn example() -> Result<()> {
72/// // Valid connection strings
73/// validate_connection_string("postgresql://user:pass@localhost:5432/mydb")?;
74/// validate_connection_string("postgres://user@host/db")?;
75///
76/// // Invalid - will return error
77/// assert!(validate_connection_string("").is_err());
78/// assert!(validate_connection_string("mysql://localhost/db").is_err());
79/// # Ok(())
80/// # }
81/// ```
82pub fn validate_connection_string(url: &str) -> Result<()> {
83    if url.trim().is_empty() {
84        bail!("Connection string cannot be empty");
85    }
86
87    // Check for common URL schemes
88    if !url.starts_with("postgres://") && !url.starts_with("postgresql://") {
89        bail!(
90            "Invalid connection string format.\n\
91             Expected format: postgresql://user:password@host:port/database\n\
92             Got: {}",
93            url
94        );
95    }
96
97    // Check for minimum required components (user@host/database)
98    if !url.contains('@') {
99        bail!(
100            "Connection string missing user credentials.\n\
101             Expected format: postgresql://user:password@host:port/database"
102        );
103    }
104
105    if !url.contains('/') || url.matches('/').count() < 3 {
106        bail!(
107            "Connection string missing database name.\n\
108             Expected format: postgresql://user:password@host:port/database"
109        );
110    }
111
112    Ok(())
113}
114
115/// Check that required PostgreSQL client tools are available
116///
117/// Verifies that the following tools are installed and in PATH:
118/// - `pg_dump` - For dumping database schema and data
119/// - `pg_dumpall` - For dumping global objects (roles, tablespaces)
120/// - `psql` - For restoring databases
121///
122/// # Returns
123///
124/// Returns `Ok(())` if all required tools are found.
125///
126/// # Errors
127///
128/// Returns an error with installation instructions if any tools are missing.
129///
130/// # Examples
131///
132/// ```
133/// # use database_replicator::utils::check_required_tools;
134/// # use anyhow::Result;
135/// # fn example() -> Result<()> {
136/// // Check if PostgreSQL tools are installed
137/// check_required_tools()?;
138/// # Ok(())
139/// # }
140/// ```
141pub fn check_required_tools() -> Result<()> {
142    let tools = ["pg_dump", "pg_dumpall", "psql"];
143    let mut missing = Vec::new();
144
145    for tool in &tools {
146        if which(tool).is_err() {
147            missing.push(*tool);
148        }
149    }
150
151    if !missing.is_empty() {
152        bail!(
153            "Missing required PostgreSQL client tools: {}\n\
154             \n\
155             Please install PostgreSQL client tools:\n\
156             - Ubuntu/Debian: sudo apt-get install postgresql-client\n\
157             - macOS: brew install postgresql\n\
158             - RHEL/CentOS: sudo yum install postgresql\n\
159             - Windows: Download from https://www.postgresql.org/download/windows/",
160            missing.join(", ")
161        );
162    }
163
164    Ok(())
165}
166
167/// Retry a function with exponential backoff
168///
169/// Executes an async operation with automatic retry on failure. Each retry doubles
170/// the delay (exponential backoff) to handle transient failures gracefully.
171///
172/// # Arguments
173///
174/// * `operation` - Async function to retry (FnMut returning Future\<Output = Result\<T\>\>)
175/// * `max_retries` - Maximum number of retry attempts (0 = no retries, just initial attempt)
176/// * `initial_delay` - Delay before first retry (doubles each subsequent retry)
177///
178/// # Returns
179///
180/// Returns the successful result or the last error after all retries exhausted.
181///
182/// # Examples
183///
184/// ```no_run
185/// # use anyhow::Result;
186/// # use std::time::Duration;
187/// # use database_replicator::utils::retry_with_backoff;
188/// # async fn example() -> Result<()> {
189/// let result = retry_with_backoff(
190///     || async { Ok("success") },
191///     3,  // Try up to 3 times
192///     Duration::from_secs(1)  // Start with 1s delay
193/// ).await?;
194/// # Ok(())
195/// # }
196/// ```
197pub async fn retry_with_backoff<F, Fut, T>(
198    mut operation: F,
199    max_retries: u32,
200    initial_delay: Duration,
201) -> Result<T>
202where
203    F: FnMut() -> Fut,
204    Fut: std::future::Future<Output = Result<T>>,
205{
206    let mut delay = initial_delay;
207    let mut last_error = None;
208
209    for attempt in 0..=max_retries {
210        match operation().await {
211            Ok(result) => return Ok(result),
212            Err(e) => {
213                last_error = Some(e);
214
215                if attempt < max_retries {
216                    tracing::warn!(
217                        "Operation failed (attempt {}/{}), retrying in {:?}...",
218                        attempt + 1,
219                        max_retries + 1,
220                        delay
221                    );
222                    tokio::time::sleep(delay).await;
223                    delay *= 2; // Exponential backoff
224                }
225            }
226        }
227    }
228
229    Err(last_error.unwrap_or_else(|| anyhow::anyhow!("Operation failed after retries")))
230}
231
232/// Retry a subprocess execution with exponential backoff on connection errors
233///
234/// Executes a subprocess command with automatic retry on connection-related failures.
235/// Each retry doubles the delay (exponential backoff) to handle transient connection issues.
236///
237/// Connection errors are detected by checking:
238/// - Non-zero exit codes
239/// - Stderr output containing connection-related error patterns:
240///   - "connection closed"
241///   - "connection refused"
242///   - "could not connect"
243///   - "server closed the connection"
244///   - "timeout"
245///   - "Connection timed out"
246///
247/// # Arguments
248///
249/// * `operation` - Function that executes a Command and returns the exit status
250/// * `max_retries` - Maximum number of retry attempts (0 = no retries, just initial attempt)
251/// * `initial_delay` - Delay before first retry (doubles each subsequent retry)
252/// * `operation_name` - Name of the operation for logging (e.g., "pg_restore", "psql")
253///
254/// # Returns
255///
256/// Returns Ok(()) on success or the last error after all retries exhausted.
257///
258/// # Examples
259///
260/// ```no_run
261/// # use anyhow::Result;
262/// # use std::time::Duration;
263/// # use std::process::Command;
264/// # use database_replicator::utils::retry_subprocess_with_backoff;
265/// # async fn example() -> Result<()> {
266/// retry_subprocess_with_backoff(
267///     || {
268///         let mut cmd = Command::new("psql");
269///         cmd.arg("--version");
270///         cmd.status().map_err(anyhow::Error::from)
271///     },
272///     3,  // Try up to 3 times
273///     Duration::from_secs(1),  // Start with 1s delay
274///     "psql"
275/// ).await?;
276/// # Ok(())
277/// # }
278/// ```
279pub async fn retry_subprocess_with_backoff<F>(
280    mut operation: F,
281    max_retries: u32,
282    initial_delay: Duration,
283    operation_name: &str,
284) -> Result<()>
285where
286    F: FnMut() -> Result<std::process::ExitStatus>,
287{
288    let mut delay = initial_delay;
289    let mut last_error = None;
290
291    for attempt in 0..=max_retries {
292        match operation() {
293            Ok(status) => {
294                if status.success() {
295                    return Ok(());
296                } else {
297                    // Non-zero exit code - check if it's a connection error
298                    // We can't easily capture stderr here, so we'll treat all non-zero
299                    // exit codes as potential connection errors for now
300                    let error = anyhow::anyhow!(
301                        "{} failed with exit code: {}",
302                        operation_name,
303                        status.code().unwrap_or(-1)
304                    );
305                    last_error = Some(error);
306
307                    if attempt < max_retries {
308                        tracing::warn!(
309                            "{} failed (attempt {}/{}), retrying in {:?}...",
310                            operation_name,
311                            attempt + 1,
312                            max_retries + 1,
313                            delay
314                        );
315                        tokio::time::sleep(delay).await;
316                        delay *= 2; // Exponential backoff
317                    }
318                }
319            }
320            Err(e) => {
321                last_error = Some(e);
322
323                if attempt < max_retries {
324                    tracing::warn!(
325                        "{} failed (attempt {}/{}): {}, retrying in {:?}...",
326                        operation_name,
327                        attempt + 1,
328                        max_retries + 1,
329                        last_error.as_ref().unwrap(),
330                        delay
331                    );
332                    tokio::time::sleep(delay).await;
333                    delay *= 2; // Exponential backoff
334                }
335            }
336        }
337    }
338
339    Err(last_error.unwrap_or_else(|| {
340        anyhow::anyhow!("{} failed after {} retries", operation_name, max_retries)
341    }))
342}
343
344/// Validate a PostgreSQL identifier (database name, schema name, etc.)
345///
346/// Validates that an identifier follows PostgreSQL naming rules to prevent SQL injection.
347/// PostgreSQL identifiers must:
348/// - Be 1-63 characters long
349/// - Start with a letter (a-z, A-Z) or underscore (_)
350/// - Contain only letters, digits (0-9), or underscores
351///
352/// # Arguments
353///
354/// * `identifier` - The identifier to validate (database name, schema name, etc.)
355///
356/// # Returns
357///
358/// Returns `Ok(())` if the identifier is valid.
359///
360/// # Errors
361///
362/// Returns an error if the identifier:
363/// - Is empty or whitespace-only
364/// - Exceeds 63 characters
365/// - Starts with an invalid character (digit or special character)
366/// - Contains invalid characters (anything except a-z, A-Z, 0-9, _)
367///
368/// # Security
369///
370/// This function is critical for preventing SQL injection attacks. All database
371/// names, schema names, and table names from untrusted sources MUST be validated
372/// before use in SQL statements.
373///
374/// # Examples
375///
376/// ```
377/// # use database_replicator::utils::validate_postgres_identifier;
378/// # use anyhow::Result;
379/// # fn example() -> Result<()> {
380/// // Valid identifiers
381/// validate_postgres_identifier("mydb")?;
382/// validate_postgres_identifier("my_database")?;
383/// validate_postgres_identifier("_private_db")?;
384///
385/// // Invalid - will return error
386/// assert!(validate_postgres_identifier("123db").is_err());
387/// assert!(validate_postgres_identifier("my-database").is_err());
388/// assert!(validate_postgres_identifier("db\"; DROP TABLE users; --").is_err());
389/// # Ok(())
390/// # }
391/// ```
392pub fn validate_postgres_identifier(identifier: &str) -> Result<()> {
393    // Check for empty or whitespace-only
394    let trimmed = identifier.trim();
395    if trimmed.is_empty() {
396        bail!("Identifier cannot be empty or whitespace-only");
397    }
398
399    // Check length (PostgreSQL limit is 63 characters)
400    if trimmed.len() > 63 {
401        bail!(
402            "Identifier '{}' exceeds maximum length of 63 characters (got {})",
403            sanitize_identifier(trimmed),
404            trimmed.len()
405        );
406    }
407
408    // Get first character
409    let first_char = trimmed.chars().next().unwrap();
410
411    // First character must be a letter or underscore
412    if !first_char.is_ascii_alphabetic() && first_char != '_' {
413        bail!(
414            "Identifier '{}' must start with a letter or underscore, not '{}'",
415            sanitize_identifier(trimmed),
416            first_char
417        );
418    }
419
420    // All characters must be alphanumeric or underscore
421    for (i, c) in trimmed.chars().enumerate() {
422        if !c.is_ascii_alphanumeric() && c != '_' {
423            bail!(
424                "Identifier '{}' contains invalid character '{}' at position {}. \
425                 Only letters, digits, and underscores are allowed",
426                sanitize_identifier(trimmed),
427                if c.is_control() {
428                    format!("\\x{:02x}", c as u32)
429                } else {
430                    c.to_string()
431                },
432                i
433            );
434        }
435    }
436
437    Ok(())
438}
439
/// Make an identifier (table name, schema name, etc.) safe to show in messages
///
/// Drops every control character and keeps at most the first 100 remaining
/// characters, which guards against log injection and keeps error messages
/// readable.
///
/// **Note**: This is for display purposes only. For SQL safety, use parameterized
/// queries instead.
///
/// # Arguments
///
/// * `identifier` - The identifier to sanitize (table name, schema name, etc.)
///
/// # Returns
///
/// A new string without control characters, at most 100 characters long.
///
/// # Examples
///
/// ```
/// # use database_replicator::utils::sanitize_identifier;
/// assert_eq!(sanitize_identifier("normal_table"), "normal_table");
/// assert_eq!(sanitize_identifier("table\x00name"), "tablename");
/// assert_eq!(sanitize_identifier("table\nname"), "tablename");
///
/// // Length limit
/// let long_name = "a".repeat(200);
/// assert_eq!(sanitize_identifier(&long_name).len(), 100);
/// ```
pub fn sanitize_identifier(identifier: &str) -> String {
    const MAX_DISPLAY_CHARS: usize = 100;

    let mut cleaned = String::new();
    let mut kept = 0usize;

    for ch in identifier.chars() {
        // Control characters never count towards the limit; they are skipped.
        if ch.is_control() {
            continue;
        }
        if kept == MAX_DISPLAY_CHARS {
            break;
        }
        cleaned.push(ch);
        kept += 1;
    }

    cleaned
}
476
/// Quote a PostgreSQL identifier (database, schema, table, column)
///
/// Wraps the identifier in double quotes and doubles every embedded double
/// quote. The identifier is expected to have been validated already; no other
/// checks are performed here.
///
/// # Examples
///
/// ```
/// use database_replicator::utils::quote_ident;
/// assert_eq!(quote_ident("users"), "\"users\"");
/// assert_eq!(quote_ident("we\"ird"), "\"we\"\"ird\"");
/// ```
pub fn quote_ident(identifier: &str) -> String {
    format!("\"{}\"", identifier.replace('"', "\"\""))
}
493
/// Quote a SQL string literal (for use in SQL statements)
///
/// Doubles every single quote in `value` and wraps the result in single
/// quotes. Intended for string values in SQL, not for identifiers.
///
/// # Examples
///
/// ```
/// use database_replicator::utils::quote_literal;
/// assert_eq!(quote_literal("hello"), "'hello'");
/// assert_eq!(quote_literal("it's"), "'it''s'");
/// assert_eq!(quote_literal(""), "''");
/// ```
pub fn quote_literal(value: &str) -> String {
    // Splitting on the quote and re-joining with a doubled quote escapes every
    // occurrence, including leading, trailing and consecutive ones.
    let escaped = value.split('\'').collect::<Vec<_>>().join("''");
    ["'", escaped.as_str(), "'"].concat()
}
519
/// Quote a MySQL identifier (database, table, column)
///
/// Wraps the identifier in backticks, MySQL's identifier quote character, and
/// doubles every backtick that appears inside it.
///
/// # Examples
///
/// ```
/// use database_replicator::utils::quote_mysql_ident;
/// assert_eq!(quote_mysql_ident("users"), "`users`");
/// assert_eq!(quote_mysql_ident("user`name"), "`user``name`");
/// ```
pub fn quote_mysql_ident(identifier: &str) -> String {
    let mut out = String::with_capacity(identifier.len() + 2);
    out.push('`');
    // Emit each character once, except the backtick which is emitted twice.
    out.extend(identifier.chars().flat_map(|ch| {
        let copies = if ch == '`' { 2 } else { 1 };
        std::iter::repeat(ch).take(copies)
    }));
    out.push('`');
    out
}
544
545/// Validate that source and target URLs are different to prevent accidental data loss
546///
547/// Compares two PostgreSQL connection URLs to ensure they point to different databases.
548/// This is critical for preventing data loss from operations like `init --drop-existing`
549/// where using the same URL for source and target would destroy the source data.
550///
551/// # Comparison Strategy
552///
553/// URLs are normalized and compared on:
554/// - Host (case-insensitive)
555/// - Port (defaulting to 5432 if not specified)
556/// - Database name (case-sensitive)
557/// - User (if present)
558///
559/// Query parameters (like SSL settings) are ignored as they don't affect database identity.
560///
561/// # Arguments
562///
563/// * `source_url` - Source database connection string
564/// * `target_url` - Target database connection string
565///
566/// # Returns
567///
568/// Returns `Ok(())` if the URLs point to different databases.
569///
570/// # Errors
571///
572/// Returns an error if:
573/// - The URLs point to the same database (same host, port, database name, and user)
574/// - Either URL is malformed and cannot be parsed
575///
576/// # Examples
577///
578/// ```
579/// # use database_replicator::utils::validate_source_target_different;
580/// # use anyhow::Result;
581/// # fn example() -> Result<()> {
582/// // Valid - different hosts
583/// validate_source_target_different(
584///     "postgresql://user:pass@source.com:5432/db",
585///     "postgresql://user:pass@target.com:5432/db"
586/// )?;
587///
588/// // Valid - different databases
589/// validate_source_target_different(
590///     "postgresql://user:pass@host:5432/db1",
591///     "postgresql://user:pass@host:5432/db2"
592/// )?;
593///
594/// // Invalid - same database
595/// assert!(validate_source_target_different(
596///     "postgresql://user:pass@host:5432/db",
597///     "postgresql://user:pass@host:5432/db"
598/// ).is_err());
599/// # Ok(())
600/// # }
601/// ```
602pub fn validate_source_target_different(source_url: &str, target_url: &str) -> Result<()> {
603    // Parse both URLs to extract components
604    let source_parts = parse_postgres_url(source_url)
605        .with_context(|| format!("Failed to parse source URL: {}", source_url))?;
606    let target_parts = parse_postgres_url(target_url)
607        .with_context(|| format!("Failed to parse target URL: {}", target_url))?;
608
609    // Compare normalized components
610    if source_parts.host == target_parts.host
611        && source_parts.port == target_parts.port
612        && source_parts.database == target_parts.database
613        && source_parts.user == target_parts.user
614    {
615        bail!(
616            "Source and target URLs point to the same database!\n\
617             \n\
618             This would cause DATA LOSS - the target would overwrite the source.\n\
619             \n\
620             Source: {}@{}:{}/{}\n\
621             Target: {}@{}:{}/{}\n\
622             \n\
623             Please ensure source and target are different databases.\n\
624             Common causes:\n\
625             - Copy-paste error in connection strings\n\
626             - Wrong environment variables (e.g., SOURCE_URL == TARGET_URL)\n\
627             - Typo in database name or host",
628            source_parts.user.as_deref().unwrap_or("(no user)"),
629            source_parts.host,
630            source_parts.port,
631            source_parts.database,
632            target_parts.user.as_deref().unwrap_or("(no user)"),
633            target_parts.host,
634            target_parts.port,
635            target_parts.database
636        );
637    }
638
639    Ok(())
640}
641
642/// Parse a PostgreSQL URL into its components
643///
644/// # Arguments
645///
646/// * `url` - PostgreSQL connection URL (postgres:// or postgresql://)
647///
648/// # Returns
649///
650/// Returns a `PostgresUrlParts` struct with normalized components.
651///
652/// # Security
653///
654/// This function extracts passwords from URLs for use with .pgpass files.
655/// Ensure returned values are handled securely and not logged.
656pub fn parse_postgres_url(url: &str) -> Result<PostgresUrlParts> {
657    // Remove scheme
658    let url_without_scheme = url
659        .trim_start_matches("postgres://")
660        .trim_start_matches("postgresql://");
661
662    // Split into base and query params
663    let (base, query_string) = if let Some((b, q)) = url_without_scheme.split_once('?') {
664        (b, Some(q))
665    } else {
666        (url_without_scheme, None)
667    };
668
669    // Parse query parameters into HashMap
670    let mut query_params = std::collections::HashMap::new();
671    if let Some(query) = query_string {
672        for param in query.split('&') {
673            if let Some((key, value)) = param.split_once('=') {
674                query_params.insert(key.to_string(), value.to_string());
675            }
676        }
677    }
678
679    // Parse: [user[:password]@]host[:port]/database
680    let (auth_and_host, database) = base
681        .rsplit_once('/')
682        .ok_or_else(|| anyhow::anyhow!("Missing database name in URL"))?;
683
684    // Parse authentication and host
685    // Use rsplit_once to split from the right, so passwords can contain '@'
686    let (user, password, host_and_port) = if let Some((auth, hp)) = auth_and_host.rsplit_once('@') {
687        // Has authentication
688        let (user, pass) = if let Some((u, p)) = auth.split_once(':') {
689            (Some(u.to_string()), Some(p.to_string()))
690        } else {
691            (Some(auth.to_string()), None)
692        };
693        (user, pass, hp)
694    } else {
695        // No authentication
696        (None, None, auth_and_host)
697    };
698
699    // Parse host and port
700    let (host, port) = if let Some((h, p)) = host_and_port.rsplit_once(':') {
701        // Port specified
702        let port = p
703            .parse::<u16>()
704            .with_context(|| format!("Invalid port number: {}", p))?;
705        (h, port)
706    } else {
707        // Use default PostgreSQL port
708        (host_and_port, 5432)
709    };
710
711    Ok(PostgresUrlParts {
712        host: host.to_lowercase(), // Hostnames are case-insensitive
713        port,
714        database: database.to_string(), // Database names are case-sensitive in PostgreSQL
715        user,
716        password,
717        query_params,
718    })
719}
720
721/// Strip password from PostgreSQL connection URL
722/// Returns a new URL with password removed, preserving all other components
723/// This is useful for storing connection strings in places where passwords should not be visible
724pub fn strip_password_from_url(url: &str) -> Result<String> {
725    let parts = parse_postgres_url(url)?;
726
727    // Reconstruct URL without password
728    let scheme = if url.starts_with("postgresql://") {
729        "postgresql://"
730    } else if url.starts_with("postgres://") {
731        "postgres://"
732    } else {
733        bail!("Invalid PostgreSQL URL scheme");
734    };
735
736    let mut result = String::from(scheme);
737
738    // Add user if present (without password)
739    if let Some(user) = &parts.user {
740        result.push_str(user);
741        result.push('@');
742    }
743
744    // Add host and port
745    result.push_str(&parts.host);
746    result.push(':');
747    result.push_str(&parts.port.to_string());
748
749    // Add database
750    result.push('/');
751    result.push_str(&parts.database);
752
753    // Preserve query parameters if present
754    if let Some(query_start) = url.find('?') {
755        result.push_str(&url[query_start..]);
756    }
757
758    Ok(result)
759}
760
/// Parsed components of a PostgreSQL connection URL
///
/// Note that the derived `Debug` output includes the `password` field.
#[derive(Debug, PartialEq)]
pub struct PostgresUrlParts {
    /// Server host name.
    pub host: String,
    /// Server port.
    pub port: u16,
    /// Database name.
    pub database: String,
    /// User name, if the URL carried one.
    pub user: Option<String>,
    /// Password, if the URL carried one.
    pub password: Option<String>,
    /// `key=value` pairs from the URL's query string.
    pub query_params: std::collections::HashMap<String, String>,
}

impl PostgresUrlParts {
    /// Translate recognised query parameters into PostgreSQL environment variables
    ///
    /// Connection URL query parameters that have an environment-variable
    /// equivalent are mapped to that variable, so SSL/TLS and other connection
    /// settings can be handed to pg_dump, pg_dumpall, psql, etc. Parameters not
    /// listed below are ignored.
    ///
    /// # Supported Parameters
    ///
    /// - `sslmode` → `PGSSLMODE`
    /// - `sslcert` → `PGSSLCERT`
    /// - `sslkey` → `PGSSLKEY`
    /// - `sslrootcert` → `PGSSLROOTCERT`
    /// - `channel_binding` → `PGCHANNELBINDING`
    /// - `connect_timeout` → `PGCONNECT_TIMEOUT`
    /// - `application_name` → `PGAPPNAME`
    /// - `client_encoding` → `PGCLIENTENCODING`
    ///
    /// # Returns
    ///
    /// Vec of (env_var_name, value) pairs, in the order of the list above,
    /// containing only the parameters present in `query_params`.
    pub fn to_pg_env_vars(&self) -> Vec<(&'static str, String)> {
        // (query parameter, environment variable) in output order
        const PARAM_TO_ENV: [(&str, &str); 8] = [
            ("sslmode", "PGSSLMODE"),
            ("sslcert", "PGSSLCERT"),
            ("sslkey", "PGSSLKEY"),
            ("sslrootcert", "PGSSLROOTCERT"),
            ("channel_binding", "PGCHANNELBINDING"),
            ("connect_timeout", "PGCONNECT_TIMEOUT"),
            ("application_name", "PGAPPNAME"),
            ("client_encoding", "PGCLIENTENCODING"),
        ];

        PARAM_TO_ENV
            .iter()
            .filter_map(|&(param, env_name)| {
                self.query_params
                    .get(param)
                    .map(|value| (env_name, value.clone()))
            })
            .collect()
    }
}
817
818/// Managed .pgpass file for secure password passing to PostgreSQL tools
819///
820/// This struct creates a temporary .pgpass file with secure permissions (0600)
821/// and automatically cleans it up when dropped. PostgreSQL command-line tools
822/// read credentials from this file instead of accepting passwords in URLs,
823/// which prevents command injection vulnerabilities.
824///
825/// # Security
826///
827/// - File permissions are set to 0600 (owner read/write only)
828/// - File is automatically removed on Drop
829/// - Credentials are never passed on command line
830///
831/// # Format
832///
833/// .pgpass file format: hostname:port:database:username:password
834/// Wildcards (*) are used for maximum compatibility
835///
836/// # Examples
837///
838/// ```no_run
839/// # use database_replicator::utils::{PgPassFile, parse_postgres_url};
840/// # use anyhow::Result;
841/// # fn example() -> Result<()> {
842/// let url = "postgresql://user:pass@localhost:5432/mydb";
843/// let parts = parse_postgres_url(url)?;
844/// let pgpass = PgPassFile::new(&parts)?;
845///
846/// // Use pgpass.path() with PGPASSFILE environment variable
847/// // File is automatically cleaned up when pgpass goes out of scope
848/// # Ok(())
849/// # }
850/// ```
851pub struct PgPassFile {
852    path: std::path::PathBuf,
853}
854
855impl PgPassFile {
856    /// Create a new .pgpass file with credentials from URL parts
857    ///
858    /// # Arguments
859    ///
860    /// * `parts` - Parsed PostgreSQL URL components
861    ///
862    /// # Returns
863    ///
864    /// Returns a PgPassFile that will be automatically cleaned up on Drop
865    ///
866    /// # Errors
867    ///
868    /// Returns an error if the file cannot be created or permissions cannot be set
869    pub fn new(parts: &PostgresUrlParts) -> Result<Self> {
870        use std::fs;
871        use std::io::Write;
872
873        // Create temp file with secure name
874        let temp_dir = std::env::temp_dir();
875        let random: u32 = rand::random();
876        let filename = format!("pgpass-{:08x}", random);
877        let path = temp_dir.join(filename);
878
879        // Write .pgpass entry
880        // Format: hostname:port:database:username:password
881        let username = parts.user.as_deref().unwrap_or("*");
882        let password = parts.password.as_deref().unwrap_or("");
883        let entry = format!(
884            "{}:{}:{}:{}:{}\n",
885            parts.host, parts.port, parts.database, username, password
886        );
887
888        let mut file = fs::File::create(&path)
889            .with_context(|| format!("Failed to create .pgpass file at {}", path.display()))?;
890
891        file.write_all(entry.as_bytes())
892            .with_context(|| format!("Failed to write to .pgpass file at {}", path.display()))?;
893
894        // Set secure permissions (0600) - owner read/write only
895        #[cfg(unix)]
896        {
897            use std::os::unix::fs::PermissionsExt;
898            let permissions = fs::Permissions::from_mode(0o600);
899            fs::set_permissions(&path, permissions).with_context(|| {
900                format!(
901                    "Failed to set permissions on .pgpass file at {}",
902                    path.display()
903                )
904            })?;
905        }
906
907        // On Windows, .pgpass is stored in %APPDATA%\postgresql\pgpass.conf
908        // but for our temporary use case, we'll just use a temp file
909        // PostgreSQL on Windows also checks permissions but less strictly
910
911        Ok(Self { path })
912    }
913
914    /// Get the path to the .pgpass file
915    ///
916    /// Use this with the PGPASSFILE environment variable when running
917    /// PostgreSQL command-line tools
918    pub fn path(&self) -> &std::path::Path {
919        &self.path
920    }
921}
922
923impl Drop for PgPassFile {
924    fn drop(&mut self) {
925        // Best effort cleanup - don't panic if removal fails
926        let _ = std::fs::remove_file(&self.path);
927    }
928}
929
930/// Create a managed temporary directory with explicit cleanup support
931///
932/// Creates a temporary directory with a timestamped name that can be cleaned up
933/// even if the process is killed with SIGKILL. Unlike `TempDir::new()` which
934/// relies on the Drop trait, this function creates named directories that can
935/// be cleaned up on next process startup.
936///
937/// Directory naming format: `postgres-seren-replicator-{timestamp}-{random}`
938/// Example: `postgres-seren-replicator-20250106-120534-a3b2c1d4`
939///
940/// # Returns
941///
942/// Returns the path to the created temporary directory.
943///
944/// # Errors
945///
946/// Returns an error if the directory cannot be created.
947///
948/// # Examples
949///
950/// ```no_run
951/// # use database_replicator::utils::create_managed_temp_dir;
952/// # use anyhow::Result;
953/// # fn example() -> Result<()> {
954/// let temp_path = create_managed_temp_dir()?;
955/// println!("Using temp directory: {}", temp_path.display());
956/// // ... do work ...
957/// // Cleanup happens automatically on next startup via cleanup_stale_temp_dirs()
958/// # Ok(())
959/// # }
960/// ```
961pub fn create_managed_temp_dir() -> Result<std::path::PathBuf> {
962    use std::fs;
963    use std::time::SystemTime;
964
965    let system_temp = std::env::temp_dir();
966
967    // Generate timestamp for directory name
968    let timestamp = SystemTime::now()
969        .duration_since(SystemTime::UNIX_EPOCH)
970        .unwrap()
971        .as_secs();
972
973    // Generate random suffix for uniqueness
974    let random: u32 = rand::random();
975
976    // Create directory name with timestamp and random suffix
977    let dir_name = format!("postgres-seren-replicator-{}-{:08x}", timestamp, random);
978
979    let temp_path = system_temp.join(dir_name);
980
981    // Create the directory
982    fs::create_dir_all(&temp_path)
983        .with_context(|| format!("Failed to create temp directory at {}", temp_path.display()))?;
984
985    tracing::debug!("Created managed temp directory: {}", temp_path.display());
986
987    Ok(temp_path)
988}
989
990/// Clean up stale temporary directories from previous runs
991///
992/// Removes temporary directories created by `create_managed_temp_dir()` that are
993/// older than the specified age. This should be called on process startup to clean
994/// up directories left behind by processes killed with SIGKILL.
995///
996/// Only directories matching the pattern `postgres-seren-replicator-*` are removed.
997///
998/// # Arguments
999///
1000/// * `max_age_secs` - Maximum age in seconds before a directory is considered stale
1001///   (recommended: 86400 for 24 hours)
1002///
1003/// # Returns
1004///
1005/// Returns the number of directories cleaned up.
1006///
1007/// # Errors
1008///
1009/// Returns an error if the system temp directory cannot be read. Individual
1010/// directory removal errors are logged but don't fail the entire operation.
1011///
1012/// # Examples
1013///
1014/// ```no_run
1015/// # use database_replicator::utils::cleanup_stale_temp_dirs;
1016/// # use anyhow::Result;
1017/// # fn example() -> Result<()> {
1018/// // Clean up temp directories older than 24 hours
1019/// let cleaned = cleanup_stale_temp_dirs(86400)?;
1020/// println!("Cleaned up {} stale temp directories", cleaned);
1021/// # Ok(())
1022/// # }
1023/// ```
1024pub fn cleanup_stale_temp_dirs(max_age_secs: u64) -> Result<usize> {
1025    use std::fs;
1026    use std::time::SystemTime;
1027
1028    let system_temp = std::env::temp_dir();
1029    let now = SystemTime::now();
1030    let mut cleaned_count = 0;
1031
1032    // Read all entries in system temp directory
1033    let entries = fs::read_dir(&system_temp).with_context(|| {
1034        format!(
1035            "Failed to read system temp directory: {}",
1036            system_temp.display()
1037        )
1038    })?;
1039
1040    for entry in entries.flatten() {
1041        let path = entry.path();
1042
1043        // Only process directories matching our naming pattern
1044        if let Some(name) = path.file_name().and_then(|n| n.to_str()) {
1045            if !name.starts_with("postgres-seren-replicator-") {
1046                continue;
1047            }
1048
1049            // Check directory age
1050            match entry.metadata() {
1051                Ok(metadata) => {
1052                    if let Ok(modified) = metadata.modified() {
1053                        if let Ok(age) = now.duration_since(modified) {
1054                            if age.as_secs() > max_age_secs {
1055                                // Directory is stale, remove it
1056                                match fs::remove_dir_all(&path) {
1057                                    Ok(_) => {
1058                                        tracing::info!(
1059                                            "Cleaned up stale temp directory: {} (age: {}s)",
1060                                            path.display(),
1061                                            age.as_secs()
1062                                        );
1063                                        cleaned_count += 1;
1064                                    }
1065                                    Err(e) => {
1066                                        tracing::warn!(
1067                                            "Failed to remove stale temp directory {}: {}",
1068                                            path.display(),
1069                                            e
1070                                        );
1071                                    }
1072                                }
1073                            }
1074                        }
1075                    }
1076                }
1077                Err(e) => {
1078                    tracing::warn!(
1079                        "Failed to get metadata for temp directory {}: {}",
1080                        path.display(),
1081                        e
1082                    );
1083                }
1084            }
1085        }
1086    }
1087
1088    if cleaned_count > 0 {
1089        tracing::info!(
1090            "Cleaned up {} stale temp directory(ies) older than {} seconds",
1091            cleaned_count,
1092            max_age_secs
1093        );
1094    }
1095
1096    Ok(cleaned_count)
1097}
1098
1099/// Parse a SerenDB URL to extract project, branch, and database IDs
1100///
1101/// SerenDB URLs have the format: postgresql://user:pass@<database-id>.<branch-id>.<project-id>.serendb.com:5432/db
1102/// This function extracts the three UUIDs from the hostname.
1103///
1104/// # Arguments
1105///
1106/// * `url` - The SerenDB PostgreSQL connection string
1107///
1108/// # Returns
1109///
1110/// An `Option` containing a tuple of `(project_id, branch_id, database_id)` if the
1111/// URL is a valid SerenDB target and contains the expected ID format, otherwise `None`.
1112pub fn parse_serendb_url_for_ids(url: &str) -> Option<(String, String, String)> {
1113    let parts = parse_postgres_url(url).ok()?;
1114
1115    if !is_serendb_target(url) {
1116        return None;
1117    }
1118
1119    // Hostname format: <database-id>.<branch-id>.<project-id>.serendb.com
1120    // Or with custom subdomains: <database-id>.<branch-id>.<project-id>.<custom>.serendb.com
1121    // We want the last three parts before .serendb.com
1122    let host_parts: Vec<&str> = parts.host.split('.').collect();
1123
1124    if host_parts.len() < 4 {
1125        return None; // Not enough parts for SerenDB ID format
1126    }
1127
1128    let num_host_parts = host_parts.len();
1129    let database_id = host_parts[num_host_parts - 4].to_string();
1130    let branch_id = host_parts[num_host_parts - 3].to_string();
1131    let project_id = host_parts[num_host_parts - 2].to_string();
1132
1133    // Basic UUID format validation (optional but good for robustness)
1134    // A real UUID check would be more extensive, but string length is a good start
1135    if database_id.len() == 36 && branch_id.len() == 36 && project_id.len() == 36 {
1136        Some((project_id, branch_id, database_id))
1137    } else {
1138        None
1139    }
1140}
1141
1142/// Remove a managed temporary directory
1143///
1144/// Explicitly removes a temporary directory created by `create_managed_temp_dir()`.
1145/// This should be called when the directory is no longer needed.
1146///
1147/// # Arguments
1148///
1149/// * `path` - Path to the temporary directory to remove
1150///
1151/// # Errors
1152///
1153/// Returns an error if the directory cannot be removed.
1154///
1155/// # Examples
1156///
1157/// ```no_run
1158/// # use database_replicator::utils::{create_managed_temp_dir, remove_managed_temp_dir};
1159/// # use anyhow::Result;
1160/// # fn example() -> Result<()> {
1161/// let temp_path = create_managed_temp_dir()?;
1162/// // ... do work ...
1163/// remove_managed_temp_dir(&temp_path)?;
1164/// # Ok(())
1165/// # }
1166/// ```
1167pub fn remove_managed_temp_dir(path: &std::path::Path) -> Result<()> {
1168    use std::fs;
1169
1170    // Verify this is one of our temp directories (safety check)
1171    if let Some(name) = path.file_name().and_then(|n| n.to_str()) {
1172        if !name.starts_with("postgres-seren-replicator-") {
1173            bail!(
1174                "Refusing to remove directory that doesn't match our naming pattern: {}",
1175                path.display()
1176            );
1177        }
1178    } else {
1179        bail!("Invalid temp directory path: {}", path.display());
1180    }
1181
1182    tracing::debug!("Removing managed temp directory: {}", path.display());
1183
1184    fs::remove_dir_all(path)
1185        .with_context(|| format!("Failed to remove temp directory at {}", path.display()))?;
1186
1187    Ok(())
1188}
1189
1190/// Replace the database name in a connection string URL
1191///
1192/// This is used internally by SerenDB to provide a generic connection string
1193/// which then needs to be specialized for a particular database.
1194///
1195/// # Arguments
1196///
1197/// * `url` - The connection string URL (e.g., postgresql://host/template_db)
1198/// * `new_db` - The new database name to insert into the URL
1199///
1200/// # Returns
1201///
1202/// A new URL string with the database name replaced.
1203///
1204/// # Errors
1205///
1206/// Returns an error if the URL is invalid and cannot be parsed.
1207pub fn replace_database_in_connection_string(url: &str, new_db: &str) -> Result<String> {
1208    let mut parsed = Url::parse(url).context("Invalid connection string URL")?;
1209    parsed.set_path(&format!("/{}", new_db));
1210
1211    Ok(parsed.to_string())
1212}
1213
1214/// Check if a PostgreSQL URL points to a SerenDB instance
1215///
1216/// SerenDB hosts have domains ending with `.serendb.com`
1217///
1218/// # Arguments
1219///
1220/// * `url` - PostgreSQL connection string to check
1221///
1222/// # Returns
1223///
1224/// Returns `true` if the URL points to a SerenDB host.
1225///
1226/// # Examples
1227///
1228/// ```
1229/// use database_replicator::utils::is_serendb_target;
1230///
1231/// assert!(is_serendb_target("postgresql://user:pass@db.serendb.com/mydb"));
1232/// assert!(is_serendb_target("postgresql://user:pass@cluster-123.console.serendb.com/mydb"));
1233/// assert!(!is_serendb_target("postgresql://user:pass@localhost/mydb"));
1234/// assert!(!is_serendb_target("postgresql://user:pass@rds.amazonaws.com/mydb"));
1235/// ```
1236pub fn is_serendb_target(url: &str) -> bool {
1237    match parse_postgres_url(url) {
1238        Ok(parts) => parts.host.ends_with(".serendb.com") || parts.host == "serendb.com",
1239        Err(_) => false,
1240    }
1241}
1242
1243/// Get the major version of a PostgreSQL client tool (pg_dump, psql, etc.)
1244///
1245/// Executes `<tool> --version` and parses the output.
1246///
1247/// # Arguments
1248///
1249/// * `tool` - Name of the tool (e.g., "pg_dump", "psql")
1250///
1251/// # Returns
1252///
1253/// The major version number (e.g., 16 for pg_dump 16.10)
1254///
1255/// # Errors
1256///
1257/// Returns an error if:
1258/// - Tool is not found in PATH
1259/// - Tool execution fails
1260/// - Version output cannot be parsed
1261///
1262/// # Examples
1263///
1264/// ```no_run
1265/// use database_replicator::utils::get_pg_tool_version;
1266/// use anyhow::Result;
1267///
1268/// fn example() -> Result<()> {
1269///     let version = get_pg_tool_version("pg_dump")?;
1270///     println!("pg_dump major version: {}", version); // e.g., 16
1271///     Ok(())
1272/// }
1273/// ```
1274pub fn get_pg_tool_version(tool: &str) -> Result<u32> {
1275    use std::process::Command;
1276
1277    let path = which(tool).with_context(|| format!("{} not found in PATH", tool))?;
1278
1279    let output = Command::new(&path)
1280        .arg("--version")
1281        .output()
1282        .with_context(|| format!("Failed to execute {} --version", tool))?;
1283
1284    let version_str = String::from_utf8_lossy(&output.stdout);
1285    parse_pg_version_string(&version_str)
1286}
1287
1288/// Parse major version from PostgreSQL version string
1289///
1290/// Handles formats like:
1291/// - "pg_dump (PostgreSQL) 16.10 (Ubuntu 16.10-0ubuntu0.24.04.1)"
1292/// - "psql (PostgreSQL) 17.2"
1293/// - "17.2 (Debian 17.2-1.pgdg120+1)"
1294///
1295/// # Arguments
1296///
1297/// * `version_str` - Version string output from a PostgreSQL tool
1298///
1299/// # Returns
1300///
1301/// The major version number (e.g., 16, 17)
1302///
1303/// # Errors
1304///
1305/// Returns an error if the version cannot be parsed.
1306pub fn parse_pg_version_string(version_str: &str) -> Result<u32> {
1307    // Look for version pattern: major.minor
1308    for word in version_str.split_whitespace() {
1309        if let Some(major_str) = word.split('.').next() {
1310            if let Ok(major) = major_str.parse::<u32>() {
1311                // Valid PostgreSQL versions are between 9 and 99
1312                if (9..=99).contains(&major) {
1313                    return Ok(major);
1314                }
1315            }
1316        }
1317    }
1318    bail!("Could not parse PostgreSQL version from: {}", version_str)
1319}
1320
1321/// Get available system memory in bytes
1322///
1323/// Cross-platform function that works on Linux, macOS, and Windows.
1324/// Returns the amount of memory available for use by applications.
1325///
1326/// # Platform Details
1327///
1328/// - **Linux**: Reads `MemAvailable` from `/proc/meminfo`
1329/// - **macOS**: Uses `sysctl hw.memsize` for total memory, estimates available
1330/// - **Windows**: Uses `GlobalMemoryStatusEx` API
1331///
1332/// # Returns
1333///
1334/// Available memory in bytes, or an error if detection fails.
1335///
1336/// # Examples
1337///
1338/// ```no_run
1339/// use database_replicator::utils::get_available_memory;
1340///
1341/// let available = get_available_memory().unwrap();
1342/// println!("Available memory: {} MB", available / 1024 / 1024);
1343/// ```
1344pub fn get_available_memory() -> Result<u64> {
1345    #[cfg(target_os = "linux")]
1346    {
1347        get_available_memory_linux()
1348    }
1349
1350    #[cfg(target_os = "macos")]
1351    {
1352        get_available_memory_macos()
1353    }
1354
1355    #[cfg(target_os = "windows")]
1356    {
1357        get_available_memory_windows()
1358    }
1359
1360    #[cfg(not(any(target_os = "linux", target_os = "macos", target_os = "windows")))]
1361    {
1362        // Fallback: assume 1GB available for unknown platforms
1363        tracing::warn!("Memory detection not supported on this platform, assuming 1GB available");
1364        Ok(1024 * 1024 * 1024)
1365    }
1366}
1367
1368#[cfg(target_os = "linux")]
1369fn get_available_memory_linux() -> Result<u64> {
1370    use std::fs;
1371
1372    let meminfo = fs::read_to_string("/proc/meminfo").context("Failed to read /proc/meminfo")?;
1373
1374    // Try MemAvailable first (more accurate, available since Linux 3.14)
1375    for line in meminfo.lines() {
1376        if line.starts_with("MemAvailable:") {
1377            let parts: Vec<&str> = line.split_whitespace().collect();
1378            if parts.len() >= 2 {
1379                let kb: u64 = parts[1]
1380                    .parse()
1381                    .context("Failed to parse MemAvailable value")?;
1382                return Ok(kb * 1024); // Convert KB to bytes
1383            }
1384        }
1385    }
1386
1387    // Fallback: MemFree + Buffers + Cached (less accurate but works on older kernels)
1388    let mut mem_free: u64 = 0;
1389    let mut buffers: u64 = 0;
1390    let mut cached: u64 = 0;
1391
1392    for line in meminfo.lines() {
1393        let parts: Vec<&str> = line.split_whitespace().collect();
1394        if parts.len() >= 2 {
1395            let value: u64 = parts[1].parse().unwrap_or(0);
1396            if line.starts_with("MemFree:") {
1397                mem_free = value;
1398            } else if line.starts_with("Buffers:") {
1399                buffers = value;
1400            } else if line.starts_with("Cached:") && !line.starts_with("SwapCached:") {
1401                cached = value;
1402            }
1403        }
1404    }
1405
1406    Ok((mem_free + buffers + cached) * 1024) // Convert KB to bytes
1407}
1408
/// macOS implementation of `get_available_memory`.
///
/// Runs `sysctl -n hw.memsize`, `sysctl -n hw.pagesize` and `vm_stat`, then
/// estimates available memory as (free + inactive + purgeable pages) times the
/// page size. If that estimate comes out as zero, half of the total physical
/// memory is returned instead. Fails if any of the commands cannot be executed
/// or `hw.memsize` is not a number.
#[cfg(target_os = "macos")]
fn get_available_memory_macos() -> Result<u64> {
    use std::process::Command;

    // Total physical memory; only used by the fallback estimate at the end
    let memsize_output = Command::new("sysctl")
        .args(["-n", "hw.memsize"])
        .output()
        .context("Failed to execute sysctl")?;
    let total_bytes: u64 = String::from_utf8_lossy(&memsize_output.stdout)
        .trim()
        .parse()
        .context("Failed to parse hw.memsize")?;

    // Actual page size: Intel Macs use 4KB (4096), Apple Silicon uses 16KB (16384).
    // An unparseable value falls back to 4KB.
    let pagesize_output = Command::new("sysctl")
        .args(["-n", "hw.pagesize"])
        .output()
        .context("Failed to execute sysctl hw.pagesize")?;
    let page_size: u64 = String::from_utf8_lossy(&pagesize_output.stdout)
        .trim()
        .parse()
        .unwrap_or(4096);

    // Page counters come from vm_stat
    let vm_output = Command::new("vm_stat")
        .output()
        .context("Failed to execute vm_stat")?;
    let vm_stat = String::from_utf8_lossy(&vm_output.stdout);

    let (mut free, mut inactive, mut purgeable) = (0u64, 0u64, 0u64);

    for line in vm_stat.lines() {
        let counter = if line.starts_with("Pages free:") {
            &mut free
        } else if line.starts_with("Pages inactive:") {
            &mut inactive
        } else if line.starts_with("Pages purgeable:") {
            &mut purgeable
        } else {
            continue;
        };

        *counter = parse_vm_stat_value(line);
    }

    // Available = free + inactive + purgeable (conservative estimate)
    let available = (free + inactive + purgeable) * page_size;
    if available != 0 {
        return Ok(available);
    }

    // Nothing usable came out of vm_stat: estimate 50% of total as available
    tracing::debug!("vm_stat parsing returned 0, estimating 50% of total memory as available");
    Ok(total_bytes / 2)
}
1470
/// Extract the numeric value from one `vm_stat` output line.
///
/// Expected format: "Pages free:    12345." — the text after the first colon
/// is trimmed, trailing dots are stripped, and the rest is parsed as `u64`.
/// Returns 0 when the line has no colon or the value is not a number.
#[cfg(target_os = "macos")]
fn parse_vm_stat_value(line: &str) -> u64 {
    let mut segments = line.split(':');
    segments.next(); // skip the label in front of the first colon

    match segments.next() {
        Some(raw) => raw.trim().trim_end_matches('.').parse().unwrap_or(0),
        None => 0,
    }
}
1479
/// Windows implementation of `get_available_memory`.
///
/// Calls `GlobalMemoryStatusEx` from kernel32 and returns the `ullAvailPhys`
/// field of the filled-in structure. Fails if the call returns 0.
#[cfg(target_os = "windows")]
fn get_available_memory_windows() -> Result<u64> {
    use std::mem;

    // Local declaration of the MEMORYSTATUSEX structure passed to the API
    #[repr(C)]
    #[allow(non_snake_case)]
    struct MEMORYSTATUSEX {
        dwLength: u32,
        dwMemoryLoad: u32,
        ullTotalPhys: u64,
        ullAvailPhys: u64,
        ullTotalPageFile: u64,
        ullAvailPageFile: u64,
        ullTotalVirtual: u64,
        ullAvailVirtual: u64,
        ullAvailExtendedVirtual: u64,
    }

    #[link(name = "kernel32")]
    extern "system" {
        fn GlobalMemoryStatusEx(lpBuffer: *mut MEMORYSTATUSEX) -> i32;
    }

    // Start from an all-zero structure with only dwLength set to its own size.
    let mut status = MEMORYSTATUSEX {
        dwLength: mem::size_of::<MEMORYSTATUSEX>() as u32,
        dwMemoryLoad: 0,
        ullTotalPhys: 0,
        ullAvailPhys: 0,
        ullTotalPageFile: 0,
        ullAvailPageFile: 0,
        ullTotalVirtual: 0,
        ullAvailVirtual: 0,
        ullAvailExtendedVirtual: 0,
    };

    // SAFETY: `status` is a live, exclusively borrowed value whose `dwLength`
    // equals the size of the structure handed to the call.
    let succeeded = unsafe { GlobalMemoryStatusEx(&mut status) } != 0;

    if !succeeded {
        bail!("GlobalMemoryStatusEx failed");
    }

    Ok(status.ullAvailPhys)
}
1515
1516/// Calculate optimal batch size based on available system memory
1517///
1518/// Automatically determines an appropriate batch size for the sync daemon
1519/// based on the amount of available system memory. This prevents OOM errors
1520/// on memory-constrained instances while maximizing throughput on larger ones.
1521///
1522/// # Memory Model
1523///
1524/// The calculation assumes:
1525/// - Each row consumes approximately 2KB in memory (conservative estimate for wide tables)
1526/// - We should use at most 25% of available memory for batch processing
1527/// - Minimum batch size: 1,000 rows (for very constrained systems)
1528/// - Maximum batch size: 50,000 rows (diminishing returns beyond this)
1529///
1530/// # Returns
1531///
1532/// Optimal batch size in number of rows, or default of 10,000 if detection fails.
1533///
1534/// # Examples
1535///
1536/// ```no_run
1537/// use database_replicator::utils::calculate_optimal_batch_size;
1538///
1539/// let batch_size = calculate_optimal_batch_size();
1540/// println!("Using batch size: {}", batch_size);
1541/// // On t3.nano (512MB): ~1,000-2,000
1542/// // On t3.small (2GB): ~10,000
1543/// // On t3.large (8GB): ~50,000 (capped)
1544/// ```
1545pub fn calculate_optimal_batch_size() -> usize {
1546    const BYTES_PER_ROW: u64 = 2048; // Conservative: 2KB per row
1547    const MEMORY_FRACTION: f64 = 0.25; // Use at most 25% of available memory
1548    const MIN_BATCH_SIZE: usize = 1_000;
1549    const MAX_BATCH_SIZE: usize = 50_000;
1550    const DEFAULT_BATCH_SIZE: usize = 10_000;
1551
1552    match get_available_memory() {
1553        Ok(available_bytes) => {
1554            // Calculate how many rows we can fit in 25% of available memory
1555            let usable_bytes = (available_bytes as f64 * MEMORY_FRACTION) as u64;
1556            let calculated_size = (usable_bytes / BYTES_PER_ROW) as usize;
1557
1558            // Clamp to min/max range
1559            let batch_size = calculated_size.clamp(MIN_BATCH_SIZE, MAX_BATCH_SIZE);
1560
1561            tracing::info!(
1562                "Auto-detected batch size: {} (available memory: {} MB)",
1563                batch_size,
1564                available_bytes / 1024 / 1024
1565            );
1566
1567            batch_size
1568        }
1569        Err(e) => {
1570            tracing::warn!(
1571                "Failed to detect available memory: {}. Using default batch size: {}",
1572                e,
1573                DEFAULT_BATCH_SIZE
1574            );
1575            DEFAULT_BATCH_SIZE
1576        }
1577    }
1578}
1579
1580#[cfg(test)]
1581mod tests {
1582    use super::*;
1583
1584    #[test]
1585    fn test_get_available_memory() {
1586        // This should work on all supported platforms
1587        let result = get_available_memory();
1588
1589        // Should succeed (may fail in very restricted environments)
1590        if let Ok(available) = result {
1591            // Sanity check: should be at least 10MB, less than 1TB
1592            assert!(
1593                available > 10 * 1024 * 1024,
1594                "Available memory too low: {}",
1595                available
1596            );
1597            assert!(
1598                available < 1024 * 1024 * 1024 * 1024,
1599                "Available memory too high: {}",
1600                available
1601            );
1602        }
1603    }
1604
1605    #[test]
1606    fn test_calculate_optimal_batch_size() {
1607        let batch_size = calculate_optimal_batch_size();
1608
1609        // Should be within expected range
1610        assert!(batch_size >= 1_000, "Batch size too small: {}", batch_size);
1611        assert!(batch_size <= 50_000, "Batch size too large: {}", batch_size);
1612    }
1613
1614    #[test]
1615    fn test_validate_connection_string_valid() {
1616        assert!(validate_connection_string("postgresql://user:pass@localhost:5432/dbname").is_ok());
1617        assert!(validate_connection_string("postgres://user@host/db").is_ok());
1618    }
1619
1620    #[test]
1621    fn test_check_required_tools() {
1622        // This test will pass if PostgreSQL client tools are installed
1623        // It will fail (appropriately) if they're not installed
1624        let result = check_required_tools();
1625
1626        // On systems with PostgreSQL installed, this should pass
1627        // On systems without it, we expect a specific error message
1628        if let Err(err) = result {
1629            let err_msg = err.to_string();
1630            assert!(err_msg.contains("Missing required PostgreSQL client tools"));
1631            assert!(
1632                err_msg.contains("pg_dump")
1633                    || err_msg.contains("pg_dumpall")
1634                    || err_msg.contains("psql")
1635            );
1636        }
1637    }
1638
1639    #[test]
1640    fn test_validate_connection_string_invalid() {
1641        assert!(validate_connection_string("").is_err());
1642        assert!(validate_connection_string("   ").is_err());
1643        assert!(validate_connection_string("mysql://localhost/db").is_err());
1644        assert!(validate_connection_string("postgresql://localhost").is_err());
1645        assert!(validate_connection_string("postgresql://localhost/db").is_err());
1646        // Missing user
1647    }
1648
1649    #[test]
1650    fn test_sanitize_identifier() {
1651        assert_eq!(sanitize_identifier("normal_table"), "normal_table");
1652        assert_eq!(sanitize_identifier("table\x00name"), "tablename");
1653        assert_eq!(sanitize_identifier("table\nname"), "tablename");
1654
1655        // Test length limit
1656        let long_name = "a".repeat(200);
1657        assert_eq!(sanitize_identifier(&long_name).len(), 100);
1658    }
1659
1660    #[tokio::test]
1661    async fn test_retry_with_backoff_success() {
1662        let mut attempts = 0;
1663        let result = retry_with_backoff(
1664            || {
1665                attempts += 1;
1666                async move {
1667                    if attempts < 3 {
1668                        anyhow::bail!("Temporary failure")
1669                    } else {
1670                        Ok("Success")
1671                    }
1672                }
1673            },
1674            5,
1675            Duration::from_millis(10),
1676        )
1677        .await;
1678
1679        assert!(result.is_ok());
1680        assert_eq!(result.unwrap(), "Success");
1681        assert_eq!(attempts, 3);
1682    }
1683
1684    #[tokio::test]
1685    async fn test_retry_with_backoff_failure() {
1686        let mut attempts = 0;
1687        let result: Result<&str> = retry_with_backoff(
1688            || {
1689                attempts += 1;
1690                async move { anyhow::bail!("Permanent failure") }
1691            },
1692            2,
1693            Duration::from_millis(10),
1694        )
1695        .await;
1696
1697        assert!(result.is_err());
1698        assert_eq!(attempts, 3); // Initial + 2 retries
1699    }
1700
1701    #[test]
1702    fn test_validate_source_target_different_valid() {
1703        // Different hosts
1704        assert!(validate_source_target_different(
1705            "postgresql://user:pass@source.com:5432/db",
1706            "postgresql://user:pass@target.com:5432/db"
1707        )
1708        .is_ok());
1709
1710        // Different databases on same host
1711        assert!(validate_source_target_different(
1712            "postgresql://user:pass@host:5432/db1",
1713            "postgresql://user:pass@host:5432/db2"
1714        )
1715        .is_ok());
1716
1717        // Different ports on same host
1718        assert!(validate_source_target_different(
1719            "postgresql://user:pass@host:5432/db",
1720            "postgresql://user:pass@host:5433/db"
1721        )
1722        .is_ok());
1723
1724        // Different users on same host/db (edge case but allowed)
1725        assert!(validate_source_target_different(
1726            "postgresql://user1:pass@host:5432/db",
1727            "postgresql://user2:pass@host:5432/db"
1728        )
1729        .is_ok());
1730    }
1731
1732    #[test]
1733    fn test_validate_source_target_different_invalid() {
1734        // Exact same URL
1735        assert!(validate_source_target_different(
1736            "postgresql://user:pass@host:5432/db",
1737            "postgresql://user:pass@host:5432/db"
1738        )
1739        .is_err());
1740
1741        // Same URL with different scheme (postgres vs postgresql)
1742        assert!(validate_source_target_different(
1743            "postgres://user:pass@host:5432/db",
1744            "postgresql://user:pass@host:5432/db"
1745        )
1746        .is_err());
1747
1748        // Same URL with default port vs explicit port
1749        assert!(validate_source_target_different(
1750            "postgresql://user:pass@host/db",
1751            "postgresql://user:pass@host:5432/db"
1752        )
1753        .is_err());
1754
1755        // Same URL with different query parameters (still same database)
1756        assert!(validate_source_target_different(
1757            "postgresql://user:pass@host:5432/db?sslmode=require",
1758            "postgresql://user:pass@host:5432/db?sslmode=prefer"
1759        )
1760        .is_err());
1761
1762        // Same host with different case (hostnames are case-insensitive)
1763        assert!(validate_source_target_different(
1764            "postgresql://user:pass@HOST.COM:5432/db",
1765            "postgresql://user:pass@host.com:5432/db"
1766        )
1767        .is_err());
1768    }
1769
1770    #[test]
1771    fn test_parse_postgres_url() {
1772        // Full URL with all components including password
1773        let parts = parse_postgres_url("postgresql://myuser:mypass@localhost:5432/mydb").unwrap();
1774        assert_eq!(parts.host, "localhost");
1775        assert_eq!(parts.port, 5432);
1776        assert_eq!(parts.database, "mydb");
1777        assert_eq!(parts.user, Some("myuser".to_string()));
1778        assert_eq!(parts.password, Some("mypass".to_string()));
1779
1780        // URL without port (should default to 5432)
1781        let parts = parse_postgres_url("postgresql://user:pass@host/db").unwrap();
1782        assert_eq!(parts.host, "host");
1783        assert_eq!(parts.port, 5432);
1784        assert_eq!(parts.database, "db");
1785        assert_eq!(parts.user, Some("user".to_string()));
1786        assert_eq!(parts.password, Some("pass".to_string()));
1787
1788        // URL with user but no password
1789        let parts = parse_postgres_url("postgresql://user@host/db").unwrap();
1790        assert_eq!(parts.host, "host");
1791        assert_eq!(parts.user, Some("user".to_string()));
1792        assert_eq!(parts.password, None);
1793
1794        // URL without authentication
1795        let parts = parse_postgres_url("postgresql://host:5433/db").unwrap();
1796        assert_eq!(parts.host, "host");
1797        assert_eq!(parts.port, 5433);
1798        assert_eq!(parts.database, "db");
1799        assert_eq!(parts.user, None);
1800        assert_eq!(parts.password, None);
1801
1802        // URL with query parameters
1803        let parts = parse_postgres_url("postgresql://user:pass@host/db?sslmode=require").unwrap();
1804        assert_eq!(parts.host, "host");
1805        assert_eq!(parts.database, "db");
1806        assert_eq!(parts.password, Some("pass".to_string()));
1807
1808        // URL with postgres:// scheme (alternative)
1809        let parts = parse_postgres_url("postgres://user:pass@host/db").unwrap();
1810        assert_eq!(parts.host, "host");
1811        assert_eq!(parts.database, "db");
1812        assert_eq!(parts.password, Some("pass".to_string()));
1813
1814        // Host normalization (lowercase)
1815        let parts = parse_postgres_url("postgresql://user:pass@HOST.COM/db").unwrap();
1816        assert_eq!(parts.host, "host.com");
1817        assert_eq!(parts.password, Some("pass".to_string()));
1818
1819        // Password with special characters
1820        let parts = parse_postgres_url("postgresql://user:p@ss!word@host/db").unwrap();
1821        assert_eq!(parts.password, Some("p@ss!word".to_string()));
1822    }
1823
1824    #[test]
1825    fn test_validate_postgres_identifier_valid() {
1826        // Valid identifiers
1827        assert!(validate_postgres_identifier("mydb").is_ok());
1828        assert!(validate_postgres_identifier("my_database").is_ok());
1829        assert!(validate_postgres_identifier("_private_db").is_ok());
1830        assert!(validate_postgres_identifier("db123").is_ok());
1831        assert!(validate_postgres_identifier("Database_2024").is_ok());
1832
1833        // Maximum length (63 characters)
1834        let max_length_name = "a".repeat(63);
1835        assert!(validate_postgres_identifier(&max_length_name).is_ok());
1836    }
1837
1838    #[test]
1839    fn test_pgpass_file_creation() {
1840        let parts = PostgresUrlParts {
1841            host: "localhost".to_string(),
1842            port: 5432,
1843            database: "testdb".to_string(),
1844            user: Some("testuser".to_string()),
1845            password: Some("testpass".to_string()),
1846            query_params: std::collections::HashMap::new(),
1847        };
1848
1849        let pgpass = PgPassFile::new(&parts).unwrap();
1850        assert!(pgpass.path().exists());
1851
1852        // Verify file content
1853        let content = std::fs::read_to_string(pgpass.path()).unwrap();
1854        assert_eq!(content, "localhost:5432:testdb:testuser:testpass\n");
1855
1856        // Verify permissions on Unix
1857        #[cfg(unix)]
1858        {
1859            use std::os::unix::fs::PermissionsExt;
1860            let metadata = std::fs::metadata(pgpass.path()).unwrap();
1861            let permissions = metadata.permissions();
1862            assert_eq!(permissions.mode() & 0o777, 0o600);
1863        }
1864
1865        // File should be cleaned up when pgpass is dropped
1866        let path = pgpass.path().to_path_buf();
1867        drop(pgpass);
1868        assert!(!path.exists());
1869    }
1870
1871    #[test]
1872    fn test_pgpass_file_without_password() {
1873        let parts = PostgresUrlParts {
1874            host: "localhost".to_string(),
1875            port: 5432,
1876            database: "testdb".to_string(),
1877            user: Some("testuser".to_string()),
1878            password: None,
1879            query_params: std::collections::HashMap::new(),
1880        };
1881
1882        let pgpass = PgPassFile::new(&parts).unwrap();
1883        let content = std::fs::read_to_string(pgpass.path()).unwrap();
1884        // Should use empty password
1885        assert_eq!(content, "localhost:5432:testdb:testuser:\n");
1886    }
1887
1888    #[test]
1889    fn test_pgpass_file_without_user() {
1890        let parts = PostgresUrlParts {
1891            host: "localhost".to_string(),
1892            port: 5432,
1893            database: "testdb".to_string(),
1894            user: None,
1895            password: Some("testpass".to_string()),
1896            query_params: std::collections::HashMap::new(),
1897        };
1898
1899        let pgpass = PgPassFile::new(&parts).unwrap();
1900        let content = std::fs::read_to_string(pgpass.path()).unwrap();
1901        // Should use wildcard for user
1902        assert_eq!(content, "localhost:5432:testdb:*:testpass\n");
1903    }
1904
1905    #[test]
1906    fn test_strip_password_from_url() {
1907        // With password
1908        let url = "postgresql://user:p@ssw0rd@host:5432/db";
1909        let stripped = strip_password_from_url(url).unwrap();
1910        assert_eq!(stripped, "postgresql://user@host:5432/db");
1911
1912        // With special characters in password
1913        let url = "postgresql://user:p@ss!w0rd@host:5432/db";
1914        let stripped = strip_password_from_url(url).unwrap();
1915        assert_eq!(stripped, "postgresql://user@host:5432/db");
1916
1917        // Without password
1918        let url = "postgresql://user@host:5432/db";
1919        let stripped = strip_password_from_url(url).unwrap();
1920        assert_eq!(stripped, "postgresql://user@host:5432/db");
1921
1922        // With query parameters
1923        let url = "postgresql://user:pass@host:5432/db?sslmode=require";
1924        let stripped = strip_password_from_url(url).unwrap();
1925        assert_eq!(stripped, "postgresql://user@host:5432/db?sslmode=require");
1926
1927        // No user
1928        let url = "postgresql://host:5432/db";
1929        let stripped = strip_password_from_url(url).unwrap();
1930        assert_eq!(stripped, "postgresql://host:5432/db");
1931    }
1932
1933    #[test]
1934    fn test_validate_postgres_identifier_invalid() {
1935        // SQL injection attempts
1936        assert!(validate_postgres_identifier("mydb\"; DROP DATABASE production; --").is_err());
1937        assert!(validate_postgres_identifier("db'; DELETE FROM users; --").is_err());
1938
1939        // Invalid start characters
1940        assert!(validate_postgres_identifier("123db").is_err()); // Starts with digit
1941        assert!(validate_postgres_identifier("$db").is_err()); // Starts with special char
1942        assert!(validate_postgres_identifier("-db").is_err()); // Starts with dash
1943
1944        // Contains invalid characters
1945        assert!(validate_postgres_identifier("my-database").is_err()); // Contains dash
1946        assert!(validate_postgres_identifier("my.database").is_err()); // Contains dot
1947        assert!(validate_postgres_identifier("my database").is_err()); // Contains space
1948        assert!(validate_postgres_identifier("my@db").is_err()); // Contains @
1949        assert!(validate_postgres_identifier("my#db").is_err()); // Contains #
1950
1951        // Empty or too long
1952        assert!(validate_postgres_identifier("").is_err());
1953        assert!(validate_postgres_identifier("   ").is_err());
1954
1955        // Over maximum length (64+ characters)
1956        let too_long = "a".repeat(64);
1957        assert!(validate_postgres_identifier(&too_long).is_err());
1958
1959        // Control characters
1960        assert!(validate_postgres_identifier("my\ndb").is_err());
1961        assert!(validate_postgres_identifier("my\tdb").is_err());
1962        assert!(validate_postgres_identifier("my\x00db").is_err());
1963    }
1964
1965    #[test]
1966    fn test_is_serendb_target() {
1967        // Positive cases - SerenDB hosts
1968        assert!(is_serendb_target(
1969            "postgresql://user:pass@db.serendb.com/mydb"
1970        ));
1971        assert!(is_serendb_target(
1972            "postgresql://user:pass@cluster.console.serendb.com/mydb"
1973        ));
1974        assert!(is_serendb_target(
1975            "postgres://u:p@x.serendb.com:5432/db?sslmode=require"
1976        ));
1977        assert!(is_serendb_target("postgresql://user:pass@serendb.com/mydb"));
1978
1979        // Negative cases - not SerenDB
1980        assert!(!is_serendb_target("postgresql://user:pass@localhost/mydb"));
1981        assert!(!is_serendb_target(
1982            "postgresql://user:pass@rds.amazonaws.com/mydb"
1983        ));
1984        assert!(!is_serendb_target("postgresql://user:pass@neon.tech/mydb"));
1985        // Domain spoofing attempt - should NOT match
1986        assert!(!is_serendb_target(
1987            "postgresql://user:pass@serendb.com.evil.com/mydb"
1988        ));
1989        assert!(!is_serendb_target(
1990            "postgresql://user:pass@notserendb.com/mydb"
1991        ));
1992        // Invalid URL
1993        assert!(!is_serendb_target("not-a-url"));
1994    }
1995
1996    #[test]
1997    fn test_parse_pg_version_string() {
1998        // Standard pg_dump output
1999        assert_eq!(
2000            parse_pg_version_string("pg_dump (PostgreSQL) 16.10 (Ubuntu 16.10-0ubuntu0.24.04.1)")
2001                .unwrap(),
2002            16
2003        );
2004
2005        // Standard psql output
2006        assert_eq!(
2007            parse_pg_version_string("psql (PostgreSQL) 17.2").unwrap(),
2008            17
2009        );
2010
2011        // pg_restore output
2012        assert_eq!(
2013            parse_pg_version_string("pg_restore (PostgreSQL) 15.4").unwrap(),
2014            15
2015        );
2016
2017        // Debian-style version
2018        assert_eq!(
2019            parse_pg_version_string("17.2 (Debian 17.2-1.pgdg120+1)").unwrap(),
2020            17
2021        );
2022
2023        // Should fail on invalid input
2024        assert!(parse_pg_version_string("not a version").is_err());
2025        assert!(parse_pg_version_string("version 1.2.3").is_err()); // 1 is < 9
2026        assert!(parse_pg_version_string("").is_err());
2027    }
2028
2029    #[test]
2030    fn test_get_pg_tool_version() {
2031        // This test will only pass if pg_dump is installed
2032        // Skip gracefully if not available
2033        if which("pg_dump").is_ok() {
2034            let version = get_pg_tool_version("pg_dump").unwrap();
2035            assert!(
2036                version >= 12,
2037                "Expected pg_dump version >= 12, got {}",
2038                version
2039            );
2040            assert!(
2041                version <= 99,
2042                "Expected pg_dump version <= 99, got {}",
2043                version
2044            );
2045        }
2046
2047        // Non-existent tool should fail
2048        assert!(get_pg_tool_version("nonexistent_pg_tool_xyz").is_err());
2049    }
2050}