database_replicator/utils.rs
1// ABOUTME: Utility functions for validation and error handling
2// ABOUTME: Provides input validation, retry logic, and resource cleanup
3
4use anyhow::{bail, Context, Result};
5use std::time::Duration;
6use which::which;
7
/// Build the TCP keepalive settings handed to PostgreSQL command-line clients.
///
/// The returned pairs are meant to be exported into the environment of
/// external PostgreSQL tools (`pg_dump`, `pg_restore`, `psql`, `pg_dumpall`)
/// so that idle connections are not dropped when traffic passes through a
/// load balancer such as AWS ELB.
///
/// Settings returned, in order:
/// - `PGKEEPALIVES=1`: turn TCP keepalives on
/// - `PGKEEPALIVESIDLE=60`: first keepalive after 60 seconds of idle time
/// - `PGKEEPALIVESINTERVAL=10`: following keepalives every 10 seconds
///
/// # Returns
///
/// A vector of `(variable_name, value)` tuples to apply to a subprocess.
///
/// # Examples
///
/// ```
/// # use database_replicator::utils::get_keepalive_env_vars;
/// # use std::process::Command;
/// let keepalive_vars = get_keepalive_env_vars();
/// let mut cmd = Command::new("psql");
/// for (key, value) in keepalive_vars {
///     cmd.env(key, value);
/// }
/// ```
pub fn get_keepalive_env_vars() -> Vec<(&'static str, &'static str)> {
    // Fixed table; copied into a fresh Vec on every call.
    const KEEPALIVE_SETTINGS: [(&str, &str); 3] = [
        ("PGKEEPALIVES", "1"),
        ("PGKEEPALIVESIDLE", "60"),
        ("PGKEEPALIVESINTERVAL", "10"),
    ];

    KEEPALIVE_SETTINGS.to_vec()
}
41
42/// Validate a PostgreSQL connection string
43///
44/// Checks that the connection string has proper format and required components:
45/// - Starts with "postgres://" or "postgresql://"
46/// - Contains user credentials (@ symbol)
47/// - Contains database name (/ separator with at least 3 occurrences)
48///
49/// # Arguments
50///
51/// * `url` - Connection string to validate
52///
53/// # Returns
54///
55/// Returns `Ok(())` if the connection string is valid.
56///
57/// # Errors
58///
59/// Returns an error with helpful message if the connection string is:
60/// - Empty or whitespace only
61/// - Missing proper scheme (postgres:// or postgresql://)
62/// - Missing user credentials (@ symbol)
63/// - Missing database name
64///
65/// # Examples
66///
67/// ```
68/// # use database_replicator::utils::validate_connection_string;
69/// # use anyhow::Result;
70/// # fn example() -> Result<()> {
71/// // Valid connection strings
72/// validate_connection_string("postgresql://user:pass@localhost:5432/mydb")?;
73/// validate_connection_string("postgres://user@host/db")?;
74///
75/// // Invalid - will return error
76/// assert!(validate_connection_string("").is_err());
77/// assert!(validate_connection_string("mysql://localhost/db").is_err());
78/// # Ok(())
79/// # }
80/// ```
81pub fn validate_connection_string(url: &str) -> Result<()> {
82 if url.trim().is_empty() {
83 bail!("Connection string cannot be empty");
84 }
85
86 // Check for common URL schemes
87 if !url.starts_with("postgres://") && !url.starts_with("postgresql://") {
88 bail!(
89 "Invalid connection string format.\n\
90 Expected format: postgresql://user:password@host:port/database\n\
91 Got: {}",
92 url
93 );
94 }
95
96 // Check for minimum required components (user@host/database)
97 if !url.contains('@') {
98 bail!(
99 "Connection string missing user credentials.\n\
100 Expected format: postgresql://user:password@host:port/database"
101 );
102 }
103
104 if !url.contains('/') || url.matches('/').count() < 3 {
105 bail!(
106 "Connection string missing database name.\n\
107 Expected format: postgresql://user:password@host:port/database"
108 );
109 }
110
111 Ok(())
112}
113
114/// Check that required PostgreSQL client tools are available
115///
116/// Verifies that the following tools are installed and in PATH:
117/// - `pg_dump` - For dumping database schema and data
118/// - `pg_dumpall` - For dumping global objects (roles, tablespaces)
119/// - `psql` - For restoring databases
120///
121/// # Returns
122///
123/// Returns `Ok(())` if all required tools are found.
124///
125/// # Errors
126///
127/// Returns an error with installation instructions if any tools are missing.
128///
129/// # Examples
130///
131/// ```
132/// # use database_replicator::utils::check_required_tools;
133/// # use anyhow::Result;
134/// # fn example() -> Result<()> {
135/// // Check if PostgreSQL tools are installed
136/// check_required_tools()?;
137/// # Ok(())
138/// # }
139/// ```
140pub fn check_required_tools() -> Result<()> {
141 let tools = ["pg_dump", "pg_dumpall", "psql"];
142 let mut missing = Vec::new();
143
144 for tool in &tools {
145 if which(tool).is_err() {
146 missing.push(*tool);
147 }
148 }
149
150 if !missing.is_empty() {
151 bail!(
152 "Missing required PostgreSQL client tools: {}\n\
153 \n\
154 Please install PostgreSQL client tools:\n\
155 - Ubuntu/Debian: sudo apt-get install postgresql-client\n\
156 - macOS: brew install postgresql\n\
157 - RHEL/CentOS: sudo yum install postgresql\n\
158 - Windows: Download from https://www.postgresql.org/download/windows/",
159 missing.join(", ")
160 );
161 }
162
163 Ok(())
164}
165
166/// Retry a function with exponential backoff
167///
168/// Executes an async operation with automatic retry on failure. Each retry doubles
169/// the delay (exponential backoff) to handle transient failures gracefully.
170///
171/// # Arguments
172///
173/// * `operation` - Async function to retry (FnMut returning Future\<Output = Result\<T\>\>)
174/// * `max_retries` - Maximum number of retry attempts (0 = no retries, just initial attempt)
175/// * `initial_delay` - Delay before first retry (doubles each subsequent retry)
176///
177/// # Returns
178///
179/// Returns the successful result or the last error after all retries exhausted.
180///
181/// # Examples
182///
183/// ```no_run
184/// # use anyhow::Result;
185/// # use std::time::Duration;
186/// # use database_replicator::utils::retry_with_backoff;
187/// # async fn example() -> Result<()> {
188/// let result = retry_with_backoff(
189/// || async { Ok("success") },
190/// 3, // Try up to 3 times
191/// Duration::from_secs(1) // Start with 1s delay
192/// ).await?;
193/// # Ok(())
194/// # }
195/// ```
196pub async fn retry_with_backoff<F, Fut, T>(
197 mut operation: F,
198 max_retries: u32,
199 initial_delay: Duration,
200) -> Result<T>
201where
202 F: FnMut() -> Fut,
203 Fut: std::future::Future<Output = Result<T>>,
204{
205 let mut delay = initial_delay;
206 let mut last_error = None;
207
208 for attempt in 0..=max_retries {
209 match operation().await {
210 Ok(result) => return Ok(result),
211 Err(e) => {
212 last_error = Some(e);
213
214 if attempt < max_retries {
215 tracing::warn!(
216 "Operation failed (attempt {}/{}), retrying in {:?}...",
217 attempt + 1,
218 max_retries + 1,
219 delay
220 );
221 tokio::time::sleep(delay).await;
222 delay *= 2; // Exponential backoff
223 }
224 }
225 }
226 }
227
228 Err(last_error.unwrap_or_else(|| anyhow::anyhow!("Operation failed after retries")))
229}
230
231/// Retry a subprocess execution with exponential backoff on connection errors
232///
233/// Executes a subprocess command with automatic retry on connection-related failures.
234/// Each retry doubles the delay (exponential backoff) to handle transient connection issues.
235///
236/// Connection errors are detected by checking:
237/// - Non-zero exit codes
238/// - Stderr output containing connection-related error patterns:
239/// - "connection closed"
240/// - "connection refused"
241/// - "could not connect"
242/// - "server closed the connection"
243/// - "timeout"
244/// - "Connection timed out"
245///
246/// # Arguments
247///
248/// * `operation` - Function that executes a Command and returns the exit status
249/// * `max_retries` - Maximum number of retry attempts (0 = no retries, just initial attempt)
250/// * `initial_delay` - Delay before first retry (doubles each subsequent retry)
251/// * `operation_name` - Name of the operation for logging (e.g., "pg_restore", "psql")
252///
253/// # Returns
254///
255/// Returns Ok(()) on success or the last error after all retries exhausted.
256///
257/// # Examples
258///
259/// ```no_run
260/// # use anyhow::Result;
261/// # use std::time::Duration;
262/// # use std::process::Command;
263/// # use database_replicator::utils::retry_subprocess_with_backoff;
264/// # fn example() -> Result<()> {
265/// retry_subprocess_with_backoff(
266/// || {
267/// let mut cmd = Command::new("psql");
268/// cmd.arg("--version");
269/// cmd.status().map_err(anyhow::Error::from)
270/// },
271/// 3, // Try up to 3 times
272/// Duration::from_secs(1), // Start with 1s delay
273/// "psql"
274/// )?;
275/// # Ok(())
276/// # }
277/// ```
278pub fn retry_subprocess_with_backoff<F>(
279 mut operation: F,
280 max_retries: u32,
281 initial_delay: Duration,
282 operation_name: &str,
283) -> Result<()>
284where
285 F: FnMut() -> Result<std::process::ExitStatus>,
286{
287 let mut delay = initial_delay;
288 let mut last_error = None;
289
290 for attempt in 0..=max_retries {
291 match operation() {
292 Ok(status) => {
293 if status.success() {
294 return Ok(());
295 } else {
296 // Non-zero exit code - check if it's a connection error
297 // We can't easily capture stderr here, so we'll treat all non-zero
298 // exit codes as potential connection errors for now
299 let error = anyhow::anyhow!(
300 "{} failed with exit code: {}",
301 operation_name,
302 status.code().unwrap_or(-1)
303 );
304 last_error = Some(error);
305
306 if attempt < max_retries {
307 tracing::warn!(
308 "{} failed (attempt {}/{}), retrying in {:?}...",
309 operation_name,
310 attempt + 1,
311 max_retries + 1,
312 delay
313 );
314 std::thread::sleep(delay);
315 delay *= 2; // Exponential backoff
316 }
317 }
318 }
319 Err(e) => {
320 last_error = Some(e);
321
322 if attempt < max_retries {
323 tracing::warn!(
324 "{} failed (attempt {}/{}): {}, retrying in {:?}...",
325 operation_name,
326 attempt + 1,
327 max_retries + 1,
328 last_error.as_ref().unwrap(),
329 delay
330 );
331 std::thread::sleep(delay);
332 delay *= 2; // Exponential backoff
333 }
334 }
335 }
336 }
337
338 Err(last_error.unwrap_or_else(|| {
339 anyhow::anyhow!("{} failed after {} retries", operation_name, max_retries)
340 }))
341}
342
343/// Validate a PostgreSQL identifier (database name, schema name, etc.)
344///
345/// Validates that an identifier follows PostgreSQL naming rules to prevent SQL injection.
346/// PostgreSQL identifiers must:
347/// - Be 1-63 characters long
348/// - Start with a letter (a-z, A-Z) or underscore (_)
349/// - Contain only letters, digits (0-9), or underscores
350///
351/// # Arguments
352///
353/// * `identifier` - The identifier to validate (database name, schema name, etc.)
354///
355/// # Returns
356///
357/// Returns `Ok(())` if the identifier is valid.
358///
359/// # Errors
360///
361/// Returns an error if the identifier:
362/// - Is empty or whitespace-only
363/// - Exceeds 63 characters
364/// - Starts with an invalid character (digit or special character)
365/// - Contains invalid characters (anything except a-z, A-Z, 0-9, _)
366///
367/// # Security
368///
369/// This function is critical for preventing SQL injection attacks. All database
370/// names, schema names, and table names from untrusted sources MUST be validated
371/// before use in SQL statements.
372///
373/// # Examples
374///
375/// ```
376/// # use database_replicator::utils::validate_postgres_identifier;
377/// # use anyhow::Result;
378/// # fn example() -> Result<()> {
379/// // Valid identifiers
380/// validate_postgres_identifier("mydb")?;
381/// validate_postgres_identifier("my_database")?;
382/// validate_postgres_identifier("_private_db")?;
383///
384/// // Invalid - will return error
385/// assert!(validate_postgres_identifier("123db").is_err());
386/// assert!(validate_postgres_identifier("my-database").is_err());
387/// assert!(validate_postgres_identifier("db\"; DROP TABLE users; --").is_err());
388/// # Ok(())
389/// # }
390/// ```
391pub fn validate_postgres_identifier(identifier: &str) -> Result<()> {
392 // Check for empty or whitespace-only
393 let trimmed = identifier.trim();
394 if trimmed.is_empty() {
395 bail!("Identifier cannot be empty or whitespace-only");
396 }
397
398 // Check length (PostgreSQL limit is 63 characters)
399 if trimmed.len() > 63 {
400 bail!(
401 "Identifier '{}' exceeds maximum length of 63 characters (got {})",
402 sanitize_identifier(trimmed),
403 trimmed.len()
404 );
405 }
406
407 // Get first character
408 let first_char = trimmed.chars().next().unwrap();
409
410 // First character must be a letter or underscore
411 if !first_char.is_ascii_alphabetic() && first_char != '_' {
412 bail!(
413 "Identifier '{}' must start with a letter or underscore, not '{}'",
414 sanitize_identifier(trimmed),
415 first_char
416 );
417 }
418
419 // All characters must be alphanumeric or underscore
420 for (i, c) in trimmed.chars().enumerate() {
421 if !c.is_ascii_alphanumeric() && c != '_' {
422 bail!(
423 "Identifier '{}' contains invalid character '{}' at position {}. \
424 Only letters, digits, and underscores are allowed",
425 sanitize_identifier(trimmed),
426 if c.is_control() {
427 format!("\\x{:02x}", c as u32)
428 } else {
429 c.to_string()
430 },
431 i
432 );
433 }
434 }
435
436 Ok(())
437}
438
/// Make an identifier (table name, schema name, etc.) safe to show in messages
///
/// Drops every control character and keeps at most the first 100 remaining
/// characters, so that hostile names cannot inject extra log lines or flood
/// an error message.
///
/// **Note**: This is for display purposes only. For SQL safety, use parameterized
/// queries instead.
///
/// # Arguments
///
/// * `identifier` - The identifier to sanitize (table name, schema name, etc.)
///
/// # Returns
///
/// A new string without control characters, at most 100 characters long.
///
/// # Examples
///
/// ```
/// # use database_replicator::utils::sanitize_identifier;
/// assert_eq!(sanitize_identifier("normal_table"), "normal_table");
/// assert_eq!(sanitize_identifier("table\x00name"), "tablename");
/// assert_eq!(sanitize_identifier("table\nname"), "tablename");
///
/// // Length limit
/// let long_name = "a".repeat(200);
/// assert_eq!(sanitize_identifier(&long_name).len(), 100);
/// ```
pub fn sanitize_identifier(identifier: &str) -> String {
    let mut display = String::new();
    let mut kept = 0usize;

    for ch in identifier.chars() {
        // Control characters are dropped and do not count towards the limit.
        if ch.is_control() {
            continue;
        }
        if kept == 100 {
            break;
        }
        display.push(ch);
        kept += 1;
    }

    display
}
475
/// Quote a PostgreSQL identifier (database, schema, table, column)
///
/// Wraps the identifier in double quotes and doubles every double quote found
/// inside it. The identifier is assumed to have been validated already; no
/// other checking happens here.
///
/// # Arguments
///
/// * `identifier` - The identifier to quote
///
/// # Returns
///
/// The quoted identifier, e.g. `users` becomes `"users"`.
pub fn quote_ident(identifier: &str) -> String {
    // An embedded `"` is written as `""` inside a quoted identifier.
    let escaped = identifier.replace('"', "\"\"");
    format!("\"{}\"", escaped)
}
492
493/// Validate that source and target URLs are different to prevent accidental data loss
494///
495/// Compares two PostgreSQL connection URLs to ensure they point to different databases.
496/// This is critical for preventing data loss from operations like `init --drop-existing`
497/// where using the same URL for source and target would destroy the source data.
498///
499/// # Comparison Strategy
500///
501/// URLs are normalized and compared on:
502/// - Host (case-insensitive)
503/// - Port (defaulting to 5432 if not specified)
504/// - Database name (case-sensitive)
505/// - User (if present)
506///
507/// Query parameters (like SSL settings) are ignored as they don't affect database identity.
508///
509/// # Arguments
510///
511/// * `source_url` - Source database connection string
512/// * `target_url` - Target database connection string
513///
514/// # Returns
515///
516/// Returns `Ok(())` if the URLs point to different databases.
517///
518/// # Errors
519///
520/// Returns an error if:
521/// - The URLs point to the same database (same host, port, database name, and user)
522/// - Either URL is malformed and cannot be parsed
523///
524/// # Examples
525///
526/// ```
527/// # use database_replicator::utils::validate_source_target_different;
528/// # use anyhow::Result;
529/// # fn example() -> Result<()> {
530/// // Valid - different hosts
531/// validate_source_target_different(
532/// "postgresql://user:pass@source.com:5432/db",
533/// "postgresql://user:pass@target.com:5432/db"
534/// )?;
535///
536/// // Valid - different databases
537/// validate_source_target_different(
538/// "postgresql://user:pass@host:5432/db1",
539/// "postgresql://user:pass@host:5432/db2"
540/// )?;
541///
542/// // Invalid - same database
543/// assert!(validate_source_target_different(
544/// "postgresql://user:pass@host:5432/db",
545/// "postgresql://user:pass@host:5432/db"
546/// ).is_err());
547/// # Ok(())
548/// # }
549/// ```
550pub fn validate_source_target_different(source_url: &str, target_url: &str) -> Result<()> {
551 // Parse both URLs to extract components
552 let source_parts = parse_postgres_url(source_url)
553 .with_context(|| format!("Failed to parse source URL: {}", source_url))?;
554 let target_parts = parse_postgres_url(target_url)
555 .with_context(|| format!("Failed to parse target URL: {}", target_url))?;
556
557 // Compare normalized components
558 if source_parts.host == target_parts.host
559 && source_parts.port == target_parts.port
560 && source_parts.database == target_parts.database
561 && source_parts.user == target_parts.user
562 {
563 bail!(
564 "Source and target URLs point to the same database!\\n\\\n \\n\\\n This would cause DATA LOSS - the target would overwrite the source.\\n\\\n \\n\\\n Source: {}@{}:{}/{}\\n\\\n Target: {}@{}:{}/{}\\n\\\n \\n\\\n Please ensure source and target are different databases.\\n\\\n Common causes:\\n\\\n - Copy-paste error in connection strings\\n\\\n - Wrong environment variables (e.g., SOURCE_URL == TARGET_URL)\\n\\\n - Typo in database name or host",
565 source_parts.user.as_deref().unwrap_or("(no user)"),
566 source_parts.host,
567 source_parts.port,
568 source_parts.database,
569 target_parts.user.as_deref().unwrap_or("(no user)"),
570 target_parts.host,
571 target_parts.port,
572 target_parts.database
573 );
574 }
575
576 Ok(())
577}
578
579/// Parse a PostgreSQL URL into its components
580///
581/// # Arguments
582///
583/// * `url` - PostgreSQL connection URL (postgres:// or postgresql://)
584///
585/// # Returns
586///
587/// Returns a `PostgresUrlParts` struct with normalized components.
588///
589/// # Security
590///
591/// This function extracts passwords from URLs for use with .pgpass files.
592/// Ensure returned values are handled securely and not logged.
593pub fn parse_postgres_url(url: &str) -> Result<PostgresUrlParts> {
594 // Remove scheme
595 let url_without_scheme = url
596 .trim_start_matches("postgres://")
597 .trim_start_matches("postgresql://");
598
599 // Split into base and query params
600 let (base, query_string) = if let Some((b, q)) = url_without_scheme.split_once('?') {
601 (b, Some(q))
602 } else {
603 (url_without_scheme, None)
604 };
605
606 // Parse query parameters into HashMap
607 let mut query_params = std::collections::HashMap::new();
608 if let Some(query) = query_string {
609 for param in query.split('&') {
610 if let Some((key, value)) = param.split_once('=') {
611 query_params.insert(key.to_string(), value.to_string());
612 }
613 }
614 }
615
616 // Parse: [user[:password]@]host[:port]/database
617 let (auth_and_host, database) = base
618 .rsplit_once('/')
619 .ok_or_else(|| anyhow::anyhow!("Missing database name in URL"))?;
620
621 // Parse authentication and host
622 // Use rsplit_once to split from the right, so passwords can contain '@'
623 let (user, password, host_and_port) = if let Some((auth, hp)) = auth_and_host.rsplit_once('@') {
624 // Has authentication
625 let (user, pass) = if let Some((u, p)) = auth.split_once(':') {
626 (Some(u.to_string()), Some(p.to_string()))
627 } else {
628 (Some(auth.to_string()), None)
629 };
630 (user, pass, hp)
631 } else {
632 // No authentication
633 (None, None, auth_and_host)
634 };
635
636 // Parse host and port
637 let (host, port) = if let Some((h, p)) = host_and_port.rsplit_once(':') {
638 // Port specified
639 let port = p
640 .parse::<u16>()
641 .with_context(|| format!("Invalid port number: {}", p))?;
642 (h, port)
643 } else {
644 // Use default PostgreSQL port
645 (host_and_port, 5432)
646 };
647
648 Ok(PostgresUrlParts {
649 host: host.to_lowercase(), // Hostnames are case-insensitive
650 port,
651 database: database.to_string(), // Database names are case-sensitive in PostgreSQL
652 user,
653 password,
654 query_params,
655 })
656}
657
658/// Strip password from PostgreSQL connection URL
659/// Returns a new URL with password removed, preserving all other components
660/// This is useful for storing connection strings in places where passwords should not be visible
661pub fn strip_password_from_url(url: &str) -> Result<String> {
662 let parts = parse_postgres_url(url)?;
663
664 // Reconstruct URL without password
665 let scheme = if url.starts_with("postgresql://") {
666 "postgresql://"
667 } else if url.starts_with("postgres://") {
668 "postgres://"
669 } else {
670 bail!("Invalid PostgreSQL URL scheme");
671 };
672
673 let mut result = String::from(scheme);
674
675 // Add user if present (without password)
676 if let Some(user) = &parts.user {
677 result.push_str(user);
678 result.push('@');
679 }
680
681 // Add host and port
682 result.push_str(&parts.host);
683 result.push(':');
684 result.push_str(&parts.port.to_string());
685
686 // Add database
687 result.push('/');
688 result.push_str(&parts.database);
689
690 // Preserve query parameters if present
691 if let Some(query_start) = url.find('?') {
692 result.push_str(&url[query_start..]);
693 }
694
695 Ok(result)
696}
697
/// Parsed components of a PostgreSQL connection URL
#[derive(Debug, PartialEq)]
pub struct PostgresUrlParts {
    /// Server host name or address
    pub host: String,
    /// TCP port of the server
    pub port: u16,
    /// Name of the database to connect to
    pub database: String,
    /// User name, when the URL carried one
    pub user: Option<String>,
    /// Password, when the URL carried one
    pub password: Option<String>,
    /// `key=value` pairs taken from the URL's query string
    pub query_params: std::collections::HashMap<String, String>,
}

