database_replicator/utils.rs
1// ABOUTME: Utility functions for validation and error handling
2// ABOUTME: Provides input validation, retry logic, and resource cleanup
3
4use anyhow::{bail, Context, Result};
5use std::time::Duration;
6use which::which;
7
/// TCP keepalive settings for external PostgreSQL client tools.
///
/// The returned pairs are libpq environment variables intended for subprocesses
/// such as `pg_dump`, `pg_restore`, `psql` and `pg_dumpall`. They keep otherwise
/// idle connections alive when a load balancer (for example AWS ELB) sits
/// between the tool and the server.
///
/// Variables returned, in this order:
/// - `PGKEEPALIVES=1`: turn TCP keepalives on
/// - `PGKEEPALIVESIDLE=60`: first probe after 60 seconds of idle time
/// - `PGKEEPALIVESINTERVAL=10`: later probes every 10 seconds
///
/// # Returns
///
/// A vector of `(variable_name, value)` tuples to set on a subprocess command.
///
/// # Examples
///
/// ```
/// # use database_replicator::utils::get_keepalive_env_vars;
/// # use std::process::Command;
/// let mut cmd = Command::new("psql");
/// for (key, value) in get_keepalive_env_vars() {
///     cmd.env(key, value);
/// }
/// ```
pub fn get_keepalive_env_vars() -> Vec<(&'static str, &'static str)> {
    const KEEPALIVE_SETTINGS: [(&str, &str); 3] = [
        ("PGKEEPALIVES", "1"),
        ("PGKEEPALIVESIDLE", "60"),
        ("PGKEEPALIVESINTERVAL", "10"),
    ];

    KEEPALIVE_SETTINGS.to_vec()
}
41
42/// Validate a PostgreSQL connection string
43///
44/// Checks that the connection string has proper format and required components:
45/// - Starts with "postgres://" or "postgresql://"
46/// - Contains user credentials (@ symbol)
47/// - Contains database name (/ separator with at least 3 occurrences)
48///
49/// # Arguments
50///
51/// * `url` - Connection string to validate
52///
53/// # Returns
54///
55/// Returns `Ok(())` if the connection string is valid.
56///
57/// # Errors
58///
59/// Returns an error with helpful message if the connection string is:
60/// - Empty or whitespace only
61/// - Missing proper scheme (postgres:// or postgresql://)
62/// - Missing user credentials (@ symbol)
63/// - Missing database name
64///
65/// # Examples
66///
67/// ```
68/// # use database_replicator::utils::validate_connection_string;
69/// # use anyhow::Result;
70/// # fn example() -> Result<()> {
71/// // Valid connection strings
72/// validate_connection_string("postgresql://user:pass@localhost:5432/mydb")?;
73/// validate_connection_string("postgres://user@host/db")?;
74///
75/// // Invalid - will return error
76/// assert!(validate_connection_string("").is_err());
77/// assert!(validate_connection_string("mysql://localhost/db").is_err());
78/// # Ok(())
79/// # }
80/// ```
81pub fn validate_connection_string(url: &str) -> Result<()> {
82 if url.trim().is_empty() {
83 bail!("Connection string cannot be empty");
84 }
85
86 // Check for common URL schemes
87 if !url.starts_with("postgres://") && !url.starts_with("postgresql://") {
88 bail!(
89 "Invalid connection string format.\n\
90 Expected format: postgresql://user:password@host:port/database\n\
91 Got: {}",
92 url
93 );
94 }
95
96 // Check for minimum required components (user@host/database)
97 if !url.contains('@') {
98 bail!(
99 "Connection string missing user credentials.\n\
100 Expected format: postgresql://user:password@host:port/database"
101 );
102 }
103
104 if !url.contains('/') || url.matches('/').count() < 3 {
105 bail!(
106 "Connection string missing database name.\n\
107 Expected format: postgresql://user:password@host:port/database"
108 );
109 }
110
111 Ok(())
112}
113
114/// Check that required PostgreSQL client tools are available
115///
116/// Verifies that the following tools are installed and in PATH:
117/// - `pg_dump` - For dumping database schema and data
118/// - `pg_dumpall` - For dumping global objects (roles, tablespaces)
119/// - `psql` - For restoring databases
120///
121/// # Returns
122///
123/// Returns `Ok(())` if all required tools are found.
124///
125/// # Errors
126///
127/// Returns an error with installation instructions if any tools are missing.
128///
129/// # Examples
130///
131/// ```
132/// # use database_replicator::utils::check_required_tools;
133/// # use anyhow::Result;
134/// # fn example() -> Result<()> {
135/// // Check if PostgreSQL tools are installed
136/// check_required_tools()?;
137/// # Ok(())
138/// # }
139/// ```
140pub fn check_required_tools() -> Result<()> {
141 let tools = ["pg_dump", "pg_dumpall", "psql"];
142 let mut missing = Vec::new();
143
144 for tool in &tools {
145 if which(tool).is_err() {
146 missing.push(*tool);
147 }
148 }
149
150 if !missing.is_empty() {
151 bail!(
152 "Missing required PostgreSQL client tools: {}\n\
153 \n\
154 Please install PostgreSQL client tools:\n\
155 - Ubuntu/Debian: sudo apt-get install postgresql-client\n\
156 - macOS: brew install postgresql\n\
157 - RHEL/CentOS: sudo yum install postgresql\n\
158 - Windows: Download from https://www.postgresql.org/download/windows/",
159 missing.join(", ")
160 );
161 }
162
163 Ok(())
164}
165
166/// Retry a function with exponential backoff
167///
168/// Executes an async operation with automatic retry on failure. Each retry doubles
169/// the delay (exponential backoff) to handle transient failures gracefully.
170///
171/// # Arguments
172///
173/// * `operation` - Async function to retry (FnMut returning Future\<Output = Result\<T\>\>)
174/// * `max_retries` - Maximum number of retry attempts (0 = no retries, just initial attempt)
175/// * `initial_delay` - Delay before first retry (doubles each subsequent retry)
176///
177/// # Returns
178///
179/// Returns the successful result or the last error after all retries exhausted.
180///
181/// # Examples
182///
183/// ```no_run
184/// # use anyhow::Result;
185/// # use std::time::Duration;
186/// # use database_replicator::utils::retry_with_backoff;
187/// # async fn example() -> Result<()> {
188/// let result = retry_with_backoff(
189/// || async { Ok("success") },
190/// 3, // Try up to 3 times
191/// Duration::from_secs(1) // Start with 1s delay
192/// ).await?;
193/// # Ok(())
194/// # }
195/// ```
196pub async fn retry_with_backoff<F, Fut, T>(
197 mut operation: F,
198 max_retries: u32,
199 initial_delay: Duration,
200) -> Result<T>
201where
202 F: FnMut() -> Fut,
203 Fut: std::future::Future<Output = Result<T>>,
204{
205 let mut delay = initial_delay;
206 let mut last_error = None;
207
208 for attempt in 0..=max_retries {
209 match operation().await {
210 Ok(result) => return Ok(result),
211 Err(e) => {
212 last_error = Some(e);
213
214 if attempt < max_retries {
215 tracing::warn!(
216 "Operation failed (attempt {}/{}), retrying in {:?}...",
217 attempt + 1,
218 max_retries + 1,
219 delay
220 );
221 tokio::time::sleep(delay).await;
222 delay *= 2; // Exponential backoff
223 }
224 }
225 }
226 }
227
228 Err(last_error.unwrap_or_else(|| anyhow::anyhow!("Operation failed after retries")))
229}
230
231/// Retry a subprocess execution with exponential backoff on connection errors
232///
233/// Executes a subprocess command with automatic retry on connection-related failures.
234/// Each retry doubles the delay (exponential backoff) to handle transient connection issues.
235///
236/// Connection errors are detected by checking:
237/// - Non-zero exit codes
238/// - Stderr output containing connection-related error patterns:
239/// - "connection closed"
240/// - "connection refused"
241/// - "could not connect"
242/// - "server closed the connection"
243/// - "timeout"
244/// - "Connection timed out"
245///
246/// # Arguments
247///
248/// * `operation` - Function that executes a Command and returns the exit status
249/// * `max_retries` - Maximum number of retry attempts (0 = no retries, just initial attempt)
250/// * `initial_delay` - Delay before first retry (doubles each subsequent retry)
251/// * `operation_name` - Name of the operation for logging (e.g., "pg_restore", "psql")
252///
253/// # Returns
254///
255/// Returns Ok(()) on success or the last error after all retries exhausted.
256///
257/// # Examples
258///
259/// ```no_run
260/// # use anyhow::Result;
261/// # use std::time::Duration;
262/// # use std::process::Command;
263/// # use database_replicator::utils::retry_subprocess_with_backoff;
264/// # async fn example() -> Result<()> {
265/// retry_subprocess_with_backoff(
266/// || {
267/// let mut cmd = Command::new("psql");
268/// cmd.arg("--version");
269/// cmd.status().map_err(anyhow::Error::from)
270/// },
271/// 3, // Try up to 3 times
272/// Duration::from_secs(1), // Start with 1s delay
273/// "psql"
274/// ).await?;
275/// # Ok(())
276/// # }
277/// ```
278pub async fn retry_subprocess_with_backoff<F>(
279 mut operation: F,
280 max_retries: u32,
281 initial_delay: Duration,
282 operation_name: &str,
283) -> Result<()>
284where
285 F: FnMut() -> Result<std::process::ExitStatus>,
286{
287 let mut delay = initial_delay;
288 let mut last_error = None;
289
290 for attempt in 0..=max_retries {
291 match operation() {
292 Ok(status) => {
293 if status.success() {
294 return Ok(());
295 } else {
296 // Non-zero exit code - check if it's a connection error
297 // We can't easily capture stderr here, so we'll treat all non-zero
298 // exit codes as potential connection errors for now
299 let error = anyhow::anyhow!(
300 "{} failed with exit code: {}",
301 operation_name,
302 status.code().unwrap_or(-1)
303 );
304 last_error = Some(error);
305
306 if attempt < max_retries {
307 tracing::warn!(
308 "{} failed (attempt {}/{}), retrying in {:?}...",
309 operation_name,
310 attempt + 1,
311 max_retries + 1,
312 delay
313 );
314 tokio::time::sleep(delay).await;
315 delay *= 2; // Exponential backoff
316 }
317 }
318 }
319 Err(e) => {
320 last_error = Some(e);
321
322 if attempt < max_retries {
323 tracing::warn!(
324 "{} failed (attempt {}/{}): {}, retrying in {:?}...",
325 operation_name,
326 attempt + 1,
327 max_retries + 1,
328 last_error.as_ref().unwrap(),
329 delay
330 );
331 tokio::time::sleep(delay).await;
332 delay *= 2; // Exponential backoff
333 }
334 }
335 }
336 }
337
338 Err(last_error.unwrap_or_else(|| {
339 anyhow::anyhow!("{} failed after {} retries", operation_name, max_retries)
340 }))
341}
342
343/// Validate a PostgreSQL identifier (database name, schema name, etc.)
344///
345/// Validates that an identifier follows PostgreSQL naming rules to prevent SQL injection.
346/// PostgreSQL identifiers must:
347/// - Be 1-63 characters long
348/// - Start with a letter (a-z, A-Z) or underscore (_)
349/// - Contain only letters, digits (0-9), or underscores
350///
351/// # Arguments
352///
353/// * `identifier` - The identifier to validate (database name, schema name, etc.)
354///
355/// # Returns
356///
357/// Returns `Ok(())` if the identifier is valid.
358///
359/// # Errors
360///
361/// Returns an error if the identifier:
362/// - Is empty or whitespace-only
363/// - Exceeds 63 characters
364/// - Starts with an invalid character (digit or special character)
365/// - Contains invalid characters (anything except a-z, A-Z, 0-9, _)
366///
367/// # Security
368///
369/// This function is critical for preventing SQL injection attacks. All database
370/// names, schema names, and table names from untrusted sources MUST be validated
371/// before use in SQL statements.
372///
373/// # Examples
374///
375/// ```
376/// # use database_replicator::utils::validate_postgres_identifier;
377/// # use anyhow::Result;
378/// # fn example() -> Result<()> {
379/// // Valid identifiers
380/// validate_postgres_identifier("mydb")?;
381/// validate_postgres_identifier("my_database")?;
382/// validate_postgres_identifier("_private_db")?;
383///
384/// // Invalid - will return error
385/// assert!(validate_postgres_identifier("123db").is_err());
386/// assert!(validate_postgres_identifier("my-database").is_err());
387/// assert!(validate_postgres_identifier("db\"; DROP TABLE users; --").is_err());
388/// # Ok(())
389/// # }
390/// ```
391pub fn validate_postgres_identifier(identifier: &str) -> Result<()> {
392 // Check for empty or whitespace-only
393 let trimmed = identifier.trim();
394 if trimmed.is_empty() {
395 bail!("Identifier cannot be empty or whitespace-only");
396 }
397
398 // Check length (PostgreSQL limit is 63 characters)
399 if trimmed.len() > 63 {
400 bail!(
401 "Identifier '{}' exceeds maximum length of 63 characters (got {})",
402 sanitize_identifier(trimmed),
403 trimmed.len()
404 );
405 }
406
407 // Get first character
408 let first_char = trimmed.chars().next().unwrap();
409
410 // First character must be a letter or underscore
411 if !first_char.is_ascii_alphabetic() && first_char != '_' {
412 bail!(
413 "Identifier '{}' must start with a letter or underscore, not '{}'",
414 sanitize_identifier(trimmed),
415 first_char
416 );
417 }
418
419 // All characters must be alphanumeric or underscore
420 for (i, c) in trimmed.chars().enumerate() {
421 if !c.is_ascii_alphanumeric() && c != '_' {
422 bail!(
423 "Identifier '{}' contains invalid character '{}' at position {}. \
424 Only letters, digits, and underscores are allowed",
425 sanitize_identifier(trimmed),
426 if c.is_control() {
427 format!("\\x{:02x}", c as u32)
428 } else {
429 c.to_string()
430 },
431 i
432 );
433 }
434 }
435
436 Ok(())
437}
438
/// Make an identifier (table name, schema name, etc.) safe to display
///
/// Drops every control character and keeps at most the first 100 remaining
/// characters, so that untrusted names cannot inject extra log lines or flood
/// error messages.
///
/// **Note**: This is for display purposes only. For SQL safety, use parameterized
/// queries instead.
///
/// # Arguments
///
/// * `identifier` - The identifier to sanitize (table name, schema name, etc.)
///
/// # Returns
///
/// The identifier without control characters, truncated to 100 characters.
///
/// # Examples
///
/// ```
/// # use database_replicator::utils::sanitize_identifier;
/// assert_eq!(sanitize_identifier("normal_table"), "normal_table");
/// assert_eq!(sanitize_identifier("table\x00name"), "tablename");
/// assert_eq!(sanitize_identifier("table\nname"), "tablename");
///
/// // Length limit
/// let long_name = "a".repeat(200);
/// assert_eq!(sanitize_identifier(&long_name).len(), 100);
/// ```
pub fn sanitize_identifier(identifier: &str) -> String {
    const MAX_DISPLAY_CHARS: usize = 100;

    let mut display = String::new();
    let mut kept = 0usize;

    for ch in identifier.chars() {
        // Control characters are skipped and do not count towards the limit
        if ch.is_control() {
            continue;
        }
        if kept == MAX_DISPLAY_CHARS {
            break;
        }
        display.push(ch);
        kept += 1;
    }

    display
}
475
/// Quote a PostgreSQL identifier (database, schema, table, column)
///
/// Wraps the identifier in double quotes and doubles every embedded double
/// quote. The identifier is assumed to have been validated already.
pub fn quote_ident(identifier: &str) -> String {
    let escaped = identifier.replace('"', "\"\"");
    format!("\"{}\"", escaped)
}
492
/// Quote a SQL string literal (for use in SQL statements)
///
/// Doubles every single quote in `value` and wraps the result in single quotes.
/// This is for string values in SQL, not for identifiers.
///
/// # Examples
///
/// ```
/// use database_replicator::utils::quote_literal;
/// assert_eq!(quote_literal("hello"), "'hello'");
/// assert_eq!(quote_literal("it's"), "'it''s'");
/// assert_eq!(quote_literal(""), "''");
/// ```
pub fn quote_literal(value: &str) -> String {
    let escaped = value.replace('\'', "''");
    ["'", escaped.as_str(), "'"].concat()
}
518
/// Quote a MySQL identifier (database, table, column)
///
/// MySQL quotes identifiers with backticks. The identifier is wrapped in
/// backticks and every embedded backtick is doubled.
///
/// # Examples
///
/// ```
/// use database_replicator::utils::quote_mysql_ident;
/// assert_eq!(quote_mysql_ident("users"), "`users`");
/// assert_eq!(quote_mysql_ident("user`name"), "`user``name`");
/// ```
pub fn quote_mysql_ident(identifier: &str) -> String {
    // Splitting on '`' and re-joining with "``" doubles each embedded backtick
    let pieces: Vec<&str> = identifier.split('`').collect();
    format!("`{}`", pieces.join("``"))
}
543
544/// Validate that source and target URLs are different to prevent accidental data loss
545///
546/// Compares two PostgreSQL connection URLs to ensure they point to different databases.
547/// This is critical for preventing data loss from operations like `init --drop-existing`
548/// where using the same URL for source and target would destroy the source data.
549///
550/// # Comparison Strategy
551///
552/// URLs are normalized and compared on:
553/// - Host (case-insensitive)
554/// - Port (defaulting to 5432 if not specified)
555/// - Database name (case-sensitive)
556/// - User (if present)
557///
558/// Query parameters (like SSL settings) are ignored as they don't affect database identity.
559///
560/// # Arguments
561///
562/// * `source_url` - Source database connection string
563/// * `target_url` - Target database connection string
564///
565/// # Returns
566///
567/// Returns `Ok(())` if the URLs point to different databases.
568///
569/// # Errors
570///
571/// Returns an error if:
572/// - The URLs point to the same database (same host, port, database name, and user)
573/// - Either URL is malformed and cannot be parsed
574///
575/// # Examples
576///
577/// ```
578/// # use database_replicator::utils::validate_source_target_different;
579/// # use anyhow::Result;
580/// # fn example() -> Result<()> {
581/// // Valid - different hosts
582/// validate_source_target_different(
583/// "postgresql://user:pass@source.com:5432/db",
584/// "postgresql://user:pass@target.com:5432/db"
585/// )?;
586///
587/// // Valid - different databases
588/// validate_source_target_different(
589/// "postgresql://user:pass@host:5432/db1",
590/// "postgresql://user:pass@host:5432/db2"
591/// )?;
592///
593/// // Invalid - same database
594/// assert!(validate_source_target_different(
595/// "postgresql://user:pass@host:5432/db",
596/// "postgresql://user:pass@host:5432/db"
597/// ).is_err());
598/// # Ok(())
599/// # }
600/// ```
601pub fn validate_source_target_different(source_url: &str, target_url: &str) -> Result<()> {
602 // Parse both URLs to extract components
603 let source_parts = parse_postgres_url(source_url)
604 .with_context(|| format!("Failed to parse source URL: {}", source_url))?;
605 let target_parts = parse_postgres_url(target_url)
606 .with_context(|| format!("Failed to parse target URL: {}", target_url))?;
607
608 // Compare normalized components
609 if source_parts.host == target_parts.host
610 && source_parts.port == target_parts.port
611 && source_parts.database == target_parts.database
612 && source_parts.user == target_parts.user
613 {
614 bail!(
615 "Source and target URLs point to the same database!\n\
616 \n\
617 This would cause DATA LOSS - the target would overwrite the source.\n\
618 \n\
619 Source: {}@{}:{}/{}\n\
620 Target: {}@{}:{}/{}\n\
621 \n\
622 Please ensure source and target are different databases.\n\
623 Common causes:\n\
624 - Copy-paste error in connection strings\n\
625 - Wrong environment variables (e.g., SOURCE_URL == TARGET_URL)\n\
626 - Typo in database name or host",
627 source_parts.user.as_deref().unwrap_or("(no user)"),
628 source_parts.host,
629 source_parts.port,
630 source_parts.database,
631 target_parts.user.as_deref().unwrap_or("(no user)"),
632 target_parts.host,
633 target_parts.port,
634 target_parts.database
635 );
636 }
637
638 Ok(())
639}
640
641/// Parse a PostgreSQL URL into its components
642///
643/// # Arguments
644///
645/// * `url` - PostgreSQL connection URL (postgres:// or postgresql://)
646///
647/// # Returns
648///
649/// Returns a `PostgresUrlParts` struct with normalized components.
650///
651/// # Security
652///
653/// This function extracts passwords from URLs for use with .pgpass files.
654/// Ensure returned values are handled securely and not logged.
655pub fn parse_postgres_url(url: &str) -> Result<PostgresUrlParts> {
656 // Remove scheme
657 let url_without_scheme = url
658 .trim_start_matches("postgres://")
659 .trim_start_matches("postgresql://");
660
661 // Split into base and query params
662 let (base, query_string) = if let Some((b, q)) = url_without_scheme.split_once('?') {
663 (b, Some(q))
664 } else {
665 (url_without_scheme, None)
666 };
667
668 // Parse query parameters into HashMap
669 let mut query_params = std::collections::HashMap::new();
670 if let Some(query) = query_string {
671 for param in query.split('&') {
672 if let Some((key, value)) = param.split_once('=') {
673 query_params.insert(key.to_string(), value.to_string());
674 }
675 }
676 }
677
678 // Parse: [user[:password]@]host[:port]/database
679 let (auth_and_host, database) = base
680 .rsplit_once('/')
681 .ok_or_else(|| anyhow::anyhow!("Missing database name in URL"))?;
682
683 // Parse authentication and host
684 // Use rsplit_once to split from the right, so passwords can contain '@'
685 let (user, password, host_and_port) = if let Some((auth, hp)) = auth_and_host.rsplit_once('@') {
686 // Has authentication
687 let (user, pass) = if let Some((u, p)) = auth.split_once(':') {
688 (Some(u.to_string()), Some(p.to_string()))
689 } else {
690 (Some(auth.to_string()), None)
691 };
692 (user, pass, hp)
693 } else {
694 // No authentication
695 (None, None, auth_and_host)
696 };
697
698 // Parse host and port
699 let (host, port) = if let Some((h, p)) = host_and_port.rsplit_once(':') {
700 // Port specified
701 let port = p
702 .parse::<u16>()
703 .with_context(|| format!("Invalid port number: {}", p))?;
704 (h, port)
705 } else {
706 // Use default PostgreSQL port
707 (host_and_port, 5432)
708 };
709
710 Ok(PostgresUrlParts {
711 host: host.to_lowercase(), // Hostnames are case-insensitive
712 port,
713 database: database.to_string(), // Database names are case-sensitive in PostgreSQL
714 user,
715 password,
716 query_params,
717 })
718}
719
720/// Strip password from PostgreSQL connection URL
721/// Returns a new URL with password removed, preserving all other components
722/// This is useful for storing connection strings in places where passwords should not be visible
723pub fn strip_password_from_url(url: &str) -> Result<String> {
724 let parts = parse_postgres_url(url)?;
725
726 // Reconstruct URL without password
727 let scheme = if url.starts_with("postgresql://") {
728 "postgresql://"
729 } else if url.starts_with("postgres://") {
730 "postgres://"
731 } else {
732 bail!("Invalid PostgreSQL URL scheme");
733 };
734
735 let mut result = String::from(scheme);
736
737 // Add user if present (without password)
738 if let Some(user) = &parts.user {
739 result.push_str(user);
740 result.push('@');
741 }
742
743 // Add host and port
744 result.push_str(&parts.host);
745 result.push(':');
746 result.push_str(&parts.port.to_string());
747
748 // Add database
749 result.push('/');
750 result.push_str(&parts.database);
751
752 // Preserve query parameters if present
753 if let Some(query_start) = url.find('?') {
754 result.push_str(&url[query_start..]);
755 }
756
757 Ok(result)
758}
759
/// Parsed components of a PostgreSQL connection URL
///
/// `Debug` is implemented by hand so that the password is never printed:
/// a present password is shown as `Some("<redacted>")`.
#[derive(PartialEq)]
pub struct PostgresUrlParts {
    /// Host name or address
    pub host: String,
    /// TCP port
    pub port: u16,
    /// Database name
    pub database: String,
    /// User name, if the URL contained credentials
    pub user: Option<String>,
    /// Password, if the URL contained one. Sensitive: do not log.
    pub password: Option<String>,
    /// Query-string parameters (`key=value` pairs)
    pub query_params: std::collections::HashMap<String, String>,
}

