database_replicator/utils.rs
1// ABOUTME: Utility functions for validation and error handling
2// ABOUTME: Provides input validation, retry logic, and resource cleanup
3
4use anyhow::{bail, Context, Result};
5use std::time::Duration;
6use url::Url;
7use which::which;
8
/// Build the TCP keepalive environment for PostgreSQL client subprocesses
///
/// The returned pairs are meant to be applied to external PostgreSQL tools
/// (`pg_dump`, `pg_restore`, `psql`, `pg_dumpall`). The aim is to keep idle
/// connections alive when traffic passes through load balancers such as AWS ELB.
///
/// Pairs returned, in this order:
/// - `PGKEEPALIVES=1`: turn TCP keepalives on
/// - `PGKEEPALIVESIDLE=60`: first probe after 60 seconds without traffic
/// - `PGKEEPALIVESINTERVAL=10`: further probes every 10 seconds
///
/// NOTE(review): libpq documents `keepalives`, `keepalives_idle` and
/// `keepalives_interval` as connection parameters. Confirm that the installed
/// client tools actually read these `PGKEEPALIVES*` environment variables. If
/// they do not, the settings must be passed in the connection string instead.
///
/// # Returns
///
/// A fresh `Vec` of `(variable_name, value)` pairs, ready for `Command::env`.
///
/// # Examples
///
/// ```
/// # use database_replicator::utils::get_keepalive_env_vars;
/// # use std::process::Command;
/// let mut cmd = Command::new("psql");
/// for (key, value) in get_keepalive_env_vars() {
///     cmd.env(key, value);
/// }
/// ```
pub fn get_keepalive_env_vars() -> Vec<(&'static str, &'static str)> {
    const KEEPALIVE_SETTINGS: [(&str, &str); 3] = [
        ("PGKEEPALIVES", "1"),
        ("PGKEEPALIVESIDLE", "60"),
        ("PGKEEPALIVESINTERVAL", "10"),
    ];

    KEEPALIVE_SETTINGS.to_vec()
}
42
43/// Validate a PostgreSQL connection string
44///
45/// Checks that the connection string has proper format and required components:
46/// - Starts with "postgres://" or "postgresql://"
47/// - Contains user credentials (@ symbol)
48/// - Contains database name (/ separator with at least 3 occurrences)
49///
50/// # Arguments
51///
52/// * `url` - Connection string to validate
53///
54/// # Returns
55///
56/// Returns `Ok(())` if the connection string is valid.
57///
58/// # Errors
59///
60/// Returns an error with helpful message if the connection string is:
61/// - Empty or whitespace only
62/// - Missing proper scheme (postgres:// or postgresql://)
63/// - Missing user credentials (@ symbol)
64/// - Missing database name
65///
66/// # Examples
67///
68/// ```
69/// # use database_replicator::utils::validate_connection_string;
70/// # use anyhow::Result;
71/// # fn example() -> Result<()> {
72/// // Valid connection strings
73/// validate_connection_string("postgresql://user:pass@localhost:5432/mydb")?;
74/// validate_connection_string("postgres://user@host/db")?;
75///
76/// // Invalid - will return error
77/// assert!(validate_connection_string("").is_err());
78/// assert!(validate_connection_string("mysql://localhost/db").is_err());
79/// # Ok(())
80/// # }
81/// ```
82pub fn validate_connection_string(url: &str) -> Result<()> {
83 if url.trim().is_empty() {
84 bail!("Connection string cannot be empty");
85 }
86
87 // Check for common URL schemes
88 if !url.starts_with("postgres://") && !url.starts_with("postgresql://") {
89 bail!(
90 "Invalid connection string format.\n\
91 Expected format: postgresql://user:password@host:port/database\n\
92 Got: {}",
93 url
94 );
95 }
96
97 // Check for minimum required components (user@host/database)
98 if !url.contains('@') {
99 bail!(
100 "Connection string missing user credentials.\n\
101 Expected format: postgresql://user:password@host:port/database"
102 );
103 }
104
105 if !url.contains('/') || url.matches('/').count() < 3 {
106 bail!(
107 "Connection string missing database name.\n\
108 Expected format: postgresql://user:password@host:port/database"
109 );
110 }
111
112 Ok(())
113}
114
115/// Check that required PostgreSQL client tools are available
116///
117/// Verifies that the following tools are installed and in PATH:
118/// - `pg_dump` - For dumping database schema and data
119/// - `pg_dumpall` - For dumping global objects (roles, tablespaces)
120/// - `psql` - For restoring databases
121///
122/// # Returns
123///
124/// Returns `Ok(())` if all required tools are found.
125///
126/// # Errors
127///
128/// Returns an error with installation instructions if any tools are missing.
129///
130/// # Examples
131///
132/// ```
133/// # use database_replicator::utils::check_required_tools;
134/// # use anyhow::Result;
135/// # fn example() -> Result<()> {
136/// // Check if PostgreSQL tools are installed
137/// check_required_tools()?;
138/// # Ok(())
139/// # }
140/// ```
141pub fn check_required_tools() -> Result<()> {
142 let tools = ["pg_dump", "pg_dumpall", "psql"];
143 let mut missing = Vec::new();
144
145 for tool in &tools {
146 if which(tool).is_err() {
147 missing.push(*tool);
148 }
149 }
150
151 if !missing.is_empty() {
152 bail!(
153 "Missing required PostgreSQL client tools: {}\n\
154 \n\
155 Please install PostgreSQL client tools:\n\
156 - Ubuntu/Debian: sudo apt-get install postgresql-client\n\
157 - macOS: brew install postgresql\n\
158 - RHEL/CentOS: sudo yum install postgresql\n\
159 - Windows: Download from https://www.postgresql.org/download/windows/",
160 missing.join(", ")
161 );
162 }
163
164 Ok(())
165}
166
167/// Retry a function with exponential backoff
168///
169/// Executes an async operation with automatic retry on failure. Each retry doubles
170/// the delay (exponential backoff) to handle transient failures gracefully.
171///
172/// # Arguments
173///
174/// * `operation` - Async function to retry (FnMut returning Future\<Output = Result\<T\>\>)
175/// * `max_retries` - Maximum number of retry attempts (0 = no retries, just initial attempt)
176/// * `initial_delay` - Delay before first retry (doubles each subsequent retry)
177///
178/// # Returns
179///
180/// Returns the successful result or the last error after all retries exhausted.
181///
182/// # Examples
183///
184/// ```no_run
185/// # use anyhow::Result;
186/// # use std::time::Duration;
187/// # use database_replicator::utils::retry_with_backoff;
188/// # async fn example() -> Result<()> {
189/// let result = retry_with_backoff(
190/// || async { Ok("success") },
191/// 3, // Try up to 3 times
192/// Duration::from_secs(1) // Start with 1s delay
193/// ).await?;
194/// # Ok(())
195/// # }
196/// ```
197pub async fn retry_with_backoff<F, Fut, T>(
198 mut operation: F,
199 max_retries: u32,
200 initial_delay: Duration,
201) -> Result<T>
202where
203 F: FnMut() -> Fut,
204 Fut: std::future::Future<Output = Result<T>>,
205{
206 let mut delay = initial_delay;
207 let mut last_error = None;
208
209 for attempt in 0..=max_retries {
210 match operation().await {
211 Ok(result) => return Ok(result),
212 Err(e) => {
213 last_error = Some(e);
214
215 if attempt < max_retries {
216 tracing::warn!(
217 "Operation failed (attempt {}/{}), retrying in {:?}...",
218 attempt + 1,
219 max_retries + 1,
220 delay
221 );
222 tokio::time::sleep(delay).await;
223 delay *= 2; // Exponential backoff
224 }
225 }
226 }
227 }
228
229 Err(last_error.unwrap_or_else(|| anyhow::anyhow!("Operation failed after retries")))
230}
231
232/// Retry a subprocess execution with exponential backoff on connection errors
233///
234/// Executes a subprocess command with automatic retry on connection-related failures.
235/// Each retry doubles the delay (exponential backoff) to handle transient connection issues.
236///
237/// Connection errors are detected by checking:
238/// - Non-zero exit codes
239/// - Stderr output containing connection-related error patterns:
240/// - "connection closed"
241/// - "connection refused"
242/// - "could not connect"
243/// - "server closed the connection"
244/// - "timeout"
245/// - "Connection timed out"
246///
247/// # Arguments
248///
249/// * `operation` - Function that executes a Command and returns the exit status
250/// * `max_retries` - Maximum number of retry attempts (0 = no retries, just initial attempt)
251/// * `initial_delay` - Delay before first retry (doubles each subsequent retry)
252/// * `operation_name` - Name of the operation for logging (e.g., "pg_restore", "psql")
253///
254/// # Returns
255///
256/// Returns Ok(()) on success or the last error after all retries exhausted.
257///
258/// # Examples
259///
260/// ```no_run
261/// # use anyhow::Result;
262/// # use std::time::Duration;
263/// # use std::process::Command;
264/// # use database_replicator::utils::retry_subprocess_with_backoff;
265/// # async fn example() -> Result<()> {
266/// retry_subprocess_with_backoff(
267/// || {
268/// let mut cmd = Command::new("psql");
269/// cmd.arg("--version");
270/// cmd.status().map_err(anyhow::Error::from)
271/// },
272/// 3, // Try up to 3 times
273/// Duration::from_secs(1), // Start with 1s delay
274/// "psql"
275/// ).await?;
276/// # Ok(())
277/// # }
278/// ```
279pub async fn retry_subprocess_with_backoff<F>(
280 mut operation: F,
281 max_retries: u32,
282 initial_delay: Duration,
283 operation_name: &str,
284) -> Result<()>
285where
286 F: FnMut() -> Result<std::process::ExitStatus>,
287{
288 let mut delay = initial_delay;
289 let mut last_error = None;
290
291 for attempt in 0..=max_retries {
292 match operation() {
293 Ok(status) => {
294 if status.success() {
295 return Ok(());
296 } else {
297 // Non-zero exit code - check if it's a connection error
298 // We can't easily capture stderr here, so we'll treat all non-zero
299 // exit codes as potential connection errors for now
300 let error = anyhow::anyhow!(
301 "{} failed with exit code: {}",
302 operation_name,
303 status.code().unwrap_or(-1)
304 );
305 last_error = Some(error);
306
307 if attempt < max_retries {
308 tracing::warn!(
309 "{} failed (attempt {}/{}), retrying in {:?}...",
310 operation_name,
311 attempt + 1,
312 max_retries + 1,
313 delay
314 );
315 tokio::time::sleep(delay).await;
316 delay *= 2; // Exponential backoff
317 }
318 }
319 }
320 Err(e) => {
321 last_error = Some(e);
322
323 if attempt < max_retries {
324 tracing::warn!(
325 "{} failed (attempt {}/{}): {}, retrying in {:?}...",
326 operation_name,
327 attempt + 1,
328 max_retries + 1,
329 last_error.as_ref().unwrap(),
330 delay
331 );
332 tokio::time::sleep(delay).await;
333 delay *= 2; // Exponential backoff
334 }
335 }
336 }
337 }
338
339 Err(last_error.unwrap_or_else(|| {
340 anyhow::anyhow!("{} failed after {} retries", operation_name, max_retries)
341 }))
342}
343
344/// Validate a PostgreSQL identifier (database name, schema name, etc.)
345///
346/// Validates that an identifier follows PostgreSQL naming rules to prevent SQL injection.
347/// PostgreSQL identifiers must:
348/// - Be 1-63 characters long
349/// - Start with a letter (a-z, A-Z) or underscore (_)
350/// - Contain only letters, digits (0-9), or underscores
351///
352/// # Arguments
353///
354/// * `identifier` - The identifier to validate (database name, schema name, etc.)
355///
356/// # Returns
357///
358/// Returns `Ok(())` if the identifier is valid.
359///
360/// # Errors
361///
362/// Returns an error if the identifier:
363/// - Is empty or whitespace-only
364/// - Exceeds 63 characters
365/// - Starts with an invalid character (digit or special character)
366/// - Contains invalid characters (anything except a-z, A-Z, 0-9, _)
367///
368/// # Security
369///
370/// This function is critical for preventing SQL injection attacks. All database
371/// names, schema names, and table names from untrusted sources MUST be validated
372/// before use in SQL statements.
373///
374/// # Examples
375///
376/// ```
377/// # use database_replicator::utils::validate_postgres_identifier;
378/// # use anyhow::Result;
379/// # fn example() -> Result<()> {
380/// // Valid identifiers
381/// validate_postgres_identifier("mydb")?;
382/// validate_postgres_identifier("my_database")?;
383/// validate_postgres_identifier("_private_db")?;
384///
385/// // Invalid - will return error
386/// assert!(validate_postgres_identifier("123db").is_err());
387/// assert!(validate_postgres_identifier("my-database").is_err());
388/// assert!(validate_postgres_identifier("db\"; DROP TABLE users; --").is_err());
389/// # Ok(())
390/// # }
391/// ```
392pub fn validate_postgres_identifier(identifier: &str) -> Result<()> {
393 // Check for empty or whitespace-only
394 let trimmed = identifier.trim();
395 if trimmed.is_empty() {
396 bail!("Identifier cannot be empty or whitespace-only");
397 }
398
399 // Check length (PostgreSQL limit is 63 characters)
400 if trimmed.len() > 63 {
401 bail!(
402 "Identifier '{}' exceeds maximum length of 63 characters (got {})",
403 sanitize_identifier(trimmed),
404 trimmed.len()
405 );
406 }
407
408 // Get first character
409 let first_char = trimmed.chars().next().unwrap();
410
411 // First character must be a letter or underscore
412 if !first_char.is_ascii_alphabetic() && first_char != '_' {
413 bail!(
414 "Identifier '{}' must start with a letter or underscore, not '{}'",
415 sanitize_identifier(trimmed),
416 first_char
417 );
418 }
419
420 // All characters must be alphanumeric or underscore
421 for (i, c) in trimmed.chars().enumerate() {
422 if !c.is_ascii_alphanumeric() && c != '_' {
423 bail!(
424 "Identifier '{}' contains invalid character '{}' at position {}. \
425 Only letters, digits, and underscores are allowed",
426 sanitize_identifier(trimmed),
427 if c.is_control() {
428 format!("\\x{:02x}", c as u32)
429 } else {
430 c.to_string()
431 },
432 i
433 );
434 }
435 }
436
437 Ok(())
438}
439
/// Make an identifier (table name, schema name, etc.) safe to display
///
/// Drops every control character and keeps at most the first 100 remaining
/// characters. This guards log output against injection and keeps error
/// messages readable.
///
/// **Note**: The result is meant for display only. It is not an SQL escaping
/// mechanism; use parameterized queries for SQL safety.
///
/// # Arguments
///
/// * `identifier` - The identifier to sanitize (table name, schema name, etc.)
///
/// # Returns
///
/// A new `String` with no control characters and no more than 100 characters.
///
/// # Examples
///
/// ```
/// # use database_replicator::utils::sanitize_identifier;
/// assert_eq!(sanitize_identifier("normal_table"), "normal_table");
/// assert_eq!(sanitize_identifier("table\x00name"), "tablename");
/// assert_eq!(sanitize_identifier("table\nname"), "tablename");
///
/// // Length limit
/// let long_name = "a".repeat(200);
/// assert_eq!(sanitize_identifier(&long_name).len(), 100);
/// ```
pub fn sanitize_identifier(identifier: &str) -> String {
    const MAX_DISPLAY_CHARS: usize = 100;

    let mut sanitized = String::new();
    let mut kept = 0;

    for ch in identifier.chars() {
        if kept == MAX_DISPLAY_CHARS {
            break;
        }
        if ch.is_control() {
            continue;
        }
        sanitized.push(ch);
        kept += 1;
    }

    sanitized
}
476
/// Quote a PostgreSQL identifier (database, schema, table, column)
///
/// Wraps the identifier in double quotes and doubles any embedded double quote.
/// The result is therefore a well-formed quoted identifier for any input.
/// Identifiers from untrusted sources should still go through
/// `validate_postgres_identifier` first.
///
/// # Examples
///
/// ```
/// use database_replicator::utils::quote_ident;
/// assert_eq!(quote_ident("users"), "\"users\"");
/// assert_eq!(quote_ident("we\"ird"), "\"we\"\"ird\"");
/// ```
pub fn quote_ident(identifier: &str) -> String {
    let escaped = identifier.replace('"', "\"\"");
    format!("\"{}\"", escaped)
}
493
/// Quote a SQL string literal (for use in SQL statements)
///
/// Doubles every single quote and wraps the result in single quotes. Use it for
/// string values in SQL, not for identifiers.
///
/// NOTE(review): backslashes are passed through unchanged. That is only safe
/// when the server treats backslashes literally in `'...'` strings (PostgreSQL
/// with `standard_conforming_strings = on`). Confirm this holds for every
/// target this is used with.
///
/// # Examples
///
/// ```
/// use database_replicator::utils::quote_literal;
/// assert_eq!(quote_literal("hello"), "'hello'");
/// assert_eq!(quote_literal("it's"), "'it''s'");
/// assert_eq!(quote_literal(""), "''");
/// ```
pub fn quote_literal(value: &str) -> String {
    let escaped = value.replace('\'', "''");
    format!("'{}'", escaped)
}
519
/// Quote a MySQL identifier (database, table, column)
///
/// MySQL quotes identifiers with backticks. Any backtick inside the identifier
/// is doubled, and the whole value is wrapped in backticks.
///
/// # Examples
///
/// ```
/// use database_replicator::utils::quote_mysql_ident;
/// assert_eq!(quote_mysql_ident("users"), "`users`");
/// assert_eq!(quote_mysql_ident("user`name"), "`user``name`");
/// ```
pub fn quote_mysql_ident(identifier: &str) -> String {
    let escaped = identifier.replace('`', "``");
    format!("`{}`", escaped)
}
544
545/// Validate that source and target URLs are different to prevent accidental data loss
546///
547/// Compares two PostgreSQL connection URLs to ensure they point to different databases.
548/// This is critical for preventing data loss from operations like `init --drop-existing`
549/// where using the same URL for source and target would destroy the source data.
550///
551/// # Comparison Strategy
552///
553/// URLs are normalized and compared on:
554/// - Host (case-insensitive)
555/// - Port (defaulting to 5432 if not specified)
556/// - Database name (case-sensitive)
557/// - User (if present)
558///
559/// Query parameters (like SSL settings) are ignored as they don't affect database identity.
560///
561/// # Arguments
562///
563/// * `source_url` - Source database connection string
564/// * `target_url` - Target database connection string
565///
566/// # Returns
567///
568/// Returns `Ok(())` if the URLs point to different databases.
569///
570/// # Errors
571///
572/// Returns an error if:
573/// - The URLs point to the same database (same host, port, database name, and user)
574/// - Either URL is malformed and cannot be parsed
575///
576/// # Examples
577///
578/// ```
579/// # use database_replicator::utils::validate_source_target_different;
580/// # use anyhow::Result;
581/// # fn example() -> Result<()> {
582/// // Valid - different hosts
583/// validate_source_target_different(
584/// "postgresql://user:pass@source.com:5432/db",
585/// "postgresql://user:pass@target.com:5432/db"
586/// )?;
587///
588/// // Valid - different databases
589/// validate_source_target_different(
590/// "postgresql://user:pass@host:5432/db1",
591/// "postgresql://user:pass@host:5432/db2"
592/// )?;
593///
594/// // Invalid - same database
595/// assert!(validate_source_target_different(
596/// "postgresql://user:pass@host:5432/db",
597/// "postgresql://user:pass@host:5432/db"
598/// ).is_err());
599/// # Ok(())
600/// # }
601/// ```
602pub fn validate_source_target_different(source_url: &str, target_url: &str) -> Result<()> {
603 // Parse both URLs to extract components
604 let source_parts = parse_postgres_url(source_url)
605 .with_context(|| format!("Failed to parse source URL: {}", source_url))?;
606 let target_parts = parse_postgres_url(target_url)
607 .with_context(|| format!("Failed to parse target URL: {}", target_url))?;
608
609 // Compare normalized components
610 if source_parts.host == target_parts.host
611 && source_parts.port == target_parts.port
612 && source_parts.database == target_parts.database
613 && source_parts.user == target_parts.user
614 {
615 bail!(
616 "Source and target URLs point to the same database!\n\
617 \n\
618 This would cause DATA LOSS - the target would overwrite the source.\n\
619 \n\
620 Source: {}@{}:{}/{}\n\
621 Target: {}@{}:{}/{}\n\
622 \n\
623 Please ensure source and target are different databases.\n\
624 Common causes:\n\
625 - Copy-paste error in connection strings\n\
626 - Wrong environment variables (e.g., SOURCE_URL == TARGET_URL)\n\
627 - Typo in database name or host",
628 source_parts.user.as_deref().unwrap_or("(no user)"),
629 source_parts.host,
630 source_parts.port,
631 source_parts.database,
632 target_parts.user.as_deref().unwrap_or("(no user)"),
633 target_parts.host,
634 target_parts.port,
635 target_parts.database
636 );
637 }
638
639 Ok(())
640}
641
642/// Parse a PostgreSQL URL into its components
643///
644/// # Arguments
645///
646/// * `url` - PostgreSQL connection URL (postgres:// or postgresql://)
647///
648/// # Returns
649///
650/// Returns a `PostgresUrlParts` struct with normalized components.
651///
652/// # Security
653///
654/// This function extracts passwords from URLs for use with .pgpass files.
655/// Ensure returned values are handled securely and not logged.
656pub fn parse_postgres_url(url: &str) -> Result<PostgresUrlParts> {
657 // Remove scheme
658 let url_without_scheme = url
659 .trim_start_matches("postgres://")
660 .trim_start_matches("postgresql://");
661
662 // Split into base and query params
663 let (base, query_string) = if let Some((b, q)) = url_without_scheme.split_once('?') {
664 (b, Some(q))
665 } else {
666 (url_without_scheme, None)
667 };
668
669 // Parse query parameters into HashMap
670 let mut query_params = std::collections::HashMap::new();
671 if let Some(query) = query_string {
672 for param in query.split('&') {
673 if let Some((key, value)) = param.split_once('=') {
674 query_params.insert(key.to_string(), value.to_string());
675 }
676 }
677 }
678
679 // Parse: [user[:password]@]host[:port]/database
680 let (auth_and_host, database) = base
681 .rsplit_once('/')
682 .ok_or_else(|| anyhow::anyhow!("Missing database name in URL"))?;
683
684 // Parse authentication and host
685 // Use rsplit_once to split from the right, so passwords can contain '@'
686 let (user, password, host_and_port) = if let Some((auth, hp)) = auth_and_host.rsplit_once('@') {
687 // Has authentication
688 let (user, pass) = if let Some((u, p)) = auth.split_once(':') {
689 (Some(u.to_string()), Some(p.to_string()))
690 } else {
691 (Some(auth.to_string()), None)
692 };
693 (user, pass, hp)
694 } else {
695 // No authentication
696 (None, None, auth_and_host)
697 };
698
699 // Parse host and port
700 let (host, port) = if let Some((h, p)) = host_and_port.rsplit_once(':') {
701 // Port specified
702 let port = p
703 .parse::<u16>()
704 .with_context(|| format!("Invalid port number: {}", p))?;
705 (h, port)
706 } else {
707 // Use default PostgreSQL port
708 (host_and_port, 5432)
709 };
710
711 Ok(PostgresUrlParts {
712 host: host.to_lowercase(), // Hostnames are case-insensitive
713 port,
714 database: database.to_string(), // Database names are case-sensitive in PostgreSQL
715 user,
716 password,
717 query_params,
718 })
719}
720
721/// Strip password from PostgreSQL connection URL
722/// Returns a new URL with password removed, preserving all other components
723/// This is useful for storing connection strings in places where passwords should not be visible
724pub fn strip_password_from_url(url: &str) -> Result<String> {
725 let parts = parse_postgres_url(url)?;
726
727 // Reconstruct URL without password
728 let scheme = if url.starts_with("postgresql://") {
729 "postgresql://"
730 } else if url.starts_with("postgres://") {
731 "postgres://"
732 } else {
733 bail!("Invalid PostgreSQL URL scheme");
734 };
735
736 let mut result = String::from(scheme);
737
738 // Add user if present (without password)
739 if let Some(user) = &parts.user {
740 result.push_str(user);
741 result.push('@');
742 }
743
744 // Add host and port
745 result.push_str(&parts.host);
746 result.push(':');
747 result.push_str(&parts.port.to_string());
748
749 // Add database
750 result.push('/');
751 result.push_str(&parts.database);
752
753 // Preserve query parameters if present
754 if let Some(query_start) = url.find('?') {
755 result.push_str(&url[query_start..]);
756 }
757
758 Ok(result)
759}
760
/// Parsed components of a PostgreSQL connection URL
///
/// `Debug` is implemented by hand. Neither the password nor a `password` query
/// parameter is ever printed by `{:?}` formatting, e.g. in log lines or
/// `assert_eq!` failure output; both show as `<redacted>`.
#[derive(PartialEq)]
pub struct PostgresUrlParts {
    pub host: String,
    pub port: u16,
    pub database: String,
    pub user: Option<String>,
    /// Secret: never log this value.
    pub password: Option<String>,
    /// Raw (not percent-decoded) query parameters from the URL.
    pub query_params: std::collections::HashMap<String, String>,
}