impl PostgresUrlParts {
    /// Translate recognised query parameters into PostgreSQL environment variables
    ///
    /// External tools such as pg_dump, pg_dumpall and psql pick up SSL/TLS and
    /// other connection settings from the environment, so each supported URL
    /// query parameter is mapped to its environment variable name.
    ///
    /// # Supported Parameters
    ///
    /// - `sslmode` → `PGSSLMODE`
    /// - `sslcert` → `PGSSLCERT`
    /// - `sslkey` → `PGSSLKEY`
    /// - `sslrootcert` → `PGSSLROOTCERT`
    /// - `channel_binding` → `PGCHANNELBINDING`
    /// - `connect_timeout` → `PGCONNECT_TIMEOUT`
    /// - `application_name` → `PGAPPNAME`
    /// - `client_encoding` → `PGCLIENTENCODING`
    ///
    /// Parameters not in this list are ignored.
    ///
    /// # Returns
    ///
    /// Vec of (env_var_name, value) pairs, in the order of the list above,
    /// containing only the parameters present in `query_params`.
    pub fn to_pg_env_vars(&self) -> Vec<(&'static str, String)> {
        // (query parameter, environment variable) pairs
        const PARAM_TO_ENV: [(&str, &str); 8] = [
            ("sslmode", "PGSSLMODE"),
            ("sslcert", "PGSSLCERT"),
            ("sslkey", "PGSSLKEY"),
            ("sslrootcert", "PGSSLROOTCERT"),
            ("channel_binding", "PGCHANNELBINDING"),
            ("connect_timeout", "PGCONNECT_TIMEOUT"),
            ("application_name", "PGAPPNAME"),
            ("client_encoding", "PGCLIENTENCODING"),
        ];

        PARAM_TO_ENV
            .iter()
            .filter_map(|&(param, env_name)| {
                self.query_params
                    .get(param)
                    .map(|value| (env_name, value.clone()))
            })
            .collect()
    }
}
754
755/// Managed .pgpass file for secure password passing to PostgreSQL tools
756///
757/// This struct creates a temporary .pgpass file with secure permissions (0600)
758/// and automatically cleans it up when dropped. PostgreSQL command-line tools
759/// read credentials from this file instead of accepting passwords in URLs,
760/// which prevents command injection vulnerabilities.
761///
762/// # Security
763///
764/// - File permissions are set to 0600 (owner read/write only)
765/// - File is automatically removed on Drop
766/// - Credentials are never passed on command line
767///
768/// # Format
769///
770/// .pgpass file format: hostname:port:database:username:password
771/// Wildcards (*) are used for maximum compatibility
772///
773/// # Examples
774///
775/// ```no_run
776/// # use database_replicator::utils::{PgPassFile, parse_postgres_url};
777/// # use anyhow::Result;
778/// # fn example() -> Result<()> {
779/// let url = "postgresql://user:pass@localhost:5432/mydb";
780/// let parts = parse_postgres_url(url)?;
781/// let pgpass = PgPassFile::new(&parts)?;
782///
783/// // Use pgpass.path() with PGPASSFILE environment variable
784/// // File is automatically cleaned up when pgpass goes out of scope
785/// # Ok(())
786/// # }
787/// ```
788pub struct PgPassFile {
789 path: std::path::PathBuf,
790}
791
792impl PgPassFile {
793 /// Create a new .pgpass file with credentials from URL parts
794 ///
795 /// # Arguments
796 ///
797 /// * `parts` - Parsed PostgreSQL URL components
798 ///
799 /// # Returns
800 ///
801 /// Returns a PgPassFile that will be automatically cleaned up on Drop
802 ///
803 /// # Errors
804 ///
805 /// Returns an error if the file cannot be created or permissions cannot be set
806 pub fn new(parts: &PostgresUrlParts) -> Result<Self> {
807 use std::fs;
808 use std::io::Write;
809
810 // Create temp file with secure name
811 let temp_dir = std::env::temp_dir();
812 let random: u32 = rand::random();
813 let filename = format!("pgpass-{:08x}", random);
814 let path = temp_dir.join(filename);
815
816 // Write .pgpass entry
817 // Format: hostname:port:database:username:password
818 let username = parts.user.as_deref().unwrap_or("*");
819 let password = parts.password.as_deref().unwrap_or("");
820 let entry = format!(
821 "{}:{}:{}:{}:{}\n",
822 parts.host, parts.port, parts.database, username, password
823 );
824
825 let mut file = fs::File::create(&path)
826 .with_context(|| format!("Failed to create .pgpass file at {}", path.display()))?;
827
828 file.write_all(entry.as_bytes())
829 .with_context(|| format!("Failed to write to .pgpass file at {}", path.display()))?;
830
831 // Set secure permissions (0600) - owner read/write only
832 #[cfg(unix)]
833 {
834 use std::os::unix::fs::PermissionsExt;
835 let permissions = fs::Permissions::from_mode(0o600);
836 fs::set_permissions(&path, permissions).with_context(|| {
837 format!(
838 "Failed to set permissions on .pgpass file at {}",
839 path.display()
840 )
841 })?;
842 }
843
844 // On Windows, .pgpass is stored in %APPDATA%\postgresql\pgpass.conf
845 // but for our temporary use case, we'll just use a temp file
846 // PostgreSQL on Windows also checks permissions but less strictly
847
848 Ok(Self { path })
849 }
850
851 /// Get the path to the .pgpass file
852 ///
853 /// Use this with the PGPASSFILE environment variable when running
854 /// PostgreSQL command-line tools
855 pub fn path(&self) -> &std::path::Path {
856 &self.path
857 }
858}
859
860impl Drop for PgPassFile {
861 fn drop(&mut self) {
862 // Best effort cleanup - don't panic if removal fails
863 let _ = std::fs::remove_file(&self.path);
864 }
865}
866
867/// Create a managed temporary directory with explicit cleanup support
868///
869/// Creates a temporary directory with a timestamped name that can be cleaned up
870/// even if the process is killed with SIGKILL. Unlike `TempDir::new()` which
871/// relies on the Drop trait, this function creates named directories that can
872/// be cleaned up on next process startup.
873///
874/// Directory naming format: `postgres-seren-replicator-{timestamp}-{random}`
875/// Example: `postgres-seren-replicator-20250106-120534-a3b2c1d4`
876///
877/// # Returns
878///
879/// Returns the path to the created temporary directory.
880///
881/// # Errors
882///
883/// Returns an error if the directory cannot be created.
884///
885/// # Examples
886///
887/// ```no_run
888/// # use database_replicator::utils::create_managed_temp_dir;
889/// # use anyhow::Result;
890/// # fn example() -> Result<()> {
891/// let temp_path = create_managed_temp_dir()?;
892/// println!("Using temp directory: {}", temp_path.display());
893/// // ... do work ...
894/// // Cleanup happens automatically on next startup via cleanup_stale_temp_dirs()
895/// # Ok(())
896/// # }
897/// ```
898pub fn create_managed_temp_dir() -> Result<std::path::PathBuf> {
899 use std::fs;
900 use std::time::SystemTime;
901
902 let system_temp = std::env::temp_dir();
903
904 // Generate timestamp for directory name
905 let timestamp = SystemTime::now()
906 .duration_since(SystemTime::UNIX_EPOCH)
907 .unwrap()
908 .as_secs();
909
910 // Generate random suffix for uniqueness
911 let random: u32 = rand::random();
912
913 // Create directory name with timestamp and random suffix
914 let dir_name = format!("postgres-seren-replicator-{}-{:08x}", timestamp, random);
915
916 let temp_path = system_temp.join(dir_name);
917
918 // Create the directory
919 fs::create_dir_all(&temp_path)
920 .with_context(|| format!("Failed to create temp directory at {}", temp_path.display()))?;
921
922 tracing::debug!("Created managed temp directory: {}", temp_path.display());
923
924 Ok(temp_path)
925}
926
927/// Clean up stale temporary directories from previous runs
928///
929/// Removes temporary directories created by `create_managed_temp_dir()` that are
930/// older than the specified age. This should be called on process startup to clean
931/// up directories left behind by processes killed with SIGKILL.
932///
933/// Only directories matching the pattern `postgres-seren-replicator-*` are removed.
934///
935/// # Arguments
936///
937/// * `max_age_secs` - Maximum age in seconds before a directory is considered stale
938/// (recommended: 86400 for 24 hours)
939///
940/// # Returns
941///
942/// Returns the number of directories cleaned up.
943///
944/// # Errors
945///
946/// Returns an error if the system temp directory cannot be read. Individual
947/// directory removal errors are logged but don't fail the entire operation.
948///
949/// # Examples
950///
951/// ```no_run
952/// # use database_replicator::utils::cleanup_stale_temp_dirs;
953/// # use anyhow::Result;
954/// # fn example() -> Result<()> {
955/// // Clean up temp directories older than 24 hours
956/// let cleaned = cleanup_stale_temp_dirs(86400)?;
957/// println!("Cleaned up {} stale temp directories", cleaned);
958/// # Ok(())
959/// # }
960/// ```
961pub fn cleanup_stale_temp_dirs(max_age_secs: u64) -> Result<usize> {
962 use std::fs;
963 use std::time::SystemTime;
964
965 let system_temp = std::env::temp_dir();
966 let now = SystemTime::now();
967 let mut cleaned_count = 0;
968
969 // Read all entries in system temp directory
970 let entries = fs::read_dir(&system_temp).with_context(|| {
971 format!(
972 "Failed to read system temp directory: {}",
973 system_temp.display()
974 )
975 })?;
976
977 for entry in entries.flatten() {
978 let path = entry.path();
979
980 // Only process directories matching our naming pattern
981 if let Some(name) = path.file_name().and_then(|n| n.to_str()) {
982 if !name.starts_with("postgres-seren-replicator-") {
983 continue;
984 }
985
986 // Check directory age
987 match entry.metadata() {
988 Ok(metadata) => {
989 if let Ok(modified) = metadata.modified() {
990 if let Ok(age) = now.duration_since(modified) {
991 if age.as_secs() > max_age_secs {
992 // Directory is stale, remove it
993 match fs::remove_dir_all(&path) {
994 Ok(_) => {
995 tracing::info!(
996 "Cleaned up stale temp directory: {} (age: {}s)",
997 path.display(),
998 age.as_secs()
999 );
1000 cleaned_count += 1;
1001 }
1002 Err(e) => {
1003 tracing::warn!(
1004 "Failed to remove stale temp directory {}: {}",
1005 path.display(),
1006 e
1007 );
1008 }
1009 }
1010 }
1011 }
1012 }
1013 }
1014 Err(e) => {
1015 tracing::warn!(
1016 "Failed to get metadata for temp directory {}: {}",
1017 path.display(),
1018 e
1019 );
1020 }
1021 }
1022 }
1023 }
1024
1025 if cleaned_count > 0 {
1026 tracing::info!(
1027 "Cleaned up {} stale temp directory(ies) older than {} seconds",
1028 cleaned_count,
1029 max_age_secs
1030 );
1031 }
1032
1033 Ok(cleaned_count)
1034}
1035
1036/// Remove a managed temporary directory
1037///
1038/// Explicitly removes a temporary directory created by `create_managed_temp_dir()`.
1039/// This should be called when the directory is no longer needed.
1040///
1041/// # Arguments
1042///
1043/// * `path` - Path to the temporary directory to remove
1044///
1045/// # Errors
1046///
1047/// Returns an error if the directory cannot be removed.
1048///
1049/// # Examples
1050///
1051/// ```no_run
1052/// # use database_replicator::utils::{create_managed_temp_dir, remove_managed_temp_dir};
1053/// # use anyhow::Result;
1054/// # fn example() -> Result<()> {
1055/// let temp_path = create_managed_temp_dir()?;
1056/// // ... do work ...
1057/// remove_managed_temp_dir(&temp_path)?;
1058/// # Ok(())
1059/// # }
1060/// ```
1061pub fn remove_managed_temp_dir(path: &std::path::Path) -> Result<()> {
1062 use std::fs;
1063
1064 // Verify this is one of our temp directories (safety check)
1065 if let Some(name) = path.file_name().and_then(|n| n.to_str()) {
1066 if !name.starts_with("postgres-seren-replicator-") {
1067 bail!(
1068 "Refusing to remove directory that doesn't match our naming pattern: {}",
1069 path.display()
1070 );
1071 }
1072 } else {
1073 bail!("Invalid temp directory path: {}", path.display());
1074 }
1075
1076 tracing::debug!("Removing managed temp directory: {}", path.display());
1077
1078 fs::remove_dir_all(path)
1079 .with_context(|| format!("Failed to remove temp directory at {}", path.display()))?;
1080
1081 Ok(())
1082}
1083
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds URL parts for `localhost:5432/testdb` with the given credentials
    /// and no query parameters.
    fn localhost_testdb_parts(user: Option<&str>, password: Option<&str>) -> PostgresUrlParts {
        PostgresUrlParts {
            host: "localhost".to_string(),
            port: 5432,
            database: "testdb".to_string(),
            user: user.map(String::from),
            password: password.map(String::from),
            query_params: std::collections::HashMap::new(),
        }
    }

    #[test]
    fn test_validate_connection_string_valid() {
        let accepted = [
            "postgresql://user:pass@localhost:5432/dbname",
            "postgres://user@host/db",
        ];

        for url in accepted {
            assert!(
                validate_connection_string(url).is_ok(),
                "should be accepted: {:?}",
                url
            );
        }
    }

    #[test]
    fn test_check_required_tools() {
        // Nothing to verify when the PostgreSQL client tools are installed.
        // When they are not, the error must be the "missing tools" message
        // and must name at least one of the tools.
        if let Err(err) = check_required_tools() {
            let err_msg = err.to_string();
            assert!(err_msg.contains("Missing required PostgreSQL client tools"));

            let names_a_tool = ["pg_dump", "pg_dumpall", "psql"]
                .iter()
                .any(|tool| err_msg.contains(*tool));
            assert!(names_a_tool);
        }
    }

    #[test]
    fn test_validate_connection_string_invalid() {
        let rejected = [
            "",
            " ",
            "mysql://localhost/db",
            "postgresql://localhost",
            // Missing user
            "postgresql://localhost/db",
        ];

        for url in rejected {
            assert!(
                validate_connection_string(url).is_err(),
                "should be rejected: {:?}",
                url
            );
        }
    }

    #[test]
    fn test_sanitize_identifier() {
        let cases = [
            ("normal_table", "normal_table"),
            ("table\x00name", "tablename"),
            ("table\nname", "tablename"),
        ];

        for (raw, expected) in cases {
            assert_eq!(sanitize_identifier(raw), expected, "input: {:?}", raw);
        }

        // Test length limit
        let long_name = "a".repeat(200);
        assert_eq!(sanitize_identifier(&long_name).len(), 100);
    }

    #[tokio::test]
    async fn test_retry_with_backoff_success() {
        let mut attempts = 0;

        // Fails on the first two calls and succeeds on the third.
        let result = retry_with_backoff(
            || {
                attempts += 1;
                let attempt = attempts;
                async move {
                    if attempt < 3 {
                        anyhow::bail!("Temporary failure")
                    } else {
                        Ok("Success")
                    }
                }
            },
            5,
            Duration::from_millis(10),
        )
        .await;

        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "Success");
        assert_eq!(attempts, 3);
    }

    #[tokio::test]
    async fn test_retry_with_backoff_failure() {
        let mut attempts = 0;

        let result: Result<&str> = retry_with_backoff(
            || {
                attempts += 1;
                async { Err(anyhow::anyhow!("Permanent failure")) }
            },
            2,
            Duration::from_millis(10),
        )
        .await;

        assert!(result.is_err());
        assert_eq!(attempts, 3); // Initial + 2 retries
    }

    #[test]
    fn test_validate_source_target_different_valid() {
        let distinct_pairs = [
            // Different hosts
            (
                "postgresql://user:pass@source.com:5432/db",
                "postgresql://user:pass@target.com:5432/db",
            ),
            // Different databases on same host
            (
                "postgresql://user:pass@host:5432/db1",
                "postgresql://user:pass@host:5432/db2",
            ),
            // Different ports on same host
            (
                "postgresql://user:pass@host:5432/db",
                "postgresql://user:pass@host:5433/db",
            ),
            // Different users on same host/db (edge case but allowed)
            (
                "postgresql://user1:pass@host:5432/db",
                "postgresql://user2:pass@host:5432/db",
            ),
        ];

        for (source, target) in distinct_pairs {
            assert!(
                validate_source_target_different(source, target).is_ok(),
                "should be treated as different: {:?} vs {:?}",
                source,
                target
            );
        }
    }

    #[test]
    fn test_validate_source_target_different_invalid() {
        let same_database_pairs = [
            // Exact same URL
            (
                "postgresql://user:pass@host:5432/db",
                "postgresql://user:pass@host:5432/db",
            ),
            // Same URL with different scheme (postgres vs postgresql)
            (
                "postgres://user:pass@host:5432/db",
                "postgresql://user:pass@host:5432/db",
            ),
            // Same URL with default port vs explicit port
            (
                "postgresql://user:pass@host/db",
                "postgresql://user:pass@host:5432/db",
            ),
            // Same URL with different query parameters (still same database)
            (
                "postgresql://user:pass@host:5432/db?sslmode=require",
                "postgresql://user:pass@host:5432/db?sslmode=prefer",
            ),
            // Same host with different case (hostnames are case-insensitive)
            (
                "postgresql://user:pass@HOST.COM:5432/db",
                "postgresql://user:pass@host.com:5432/db",
            ),
        ];

        for (source, target) in same_database_pairs {
            assert!(
                validate_source_target_different(source, target).is_err(),
                "should be treated as the same database: {:?} vs {:?}",
                source,
                target
            );
        }
    }

    #[test]
    fn test_parse_postgres_url() {
        let parse = |url: &str| parse_postgres_url(url).unwrap();

        // Full URL with all components including password
        let full = parse("postgresql://myuser:mypass@localhost:5432/mydb");
        assert_eq!(full.host, "localhost");
        assert_eq!(full.port, 5432);
        assert_eq!(full.database, "mydb");
        assert_eq!(full.user.as_deref(), Some("myuser"));
        assert_eq!(full.password.as_deref(), Some("mypass"));

        // URL without port (should default to 5432)
        let default_port = parse("postgresql://user:pass@host/db");
        assert_eq!(default_port.host, "host");
        assert_eq!(default_port.port, 5432);
        assert_eq!(default_port.database, "db");
        assert_eq!(default_port.user.as_deref(), Some("user"));
        assert_eq!(default_port.password.as_deref(), Some("pass"));

        // URL with user but no password
        let user_only = parse("postgresql://user@host/db");
        assert_eq!(user_only.host, "host");
        assert_eq!(user_only.user.as_deref(), Some("user"));
        assert!(user_only.password.is_none());

        // URL without authentication
        let anonymous = parse("postgresql://host:5433/db");
        assert_eq!(anonymous.host, "host");
        assert_eq!(anonymous.port, 5433);
        assert_eq!(anonymous.database, "db");
        assert!(anonymous.user.is_none());
        assert!(anonymous.password.is_none());

        // URL with query parameters
        let with_query = parse("postgresql://user:pass@host/db?sslmode=require");
        assert_eq!(with_query.host, "host");
        assert_eq!(with_query.database, "db");
        assert_eq!(with_query.password.as_deref(), Some("pass"));

        // URL with postgres:// scheme (alternative)
        let short_scheme = parse("postgres://user:pass@host/db");
        assert_eq!(short_scheme.host, "host");
        assert_eq!(short_scheme.database, "db");
        assert_eq!(short_scheme.password.as_deref(), Some("pass"));

        // Host normalization (lowercase)
        let upper_host = parse("postgresql://user:pass@HOST.COM/db");
        assert_eq!(upper_host.host, "host.com");
        assert_eq!(upper_host.password.as_deref(), Some("pass"));

        // Password with special characters
        let special_password = parse("postgresql://user:p@ss!word@host/db");
        assert_eq!(special_password.password.as_deref(), Some("p@ss!word"));
    }

    #[test]
    fn test_validate_postgres_identifier_valid() {
        let accepted = [
            "mydb",
            "my_database",
            "_private_db",
            "db123",
            "Database_2024",
        ];

        for identifier in accepted {
            assert!(
                validate_postgres_identifier(identifier).is_ok(),
                "should be accepted: {:?}",
                identifier
            );
        }

        // Maximum length (63 characters)
        let max_length_name = "a".repeat(63);
        assert!(validate_postgres_identifier(&max_length_name).is_ok());
    }

    #[test]
    fn test_pgpass_file_creation() {
        let parts = localhost_testdb_parts(Some("testuser"), Some("testpass"));

        let pgpass = PgPassFile::new(&parts).unwrap();
        assert!(pgpass.path().exists());

        // Verify file content
        let written = std::fs::read_to_string(pgpass.path()).unwrap();
        assert_eq!(written, "localhost:5432:testdb:testuser:testpass\n");

        // Verify permissions on Unix
        #[cfg(unix)]
        {
            use std::os::unix::fs::PermissionsExt;
            let mode = std::fs::metadata(pgpass.path())
                .unwrap()
                .permissions()
                .mode();
            assert_eq!(mode & 0o777, 0o600);
        }

        // File should be cleaned up when pgpass is dropped
        let pgpass_location = pgpass.path().to_path_buf();
        drop(pgpass);
        assert!(!pgpass_location.exists());
    }

    #[test]
    fn test_pgpass_file_without_password() {
        let parts = localhost_testdb_parts(Some("testuser"), None);

        let pgpass = PgPassFile::new(&parts).unwrap();
        let written = std::fs::read_to_string(pgpass.path()).unwrap();

        // Should use empty password
        assert_eq!(written, "localhost:5432:testdb:testuser:\n");
    }

    #[test]
    fn test_pgpass_file_without_user() {
        let parts = localhost_testdb_parts(None, Some("testpass"));

        let pgpass = PgPassFile::new(&parts).unwrap();
        let written = std::fs::read_to_string(pgpass.path()).unwrap();

        // Should use wildcard for user
        assert_eq!(written, "localhost:5432:testdb:*:testpass\n");
    }

    #[test]
    fn test_strip_password_from_url() {
        let cases = [
            // With password
            (
                "postgresql://user:p@ssw0rd@host:5432/db",
                "postgresql://user@host:5432/db",
            ),
            // With special characters in password
            (
                "postgresql://user:p@ss!w0rd@host:5432/db",
                "postgresql://user@host:5432/db",
            ),
            // Without password
            (
                "postgresql://user@host:5432/db",
                "postgresql://user@host:5432/db",
            ),
            // With query parameters
            (
                "postgresql://user:pass@host:5432/db?sslmode=require",
                "postgresql://user@host:5432/db?sslmode=require",
            ),
            // No user
            ("postgresql://host:5432/db", "postgresql://host:5432/db"),
        ];

        for (url, expected) in cases {
            let stripped = strip_password_from_url(url).unwrap();
            assert_eq!(stripped, expected, "input: {:?}", url);
        }
    }

    #[test]
    fn test_validate_postgres_identifier_invalid() {
        let rejected = [
            // SQL injection attempts
            "mydb\"; DROP DATABASE production; --",
            "db'; DELETE FROM users; --",
            // Invalid start characters
            "123db", // Starts with digit
            "$db",   // Starts with special char
            "-db",   // Starts with dash
            // Contains invalid characters
            "my-database", // Contains dash
            "my.database", // Contains dot
            "my database", // Contains space
            "my@db",       // Contains @
            "my#db",       // Contains #
            // Empty or whitespace only
            "",
            " ",
            // Control characters
            "my\ndb",
            "my\tdb",
            "my\x00db",
        ];

        for identifier in rejected {
            assert!(
                validate_postgres_identifier(identifier).is_err(),
                "should be rejected: {:?}",
                identifier
            );
        }

        // Over maximum length (64+ characters)
        let too_long = "a".repeat(64);
        assert!(validate_postgres_identifier(&too_long).is_err());
    }
}