impl std::fmt::Debug for PostgresUrlParts {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Keep the Some/None shape visible but never the secret itself
        let password = self.password.as_ref().map(|_| "<redacted>");

        f.debug_struct("PostgresUrlParts")
            .field("host", &self.host)
            .field("port", &self.port)
            .field("database", &self.database)
            .field("user", &self.user)
            .field("password", &password)
            .field("query_params", &self.query_params)
            .finish()
    }
}
770
771impl PostgresUrlParts {
772 /// Convert query parameters to PostgreSQL environment variables
773 ///
774 /// Maps common connection URL query parameters to their corresponding
775 /// PostgreSQL environment variable names. This allows SSL/TLS and other
776 /// connection settings to be passed to pg_dump, pg_dumpall, psql, etc.
777 ///
778 /// # Supported Parameters
779 ///
780 /// - `sslmode` → `PGSSLMODE`
781 /// - `sslcert` → `PGSSLCERT`
782 /// - `sslkey` → `PGSSLKEY`
783 /// - `sslrootcert` → `PGSSLROOTCERT`
784 /// - `channel_binding` → `PGCHANNELBINDING`
785 /// - `connect_timeout` → `PGCONNECT_TIMEOUT`
786 /// - `application_name` → `PGAPPNAME`
787 /// - `client_encoding` → `PGCLIENTENCODING`
788 ///
789 /// # Returns
790 ///
791 /// Vec of (env_var_name, value) pairs to be set as environment variables
792 pub fn to_pg_env_vars(&self) -> Vec<(&'static str, String)> {
793 let mut env_vars = Vec::new();
794
795 // Map query parameters to PostgreSQL environment variables
796 let param_mapping = [
797 ("sslmode", "PGSSLMODE"),
798 ("sslcert", "PGSSLCERT"),
799 ("sslkey", "PGSSLKEY"),
800 ("sslrootcert", "PGSSLROOTCERT"),
801 ("channel_binding", "PGCHANNELBINDING"),
802 ("connect_timeout", "PGCONNECT_TIMEOUT"),
803 ("application_name", "PGAPPNAME"),
804 ("client_encoding", "PGCLIENTENCODING"),
805 ];
806
807 for (param_name, env_var_name) in param_mapping {
808 if let Some(value) = self.query_params.get(param_name) {
809 env_vars.push((env_var_name, value.clone()));
810 }
811 }
812
813 env_vars
814 }
815}
816
/// Managed .pgpass file for secure password passing to PostgreSQL tools
///
/// Owns a temporary password file in the system temp directory. PostgreSQL
/// command-line tools can be pointed at it through the `PGPASSFILE`
/// environment variable, so the password never has to appear in a URL or on
/// a command line. The file is deleted when this value is dropped.
///
/// # Security
///
/// - On Unix the file is restricted to mode 0600 (owner read/write only)
/// - File is removed on Drop (best effort)
/// - Credentials are never passed on the command line
///
/// # Format
///
/// The file holds a single entry of the form
/// `hostname:port:database:username:password`. The username is written as the
/// wildcard `*` when the URL carried no user.
///
/// # Examples
///
/// ```no_run
/// # use database_replicator::utils::{PgPassFile, parse_postgres_url};
/// # use anyhow::Result;
/// # fn example() -> Result<()> {
/// let url = "postgresql://user:pass@localhost:5432/mydb";
/// let parts = parse_postgres_url(url)?;
/// let pgpass = PgPassFile::new(&parts)?;
///
/// // Use pgpass.path() with PGPASSFILE environment variable
/// // File is automatically cleaned up when pgpass goes out of scope
/// # Ok(())
/// # }
/// ```
pub struct PgPassFile {
    // Location of the temporary password file owned by this value
    path: std::path::PathBuf,
}
853
854impl PgPassFile {
855 /// Create a new .pgpass file with credentials from URL parts
856 ///
857 /// # Arguments
858 ///
859 /// * `parts` - Parsed PostgreSQL URL components
860 ///
861 /// # Returns
862 ///
863 /// Returns a PgPassFile that will be automatically cleaned up on Drop
864 ///
865 /// # Errors
866 ///
867 /// Returns an error if the file cannot be created or permissions cannot be set
868 pub fn new(parts: &PostgresUrlParts) -> Result<Self> {
869 use std::fs;
870 use std::io::Write;
871
872 // Create temp file with secure name
873 let temp_dir = std::env::temp_dir();
874 let random: u32 = rand::random();
875 let filename = format!("pgpass-{:08x}", random);
876 let path = temp_dir.join(filename);
877
878 // Write .pgpass entry
879 // Format: hostname:port:database:username:password
880 let username = parts.user.as_deref().unwrap_or("*");
881 let password = parts.password.as_deref().unwrap_or("");
882 let entry = format!(
883 "{}:{}:{}:{}:{}\n",
884 parts.host, parts.port, parts.database, username, password
885 );
886
887 let mut file = fs::File::create(&path)
888 .with_context(|| format!("Failed to create .pgpass file at {}", path.display()))?;
889
890 file.write_all(entry.as_bytes())
891 .with_context(|| format!("Failed to write to .pgpass file at {}", path.display()))?;
892
893 // Set secure permissions (0600) - owner read/write only
894 #[cfg(unix)]
895 {
896 use std::os::unix::fs::PermissionsExt;
897 let permissions = fs::Permissions::from_mode(0o600);
898 fs::set_permissions(&path, permissions).with_context(|| {
899 format!(
900 "Failed to set permissions on .pgpass file at {}",
901 path.display()
902 )
903 })?;
904 }
905
906 // On Windows, .pgpass is stored in %APPDATA%\postgresql\pgpass.conf
907 // but for our temporary use case, we'll just use a temp file
908 // PostgreSQL on Windows also checks permissions but less strictly
909
910 Ok(Self { path })
911 }
912
913 /// Get the path to the .pgpass file
914 ///
915 /// Use this with the PGPASSFILE environment variable when running
916 /// PostgreSQL command-line tools
917 pub fn path(&self) -> &std::path::Path {
918 &self.path
919 }
920}
921
922impl Drop for PgPassFile {
923 fn drop(&mut self) {
924 // Best effort cleanup - don't panic if removal fails
925 let _ = std::fs::remove_file(&self.path);
926 }
927}
928
929/// Create a managed temporary directory with explicit cleanup support
930///
931/// Creates a temporary directory with a timestamped name that can be cleaned up
932/// even if the process is killed with SIGKILL. Unlike `TempDir::new()` which
933/// relies on the Drop trait, this function creates named directories that can
934/// be cleaned up on next process startup.
935///
936/// Directory naming format: `postgres-seren-replicator-{timestamp}-{random}`
937/// Example: `postgres-seren-replicator-20250106-120534-a3b2c1d4`
938///
939/// # Returns
940///
941/// Returns the path to the created temporary directory.
942///
943/// # Errors
944///
945/// Returns an error if the directory cannot be created.
946///
947/// # Examples
948///
949/// ```no_run
950/// # use database_replicator::utils::create_managed_temp_dir;
951/// # use anyhow::Result;
952/// # fn example() -> Result<()> {
953/// let temp_path = create_managed_temp_dir()?;
954/// println!("Using temp directory: {}", temp_path.display());
955/// // ... do work ...
956/// // Cleanup happens automatically on next startup via cleanup_stale_temp_dirs()
957/// # Ok(())
958/// # }
959/// ```
960pub fn create_managed_temp_dir() -> Result<std::path::PathBuf> {
961 use std::fs;
962 use std::time::SystemTime;
963
964 let system_temp = std::env::temp_dir();
965
966 // Generate timestamp for directory name
967 let timestamp = SystemTime::now()
968 .duration_since(SystemTime::UNIX_EPOCH)
969 .unwrap()
970 .as_secs();
971
972 // Generate random suffix for uniqueness
973 let random: u32 = rand::random();
974
975 // Create directory name with timestamp and random suffix
976 let dir_name = format!("postgres-seren-replicator-{}-{:08x}", timestamp, random);
977
978 let temp_path = system_temp.join(dir_name);
979
980 // Create the directory
981 fs::create_dir_all(&temp_path)
982 .with_context(|| format!("Failed to create temp directory at {}", temp_path.display()))?;
983
984 tracing::debug!("Created managed temp directory: {}", temp_path.display());
985
986 Ok(temp_path)
987}
988
989/// Clean up stale temporary directories from previous runs
990///
991/// Removes temporary directories created by `create_managed_temp_dir()` that are
992/// older than the specified age. This should be called on process startup to clean
993/// up directories left behind by processes killed with SIGKILL.
994///
995/// Only directories matching the pattern `postgres-seren-replicator-*` are removed.
996///
997/// # Arguments
998///
999/// * `max_age_secs` - Maximum age in seconds before a directory is considered stale
1000/// (recommended: 86400 for 24 hours)
1001///
1002/// # Returns
1003///
1004/// Returns the number of directories cleaned up.
1005///
1006/// # Errors
1007///
1008/// Returns an error if the system temp directory cannot be read. Individual
1009/// directory removal errors are logged but don't fail the entire operation.
1010///
1011/// # Examples
1012///
1013/// ```no_run
1014/// # use database_replicator::utils::cleanup_stale_temp_dirs;
1015/// # use anyhow::Result;
1016/// # fn example() -> Result<()> {
1017/// // Clean up temp directories older than 24 hours
1018/// let cleaned = cleanup_stale_temp_dirs(86400)?;
1019/// println!("Cleaned up {} stale temp directories", cleaned);
1020/// # Ok(())
1021/// # }
1022/// ```
1023pub fn cleanup_stale_temp_dirs(max_age_secs: u64) -> Result<usize> {
1024 use std::fs;
1025 use std::time::SystemTime;
1026
1027 let system_temp = std::env::temp_dir();
1028 let now = SystemTime::now();
1029 let mut cleaned_count = 0;
1030
1031 // Read all entries in system temp directory
1032 let entries = fs::read_dir(&system_temp).with_context(|| {
1033 format!(
1034 "Failed to read system temp directory: {}",
1035 system_temp.display()
1036 )
1037 })?;
1038
1039 for entry in entries.flatten() {
1040 let path = entry.path();
1041
1042 // Only process directories matching our naming pattern
1043 if let Some(name) = path.file_name().and_then(|n| n.to_str()) {
1044 if !name.starts_with("postgres-seren-replicator-") {
1045 continue;
1046 }
1047
1048 // Check directory age
1049 match entry.metadata() {
1050 Ok(metadata) => {
1051 if let Ok(modified) = metadata.modified() {
1052 if let Ok(age) = now.duration_since(modified) {
1053 if age.as_secs() > max_age_secs {
1054 // Directory is stale, remove it
1055 match fs::remove_dir_all(&path) {
1056 Ok(_) => {
1057 tracing::info!(
1058 "Cleaned up stale temp directory: {} (age: {}s)",
1059 path.display(),
1060 age.as_secs()
1061 );
1062 cleaned_count += 1;
1063 }
1064 Err(e) => {
1065 tracing::warn!(
1066 "Failed to remove stale temp directory {}: {}",
1067 path.display(),
1068 e
1069 );
1070 }
1071 }
1072 }
1073 }
1074 }
1075 }
1076 Err(e) => {
1077 tracing::warn!(
1078 "Failed to get metadata for temp directory {}: {}",
1079 path.display(),
1080 e
1081 );
1082 }
1083 }
1084 }
1085 }
1086
1087 if cleaned_count > 0 {
1088 tracing::info!(
1089 "Cleaned up {} stale temp directory(ies) older than {} seconds",
1090 cleaned_count,
1091 max_age_secs
1092 );
1093 }
1094
1095 Ok(cleaned_count)
1096}
1097
1098/// Remove a managed temporary directory
1099///
1100/// Explicitly removes a temporary directory created by `create_managed_temp_dir()`.
1101/// This should be called when the directory is no longer needed.
1102///
1103/// # Arguments
1104///
1105/// * `path` - Path to the temporary directory to remove
1106///
1107/// # Errors
1108///
1109/// Returns an error if the directory cannot be removed.
1110///
1111/// # Examples
1112///
1113/// ```no_run
1114/// # use database_replicator::utils::{create_managed_temp_dir, remove_managed_temp_dir};
1115/// # use anyhow::Result;
1116/// # fn example() -> Result<()> {
1117/// let temp_path = create_managed_temp_dir()?;
1118/// // ... do work ...
1119/// remove_managed_temp_dir(&temp_path)?;
1120/// # Ok(())
1121/// # }
1122/// ```
1123pub fn remove_managed_temp_dir(path: &std::path::Path) -> Result<()> {
1124 use std::fs;
1125
1126 // Verify this is one of our temp directories (safety check)
1127 if let Some(name) = path.file_name().and_then(|n| n.to_str()) {
1128 if !name.starts_with("postgres-seren-replicator-") {
1129 bail!(
1130 "Refusing to remove directory that doesn't match our naming pattern: {}",
1131 path.display()
1132 );
1133 }
1134 } else {
1135 bail!("Invalid temp directory path: {}", path.display());
1136 }
1137
1138 tracing::debug!("Removing managed temp directory: {}", path.display());
1139
1140 fs::remove_dir_all(path)
1141 .with_context(|| format!("Failed to remove temp directory at {}", path.display()))?;
1142
1143 Ok(())
1144}
1145
1146/// Check if a PostgreSQL URL points to a SerenDB instance
1147///
1148/// SerenDB hosts have domains ending with `.serendb.com`
1149///
1150/// # Arguments
1151///
1152/// * `url` - PostgreSQL connection string to check
1153///
1154/// # Returns
1155///
1156/// Returns `true` if the URL points to a SerenDB host.
1157///
1158/// # Examples
1159///
1160/// ```
1161/// use database_replicator::utils::is_serendb_target;
1162///
1163/// assert!(is_serendb_target("postgresql://user:pass@db.serendb.com/mydb"));
1164/// assert!(is_serendb_target("postgresql://user:pass@cluster-123.console.serendb.com/mydb"));
1165/// assert!(!is_serendb_target("postgresql://user:pass@localhost/mydb"));
1166/// assert!(!is_serendb_target("postgresql://user:pass@rds.amazonaws.com/mydb"));
1167/// ```
1168pub fn is_serendb_target(url: &str) -> bool {
1169 match parse_postgres_url(url) {
1170 Ok(parts) => parts.host.ends_with(".serendb.com") || parts.host == "serendb.com",
1171 Err(_) => false,
1172 }
1173}
1174
1175/// Get the major version of a PostgreSQL client tool (pg_dump, psql, etc.)
1176///
1177/// Executes `<tool> --version` and parses the output.
1178///
1179/// # Arguments
1180///
1181/// * `tool` - Name of the tool (e.g., "pg_dump", "psql")
1182///
1183/// # Returns
1184///
1185/// The major version number (e.g., 16 for pg_dump 16.10)
1186///
1187/// # Errors
1188///
1189/// Returns an error if:
1190/// - Tool is not found in PATH
1191/// - Tool execution fails
1192/// - Version output cannot be parsed
1193///
1194/// # Examples
1195///
1196/// ```no_run
1197/// use database_replicator::utils::get_pg_tool_version;
1198/// use anyhow::Result;
1199///
1200/// fn example() -> Result<()> {
1201/// let version = get_pg_tool_version("pg_dump")?;
1202/// println!("pg_dump major version: {}", version); // e.g., 16
1203/// Ok(())
1204/// }
1205/// ```
1206pub fn get_pg_tool_version(tool: &str) -> Result<u32> {
1207 use std::process::Command;
1208
1209 let path = which(tool).with_context(|| format!("{} not found in PATH", tool))?;
1210
1211 let output = Command::new(&path)
1212 .arg("--version")
1213 .output()
1214 .with_context(|| format!("Failed to execute {} --version", tool))?;
1215
1216 let version_str = String::from_utf8_lossy(&output.stdout);
1217 parse_pg_version_string(&version_str)
1218}
1219
1220/// Parse major version from PostgreSQL version string
1221///
1222/// Handles formats like:
1223/// - "pg_dump (PostgreSQL) 16.10 (Ubuntu 16.10-0ubuntu0.24.04.1)"
1224/// - "psql (PostgreSQL) 17.2"
1225/// - "17.2 (Debian 17.2-1.pgdg120+1)"
1226///
1227/// # Arguments
1228///
1229/// * `version_str` - Version string output from a PostgreSQL tool
1230///
1231/// # Returns
1232///
1233/// The major version number (e.g., 16, 17)
1234///
1235/// # Errors
1236///
1237/// Returns an error if the version cannot be parsed.
1238pub fn parse_pg_version_string(version_str: &str) -> Result<u32> {
1239 // Look for version pattern: major.minor
1240 for word in version_str.split_whitespace() {
1241 if let Some(major_str) = word.split('.').next() {
1242 if let Ok(major) = major_str.parse::<u32>() {
1243 // Valid PostgreSQL versions are between 9 and 99
1244 if (9..=99).contains(&major) {
1245 return Ok(major);
1246 }
1247 }
1248 }
1249 }
1250 bail!("Could not parse PostgreSQL version from: {}", version_str)
1251}
1252
#[cfg(test)]
mod tests {
    use super::*;