impl std::fmt::Debug for PostgresUrlParts {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        const REDACTED: &str = "<redacted>";

        // Sorted for deterministic output; a `password` parameter is masked.
        let query_params: std::collections::BTreeMap<&str, &str> = self
            .query_params
            .iter()
            .map(|(key, value)| {
                let shown = if key.as_str() == "password" {
                    REDACTED
                } else {
                    value.as_str()
                };
                (key.as_str(), shown)
            })
            .collect();

        f.debug_struct("PostgresUrlParts")
            .field("host", &self.host)
            .field("port", &self.port)
            .field("database", &self.database)
            .field("user", &self.user)
            .field("password", &self.password.as_ref().map(|_| REDACTED))
            .field("query_params", &query_params)
            .finish()
    }
}

impl PostgresUrlParts {
    /// Convert query parameters to PostgreSQL environment variables
    ///
    /// Maps common connection URL query parameters to the matching PostgreSQL
    /// environment variable names. This lets SSL/TLS and other connection
    /// settings reach pg_dump, pg_dumpall, psql, etc.
    ///
    /// # Supported Parameters
    ///
    /// - `sslmode` → `PGSSLMODE`
    /// - `sslcert` → `PGSSLCERT`
    /// - `sslkey` → `PGSSLKEY`
    /// - `sslrootcert` → `PGSSLROOTCERT`
    /// - `channel_binding` → `PGCHANNELBINDING`
    /// - `connect_timeout` → `PGCONNECT_TIMEOUT`
    /// - `application_name` → `PGAPPNAME`
    /// - `client_encoding` → `PGCLIENTENCODING`
    ///
    /// # Returns
    ///
    /// A Vec of (env_var_name, value) pairs, listed in the order of the table
    /// above. Only parameters present in the URL are included, and values are
    /// copied verbatim.
    pub fn to_pg_env_vars(&self) -> Vec<(&'static str, String)> {
        // Map query parameters to PostgreSQL environment variables
        let param_mapping = [
            ("sslmode", "PGSSLMODE"),
            ("sslcert", "PGSSLCERT"),
            ("sslkey", "PGSSLKEY"),
            ("sslrootcert", "PGSSLROOTCERT"),
            ("channel_binding", "PGCHANNELBINDING"),
            ("connect_timeout", "PGCONNECT_TIMEOUT"),
            ("application_name", "PGAPPNAME"),
            ("client_encoding", "PGCLIENTENCODING"),
        ];

        param_mapping
            .iter()
            .filter_map(|&(param_name, env_var_name)| {
                self.query_params
                    .get(param_name)
                    .map(|value| (env_var_name, value.clone()))
            })
            .collect()
    }
}
817
818/// Managed .pgpass file for secure password passing to PostgreSQL tools
819///
820/// This struct creates a temporary .pgpass file with secure permissions (0600)
821/// and automatically cleans it up when dropped. PostgreSQL command-line tools
822/// read credentials from this file instead of accepting passwords in URLs,
823/// which prevents command injection vulnerabilities.
824///
825/// # Security
826///
827/// - File permissions are set to 0600 (owner read/write only)
828/// - File is automatically removed on Drop
829/// - Credentials are never passed on command line
830///
831/// # Format
832///
833/// .pgpass file format: hostname:port:database:username:password
834/// Wildcards (*) are used for maximum compatibility
835///
836/// # Examples
837///
838/// ```no_run
839/// # use database_replicator::utils::{PgPassFile, parse_postgres_url};
840/// # use anyhow::Result;
841/// # fn example() -> Result<()> {
842/// let url = "postgresql://user:pass@localhost:5432/mydb";
843/// let parts = parse_postgres_url(url)?;
844/// let pgpass = PgPassFile::new(&parts)?;
845///
846/// // Use pgpass.path() with PGPASSFILE environment variable
847/// // File is automatically cleaned up when pgpass goes out of scope
848/// # Ok(())
849/// # }
850/// ```
851pub struct PgPassFile {
852 path: std::path::PathBuf,
853}
854
855impl PgPassFile {
856 /// Create a new .pgpass file with credentials from URL parts
857 ///
858 /// # Arguments
859 ///
860 /// * `parts` - Parsed PostgreSQL URL components
861 ///
862 /// # Returns
863 ///
864 /// Returns a PgPassFile that will be automatically cleaned up on Drop
865 ///
866 /// # Errors
867 ///
868 /// Returns an error if the file cannot be created or permissions cannot be set
869 pub fn new(parts: &PostgresUrlParts) -> Result<Self> {
870 use std::fs;
871 use std::io::Write;
872
873 // Create temp file with secure name
874 let temp_dir = std::env::temp_dir();
875 let random: u32 = rand::random();
876 let filename = format!("pgpass-{:08x}", random);
877 let path = temp_dir.join(filename);
878
879 // Write .pgpass entry
880 // Format: hostname:port:database:username:password
881 let username = parts.user.as_deref().unwrap_or("*");
882 let password = parts.password.as_deref().unwrap_or("");
883 let entry = format!(
884 "{}:{}:{}:{}:{}\n",
885 parts.host, parts.port, parts.database, username, password
886 );
887
888 let mut file = fs::File::create(&path)
889 .with_context(|| format!("Failed to create .pgpass file at {}", path.display()))?;
890
891 file.write_all(entry.as_bytes())
892 .with_context(|| format!("Failed to write to .pgpass file at {}", path.display()))?;
893
894 // Set secure permissions (0600) - owner read/write only
895 #[cfg(unix)]
896 {
897 use std::os::unix::fs::PermissionsExt;
898 let permissions = fs::Permissions::from_mode(0o600);
899 fs::set_permissions(&path, permissions).with_context(|| {
900 format!(
901 "Failed to set permissions on .pgpass file at {}",
902 path.display()
903 )
904 })?;
905 }
906
907 // On Windows, .pgpass is stored in %APPDATA%\postgresql\pgpass.conf
908 // but for our temporary use case, we'll just use a temp file
909 // PostgreSQL on Windows also checks permissions but less strictly
910
911 Ok(Self { path })
912 }
913
914 /// Get the path to the .pgpass file
915 ///
916 /// Use this with the PGPASSFILE environment variable when running
917 /// PostgreSQL command-line tools
918 pub fn path(&self) -> &std::path::Path {
919 &self.path
920 }
921}
922
923impl Drop for PgPassFile {
924 fn drop(&mut self) {
925 // Best effort cleanup - don't panic if removal fails
926 let _ = std::fs::remove_file(&self.path);
927 }
928}
929
930/// Create a managed temporary directory with explicit cleanup support
931///
932/// Creates a temporary directory with a timestamped name that can be cleaned up
933/// even if the process is killed with SIGKILL. Unlike `TempDir::new()` which
934/// relies on the Drop trait, this function creates named directories that can
935/// be cleaned up on next process startup.
936///
937/// Directory naming format: `postgres-seren-replicator-{timestamp}-{random}`
938/// Example: `postgres-seren-replicator-20250106-120534-a3b2c1d4`
939///
940/// # Returns
941///
942/// Returns the path to the created temporary directory.
943///
944/// # Errors
945///
946/// Returns an error if the directory cannot be created.
947///
948/// # Examples
949///
950/// ```no_run
951/// # use database_replicator::utils::create_managed_temp_dir;
952/// # use anyhow::Result;
953/// # fn example() -> Result<()> {
954/// let temp_path = create_managed_temp_dir()?;
955/// println!("Using temp directory: {}", temp_path.display());
956/// // ... do work ...
957/// // Cleanup happens automatically on next startup via cleanup_stale_temp_dirs()
958/// # Ok(())
959/// # }
960/// ```
961pub fn create_managed_temp_dir() -> Result<std::path::PathBuf> {
962 use std::fs;
963 use std::time::SystemTime;
964
965 let system_temp = std::env::temp_dir();
966
967 // Generate timestamp for directory name
968 let timestamp = SystemTime::now()
969 .duration_since(SystemTime::UNIX_EPOCH)
970 .unwrap()
971 .as_secs();
972
973 // Generate random suffix for uniqueness
974 let random: u32 = rand::random();
975
976 // Create directory name with timestamp and random suffix
977 let dir_name = format!("postgres-seren-replicator-{}-{:08x}", timestamp, random);
978
979 let temp_path = system_temp.join(dir_name);
980
981 // Create the directory
982 fs::create_dir_all(&temp_path)
983 .with_context(|| format!("Failed to create temp directory at {}", temp_path.display()))?;
984
985 tracing::debug!("Created managed temp directory: {}", temp_path.display());
986
987 Ok(temp_path)
988}
989
990/// Clean up stale temporary directories from previous runs
991///
992/// Removes temporary directories created by `create_managed_temp_dir()` that are
993/// older than the specified age. This should be called on process startup to clean
994/// up directories left behind by processes killed with SIGKILL.
995///
996/// Only directories matching the pattern `postgres-seren-replicator-*` are removed.
997///
998/// # Arguments
999///
1000/// * `max_age_secs` - Maximum age in seconds before a directory is considered stale
1001/// (recommended: 86400 for 24 hours)
1002///
1003/// # Returns
1004///
1005/// Returns the number of directories cleaned up.
1006///
1007/// # Errors
1008///
1009/// Returns an error if the system temp directory cannot be read. Individual
1010/// directory removal errors are logged but don't fail the entire operation.
1011///
1012/// # Examples
1013///
1014/// ```no_run
1015/// # use database_replicator::utils::cleanup_stale_temp_dirs;
1016/// # use anyhow::Result;
1017/// # fn example() -> Result<()> {
1018/// // Clean up temp directories older than 24 hours
1019/// let cleaned = cleanup_stale_temp_dirs(86400)?;
1020/// println!("Cleaned up {} stale temp directories", cleaned);
1021/// # Ok(())
1022/// # }
1023/// ```
1024pub fn cleanup_stale_temp_dirs(max_age_secs: u64) -> Result<usize> {
1025 use std::fs;
1026 use std::time::SystemTime;
1027
1028 let system_temp = std::env::temp_dir();
1029 let now = SystemTime::now();
1030 let mut cleaned_count = 0;
1031
1032 // Read all entries in system temp directory
1033 let entries = fs::read_dir(&system_temp).with_context(|| {
1034 format!(
1035 "Failed to read system temp directory: {}",
1036 system_temp.display()
1037 )
1038 })?;
1039
1040 for entry in entries.flatten() {
1041 let path = entry.path();
1042
1043 // Only process directories matching our naming pattern
1044 if let Some(name) = path.file_name().and_then(|n| n.to_str()) {
1045 if !name.starts_with("postgres-seren-replicator-") {
1046 continue;
1047 }
1048
1049 // Check directory age
1050 match entry.metadata() {
1051 Ok(metadata) => {
1052 if let Ok(modified) = metadata.modified() {
1053 if let Ok(age) = now.duration_since(modified) {
1054 if age.as_secs() > max_age_secs {
1055 // Directory is stale, remove it
1056 match fs::remove_dir_all(&path) {
1057 Ok(_) => {
1058 tracing::info!(
1059 "Cleaned up stale temp directory: {} (age: {}s)",
1060 path.display(),
1061 age.as_secs()
1062 );
1063 cleaned_count += 1;
1064 }
1065 Err(e) => {
1066 tracing::warn!(
1067 "Failed to remove stale temp directory {}: {}",
1068 path.display(),
1069 e
1070 );
1071 }
1072 }
1073 }
1074 }
1075 }
1076 }
1077 Err(e) => {
1078 tracing::warn!(
1079 "Failed to get metadata for temp directory {}: {}",
1080 path.display(),
1081 e
1082 );
1083 }
1084 }
1085 }
1086 }
1087
1088 if cleaned_count > 0 {
1089 tracing::info!(
1090 "Cleaned up {} stale temp directory(ies) older than {} seconds",
1091 cleaned_count,
1092 max_age_secs
1093 );
1094 }
1095
1096 Ok(cleaned_count)
1097}
1098
1099/// Parse a SerenDB URL to extract project, branch, and database IDs
1100///
1101/// SerenDB URLs have the format: postgresql://user:pass@<database-id>.<branch-id>.<project-id>.serendb.com:5432/db
1102/// This function extracts the three UUIDs from the hostname.
1103///
1104/// # Arguments
1105///
1106/// * `url` - The SerenDB PostgreSQL connection string
1107///
1108/// # Returns
1109///
1110/// An `Option` containing a tuple of `(project_id, branch_id, database_id)` if the
1111/// URL is a valid SerenDB target and contains the expected ID format, otherwise `None`.
1112pub fn parse_serendb_url_for_ids(url: &str) -> Option<(String, String, String)> {
1113 let parts = parse_postgres_url(url).ok()?;
1114
1115 if !is_serendb_target(url) {
1116 return None;
1117 }
1118
1119 // Hostname format: <database-id>.<branch-id>.<project-id>.serendb.com
1120 // Or with custom subdomains: <database-id>.<branch-id>.<project-id>.<custom>.serendb.com
1121 // We want the last three parts before .serendb.com
1122 let host_parts: Vec<&str> = parts.host.split('.').collect();
1123
1124 if host_parts.len() < 4 {
1125 return None; // Not enough parts for SerenDB ID format
1126 }
1127
1128 let num_host_parts = host_parts.len();
1129 let database_id = host_parts[num_host_parts - 4].to_string();
1130 let branch_id = host_parts[num_host_parts - 3].to_string();
1131 let project_id = host_parts[num_host_parts - 2].to_string();
1132
1133 // Basic UUID format validation (optional but good for robustness)
1134 // A real UUID check would be more extensive, but string length is a good start
1135 if database_id.len() == 36 && branch_id.len() == 36 && project_id.len() == 36 {
1136 Some((project_id, branch_id, database_id))
1137 } else {
1138 None
1139 }
1140}
1141
1142/// Remove a managed temporary directory
1143///
1144/// Explicitly removes a temporary directory created by `create_managed_temp_dir()`.
1145/// This should be called when the directory is no longer needed.
1146///
1147/// # Arguments
1148///
1149/// * `path` - Path to the temporary directory to remove
1150///
1151/// # Errors
1152///
1153/// Returns an error if the directory cannot be removed.
1154///
1155/// # Examples
1156///
1157/// ```no_run
1158/// # use database_replicator::utils::{create_managed_temp_dir, remove_managed_temp_dir};
1159/// # use anyhow::Result;
1160/// # fn example() -> Result<()> {
1161/// let temp_path = create_managed_temp_dir()?;
1162/// // ... do work ...
1163/// remove_managed_temp_dir(&temp_path)?;
1164/// # Ok(())
1165/// # }
1166/// ```
1167pub fn remove_managed_temp_dir(path: &std::path::Path) -> Result<()> {
1168 use std::fs;
1169
1170 // Verify this is one of our temp directories (safety check)
1171 if let Some(name) = path.file_name().and_then(|n| n.to_str()) {
1172 if !name.starts_with("postgres-seren-replicator-") {
1173 bail!(
1174 "Refusing to remove directory that doesn't match our naming pattern: {}",
1175 path.display()
1176 );
1177 }
1178 } else {
1179 bail!("Invalid temp directory path: {}", path.display());
1180 }
1181
1182 tracing::debug!("Removing managed temp directory: {}", path.display());
1183
1184 fs::remove_dir_all(path)
1185 .with_context(|| format!("Failed to remove temp directory at {}", path.display()))?;
1186
1187 Ok(())
1188}
1189
1190/// Replace the database name in a connection string URL
1191///
1192/// This is used internally by SerenDB to provide a generic connection string
1193/// which then needs to be specialized for a particular database.
1194///
1195/// # Arguments
1196///
1197/// * `url` - The connection string URL (e.g., postgresql://host/template_db)
1198/// * `new_db` - The new database name to insert into the URL
1199///
1200/// # Returns
1201///
1202/// A new URL string with the database name replaced.
1203///
1204/// # Errors
1205///
1206/// Returns an error if the URL is invalid and cannot be parsed.
1207pub fn replace_database_in_connection_string(url: &str, new_db: &str) -> Result<String> {
1208 let mut parsed = Url::parse(url).context("Invalid connection string URL")?;
1209 parsed.set_path(&format!("/{}", new_db));
1210
1211 Ok(parsed.to_string())
1212}
1213
1214/// Check if a PostgreSQL URL points to a SerenDB instance
1215///
1216/// SerenDB hosts have domains ending with `.serendb.com`
1217///
1218/// # Arguments
1219///
1220/// * `url` - PostgreSQL connection string to check
1221///
1222/// # Returns
1223///
1224/// Returns `true` if the URL points to a SerenDB host.
1225///
1226/// # Examples
1227///
1228/// ```
1229/// use database_replicator::utils::is_serendb_target;
1230///
1231/// assert!(is_serendb_target("postgresql://user:pass@db.serendb.com/mydb"));
1232/// assert!(is_serendb_target("postgresql://user:pass@cluster-123.console.serendb.com/mydb"));
1233/// assert!(!is_serendb_target("postgresql://user:pass@localhost/mydb"));
1234/// assert!(!is_serendb_target("postgresql://user:pass@rds.amazonaws.com/mydb"));
1235/// ```
1236pub fn is_serendb_target(url: &str) -> bool {
1237 match parse_postgres_url(url) {
1238 Ok(parts) => parts.host.ends_with(".serendb.com") || parts.host == "serendb.com",
1239 Err(_) => false,
1240 }
1241}
1242
1243/// Get the major version of a PostgreSQL client tool (pg_dump, psql, etc.)
1244///
1245/// Executes `<tool> --version` and parses the output.
1246///
1247/// # Arguments
1248///
1249/// * `tool` - Name of the tool (e.g., "pg_dump", "psql")
1250///
1251/// # Returns
1252///
1253/// The major version number (e.g., 16 for pg_dump 16.10)
1254///
1255/// # Errors
1256///
1257/// Returns an error if:
1258/// - Tool is not found in PATH
1259/// - Tool execution fails
1260/// - Version output cannot be parsed
1261///
1262/// # Examples
1263///
1264/// ```no_run
1265/// use database_replicator::utils::get_pg_tool_version;
1266/// use anyhow::Result;
1267///
1268/// fn example() -> Result<()> {
1269/// let version = get_pg_tool_version("pg_dump")?;
1270/// println!("pg_dump major version: {}", version); // e.g., 16
1271/// Ok(())
1272/// }
1273/// ```
1274pub fn get_pg_tool_version(tool: &str) -> Result<u32> {
1275 use std::process::Command;
1276
1277 let path = which(tool).with_context(|| format!("{} not found in PATH", tool))?;
1278
1279 let output = Command::new(&path)
1280 .arg("--version")
1281 .output()
1282 .with_context(|| format!("Failed to execute {} --version", tool))?;
1283
1284 let version_str = String::from_utf8_lossy(&output.stdout);
1285 parse_pg_version_string(&version_str)
1286}
1287
1288/// Parse major version from PostgreSQL version string
1289///
1290/// Handles formats like:
1291/// - "pg_dump (PostgreSQL) 16.10 (Ubuntu 16.10-0ubuntu0.24.04.1)"
1292/// - "psql (PostgreSQL) 17.2"
1293/// - "17.2 (Debian 17.2-1.pgdg120+1)"
1294///
1295/// # Arguments
1296///
1297/// * `version_str` - Version string output from a PostgreSQL tool
1298///
1299/// # Returns
1300///
1301/// The major version number (e.g., 16, 17)
1302///
1303/// # Errors
1304///
1305/// Returns an error if the version cannot be parsed.
1306pub fn parse_pg_version_string(version_str: &str) -> Result<u32> {
1307 // Look for version pattern: major.minor
1308 for word in version_str.split_whitespace() {
1309 if let Some(major_str) = word.split('.').next() {
1310 if let Ok(major) = major_str.parse::<u32>() {
1311 // Valid PostgreSQL versions are between 9 and 99
1312 if (9..=99).contains(&major) {
1313 return Ok(major);
1314 }
1315 }
1316 }
1317 }
1318 bail!("Could not parse PostgreSQL version from: {}", version_str)
1319}
1320
#[cfg(test)]
mod tests {
    // Unit tests for the validation, URL-parsing, pgpass, temp-dir and
    // version-detection helpers in this module.
    use super::*;