    /// Build URL parts for database `testdb` on `localhost:5432` with the given credentials.
    fn local_testdb_parts(user: Option<&str>, password: Option<&str>) -> PostgresUrlParts {
        PostgresUrlParts {
            host: "localhost".to_string(),
            port: 5432,
            database: "testdb".to_string(),
            user: user.map(String::from),
            password: password.map(String::from),
            query_params: std::collections::HashMap::new(),
        }
    }

    #[test]
    fn test_validate_connection_string_valid() {
        let accepted = [
            "postgresql://user:pass@localhost:5432/dbname",
            "postgres://user@host/db",
        ];
        for url in accepted {
            assert!(validate_connection_string(url).is_ok());
        }
    }

    #[test]
    fn test_check_required_tools() {
        // Passes outright when the PostgreSQL client tools are installed.
        // When they are not, the failure must at least carry the expected
        // message and name one of the tools.
        if let Some(err) = check_required_tools().err() {
            let message = err.to_string();
            assert!(message.contains("Missing required PostgreSQL client tools"));

            let names_a_tool = ["pg_dump", "pg_dumpall", "psql"]
                .iter()
                .any(|tool| message.contains(*tool));
            assert!(names_a_tool);
        }
    }

    #[test]
    fn test_validate_connection_string_invalid() {
        let rejected = [
            "",
            " ",
            "mysql://localhost/db",
            "postgresql://localhost",
            // Missing user
            "postgresql://localhost/db",
        ];
        for url in rejected {
            assert!(validate_connection_string(url).is_err());
        }
    }