    #[test]
    fn test_validate_connection_string_valid() {
        assert!(validate_connection_string("postgresql://user:pass@localhost:5432/dbname").is_ok());
        assert!(validate_connection_string("postgres://user@host/db").is_ok());
    }

    #[test]
    fn test_check_required_tools() {
        // Environment-dependent: passes when the PostgreSQL client tools are
        // installed. When they are missing, the test still passes as long as the
        // error message names the missing tools in the expected format.
        let result = check_required_tools();

        // On systems with PostgreSQL installed, this should pass
        // On systems without it, we expect a specific error message
        if let Err(err) = result {
            let err_msg = err.to_string();
            assert!(err_msg.contains("Missing required PostgreSQL client tools"));
            assert!(
                err_msg.contains("pg_dump")
                    || err_msg.contains("pg_dumpall")
                    || err_msg.contains("psql")
            );
        }
    }

    #[test]
    fn test_validate_connection_string_invalid() {
        // Empty / whitespace-only input
        assert!(validate_connection_string("").is_err());
        assert!(validate_connection_string(" ").is_err());
        // Wrong scheme
        assert!(validate_connection_string("mysql://localhost/db").is_err());
        // No user credentials and no database path
        assert!(validate_connection_string("postgresql://localhost").is_err());
        // Missing user
        assert!(validate_connection_string("postgresql://localhost/db").is_err());
    }