    #[test]
    fn test_sanitize_identifier() {
        let cases = [
            ("normal_table", "normal_table"),
            ("table\x00name", "tablename"),
            ("table\nname", "tablename"),
        ];
        for (raw, cleaned) in cases {
            assert_eq!(sanitize_identifier(raw), cleaned);
        }

        // Over-long input is cut down to the length limit
        let oversized = "a".repeat(200);
        assert_eq!(sanitize_identifier(&oversized).len(), 100);
    }

    #[tokio::test]
    async fn test_retry_with_backoff_success() {
        let mut calls = 0;
        let outcome = retry_with_backoff(
            || {
                calls += 1;
                // The future captures the call count as it was when it was created.
                async move {
                    if calls < 3 {
                        anyhow::bail!("Temporary failure")
                    } else {
                        Ok("Success")
                    }
                }
            },
            5,
            Duration::from_millis(10),
        )
        .await;

        assert!(outcome.is_ok());
        assert_eq!(outcome.unwrap(), "Success");
        assert_eq!(calls, 3);
    }

    #[tokio::test]
    async fn test_retry_with_backoff_failure() {
        let mut calls = 0;
        let outcome: Result<&str> = retry_with_backoff(
            || {
                calls += 1;
                async move { anyhow::bail!("Permanent failure") }
            },
            2,
            Duration::from_millis(10),
        )
        .await;

        assert!(outcome.is_err());
        // One initial attempt plus two retries
        assert_eq!(calls, 3);
    }

    #[test]
    fn test_validate_source_target_different_valid() {
        let distinct_pairs = [
            // Different hosts
            (
                "postgresql://user:pass@source.com:5432/db",
                "postgresql://user:pass@target.com:5432/db",
            ),
            // Different databases on same host
            (
                "postgresql://user:pass@host:5432/db1",
                "postgresql://user:pass@host:5432/db2",
            ),
            // Different ports on same host
            (
                "postgresql://user:pass@host:5432/db",
                "postgresql://user:pass@host:5433/db",
            ),
            // Different users on same host/db (edge case but allowed)
            (
                "postgresql://user1:pass@host:5432/db",
                "postgresql://user2:pass@host:5432/db",
            ),
        ];

        for (source, target) in distinct_pairs {
            assert!(validate_source_target_different(source, target).is_ok());
        }
    }

    #[test]
    fn test_validate_source_target_different_invalid() {
        let same_database_pairs = [
            // Exact same URL
            (
                "postgresql://user:pass@host:5432/db",
                "postgresql://user:pass@host:5432/db",
            ),
            // Same URL with different scheme (postgres vs postgresql)
            (
                "postgres://user:pass@host:5432/db",
                "postgresql://user:pass@host:5432/db",
            ),
            // Same URL with default port vs explicit port
            (
                "postgresql://user:pass@host/db",
                "postgresql://user:pass@host:5432/db",
            ),
            // Same URL with different query parameters (still same database)
            (
                "postgresql://user:pass@host:5432/db?sslmode=require",
                "postgresql://user:pass@host:5432/db?sslmode=prefer",
            ),
            // Same host with different case (hostnames are case-insensitive)
            (
                "postgresql://user:pass@HOST.COM:5432/db",
                "postgresql://user:pass@host.com:5432/db",
            ),
        ];

        for (source, target) in same_database_pairs {
            assert!(validate_source_target_different(source, target).is_err());
        }
    }