    #[test]
    fn test_sanitize_identifier() {
        // Control characters (NUL, newline) are stripped
        assert_eq!(sanitize_identifier("normal_table"), "normal_table");
        assert_eq!(sanitize_identifier("table\x00name"), "tablename");
        assert_eq!(sanitize_identifier("table\nname"), "tablename");

        // Test length limit: output is truncated to 100 characters
        let long_name = "a".repeat(200);
        assert_eq!(sanitize_identifier(&long_name).len(), 100);
    }

    #[tokio::test]
    async fn test_retry_with_backoff_success() {
        let mut attempts = 0;
        let result = retry_with_backoff(
            || {
                attempts += 1;
                // `async move` copies the current attempt count into the future,
                // so attempts 1 and 2 fail and attempt 3 succeeds.
                async move {
                    if attempts < 3 {
                        anyhow::bail!("Temporary failure")
                    } else {
                        Ok("Success")
                    }
                }
            },
            5,
            Duration::from_millis(10),
        )
        .await;

        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "Success");
        assert_eq!(attempts, 3);
    }

    #[tokio::test]
    async fn test_retry_with_backoff_failure() {
        let mut attempts = 0;
        let result: Result<&str> = retry_with_backoff(
            || {
                attempts += 1;
                async move { anyhow::bail!("Permanent failure") }
            },
            2,
            Duration::from_millis(10),
        )
        .await;

        // The retry count excludes the initial attempt.
        assert!(result.is_err());
        assert_eq!(attempts, 3); // Initial + 2 retries
    }

    #[test]
    fn test_validate_source_target_different_valid() {
        // Different hosts
        assert!(validate_source_target_different(
            "postgresql://user:pass@source.com:5432/db",
            "postgresql://user:pass@target.com:5432/db"
        )
        .is_ok());

        // Different databases on same host
        assert!(validate_source_target_different(
            "postgresql://user:pass@host:5432/db1",
            "postgresql://user:pass@host:5432/db2"
        )
        .is_ok());

        // Different ports on same host
        assert!(validate_source_target_different(
            "postgresql://user:pass@host:5432/db",
            "postgresql://user:pass@host:5433/db"
        )
        .is_ok());

        // Different users on same host/db (edge case but allowed)
        assert!(validate_source_target_different(
            "postgresql://user1:pass@host:5432/db",
            "postgresql://user2:pass@host:5432/db"
        )
        .is_ok());
    }

    #[test]
    fn test_validate_source_target_different_invalid() {
        // Exact same URL
        assert!(validate_source_target_different(
            "postgresql://user:pass@host:5432/db",
            "postgresql://user:pass@host:5432/db"
        )
        .is_err());

        // Same URL with different scheme (postgres vs postgresql)
        assert!(validate_source_target_different(
            "postgres://user:pass@host:5432/db",
            "postgresql://user:pass@host:5432/db"
        )
        .is_err());

        // Same URL with default port vs explicit port
        assert!(validate_source_target_different(
            "postgresql://user:pass@host/db",
            "postgresql://user:pass@host:5432/db"
        )
        .is_err());

        // Same URL with different query parameters (still same database)
        assert!(validate_source_target_different(
            "postgresql://user:pass@host:5432/db?sslmode=require",
            "postgresql://user:pass@host:5432/db?sslmode=prefer"
        )
        .is_err());

        // Same host with different case (hostnames are case-insensitive)
        assert!(validate_source_target_different(
            "postgresql://user:pass@HOST.COM:5432/db",
            "postgresql://user:pass@host.com:5432/db"
        )
        .is_err());
    }

    #[test]
    fn test_parse_postgres_url() {
        // Full URL with all components including password
        let parts = parse_postgres_url("postgresql://myuser:mypass@localhost:5432/mydb").unwrap();
        assert_eq!(parts.host, "localhost");
        assert_eq!(parts.port, 5432);
        assert_eq!(parts.database, "mydb");
        assert_eq!(parts.user, Some("myuser".to_string()));
        assert_eq!(parts.password, Some("mypass".to_string()));

        // URL without port (should default to 5432)
        let parts = parse_postgres_url("postgresql://user:pass@host/db").unwrap();
        assert_eq!(parts.host, "host");
        assert_eq!(parts.port, 5432);
        assert_eq!(parts.database, "db");
        assert_eq!(parts.user, Some("user".to_string()));
        assert_eq!(parts.password, Some("pass".to_string()));

        // URL with user but no password
        let parts = parse_postgres_url("postgresql://user@host/db").unwrap();
        assert_eq!(parts.host, "host");
        assert_eq!(parts.user, Some("user".to_string()));
        assert_eq!(parts.password, None);

        // URL without authentication
        let parts = parse_postgres_url("postgresql://host:5433/db").unwrap();
        assert_eq!(parts.host, "host");
        assert_eq!(parts.port, 5433);
        assert_eq!(parts.database, "db");
        assert_eq!(parts.user, None);
        assert_eq!(parts.password, None);

        // URL with query parameters (not part of the database name)
        let parts = parse_postgres_url("postgresql://user:pass@host/db?sslmode=require").unwrap();
        assert_eq!(parts.host, "host");
        assert_eq!(parts.database, "db");
        assert_eq!(parts.password, Some("pass".to_string()));

        // URL with postgres:// scheme (alternative)
        let parts = parse_postgres_url("postgres://user:pass@host/db").unwrap();
        assert_eq!(parts.host, "host");
        assert_eq!(parts.database, "db");
        assert_eq!(parts.password, Some("pass".to_string()));

        // Host normalization (lowercase)
        let parts = parse_postgres_url("postgresql://user:pass@HOST.COM/db").unwrap();
        assert_eq!(parts.host, "host.com");
        assert_eq!(parts.password, Some("pass".to_string()));

        // Password with special characters: the last '@' separates credentials from host
        let parts = parse_postgres_url("postgresql://user:p@ss!word@host/db").unwrap();
        assert_eq!(parts.password, Some("p@ss!word".to_string()));
    }

    #[test]
    fn test_validate_postgres_identifier_valid() {
        // Valid identifiers
        assert!(validate_postgres_identifier("mydb").is_ok());
        assert!(validate_postgres_identifier("my_database").is_ok());
        assert!(validate_postgres_identifier("_private_db").is_ok());
        assert!(validate_postgres_identifier("db123").is_ok());
        assert!(validate_postgres_identifier("Database_2024").is_ok());

        // Maximum length (63 characters)
        let max_length_name = "a".repeat(63);
        assert!(validate_postgres_identifier(&max_length_name).is_ok());
    }

    #[test]
    fn test_pgpass_file_creation() {
        let parts = PostgresUrlParts {
            host: "localhost".to_string(),
            port: 5432,
            database: "testdb".to_string(),
            user: Some("testuser".to_string()),
            password: Some("testpass".to_string()),
            query_params: std::collections::HashMap::new(),
        };

        let pgpass = PgPassFile::new(&parts).unwrap();
        assert!(pgpass.path().exists());

        // Verify file content: a single host:port:database:user:password line
        let content = std::fs::read_to_string(pgpass.path()).unwrap();
        assert_eq!(content, "localhost:5432:testdb:testuser:testpass\n");

        // Verify permissions on Unix (owner read/write only)
        #[cfg(unix)]
        {
            use std::os::unix::fs::PermissionsExt;
            let metadata = std::fs::metadata(pgpass.path()).unwrap();
            let permissions = metadata.permissions();
            assert_eq!(permissions.mode() & 0o777, 0o600);
        }

        // File should be cleaned up when pgpass is dropped
        let path = pgpass.path().to_path_buf();
        drop(pgpass);
        assert!(!path.exists());
    }

    #[test]
    fn test_pgpass_file_without_password() {
        let parts = PostgresUrlParts {
            host: "localhost".to_string(),
            port: 5432,
            database: "testdb".to_string(),
            user: Some("testuser".to_string()),
            password: None,
            query_params: std::collections::HashMap::new(),
        };

        let pgpass = PgPassFile::new(&parts).unwrap();
        let content = std::fs::read_to_string(pgpass.path()).unwrap();
        // Should use empty password
        assert_eq!(content, "localhost:5432:testdb:testuser:\n");
    }

    #[test]
    fn test_pgpass_file_without_user() {
        let parts = PostgresUrlParts {
            host: "localhost".to_string(),
            port: 5432,
            database: "testdb".to_string(),
            user: None,
            password: Some("testpass".to_string()),
            query_params: std::collections::HashMap::new(),
        };

        let pgpass = PgPassFile::new(&parts).unwrap();
        let content = std::fs::read_to_string(pgpass.path()).unwrap();
        // Should use wildcard for user
        assert_eq!(content, "localhost:5432:testdb:*:testpass\n");
    }

    #[test]
    fn test_strip_password_from_url() {
        // With password
        let url = "postgresql://user:p@ssw0rd@host:5432/db";
        let stripped = strip_password_from_url(url).unwrap();
        assert_eq!(stripped, "postgresql://user@host:5432/db");

        // With special characters in password
        let url = "postgresql://user:p@ss!w0rd@host:5432/db";
        let stripped = strip_password_from_url(url).unwrap();
        assert_eq!(stripped, "postgresql://user@host:5432/db");

        // Without password
        let url = "postgresql://user@host:5432/db";
        let stripped = strip_password_from_url(url).unwrap();
        assert_eq!(stripped, "postgresql://user@host:5432/db");

        // With query parameters (preserved)
        let url = "postgresql://user:pass@host:5432/db?sslmode=require";
        let stripped = strip_password_from_url(url).unwrap();
        assert_eq!(stripped, "postgresql://user@host:5432/db?sslmode=require");

        // No user
        let url = "postgresql://host:5432/db";
        let stripped = strip_password_from_url(url).unwrap();
        assert_eq!(stripped, "postgresql://host:5432/db");
    }

    #[test]
    fn test_validate_postgres_identifier_invalid() {
        // SQL injection attempts
        assert!(validate_postgres_identifier("mydb\"; DROP DATABASE production; --").is_err());
        assert!(validate_postgres_identifier("db'; DELETE FROM users; --").is_err());

        // Invalid start characters
        assert!(validate_postgres_identifier("123db").is_err()); // Starts with digit
        assert!(validate_postgres_identifier("$db").is_err()); // Starts with special char
        assert!(validate_postgres_identifier("-db").is_err()); // Starts with dash

        // Contains invalid characters
        assert!(validate_postgres_identifier("my-database").is_err()); // Contains dash
        assert!(validate_postgres_identifier("my.database").is_err()); // Contains dot
        assert!(validate_postgres_identifier("my database").is_err()); // Contains space
        assert!(validate_postgres_identifier("my@db").is_err()); // Contains @
        assert!(validate_postgres_identifier("my#db").is_err()); // Contains #

        // Empty or whitespace-only
        assert!(validate_postgres_identifier("").is_err());
        assert!(validate_postgres_identifier(" ").is_err());

        // Over maximum length (64+ characters)
        let too_long = "a".repeat(64);
        assert!(validate_postgres_identifier(&too_long).is_err());

        // Control characters
        assert!(validate_postgres_identifier("my\ndb").is_err());
        assert!(validate_postgres_identifier("my\tdb").is_err());
        assert!(validate_postgres_identifier("my\x00db").is_err());
    }

    #[test]
    fn test_is_serendb_target() {
        // Positive cases - SerenDB hosts
        assert!(is_serendb_target(
            "postgresql://user:pass@db.serendb.com/mydb"
        ));
        assert!(is_serendb_target(
            "postgresql://user:pass@cluster.console.serendb.com/mydb"
        ));
        assert!(is_serendb_target(
            "postgres://u:p@x.serendb.com:5432/db?sslmode=require"
        ));
        assert!(is_serendb_target("postgresql://user:pass@serendb.com/mydb"));

        // Negative cases - not SerenDB
        assert!(!is_serendb_target("postgresql://user:pass@localhost/mydb"));
        assert!(!is_serendb_target(
            "postgresql://user:pass@rds.amazonaws.com/mydb"
        ));
        assert!(!is_serendb_target("postgresql://user:pass@neon.tech/mydb"));
        // Domain spoofing attempt - should NOT match
        assert!(!is_serendb_target(
            "postgresql://user:pass@serendb.com.evil.com/mydb"
        ));
        // Suffix match must be on a label boundary
        assert!(!is_serendb_target(
            "postgresql://user:pass@notserendb.com/mydb"
        ));
        // Invalid URL
        assert!(!is_serendb_target("not-a-url"));
    }

    #[test]
    fn test_parse_pg_version_string() {
        // Standard pg_dump output
        assert_eq!(
            parse_pg_version_string("pg_dump (PostgreSQL) 16.10 (Ubuntu 16.10-0ubuntu0.24.04.1)")
                .unwrap(),
            16
        );

        // Standard psql output
        assert_eq!(
            parse_pg_version_string("psql (PostgreSQL) 17.2").unwrap(),
            17
        );

        // pg_restore output
        assert_eq!(
            parse_pg_version_string("pg_restore (PostgreSQL) 15.4").unwrap(),
            15
        );

        // Debian-style version
        assert_eq!(
            parse_pg_version_string("17.2 (Debian 17.2-1.pgdg120+1)").unwrap(),
            17
        );

        // Should fail on invalid input
        assert!(parse_pg_version_string("not a version").is_err());
        assert!(parse_pg_version_string("version 1.2.3").is_err()); // 1 is < 9
        assert!(parse_pg_version_string("").is_err());
    }

    #[test]
    fn test_get_pg_tool_version() {
        // This test will only pass if pg_dump is installed
        // Skip gracefully if not available
        if which("pg_dump").is_ok() {
            let version = get_pg_tool_version("pg_dump").unwrap();
            assert!(
                version >= 12,
                "Expected pg_dump version >= 12, got {}",
                version
            );
            assert!(
                version <= 99,
                "Expected pg_dump version <= 99, got {}",
                version
            );
        }

        // Non-existent tool should fail
        assert!(get_pg_tool_version("nonexistent_pg_tool_xyz").is_err());
    }
}