    #[test]
    fn test_parse_postgres_url() {
        // Every component present, including the password
        let full = parse_postgres_url("postgresql://myuser:mypass@localhost:5432/mydb").unwrap();
        assert_eq!(full.host, "localhost");
        assert_eq!(full.port, 5432);
        assert_eq!(full.database, "mydb");
        assert_eq!(full.user.as_deref(), Some("myuser"));
        assert_eq!(full.password.as_deref(), Some("mypass"));

        // Port omitted: falls back to 5432
        let default_port = parse_postgres_url("postgresql://user:pass@host/db").unwrap();
        assert_eq!(default_port.host, "host");
        assert_eq!(default_port.port, 5432);
        assert_eq!(default_port.database, "db");
        assert_eq!(default_port.user.as_deref(), Some("user"));
        assert_eq!(default_port.password.as_deref(), Some("pass"));

        // User given, password omitted
        let user_only = parse_postgres_url("postgresql://user@host/db").unwrap();
        assert_eq!(user_only.host, "host");
        assert_eq!(user_only.user.as_deref(), Some("user"));
        assert_eq!(user_only.password, None);

        // No authentication section at all
        let anonymous = parse_postgres_url("postgresql://host:5433/db").unwrap();
        assert_eq!(anonymous.host, "host");
        assert_eq!(anonymous.port, 5433);
        assert_eq!(anonymous.database, "db");
        assert_eq!(anonymous.user, None);
        assert_eq!(anonymous.password, None);

        // Query parameters after the database name
        let with_query =
            parse_postgres_url("postgresql://user:pass@host/db?sslmode=require").unwrap();
        assert_eq!(with_query.host, "host");
        assert_eq!(with_query.database, "db");
        assert_eq!(with_query.password.as_deref(), Some("pass"));

        // Alternative postgres:// scheme
        let short_scheme = parse_postgres_url("postgres://user:pass@host/db").unwrap();
        assert_eq!(short_scheme.host, "host");
        assert_eq!(short_scheme.database, "db");
        assert_eq!(short_scheme.password.as_deref(), Some("pass"));

        // Host is normalized to lowercase
        let upper_host = parse_postgres_url("postgresql://user:pass@HOST.COM/db").unwrap();
        assert_eq!(upper_host.host, "host.com");
        assert_eq!(upper_host.password.as_deref(), Some("pass"));

        // Password containing special characters
        let odd_password = parse_postgres_url("postgresql://user:p@ss!word@host/db").unwrap();
        assert_eq!(odd_password.password.as_deref(), Some("p@ss!word"));
    }

    #[test]
    fn test_validate_postgres_identifier_valid() {
        let accepted = [
            "mydb",
            "my_database",
            "_private_db",
            "db123",
            "Database_2024",
        ];
        for identifier in accepted {
            assert!(validate_postgres_identifier(identifier).is_ok());
        }

        // Exactly the maximum length (63 characters)
        let longest_allowed = "a".repeat(63);
        assert!(validate_postgres_identifier(&longest_allowed).is_ok());
    }

    #[test]
    fn test_pgpass_file_creation() {
        let parts = local_testdb_parts(Some("testuser"), Some("testpass"));

        let pgpass = PgPassFile::new(&parts).unwrap();
        assert!(pgpass.path().exists());

        // File content
        let written = std::fs::read_to_string(pgpass.path()).unwrap();
        assert_eq!(written, "localhost:5432:testdb:testuser:testpass\n");

        // Permissions (Unix only)
        #[cfg(unix)]
        {
            use std::os::unix::fs::PermissionsExt;
            let mode = std::fs::metadata(pgpass.path())
                .unwrap()
                .permissions()
                .mode();
            assert_eq!(mode & 0o777, 0o600);
        }

        // Dropping the handle must delete the file
        let location = pgpass.path().to_path_buf();
        drop(pgpass);
        assert!(!location.exists());
    }

    #[test]
    fn test_pgpass_file_without_password() {
        let parts = local_testdb_parts(Some("testuser"), None);

        let pgpass = PgPassFile::new(&parts).unwrap();
        let written = std::fs::read_to_string(pgpass.path()).unwrap();
        // Password field is left empty
        assert_eq!(written, "localhost:5432:testdb:testuser:\n");
    }

    #[test]
    fn test_pgpass_file_without_user() {
        let parts = local_testdb_parts(None, Some("testpass"));

        let pgpass = PgPassFile::new(&parts).unwrap();
        let written = std::fs::read_to_string(pgpass.path()).unwrap();
        // User field becomes a wildcard
        assert_eq!(written, "localhost:5432:testdb:*:testpass\n");
    }

    #[test]
    fn test_strip_password_from_url() {
        let cases = [
            // With password
            (
                "postgresql://user:p@ssw0rd@host:5432/db",
                "postgresql://user@host:5432/db",
            ),
            // With special characters in password
            (
                "postgresql://user:p@ss!w0rd@host:5432/db",
                "postgresql://user@host:5432/db",
            ),
            // Without password
            (
                "postgresql://user@host:5432/db",
                "postgresql://user@host:5432/db",
            ),
            // With query parameters
            (
                "postgresql://user:pass@host:5432/db?sslmode=require",
                "postgresql://user@host:5432/db?sslmode=require",
            ),
            // No user
            ("postgresql://host:5432/db", "postgresql://host:5432/db"),
        ];

        for (url, expected) in cases {
            assert_eq!(strip_password_from_url(url).unwrap(), expected);
        }
    }

    #[test]
    fn test_validate_postgres_identifier_invalid() {
        let rejected = [
            // SQL injection attempts
            "mydb\"; DROP DATABASE production; --",
            "db'; DELETE FROM users; --",
            // Invalid start characters: digit, special char, dash
            "123db",
            "$db",
            "-db",
            // Invalid characters inside: dash, dot, space, @, #
            "my-database",
            "my.database",
            "my database",
            "my@db",
            "my#db",
            // Empty or blank
            "",
            " ",
            // Control characters
            "my\ndb",
            "my\tdb",
            "my\x00db",
        ];
        for identifier in rejected {
            assert!(validate_postgres_identifier(identifier).is_err());
        }

        // One character over the maximum length (64 characters)
        let too_long = "a".repeat(64);
        assert!(validate_postgres_identifier(&too_long).is_err());
    }

    #[test]
    fn test_is_serendb_target() {
        // SerenDB hosts
        let serendb_urls = [
            "postgresql://user:pass@db.serendb.com/mydb",
            "postgresql://user:pass@cluster.console.serendb.com/mydb",
            "postgres://u:p@x.serendb.com:5432/db?sslmode=require",
            "postgresql://user:pass@serendb.com/mydb",
        ];
        for url in serendb_urls {
            assert!(is_serendb_target(url));
        }

        let other_urls = [
            // Hosts that are not SerenDB
            "postgresql://user:pass@localhost/mydb",
            "postgresql://user:pass@rds.amazonaws.com/mydb",
            "postgresql://user:pass@neon.tech/mydb",
            // Domain spoofing attempts - must NOT match
            "postgresql://user:pass@serendb.com.evil.com/mydb",
            "postgresql://user:pass@notserendb.com/mydb",
            // Invalid URL
            "not-a-url",
        ];
        for url in other_urls {
            assert!(!is_serendb_target(url));
        }
    }

    #[test]
    fn test_parse_pg_version_string() {
        let parsed_cases = [
            // Standard pg_dump output
            (
                "pg_dump (PostgreSQL) 16.10 (Ubuntu 16.10-0ubuntu0.24.04.1)",
                16,
            ),
            // Standard psql output
            ("psql (PostgreSQL) 17.2", 17),
            // pg_restore output
            ("pg_restore (PostgreSQL) 15.4", 15),
            // Debian-style version
            ("17.2 (Debian 17.2-1.pgdg120+1)", 17),
        ];
        for (output, major) in parsed_cases {
            assert_eq!(parse_pg_version_string(output).unwrap(), major);
        }

        // Invalid input must be rejected; in "version 1.2.3" the 1 is below 9
        for output in ["not a version", "version 1.2.3", ""] {
            assert!(parse_pg_version_string(output).is_err());
        }
    }

    #[test]
    fn test_get_pg_tool_version() {
        // The positive half only runs when pg_dump is actually on PATH.
        if which("pg_dump").is_ok() {
            let major = get_pg_tool_version("pg_dump").unwrap();
            assert!(
                major >= 12,
                "Expected pg_dump version >= 12, got {}",
                major
            );
            assert!(
                major <= 99,
                "Expected pg_dump version <= 99, got {}",
                major
            );
        }

        // A tool that does not exist must produce an error
        assert!(get_pg_tool_version("nonexistent_pg_tool_xyz").is_err());
    }
}