database_replicator/utils.rs
1// ABOUTME: Utility functions for validation and error handling
2// ABOUTME: Provides input validation, retry logic, and resource cleanup
3
4use anyhow::{bail, Context, Result};
5use std::time::Duration;
6use url::Url;
7use which::which;
8
/// Build the TCP keepalive environment for PostgreSQL client subprocesses.
///
/// The returned pairs are meant to be applied to external PostgreSQL tools
/// (pg_dump, pg_restore, psql, pg_dumpall) so that long-idle connections are
/// kept alive, e.g. when connecting through a load balancer with idle timeouts.
///
/// Environment variables returned, in this order:
/// - `PGKEEPALIVES=1`: enable TCP keepalives
/// - `PGKEEPALIVESIDLE=60`: first keepalive after 60 seconds of idle time
/// - `PGKEEPALIVESINTERVAL=10`: subsequent keepalives every 10 seconds
///
/// # Returns
///
/// A vector of `(variable_name, value)` tuples to pass to subprocess commands.
///
/// # Examples
///
/// ```
/// # use database_replicator::utils::get_keepalive_env_vars;
/// # use std::process::Command;
/// let keepalive_vars = get_keepalive_env_vars();
/// let mut cmd = Command::new("psql");
/// for (key, value) in keepalive_vars {
///     cmd.env(key, value);
/// }
/// ```
pub fn get_keepalive_env_vars() -> Vec<(&'static str, &'static str)> {
    // Fixed table of libpq keepalive settings; copied into a fresh Vec per call.
    const KEEPALIVE_SETTINGS: [(&str, &str); 3] = [
        ("PGKEEPALIVES", "1"),
        ("PGKEEPALIVESIDLE", "60"),
        ("PGKEEPALIVESINTERVAL", "10"),
    ];

    KEEPALIVE_SETTINGS.to_vec()
}
42
43/// Validate a PostgreSQL connection string
44///
45/// Checks that the connection string has proper format and required components:
46/// - Starts with "postgres://" or "postgresql://"
47/// - Contains user credentials (@ symbol)
48/// - Contains database name (/ separator with at least 3 occurrences)
49///
50/// # Arguments
51///
52/// * `url` - Connection string to validate
53///
54/// # Returns
55///
56/// Returns `Ok(())` if the connection string is valid.
57///
58/// # Errors
59///
60/// Returns an error with helpful message if the connection string is:
61/// - Empty or whitespace only
62/// - Missing proper scheme (postgres:// or postgresql://)
63/// - Missing user credentials (@ symbol)
64/// - Missing database name
65///
66/// # Examples
67///
68/// ```
69/// # use database_replicator::utils::validate_connection_string;
70/// # use anyhow::Result;
71/// # fn example() -> Result<()> {
72/// // Valid connection strings
73/// validate_connection_string("postgresql://user:pass@localhost:5432/mydb")?;
74/// validate_connection_string("postgres://user@host/db")?;
75///
76/// // Invalid - will return error
77/// assert!(validate_connection_string("").is_err());
78/// assert!(validate_connection_string("mysql://localhost/db").is_err());
79/// # Ok(())
80/// # }
81/// ```
82pub fn validate_connection_string(url: &str) -> Result<()> {
83 if url.trim().is_empty() {
84 bail!("Connection string cannot be empty");
85 }
86
87 // Check for common URL schemes
88 if !url.starts_with("postgres://") && !url.starts_with("postgresql://") {
89 bail!(
90 "Invalid connection string format.\n\
91 Expected format: postgresql://user:password@host:port/database\n\
92 Got: {}",
93 url
94 );
95 }
96
97 // Check for minimum required components (user@host/database)
98 if !url.contains('@') {
99 bail!(
100 "Connection string missing user credentials.\n\
101 Expected format: postgresql://user:password@host:port/database"
102 );
103 }
104
105 if !url.contains('/') || url.matches('/').count() < 3 {
106 bail!(
107 "Connection string missing database name.\n\
108 Expected format: postgresql://user:password@host:port/database"
109 );
110 }
111
112 Ok(())
113}
114
115/// Check that required PostgreSQL client tools are available
116///
117/// Verifies that the following tools are installed and in PATH:
118/// - `pg_dump` - For dumping database schema and data
119/// - `pg_dumpall` - For dumping global objects (roles, tablespaces)
120/// - `psql` - For restoring databases
121///
122/// # Returns
123///
124/// Returns `Ok(())` if all required tools are found.
125///
126/// # Errors
127///
128/// Returns an error with installation instructions if any tools are missing.
129///
130/// # Examples
131///
132/// ```
133/// # use database_replicator::utils::check_required_tools;
134/// # use anyhow::Result;
135/// # fn example() -> Result<()> {
136/// // Check if PostgreSQL tools are installed
137/// check_required_tools()?;
138/// # Ok(())
139/// # }
140/// ```
141pub fn check_required_tools() -> Result<()> {
142 let tools = ["pg_dump", "pg_dumpall", "psql"];
143 let mut missing = Vec::new();
144
145 for tool in &tools {
146 if which(tool).is_err() {
147 missing.push(*tool);
148 }
149 }
150
151 if !missing.is_empty() {
152 bail!(
153 "Missing required PostgreSQL client tools: {}\n\
154 \n\
155 Please install PostgreSQL client tools:\n\
156 - Ubuntu/Debian: sudo apt-get install postgresql-client\n\
157 - macOS: brew install postgresql\n\
158 - RHEL/CentOS: sudo yum install postgresql\n\
159 - Windows: Download from https://www.postgresql.org/download/windows/",
160 missing.join(", ")
161 );
162 }
163
164 Ok(())
165}
166
167/// Retry a function with exponential backoff
168///
169/// Executes an async operation with automatic retry on failure. Each retry doubles
170/// the delay (exponential backoff) to handle transient failures gracefully.
171///
172/// # Arguments
173///
174/// * `operation` - Async function to retry (FnMut returning Future\<Output = Result\<T\>\>)
175/// * `max_retries` - Maximum number of retry attempts (0 = no retries, just initial attempt)
176/// * `initial_delay` - Delay before first retry (doubles each subsequent retry)
177///
178/// # Returns
179///
180/// Returns the successful result or the last error after all retries exhausted.
181///
182/// # Examples
183///
184/// ```no_run
185/// # use anyhow::Result;
186/// # use std::time::Duration;
187/// # use database_replicator::utils::retry_with_backoff;
188/// # async fn example() -> Result<()> {
189/// let result = retry_with_backoff(
190/// || async { Ok("success") },
191/// 3, // Try up to 3 times
192/// Duration::from_secs(1) // Start with 1s delay
193/// ).await?;
194/// # Ok(())
195/// # }
196/// ```
197pub async fn retry_with_backoff<F, Fut, T>(
198 mut operation: F,
199 max_retries: u32,
200 initial_delay: Duration,
201) -> Result<T>
202where
203 F: FnMut() -> Fut,
204 Fut: std::future::Future<Output = Result<T>>,
205{
206 let mut delay = initial_delay;
207 let mut last_error = None;
208
209 for attempt in 0..=max_retries {
210 match operation().await {
211 Ok(result) => return Ok(result),
212 Err(e) => {
213 last_error = Some(e);
214
215 if attempt < max_retries {
216 tracing::warn!(
217 "Operation failed (attempt {}/{}), retrying in {:?}...",
218 attempt + 1,
219 max_retries + 1,
220 delay
221 );
222 tokio::time::sleep(delay).await;
223 delay *= 2; // Exponential backoff
224 }
225 }
226 }
227 }
228
229 Err(last_error.unwrap_or_else(|| anyhow::anyhow!("Operation failed after retries")))
230}
231
232/// Retry a subprocess execution with exponential backoff on connection errors
233///
234/// Executes a subprocess command with automatic retry on connection-related failures.
235/// Each retry doubles the delay (exponential backoff) to handle transient connection issues.
236///
237/// Connection errors are detected by checking:
238/// - Non-zero exit codes
239/// - Stderr output containing connection-related error patterns:
240/// - "connection closed"
241/// - "connection refused"
242/// - "could not connect"
243/// - "server closed the connection"
244/// - "timeout"
245/// - "Connection timed out"
246///
247/// # Arguments
248///
249/// * `operation` - Function that executes a Command and returns the exit status
250/// * `max_retries` - Maximum number of retry attempts (0 = no retries, just initial attempt)
251/// * `initial_delay` - Delay before first retry (doubles each subsequent retry)
252/// * `operation_name` - Name of the operation for logging (e.g., "pg_restore", "psql")
253///
254/// # Returns
255///
256/// Returns Ok(()) on success or the last error after all retries exhausted.
257///
258/// # Examples
259///
260/// ```no_run
261/// # use anyhow::Result;
262/// # use std::time::Duration;
263/// # use std::process::Command;
264/// # use database_replicator::utils::retry_subprocess_with_backoff;
265/// # async fn example() -> Result<()> {
266/// retry_subprocess_with_backoff(
267/// || {
268/// let mut cmd = Command::new("psql");
269/// cmd.arg("--version");
270/// cmd.status().map_err(anyhow::Error::from)
271/// },
272/// 3, // Try up to 3 times
273/// Duration::from_secs(1), // Start with 1s delay
274/// "psql"
275/// ).await?;
276/// # Ok(())
277/// # }
278/// ```
279pub async fn retry_subprocess_with_backoff<F>(
280 mut operation: F,
281 max_retries: u32,
282 initial_delay: Duration,
283 operation_name: &str,
284) -> Result<()>
285where
286 F: FnMut() -> Result<std::process::ExitStatus>,
287{
288 let mut delay = initial_delay;
289 let mut last_error = None;
290
291 for attempt in 0..=max_retries {
292 match operation() {
293 Ok(status) => {
294 if status.success() {
295 return Ok(());
296 } else {
297 // Non-zero exit code - check if it's a connection error
298 // We can't easily capture stderr here, so we'll treat all non-zero
299 // exit codes as potential connection errors for now
300 let error = anyhow::anyhow!(
301 "{} failed with exit code: {}",
302 operation_name,
303 status.code().unwrap_or(-1)
304 );
305 last_error = Some(error);
306
307 if attempt < max_retries {
308 tracing::warn!(
309 "{} failed (attempt {}/{}), retrying in {:?}...",
310 operation_name,
311 attempt + 1,
312 max_retries + 1,
313 delay
314 );
315 tokio::time::sleep(delay).await;
316 delay *= 2; // Exponential backoff
317 }
318 }
319 }
320 Err(e) => {
321 last_error = Some(e);
322
323 if attempt < max_retries {
324 tracing::warn!(
325 "{} failed (attempt {}/{}): {}, retrying in {:?}...",
326 operation_name,
327 attempt + 1,
328 max_retries + 1,
329 last_error.as_ref().unwrap(),
330 delay
331 );
332 tokio::time::sleep(delay).await;
333 delay *= 2; // Exponential backoff
334 }
335 }
336 }
337 }
338
339 Err(last_error.unwrap_or_else(|| {
340 anyhow::anyhow!("{} failed after {} retries", operation_name, max_retries)
341 }))
342}
343
344/// Validate a PostgreSQL identifier (database name, schema name, etc.)
345///
346/// Validates that an identifier follows PostgreSQL naming rules to prevent SQL injection.
347/// PostgreSQL identifiers must:
348/// - Be 1-63 characters long
349/// - Start with a letter (a-z, A-Z) or underscore (_)
350/// - Contain only letters, digits (0-9), or underscores
351///
352/// # Arguments
353///
354/// * `identifier` - The identifier to validate (database name, schema name, etc.)
355///
356/// # Returns
357///
358/// Returns `Ok(())` if the identifier is valid.
359///
360/// # Errors
361///
362/// Returns an error if the identifier:
363/// - Is empty or whitespace-only
364/// - Exceeds 63 characters
365/// - Starts with an invalid character (digit or special character)
366/// - Contains invalid characters (anything except a-z, A-Z, 0-9, _)
367///
368/// # Security
369///
370/// This function is critical for preventing SQL injection attacks. All database
371/// names, schema names, and table names from untrusted sources MUST be validated
372/// before use in SQL statements.
373///
374/// # Examples
375///
376/// ```
377/// # use database_replicator::utils::validate_postgres_identifier;
378/// # use anyhow::Result;
379/// # fn example() -> Result<()> {
380/// // Valid identifiers
381/// validate_postgres_identifier("mydb")?;
382/// validate_postgres_identifier("my_database")?;
383/// validate_postgres_identifier("_private_db")?;
384///
385/// // Invalid - will return error
386/// assert!(validate_postgres_identifier("123db").is_err());
387/// assert!(validate_postgres_identifier("my-database").is_err());
388/// assert!(validate_postgres_identifier("db\"; DROP TABLE users; --").is_err());
389/// # Ok(())
390/// # }
391/// ```
392pub fn validate_postgres_identifier(identifier: &str) -> Result<()> {
393 // Check for empty or whitespace-only
394 let trimmed = identifier.trim();
395 if trimmed.is_empty() {
396 bail!("Identifier cannot be empty or whitespace-only");
397 }
398
399 // Check length (PostgreSQL limit is 63 characters)
400 if trimmed.len() > 63 {
401 bail!(
402 "Identifier '{}' exceeds maximum length of 63 characters (got {})",
403 sanitize_identifier(trimmed),
404 trimmed.len()
405 );
406 }
407
408 // Get first character
409 let first_char = trimmed.chars().next().unwrap();
410
411 // First character must be a letter or underscore
412 if !first_char.is_ascii_alphabetic() && first_char != '_' {
413 bail!(
414 "Identifier '{}' must start with a letter or underscore, not '{}'",
415 sanitize_identifier(trimmed),
416 first_char
417 );
418 }
419
420 // All characters must be alphanumeric or underscore
421 for (i, c) in trimmed.chars().enumerate() {
422 if !c.is_ascii_alphanumeric() && c != '_' {
423 bail!(
424 "Identifier '{}' contains invalid character '{}' at position {}. \
425 Only letters, digits, and underscores are allowed",
426 sanitize_identifier(trimmed),
427 if c.is_control() {
428 format!("\\x{:02x}", c as u32)
429 } else {
430 c.to_string()
431 },
432 i
433 );
434 }
435 }
436
437 Ok(())
438}
439
/// Make an identifier (table name, schema name, etc.) safe to print.
///
/// Drops every control character and keeps at most the first 100 remaining
/// characters, so that error messages stay readable and cannot be used for
/// log injection.
///
/// **Note**: This is for display purposes only. For SQL safety, use parameterized
/// queries instead.
///
/// # Arguments
///
/// * `identifier` - The identifier to sanitize (table name, schema name, etc.)
///
/// # Returns
///
/// The sanitized string: no control characters, at most 100 characters long.
///
/// # Examples
///
/// ```
/// # use database_replicator::utils::sanitize_identifier;
/// assert_eq!(sanitize_identifier("normal_table"), "normal_table");
/// assert_eq!(sanitize_identifier("table\x00name"), "tablename");
/// assert_eq!(sanitize_identifier("table\nname"), "tablename");
///
/// // Length limit
/// let long_name = "a".repeat(200);
/// assert_eq!(sanitize_identifier(&long_name).len(), 100);
/// ```
pub fn sanitize_identifier(identifier: &str) -> String {
    const MAX_DISPLAY_CHARS: usize = 100;

    let mut cleaned = String::new();
    let mut kept = 0usize;

    for ch in identifier.chars() {
        // Control characters are skipped and do not count towards the limit.
        if ch.is_control() {
            continue;
        }
        if kept == MAX_DISPLAY_CHARS {
            break;
        }
        cleaned.push(ch);
        kept += 1;
    }

    cleaned
}
476
/// Quote a PostgreSQL identifier (database, schema, table, column)
///
/// Assumes the identifier has already been validated. Every embedded double
/// quote is doubled and the result is wrapped in double quotes.
///
/// # Examples
///
/// ```
/// use database_replicator::utils::quote_ident;
/// assert_eq!(quote_ident("users"), "\"users\"");
/// assert_eq!(quote_ident("we\"ird"), "\"we\"\"ird\"");
/// ```
pub fn quote_ident(identifier: &str) -> String {
    let escaped = identifier.replace('"', "\"\"");
    format!("\"{}\"", escaped)
}
493
/// Quote a SQL string literal (for use in SQL statements)
///
/// Doubles every embedded single quote and wraps the value in single quotes.
/// Use this for string values in SQL, not for identifiers.
///
/// # Examples
///
/// ```
/// use database_replicator::utils::quote_literal;
/// assert_eq!(quote_literal("hello"), "'hello'");
/// assert_eq!(quote_literal("it's"), "'it''s'");
/// assert_eq!(quote_literal(""), "''");
/// ```
pub fn quote_literal(value: &str) -> String {
    let escaped = value.replace('\'', "''");
    format!("'{}'", escaped)
}
519
/// Quote a MySQL identifier (database, table, column)
///
/// MySQL quotes identifiers with backticks. Every embedded backtick is doubled
/// and the result is wrapped in backticks.
///
/// # Examples
///
/// ```
/// use database_replicator::utils::quote_mysql_ident;
/// assert_eq!(quote_mysql_ident("users"), "`users`");
/// assert_eq!(quote_mysql_ident("user`name"), "`user``name`");
/// ```
pub fn quote_mysql_ident(identifier: &str) -> String {
    let escaped = identifier.replace('`', "``");
    format!("`{}`", escaped)
}
544
545/// Validate that source and target URLs are different to prevent accidental data loss
546///
547/// Compares two PostgreSQL connection URLs to ensure they point to different databases.
548/// This is critical for preventing data loss from operations like `init --drop-existing`
549/// where using the same URL for source and target would destroy the source data.
550///
551/// # Comparison Strategy
552///
553/// URLs are normalized and compared on:
554/// - Host (case-insensitive)
555/// - Port (defaulting to 5432 if not specified)
556/// - Database name (case-sensitive)
557/// - User (if present)
558///
559/// Query parameters (like SSL settings) are ignored as they don't affect database identity.
560///
561/// # Arguments
562///
563/// * `source_url` - Source database connection string
564/// * `target_url` - Target database connection string
565///
566/// # Returns
567///
568/// Returns `Ok(())` if the URLs point to different databases.
569///
570/// # Errors
571///
572/// Returns an error if:
573/// - The URLs point to the same database (same host, port, database name, and user)
574/// - Either URL is malformed and cannot be parsed
575///
576/// # Examples
577///
578/// ```
579/// # use database_replicator::utils::validate_source_target_different;
580/// # use anyhow::Result;
581/// # fn example() -> Result<()> {
582/// // Valid - different hosts
583/// validate_source_target_different(
584/// "postgresql://user:pass@source.com:5432/db",
585/// "postgresql://user:pass@target.com:5432/db"
586/// )?;
587///
588/// // Valid - different databases
589/// validate_source_target_different(
590/// "postgresql://user:pass@host:5432/db1",
591/// "postgresql://user:pass@host:5432/db2"
592/// )?;
593///
594/// // Invalid - same database
595/// assert!(validate_source_target_different(
596/// "postgresql://user:pass@host:5432/db",
597/// "postgresql://user:pass@host:5432/db"
598/// ).is_err());
599/// # Ok(())
600/// # }
601/// ```
602pub fn validate_source_target_different(source_url: &str, target_url: &str) -> Result<()> {
603 // Parse both URLs to extract components
604 let source_parts = parse_postgres_url(source_url)
605 .with_context(|| format!("Failed to parse source URL: {}", source_url))?;
606 let target_parts = parse_postgres_url(target_url)
607 .with_context(|| format!("Failed to parse target URL: {}", target_url))?;
608
609 // Compare normalized components
610 if source_parts.host == target_parts.host
611 && source_parts.port == target_parts.port
612 && source_parts.database == target_parts.database
613 && source_parts.user == target_parts.user
614 {
615 bail!(
616 "Source and target URLs point to the same database!\n\
617 \n\
618 This would cause DATA LOSS - the target would overwrite the source.\n\
619 \n\
620 Source: {}@{}:{}/{}\n\
621 Target: {}@{}:{}/{}\n\
622 \n\
623 Please ensure source and target are different databases.\n\
624 Common causes:\n\
625 - Copy-paste error in connection strings\n\
626 - Wrong environment variables (e.g., SOURCE_URL == TARGET_URL)\n\
627 - Typo in database name or host",
628 source_parts.user.as_deref().unwrap_or("(no user)"),
629 source_parts.host,
630 source_parts.port,
631 source_parts.database,
632 target_parts.user.as_deref().unwrap_or("(no user)"),
633 target_parts.host,
634 target_parts.port,
635 target_parts.database
636 );
637 }
638
639 Ok(())
640}
641
642/// Parse a PostgreSQL URL into its components
643///
644/// # Arguments
645///
646/// * `url` - PostgreSQL connection URL (postgres:// or postgresql://)
647///
648/// # Returns
649///
650/// Returns a `PostgresUrlParts` struct with normalized components.
651///
652/// # Security
653///
654/// This function extracts passwords from URLs for use with .pgpass files.
655/// Ensure returned values are handled securely and not logged.
656pub fn parse_postgres_url(url: &str) -> Result<PostgresUrlParts> {
657 // Remove scheme
658 let url_without_scheme = url
659 .trim_start_matches("postgres://")
660 .trim_start_matches("postgresql://");
661
662 // Split into base and query params
663 let (base, query_string) = if let Some((b, q)) = url_without_scheme.split_once('?') {
664 (b, Some(q))
665 } else {
666 (url_without_scheme, None)
667 };
668
669 // Parse query parameters into HashMap
670 let mut query_params = std::collections::HashMap::new();
671 if let Some(query) = query_string {
672 for param in query.split('&') {
673 if let Some((key, value)) = param.split_once('=') {
674 query_params.insert(key.to_string(), value.to_string());
675 }
676 }
677 }
678
679 // Parse: [user[:password]@]host[:port]/database
680 let (auth_and_host, database) = base
681 .rsplit_once('/')
682 .ok_or_else(|| anyhow::anyhow!("Missing database name in URL"))?;
683
684 // Parse authentication and host
685 // Use rsplit_once to split from the right, so passwords can contain '@'
686 let (user, password, host_and_port) = if let Some((auth, hp)) = auth_and_host.rsplit_once('@') {
687 // Has authentication
688 let (user, pass) = if let Some((u, p)) = auth.split_once(':') {
689 (Some(u.to_string()), Some(p.to_string()))
690 } else {
691 (Some(auth.to_string()), None)
692 };
693 (user, pass, hp)
694 } else {
695 // No authentication
696 (None, None, auth_and_host)
697 };
698
699 // Parse host and port
700 let (host, port) = if let Some((h, p)) = host_and_port.rsplit_once(':') {
701 // Port specified
702 let port = p
703 .parse::<u16>()
704 .with_context(|| format!("Invalid port number: {}", p))?;
705 (h, port)
706 } else {
707 // Use default PostgreSQL port
708 (host_and_port, 5432)
709 };
710
711 Ok(PostgresUrlParts {
712 host: host.to_lowercase(), // Hostnames are case-insensitive
713 port,
714 database: database.to_string(), // Database names are case-sensitive in PostgreSQL
715 user,
716 password,
717 query_params,
718 })
719}
720
721/// Strip password from PostgreSQL connection URL
722/// Returns a new URL with password removed, preserving all other components
723/// This is useful for storing connection strings in places where passwords should not be visible
724pub fn strip_password_from_url(url: &str) -> Result<String> {
725 let parts = parse_postgres_url(url)?;
726
727 // Reconstruct URL without password
728 let scheme = if url.starts_with("postgresql://") {
729 "postgresql://"
730 } else if url.starts_with("postgres://") {
731 "postgres://"
732 } else {
733 bail!("Invalid PostgreSQL URL scheme");
734 };
735
736 let mut result = String::from(scheme);
737
738 // Add user if present (without password)
739 if let Some(user) = &parts.user {
740 result.push_str(user);
741 result.push('@');
742 }
743
744 // Add host and port
745 result.push_str(&parts.host);
746 result.push(':');
747 result.push_str(&parts.port.to_string());
748
749 // Add database
750 result.push('/');
751 result.push_str(&parts.database);
752
753 // Preserve query parameters if present
754 if let Some(query_start) = url.find('?') {
755 result.push_str(&url[query_start..]);
756 }
757
758 Ok(result)
759}
760
/// Parsed components of a PostgreSQL connection URL
#[derive(Debug, PartialEq)]
pub struct PostgresUrlParts {
    /// Host portion of the URL
    pub host: String,
    /// TCP port
    pub port: u16,
    /// Database name
    pub database: String,
    /// User name, if the URL carried credentials
    pub user: Option<String>,
    /// Password, if the URL carried one
    pub password: Option<String>,
    /// `key=value` pairs from the URL's query string
    pub query_params: std::collections::HashMap<String, String>,
}

impl PostgresUrlParts {
    /// Convert query parameters to PostgreSQL environment variables
    ///
    /// Translates the recognised connection-URL query parameters into the
    /// matching PostgreSQL environment variable names, so SSL/TLS and other
    /// connection settings can be handed to pg_dump, pg_dumpall, psql, etc.
    /// Parameters that are not listed below are ignored.
    ///
    /// # Supported Parameters
    ///
    /// - `sslmode` → `PGSSLMODE`
    /// - `sslcert` → `PGSSLCERT`
    /// - `sslkey` → `PGSSLKEY`
    /// - `sslrootcert` → `PGSSLROOTCERT`
    /// - `channel_binding` → `PGCHANNELBINDING`
    /// - `connect_timeout` → `PGCONNECT_TIMEOUT`
    /// - `application_name` → `PGAPPNAME`
    /// - `client_encoding` → `PGCLIENTENCODING`
    ///
    /// # Returns
    ///
    /// Vec of (env_var_name, value) pairs, in the order of the list above.
    pub fn to_pg_env_vars(&self) -> Vec<(&'static str, String)> {
        // (query parameter, environment variable) pairs; the order fixes the output order.
        const PARAM_TO_ENV: [(&str, &str); 8] = [
            ("sslmode", "PGSSLMODE"),
            ("sslcert", "PGSSLCERT"),
            ("sslkey", "PGSSLKEY"),
            ("sslrootcert", "PGSSLROOTCERT"),
            ("channel_binding", "PGCHANNELBINDING"),
            ("connect_timeout", "PGCONNECT_TIMEOUT"),
            ("application_name", "PGAPPNAME"),
            ("client_encoding", "PGCLIENTENCODING"),
        ];

        PARAM_TO_ENV
            .iter()
            .filter_map(|&(param, env_name)| {
                self.query_params
                    .get(param)
                    .map(|value| (env_name, value.clone()))
            })
            .collect()
    }
}
817
818/// Managed .pgpass file for secure password passing to PostgreSQL tools
819///
820/// This struct creates a temporary .pgpass file with secure permissions (0600)
821/// and automatically cleans it up when dropped. PostgreSQL command-line tools
822/// read credentials from this file instead of accepting passwords in URLs,
823/// which prevents command injection vulnerabilities.
824///
825/// # Security
826///
827/// - File permissions are set to 0600 (owner read/write only)
828/// - File is automatically removed on Drop
829/// - Credentials are never passed on command line
830///
831/// # Format
832///
833/// .pgpass file format: hostname:port:database:username:password
834/// Wildcards (*) are used for maximum compatibility
835///
836/// # Examples
837///
838/// ```no_run
839/// # use database_replicator::utils::{PgPassFile, parse_postgres_url};
840/// # use anyhow::Result;
841/// # fn example() -> Result<()> {
842/// let url = "postgresql://user:pass@localhost:5432/mydb";
843/// let parts = parse_postgres_url(url)?;
844/// let pgpass = PgPassFile::new(&parts)?;
845///
846/// // Use pgpass.path() with PGPASSFILE environment variable
847/// // File is automatically cleaned up when pgpass goes out of scope
848/// # Ok(())
849/// # }
850/// ```
851pub struct PgPassFile {
852 path: std::path::PathBuf,
853}
854
855impl PgPassFile {
856 /// Create a new .pgpass file with credentials from URL parts
857 ///
858 /// # Arguments
859 ///
860 /// * `parts` - Parsed PostgreSQL URL components
861 ///
862 /// # Returns
863 ///
864 /// Returns a PgPassFile that will be automatically cleaned up on Drop
865 ///
866 /// # Errors
867 ///
868 /// Returns an error if the file cannot be created or permissions cannot be set
869 pub fn new(parts: &PostgresUrlParts) -> Result<Self> {
870 use std::fs;
871 use std::io::Write;
872
873 // Create temp file with secure name
874 let temp_dir = std::env::temp_dir();
875 let random: u32 = rand::random();
876 let filename = format!("pgpass-{:08x}", random);
877 let path = temp_dir.join(filename);
878
879 // Write .pgpass entry
880 // Format: hostname:port:database:username:password
881 let username = parts.user.as_deref().unwrap_or("*");
882 let password = parts.password.as_deref().unwrap_or("");
883 let entry = format!(
884 "{}:{}:{}:{}:{}\n",
885 parts.host, parts.port, parts.database, username, password
886 );
887
888 let mut file = fs::File::create(&path)
889 .with_context(|| format!("Failed to create .pgpass file at {}", path.display()))?;
890
891 file.write_all(entry.as_bytes())
892 .with_context(|| format!("Failed to write to .pgpass file at {}", path.display()))?;
893
894 // Set secure permissions (0600) - owner read/write only
895 #[cfg(unix)]
896 {
897 use std::os::unix::fs::PermissionsExt;
898 let permissions = fs::Permissions::from_mode(0o600);
899 fs::set_permissions(&path, permissions).with_context(|| {
900 format!(
901 "Failed to set permissions on .pgpass file at {}",
902 path.display()
903 )
904 })?;
905 }
906
907 // On Windows, .pgpass is stored in %APPDATA%\postgresql\pgpass.conf
908 // but for our temporary use case, we'll just use a temp file
909 // PostgreSQL on Windows also checks permissions but less strictly
910
911 Ok(Self { path })
912 }
913
914 /// Get the path to the .pgpass file
915 ///
916 /// Use this with the PGPASSFILE environment variable when running
917 /// PostgreSQL command-line tools
918 pub fn path(&self) -> &std::path::Path {
919 &self.path
920 }
921}
922
923impl Drop for PgPassFile {
924 fn drop(&mut self) {
925 // Best effort cleanup - don't panic if removal fails
926 let _ = std::fs::remove_file(&self.path);
927 }
928}
929
930/// Create a managed temporary directory with explicit cleanup support
931///
932/// Creates a temporary directory with a timestamped name that can be cleaned up
933/// even if the process is killed with SIGKILL. Unlike `TempDir::new()` which
934/// relies on the Drop trait, this function creates named directories that can
935/// be cleaned up on next process startup.
936///
937/// Directory naming format: `postgres-seren-replicator-{timestamp}-{random}`
938/// Example: `postgres-seren-replicator-20250106-120534-a3b2c1d4`
939///
940/// # Returns
941///
942/// Returns the path to the created temporary directory.
943///
944/// # Errors
945///
946/// Returns an error if the directory cannot be created.
947///
948/// # Examples
949///
950/// ```no_run
951/// # use database_replicator::utils::create_managed_temp_dir;
952/// # use anyhow::Result;
953/// # fn example() -> Result<()> {
954/// let temp_path = create_managed_temp_dir()?;
955/// println!("Using temp directory: {}", temp_path.display());
956/// // ... do work ...
957/// // Cleanup happens automatically on next startup via cleanup_stale_temp_dirs()
958/// # Ok(())
959/// # }
960/// ```
961pub fn create_managed_temp_dir() -> Result<std::path::PathBuf> {
962 use std::fs;
963 use std::time::SystemTime;
964
965 let system_temp = std::env::temp_dir();
966
967 // Generate timestamp for directory name
968 let timestamp = SystemTime::now()
969 .duration_since(SystemTime::UNIX_EPOCH)
970 .unwrap()
971 .as_secs();
972
973 // Generate random suffix for uniqueness
974 let random: u32 = rand::random();
975
976 // Create directory name with timestamp and random suffix
977 let dir_name = format!("postgres-seren-replicator-{}-{:08x}", timestamp, random);
978
979 let temp_path = system_temp.join(dir_name);
980
981 // Create the directory
982 fs::create_dir_all(&temp_path)
983 .with_context(|| format!("Failed to create temp directory at {}", temp_path.display()))?;
984
985 tracing::debug!("Created managed temp directory: {}", temp_path.display());
986
987 Ok(temp_path)
988}
989
990/// Clean up stale temporary directories from previous runs
991///
992/// Removes temporary directories created by `create_managed_temp_dir()` that are
993/// older than the specified age. This should be called on process startup to clean
994/// up directories left behind by processes killed with SIGKILL.
995///
996/// Only directories matching the pattern `postgres-seren-replicator-*` are removed.
997///
998/// # Arguments
999///
1000/// * `max_age_secs` - Maximum age in seconds before a directory is considered stale
1001/// (recommended: 86400 for 24 hours)
1002///
1003/// # Returns
1004///
1005/// Returns the number of directories cleaned up.
1006///
1007/// # Errors
1008///
1009/// Returns an error if the system temp directory cannot be read. Individual
1010/// directory removal errors are logged but don't fail the entire operation.
1011///
1012/// # Examples
1013///
1014/// ```no_run
1015/// # use database_replicator::utils::cleanup_stale_temp_dirs;
1016/// # use anyhow::Result;
1017/// # fn example() -> Result<()> {
1018/// // Clean up temp directories older than 24 hours
1019/// let cleaned = cleanup_stale_temp_dirs(86400)?;
1020/// println!("Cleaned up {} stale temp directories", cleaned);
1021/// # Ok(())
1022/// # }
1023/// ```
1024pub fn cleanup_stale_temp_dirs(max_age_secs: u64) -> Result<usize> {
1025 use std::fs;
1026 use std::time::SystemTime;
1027
1028 let system_temp = std::env::temp_dir();
1029 let now = SystemTime::now();
1030 let mut cleaned_count = 0;
1031
1032 // Read all entries in system temp directory
1033 let entries = fs::read_dir(&system_temp).with_context(|| {
1034 format!(
1035 "Failed to read system temp directory: {}",
1036 system_temp.display()
1037 )
1038 })?;
1039
1040 for entry in entries.flatten() {
1041 let path = entry.path();
1042
1043 // Only process directories matching our naming pattern
1044 if let Some(name) = path.file_name().and_then(|n| n.to_str()) {
1045 if !name.starts_with("postgres-seren-replicator-") {
1046 continue;
1047 }
1048
1049 // Check directory age
1050 match entry.metadata() {
1051 Ok(metadata) => {
1052 if let Ok(modified) = metadata.modified() {
1053 if let Ok(age) = now.duration_since(modified) {
1054 if age.as_secs() > max_age_secs {
1055 // Directory is stale, remove it
1056 match fs::remove_dir_all(&path) {
1057 Ok(_) => {
1058 tracing::info!(
1059 "Cleaned up stale temp directory: {} (age: {}s)",
1060 path.display(),
1061 age.as_secs()
1062 );
1063 cleaned_count += 1;
1064 }
1065 Err(e) => {
1066 tracing::warn!(
1067 "Failed to remove stale temp directory {}: {}",
1068 path.display(),
1069 e
1070 );
1071 }
1072 }
1073 }
1074 }
1075 }
1076 }
1077 Err(e) => {
1078 tracing::warn!(
1079 "Failed to get metadata for temp directory {}: {}",
1080 path.display(),
1081 e
1082 );
1083 }
1084 }
1085 }
1086 }
1087
1088 if cleaned_count > 0 {
1089 tracing::info!(
1090 "Cleaned up {} stale temp directory(ies) older than {} seconds",
1091 cleaned_count,
1092 max_age_secs
1093 );
1094 }
1095
1096 Ok(cleaned_count)
1097}
1098
1099/// Parse a SerenDB URL to extract project, branch, and database IDs
1100///
1101/// SerenDB URLs have the format: postgresql://user:pass@<database-id>.<branch-id>.<project-id>.serendb.com:5432/db
1102/// This function extracts the three UUIDs from the hostname.
1103///
1104/// # Arguments
1105///
1106/// * `url` - The SerenDB PostgreSQL connection string
1107///
1108/// # Returns
1109///
1110/// An `Option` containing a tuple of `(project_id, branch_id, database_id)` if the
1111/// URL is a valid SerenDB target and contains the expected ID format, otherwise `None`.
1112pub fn parse_serendb_url_for_ids(url: &str) -> Option<(String, String, String)> {
1113 let parts = parse_postgres_url(url).ok()?;
1114
1115 if !is_serendb_target(url) {
1116 return None;
1117 }
1118
1119 // Hostname format: <database-id>.<branch-id>.<project-id>.serendb.com
1120 // Or with custom subdomains: <database-id>.<branch-id>.<project-id>.<custom>.serendb.com
1121 // We want the last three parts before .serendb.com
1122 let host_parts: Vec<&str> = parts.host.split('.').collect();
1123
1124 if host_parts.len() < 4 {
1125 return None; // Not enough parts for SerenDB ID format
1126 }
1127
1128 let num_host_parts = host_parts.len();
1129 let database_id = host_parts[num_host_parts - 4].to_string();
1130 let branch_id = host_parts[num_host_parts - 3].to_string();
1131 let project_id = host_parts[num_host_parts - 2].to_string();
1132
1133 // Basic UUID format validation (optional but good for robustness)
1134 // A real UUID check would be more extensive, but string length is a good start
1135 if database_id.len() == 36 && branch_id.len() == 36 && project_id.len() == 36 {
1136 Some((project_id, branch_id, database_id))
1137 } else {
1138 None
1139 }
1140}
1141
1142/// Remove a managed temporary directory
1143///
1144/// Explicitly removes a temporary directory created by `create_managed_temp_dir()`.
1145/// This should be called when the directory is no longer needed.
1146///
1147/// # Arguments
1148///
1149/// * `path` - Path to the temporary directory to remove
1150///
1151/// # Errors
1152///
1153/// Returns an error if the directory cannot be removed.
1154///
1155/// # Examples
1156///
1157/// ```no_run
1158/// # use database_replicator::utils::{create_managed_temp_dir, remove_managed_temp_dir};
1159/// # use anyhow::Result;
1160/// # fn example() -> Result<()> {
1161/// let temp_path = create_managed_temp_dir()?;
1162/// // ... do work ...
1163/// remove_managed_temp_dir(&temp_path)?;
1164/// # Ok(())
1165/// # }
1166/// ```
1167pub fn remove_managed_temp_dir(path: &std::path::Path) -> Result<()> {
1168 use std::fs;
1169
1170 // Verify this is one of our temp directories (safety check)
1171 if let Some(name) = path.file_name().and_then(|n| n.to_str()) {
1172 if !name.starts_with("postgres-seren-replicator-") {
1173 bail!(
1174 "Refusing to remove directory that doesn't match our naming pattern: {}",
1175 path.display()
1176 );
1177 }
1178 } else {
1179 bail!("Invalid temp directory path: {}", path.display());
1180 }
1181
1182 tracing::debug!("Removing managed temp directory: {}", path.display());
1183
1184 fs::remove_dir_all(path)
1185 .with_context(|| format!("Failed to remove temp directory at {}", path.display()))?;
1186
1187 Ok(())
1188}
1189
1190/// Replace the database name in a connection string URL
1191///
1192/// This is used internally by SerenDB to provide a generic connection string
1193/// which then needs to be specialized for a particular database.
1194///
1195/// # Arguments
1196///
1197/// * `url` - The connection string URL (e.g., postgresql://host/template_db)
1198/// * `new_db` - The new database name to insert into the URL
1199///
1200/// # Returns
1201///
1202/// A new URL string with the database name replaced.
1203///
1204/// # Errors
1205///
1206/// Returns an error if the URL is invalid and cannot be parsed.
1207pub fn replace_database_in_connection_string(url: &str, new_db: &str) -> Result<String> {
1208 let mut parsed = Url::parse(url).context("Invalid connection string URL")?;
1209 parsed.set_path(&format!("/{}", new_db));
1210
1211 Ok(parsed.to_string())
1212}
1213
1214/// Check if a PostgreSQL URL points to a SerenDB instance
1215///
1216/// SerenDB hosts have domains ending with `.serendb.com`
1217///
1218/// # Arguments
1219///
1220/// * `url` - PostgreSQL connection string to check
1221///
1222/// # Returns
1223///
1224/// Returns `true` if the URL points to a SerenDB host.
1225///
1226/// # Examples
1227///
1228/// ```
1229/// use database_replicator::utils::is_serendb_target;
1230///
1231/// assert!(is_serendb_target("postgresql://user:pass@db.serendb.com/mydb"));
1232/// assert!(is_serendb_target("postgresql://user:pass@cluster-123.console.serendb.com/mydb"));
1233/// assert!(!is_serendb_target("postgresql://user:pass@localhost/mydb"));
1234/// assert!(!is_serendb_target("postgresql://user:pass@rds.amazonaws.com/mydb"));
1235/// ```
1236pub fn is_serendb_target(url: &str) -> bool {
1237 match parse_postgres_url(url) {
1238 Ok(parts) => parts.host.ends_with(".serendb.com") || parts.host == "serendb.com",
1239 Err(_) => false,
1240 }
1241}
1242
1243/// Get the major version of a PostgreSQL client tool (pg_dump, psql, etc.)
1244///
1245/// Executes `<tool> --version` and parses the output.
1246///
1247/// # Arguments
1248///
1249/// * `tool` - Name of the tool (e.g., "pg_dump", "psql")
1250///
1251/// # Returns
1252///
1253/// The major version number (e.g., 16 for pg_dump 16.10)
1254///
1255/// # Errors
1256///
1257/// Returns an error if:
1258/// - Tool is not found in PATH
1259/// - Tool execution fails
1260/// - Version output cannot be parsed
1261///
1262/// # Examples
1263///
1264/// ```no_run
1265/// use database_replicator::utils::get_pg_tool_version;
1266/// use anyhow::Result;
1267///
1268/// fn example() -> Result<()> {
1269/// let version = get_pg_tool_version("pg_dump")?;
1270/// println!("pg_dump major version: {}", version); // e.g., 16
1271/// Ok(())
1272/// }
1273/// ```
1274pub fn get_pg_tool_version(tool: &str) -> Result<u32> {
1275 use std::process::Command;
1276
1277 let path = which(tool).with_context(|| format!("{} not found in PATH", tool))?;
1278
1279 let output = Command::new(&path)
1280 .arg("--version")
1281 .output()
1282 .with_context(|| format!("Failed to execute {} --version", tool))?;
1283
1284 let version_str = String::from_utf8_lossy(&output.stdout);
1285 parse_pg_version_string(&version_str)
1286}
1287
1288/// Parse major version from PostgreSQL version string
1289///
1290/// Handles formats like:
1291/// - "pg_dump (PostgreSQL) 16.10 (Ubuntu 16.10-0ubuntu0.24.04.1)"
1292/// - "psql (PostgreSQL) 17.2"
1293/// - "17.2 (Debian 17.2-1.pgdg120+1)"
1294///
1295/// # Arguments
1296///
1297/// * `version_str` - Version string output from a PostgreSQL tool
1298///
1299/// # Returns
1300///
1301/// The major version number (e.g., 16, 17)
1302///
1303/// # Errors
1304///
1305/// Returns an error if the version cannot be parsed.
1306pub fn parse_pg_version_string(version_str: &str) -> Result<u32> {
1307 // Look for version pattern: major.minor
1308 for word in version_str.split_whitespace() {
1309 if let Some(major_str) = word.split('.').next() {
1310 if let Ok(major) = major_str.parse::<u32>() {
1311 // Valid PostgreSQL versions are between 9 and 99
1312 if (9..=99).contains(&major) {
1313 return Ok(major);
1314 }
1315 }
1316 }
1317 }
1318 bail!("Could not parse PostgreSQL version from: {}", version_str)
1319}
1320
1321/// Get available system memory in bytes
1322///
1323/// Cross-platform function that works on Linux, macOS, and Windows.
1324/// Returns the amount of memory available for use by applications.
1325///
1326/// # Platform Details
1327///
1328/// - **Linux**: Reads `MemAvailable` from `/proc/meminfo`
1329/// - **macOS**: Uses `sysctl hw.memsize` for total memory, estimates available
1330/// - **Windows**: Uses `GlobalMemoryStatusEx` API
1331///
1332/// # Returns
1333///
1334/// Available memory in bytes, or an error if detection fails.
1335///
1336/// # Examples
1337///
1338/// ```no_run
1339/// use database_replicator::utils::get_available_memory;
1340///
1341/// let available = get_available_memory().unwrap();
1342/// println!("Available memory: {} MB", available / 1024 / 1024);
1343/// ```
1344pub fn get_available_memory() -> Result<u64> {
1345 #[cfg(target_os = "linux")]
1346 {
1347 get_available_memory_linux()
1348 }
1349
1350 #[cfg(target_os = "macos")]
1351 {
1352 get_available_memory_macos()
1353 }
1354
1355 #[cfg(target_os = "windows")]
1356 {
1357 get_available_memory_windows()
1358 }
1359
1360 #[cfg(not(any(target_os = "linux", target_os = "macos", target_os = "windows")))]
1361 {
1362 // Fallback: assume 1GB available for unknown platforms
1363 tracing::warn!("Memory detection not supported on this platform, assuming 1GB available");
1364 Ok(1024 * 1024 * 1024)
1365 }
1366}
1367
1368#[cfg(target_os = "linux")]
1369fn get_available_memory_linux() -> Result<u64> {
1370 use std::fs;
1371
1372 let meminfo = fs::read_to_string("/proc/meminfo").context("Failed to read /proc/meminfo")?;
1373
1374 // Try MemAvailable first (more accurate, available since Linux 3.14)
1375 for line in meminfo.lines() {
1376 if line.starts_with("MemAvailable:") {
1377 let parts: Vec<&str> = line.split_whitespace().collect();
1378 if parts.len() >= 2 {
1379 let kb: u64 = parts[1]
1380 .parse()
1381 .context("Failed to parse MemAvailable value")?;
1382 return Ok(kb * 1024); // Convert KB to bytes
1383 }
1384 }
1385 }
1386
1387 // Fallback: MemFree + Buffers + Cached (less accurate but works on older kernels)
1388 let mut mem_free: u64 = 0;
1389 let mut buffers: u64 = 0;
1390 let mut cached: u64 = 0;
1391
1392 for line in meminfo.lines() {
1393 let parts: Vec<&str> = line.split_whitespace().collect();
1394 if parts.len() >= 2 {
1395 let value: u64 = parts[1].parse().unwrap_or(0);
1396 if line.starts_with("MemFree:") {
1397 mem_free = value;
1398 } else if line.starts_with("Buffers:") {
1399 buffers = value;
1400 } else if line.starts_with("Cached:") && !line.starts_with("SwapCached:") {
1401 cached = value;
1402 }
1403 }
1404 }
1405
1406 Ok((mem_free + buffers + cached) * 1024) // Convert KB to bytes
1407}
1408
/// macOS implementation of `get_available_memory`.
///
/// Runs `sysctl` for total memory and page size, then `vm_stat` for page
/// counts. Available memory is estimated as
/// `(free + inactive + purgeable pages) * page size`. If that comes out as
/// zero, half of the total physical memory is returned instead.
#[cfg(target_os = "macos")]
fn get_available_memory_macos() -> Result<u64> {
    use std::process::Command;

    // Total physical memory, needed only for the fallback estimate below.
    let memsize_out = Command::new("sysctl")
        .args(["-n", "hw.memsize"])
        .output()
        .context("Failed to execute sysctl")?;
    let total_bytes: u64 = String::from_utf8_lossy(&memsize_out.stdout)
        .trim()
        .parse()
        .context("Failed to parse hw.memsize")?;

    // Page size differs by hardware: Intel Macs use 4KB (4096), Apple Silicon
    // uses 16KB (16384). Default to 4KB if the value cannot be parsed.
    let pagesize_out = Command::new("sysctl")
        .args(["-n", "hw.pagesize"])
        .output()
        .context("Failed to execute sysctl hw.pagesize")?;
    let page_size: u64 = String::from_utf8_lossy(&pagesize_out.stdout)
        .trim()
        .parse()
        .unwrap_or(4096);

    let vm_out = Command::new("vm_stat")
        .output()
        .context("Failed to execute vm_stat")?;
    let vm_report = String::from_utf8_lossy(&vm_out.stdout);

    // Page count from the last line carrying the given label; 0 when absent.
    let pages_for = |label: &str| -> u64 {
        vm_report
            .lines()
            .filter(|line| line.starts_with(label))
            .last()
            .map_or(0, parse_vm_stat_value)
    };

    // Available = free + inactive + purgeable (conservative estimate)
    let reclaimable_pages =
        pages_for("Pages free:") + pages_for("Pages inactive:") + pages_for("Pages purgeable:");
    let available = reclaimable_pages * page_size;

    if available != 0 {
        return Ok(available);
    }

    // Nothing usable came out of vm_stat: estimate 50% of total as available.
    tracing::debug!("vm_stat parsing returned 0, estimating 50% of total memory as available");
    Ok(total_bytes / 2)
}
1470
/// Pull the numeric page count out of one `vm_stat` line.
///
/// Expects the form `"Pages free: 12345."`: the text between the first and
/// second `:` is trimmed, trailing dots are removed, and the rest is parsed.
/// Returns 0 when the line has no `:` or the value is not a number.
#[cfg(target_os = "macos")]
fn parse_vm_stat_value(line: &str) -> u64 {
    let mut segments = line.split(':');
    segments.next(); // discard the label in front of the first ':'

    match segments.next() {
        Some(raw) => raw.trim().trim_end_matches('.').parse().unwrap_or(0),
        None => 0,
    }
}
1479
/// Windows implementation of `get_available_memory`.
///
/// Calls the Win32 `GlobalMemoryStatusEx` function from `kernel32` through a
/// hand-declared FFI binding and returns the `ullAvailPhys` field of the
/// filled-in structure. Fails if the call returns 0.
#[cfg(target_os = "windows")]
fn get_available_memory_windows() -> Result<u64> {
    use std::mem;

    // Local mirror of the Win32 MEMORYSTATUSEX structure. `repr(C)` keeps the
    // field order and layout as declared; field names follow the Win32
    // spelling, hence the lint exemption.
    // NOTE(review): layout is assumed to match the Windows SDK definition —
    // confirm against the official MEMORYSTATUSEX documentation.
    #[repr(C)]
    #[allow(non_snake_case)]
    struct MEMORYSTATUSEX {
        dwLength: u32,
        dwMemoryLoad: u32,
        ullTotalPhys: u64,
        ullAvailPhys: u64,
        ullTotalPageFile: u64,
        ullAvailPageFile: u64,
        ullTotalVirtual: u64,
        ullAvailVirtual: u64,
        ullAvailExtendedVirtual: u64,
    }

    #[link(name = "kernel32")]
    extern "system" {
        fn GlobalMemoryStatusEx(lpBuffer: *mut MEMORYSTATUSEX) -> i32;
    }

    // SAFETY: every field of MEMORYSTATUSEX is a plain integer, so the
    // all-zero bit pattern is a valid value of the struct.
    let mut mem_status: MEMORYSTATUSEX = unsafe { mem::zeroed() };
    // The struct size is stored in dwLength before the call; presumably the
    // API requires this — see the GlobalMemoryStatusEx documentation.
    mem_status.dwLength = mem::size_of::<MEMORYSTATUSEX>() as u32;

    // SAFETY: the pointer comes from a live, exclusively borrowed local that
    // outlives the call.
    let result = unsafe { GlobalMemoryStatusEx(&mut mem_status) };

    // A zero return value is treated as failure.
    if result == 0 {
        bail!("GlobalMemoryStatusEx failed");
    }

    Ok(mem_status.ullAvailPhys)
}
1515
1516/// Calculate optimal batch size based on available system memory
1517///
1518/// Automatically determines an appropriate batch size for the sync daemon
1519/// based on the amount of available system memory. This prevents OOM errors
1520/// on memory-constrained instances while maximizing throughput on larger ones.
1521///
1522/// # Memory Model
1523///
1524/// The calculation assumes:
1525/// - Each row consumes approximately 2KB in memory (conservative estimate for wide tables)
1526/// - We should use at most 25% of available memory for batch processing
1527/// - Minimum batch size: 1,000 rows (for very constrained systems)
1528/// - Maximum batch size: 50,000 rows (diminishing returns beyond this)
1529///
1530/// # Returns
1531///
1532/// Optimal batch size in number of rows, or default of 10,000 if detection fails.
1533///
1534/// # Examples
1535///
1536/// ```no_run
1537/// use database_replicator::utils::calculate_optimal_batch_size;
1538///
1539/// let batch_size = calculate_optimal_batch_size();
1540/// println!("Using batch size: {}", batch_size);
1541/// // On t3.nano (512MB): ~1,000-2,000
1542/// // On t3.small (2GB): ~10,000
1543/// // On t3.large (8GB): ~50,000 (capped)
1544/// ```
1545pub fn calculate_optimal_batch_size() -> usize {
1546 const BYTES_PER_ROW: u64 = 2048; // Conservative: 2KB per row
1547 const MEMORY_FRACTION: f64 = 0.25; // Use at most 25% of available memory
1548 const MIN_BATCH_SIZE: usize = 1_000;
1549 const MAX_BATCH_SIZE: usize = 50_000;
1550 const DEFAULT_BATCH_SIZE: usize = 10_000;
1551
1552 match get_available_memory() {
1553 Ok(available_bytes) => {
1554 // Calculate how many rows we can fit in 25% of available memory
1555 let usable_bytes = (available_bytes as f64 * MEMORY_FRACTION) as u64;
1556 let calculated_size = (usable_bytes / BYTES_PER_ROW) as usize;
1557
1558 // Clamp to min/max range
1559 let batch_size = calculated_size.clamp(MIN_BATCH_SIZE, MAX_BATCH_SIZE);
1560
1561 tracing::info!(
1562 "Auto-detected batch size: {} (available memory: {} MB)",
1563 batch_size,
1564 available_bytes / 1024 / 1024
1565 );
1566
1567 batch_size
1568 }
1569 Err(e) => {
1570 tracing::warn!(
1571 "Failed to detect available memory: {}. Using default batch size: {}",
1572 e,
1573 DEFAULT_BATCH_SIZE
1574 );
1575 DEFAULT_BATCH_SIZE
1576 }
1577 }
1578}
1579
1580#[cfg(test)]
1581mod tests {
1582 use super::*;
1583
1584 #[test]
1585 fn test_get_available_memory() {
1586 // This should work on all supported platforms
1587 let result = get_available_memory();
1588
1589 // Should succeed (may fail in very restricted environments)
1590 if let Ok(available) = result {
1591 // Sanity check: should be at least 10MB, less than 1TB
1592 assert!(
1593 available > 10 * 1024 * 1024,
1594 "Available memory too low: {}",
1595 available
1596 );
1597 assert!(
1598 available < 1024 * 1024 * 1024 * 1024,
1599 "Available memory too high: {}",
1600 available
1601 );
1602 }
1603 }
1604
1605 #[test]
1606 fn test_calculate_optimal_batch_size() {
1607 let batch_size = calculate_optimal_batch_size();
1608
1609 // Should be within expected range
1610 assert!(batch_size >= 1_000, "Batch size too small: {}", batch_size);
1611 assert!(batch_size <= 50_000, "Batch size too large: {}", batch_size);
1612 }
1613
1614 #[test]
1615 fn test_validate_connection_string_valid() {
1616 assert!(validate_connection_string("postgresql://user:pass@localhost:5432/dbname").is_ok());
1617 assert!(validate_connection_string("postgres://user@host/db").is_ok());
1618 }
1619
1620 #[test]
1621 fn test_check_required_tools() {
1622 // This test will pass if PostgreSQL client tools are installed
1623 // It will fail (appropriately) if they're not installed
1624 let result = check_required_tools();
1625
1626 // On systems with PostgreSQL installed, this should pass
1627 // On systems without it, we expect a specific error message
1628 if let Err(err) = result {
1629 let err_msg = err.to_string();
1630 assert!(err_msg.contains("Missing required PostgreSQL client tools"));
1631 assert!(
1632 err_msg.contains("pg_dump")
1633 || err_msg.contains("pg_dumpall")
1634 || err_msg.contains("psql")
1635 );
1636 }
1637 }
1638
1639 #[test]
1640 fn test_validate_connection_string_invalid() {
1641 assert!(validate_connection_string("").is_err());
1642 assert!(validate_connection_string(" ").is_err());
1643 assert!(validate_connection_string("mysql://localhost/db").is_err());
1644 assert!(validate_connection_string("postgresql://localhost").is_err());
1645 assert!(validate_connection_string("postgresql://localhost/db").is_err());
1646 // Missing user
1647 }
1648
1649 #[test]
1650 fn test_sanitize_identifier() {
1651 assert_eq!(sanitize_identifier("normal_table"), "normal_table");
1652 assert_eq!(sanitize_identifier("table\x00name"), "tablename");
1653 assert_eq!(sanitize_identifier("table\nname"), "tablename");
1654
1655 // Test length limit
1656 let long_name = "a".repeat(200);
1657 assert_eq!(sanitize_identifier(&long_name).len(), 100);
1658 }
1659
1660 #[tokio::test]
1661 async fn test_retry_with_backoff_success() {
1662 let mut attempts = 0;
1663 let result = retry_with_backoff(
1664 || {
1665 attempts += 1;
1666 async move {
1667 if attempts < 3 {
1668 anyhow::bail!("Temporary failure")
1669 } else {
1670 Ok("Success")
1671 }
1672 }
1673 },
1674 5,
1675 Duration::from_millis(10),
1676 )
1677 .await;
1678
1679 assert!(result.is_ok());
1680 assert_eq!(result.unwrap(), "Success");
1681 assert_eq!(attempts, 3);
1682 }
1683
1684 #[tokio::test]
1685 async fn test_retry_with_backoff_failure() {
1686 let mut attempts = 0;
1687 let result: Result<&str> = retry_with_backoff(
1688 || {
1689 attempts += 1;
1690 async move { anyhow::bail!("Permanent failure") }
1691 },
1692 2,
1693 Duration::from_millis(10),
1694 )
1695 .await;
1696
1697 assert!(result.is_err());
1698 assert_eq!(attempts, 3); // Initial + 2 retries
1699 }
1700
1701 #[test]
1702 fn test_validate_source_target_different_valid() {
1703 // Different hosts
1704 assert!(validate_source_target_different(
1705 "postgresql://user:pass@source.com:5432/db",
1706 "postgresql://user:pass@target.com:5432/db"
1707 )
1708 .is_ok());
1709
1710 // Different databases on same host
1711 assert!(validate_source_target_different(
1712 "postgresql://user:pass@host:5432/db1",
1713 "postgresql://user:pass@host:5432/db2"
1714 )
1715 .is_ok());
1716
1717 // Different ports on same host
1718 assert!(validate_source_target_different(
1719 "postgresql://user:pass@host:5432/db",
1720 "postgresql://user:pass@host:5433/db"
1721 )
1722 .is_ok());
1723
1724 // Different users on same host/db (edge case but allowed)
1725 assert!(validate_source_target_different(
1726 "postgresql://user1:pass@host:5432/db",
1727 "postgresql://user2:pass@host:5432/db"
1728 )
1729 .is_ok());
1730 }
1731
1732 #[test]
1733 fn test_validate_source_target_different_invalid() {
1734 // Exact same URL
1735 assert!(validate_source_target_different(
1736 "postgresql://user:pass@host:5432/db",
1737 "postgresql://user:pass@host:5432/db"
1738 )
1739 .is_err());
1740
1741 // Same URL with different scheme (postgres vs postgresql)
1742 assert!(validate_source_target_different(
1743 "postgres://user:pass@host:5432/db",
1744 "postgresql://user:pass@host:5432/db"
1745 )
1746 .is_err());
1747
1748 // Same URL with default port vs explicit port
1749 assert!(validate_source_target_different(
1750 "postgresql://user:pass@host/db",
1751 "postgresql://user:pass@host:5432/db"
1752 )
1753 .is_err());
1754
1755 // Same URL with different query parameters (still same database)
1756 assert!(validate_source_target_different(
1757 "postgresql://user:pass@host:5432/db?sslmode=require",
1758 "postgresql://user:pass@host:5432/db?sslmode=prefer"
1759 )
1760 .is_err());
1761
1762 // Same host with different case (hostnames are case-insensitive)
1763 assert!(validate_source_target_different(
1764 "postgresql://user:pass@HOST.COM:5432/db",
1765 "postgresql://user:pass@host.com:5432/db"
1766 )
1767 .is_err());
1768 }
1769
1770 #[test]
1771 fn test_parse_postgres_url() {
1772 // Full URL with all components including password
1773 let parts = parse_postgres_url("postgresql://myuser:mypass@localhost:5432/mydb").unwrap();
1774 assert_eq!(parts.host, "localhost");
1775 assert_eq!(parts.port, 5432);
1776 assert_eq!(parts.database, "mydb");
1777 assert_eq!(parts.user, Some("myuser".to_string()));
1778 assert_eq!(parts.password, Some("mypass".to_string()));
1779
1780 // URL without port (should default to 5432)
1781 let parts = parse_postgres_url("postgresql://user:pass@host/db").unwrap();
1782 assert_eq!(parts.host, "host");
1783 assert_eq!(parts.port, 5432);
1784 assert_eq!(parts.database, "db");
1785 assert_eq!(parts.user, Some("user".to_string()));
1786 assert_eq!(parts.password, Some("pass".to_string()));
1787
1788 // URL with user but no password
1789 let parts = parse_postgres_url("postgresql://user@host/db").unwrap();
1790 assert_eq!(parts.host, "host");
1791 assert_eq!(parts.user, Some("user".to_string()));
1792 assert_eq!(parts.password, None);
1793
1794 // URL without authentication
1795 let parts = parse_postgres_url("postgresql://host:5433/db").unwrap();
1796 assert_eq!(parts.host, "host");
1797 assert_eq!(parts.port, 5433);
1798 assert_eq!(parts.database, "db");
1799 assert_eq!(parts.user, None);
1800 assert_eq!(parts.password, None);
1801
1802 // URL with query parameters
1803 let parts = parse_postgres_url("postgresql://user:pass@host/db?sslmode=require").unwrap();
1804 assert_eq!(parts.host, "host");
1805 assert_eq!(parts.database, "db");
1806 assert_eq!(parts.password, Some("pass".to_string()));
1807
1808 // URL with postgres:// scheme (alternative)
1809 let parts = parse_postgres_url("postgres://user:pass@host/db").unwrap();
1810 assert_eq!(parts.host, "host");
1811 assert_eq!(parts.database, "db");
1812 assert_eq!(parts.password, Some("pass".to_string()));
1813
1814 // Host normalization (lowercase)
1815 let parts = parse_postgres_url("postgresql://user:pass@HOST.COM/db").unwrap();
1816 assert_eq!(parts.host, "host.com");
1817 assert_eq!(parts.password, Some("pass".to_string()));
1818
1819 // Password with special characters
1820 let parts = parse_postgres_url("postgresql://user:p@ss!word@host/db").unwrap();
1821 assert_eq!(parts.password, Some("p@ss!word".to_string()));
1822 }
1823
1824 #[test]
1825 fn test_validate_postgres_identifier_valid() {
1826 // Valid identifiers
1827 assert!(validate_postgres_identifier("mydb").is_ok());
1828 assert!(validate_postgres_identifier("my_database").is_ok());
1829 assert!(validate_postgres_identifier("_private_db").is_ok());
1830 assert!(validate_postgres_identifier("db123").is_ok());
1831 assert!(validate_postgres_identifier("Database_2024").is_ok());
1832
1833 // Maximum length (63 characters)
1834 let max_length_name = "a".repeat(63);
1835 assert!(validate_postgres_identifier(&max_length_name).is_ok());
1836 }
1837
1838 #[test]
1839 fn test_pgpass_file_creation() {
1840 let parts = PostgresUrlParts {
1841 host: "localhost".to_string(),
1842 port: 5432,
1843 database: "testdb".to_string(),
1844 user: Some("testuser".to_string()),
1845 password: Some("testpass".to_string()),
1846 query_params: std::collections::HashMap::new(),
1847 };
1848
1849 let pgpass = PgPassFile::new(&parts).unwrap();
1850 assert!(pgpass.path().exists());
1851
1852 // Verify file content
1853 let content = std::fs::read_to_string(pgpass.path()).unwrap();
1854 assert_eq!(content, "localhost:5432:testdb:testuser:testpass\n");
1855
1856 // Verify permissions on Unix
1857 #[cfg(unix)]
1858 {
1859 use std::os::unix::fs::PermissionsExt;
1860 let metadata = std::fs::metadata(pgpass.path()).unwrap();
1861 let permissions = metadata.permissions();
1862 assert_eq!(permissions.mode() & 0o777, 0o600);
1863 }
1864
1865 // File should be cleaned up when pgpass is dropped
1866 let path = pgpass.path().to_path_buf();
1867 drop(pgpass);
1868 assert!(!path.exists());
1869 }
1870
1871 #[test]
1872 fn test_pgpass_file_without_password() {
1873 let parts = PostgresUrlParts {
1874 host: "localhost".to_string(),
1875 port: 5432,
1876 database: "testdb".to_string(),
1877 user: Some("testuser".to_string()),
1878 password: None,
1879 query_params: std::collections::HashMap::new(),
1880 };
1881
1882 let pgpass = PgPassFile::new(&parts).unwrap();
1883 let content = std::fs::read_to_string(pgpass.path()).unwrap();
1884 // Should use empty password
1885 assert_eq!(content, "localhost:5432:testdb:testuser:\n");
1886 }
1887
1888 #[test]
1889 fn test_pgpass_file_without_user() {
1890 let parts = PostgresUrlParts {
1891 host: "localhost".to_string(),
1892 port: 5432,
1893 database: "testdb".to_string(),
1894 user: None,
1895 password: Some("testpass".to_string()),
1896 query_params: std::collections::HashMap::new(),
1897 };
1898
1899 let pgpass = PgPassFile::new(&parts).unwrap();
1900 let content = std::fs::read_to_string(pgpass.path()).unwrap();
1901 // Should use wildcard for user
1902 assert_eq!(content, "localhost:5432:testdb:*:testpass\n");
1903 }
1904
1905 #[test]
1906 fn test_strip_password_from_url() {
1907 // With password
1908 let url = "postgresql://user:p@ssw0rd@host:5432/db";
1909 let stripped = strip_password_from_url(url).unwrap();
1910 assert_eq!(stripped, "postgresql://user@host:5432/db");
1911
1912 // With special characters in password
1913 let url = "postgresql://user:p@ss!w0rd@host:5432/db";
1914 let stripped = strip_password_from_url(url).unwrap();
1915 assert_eq!(stripped, "postgresql://user@host:5432/db");
1916
1917 // Without password
1918 let url = "postgresql://user@host:5432/db";
1919 let stripped = strip_password_from_url(url).unwrap();
1920 assert_eq!(stripped, "postgresql://user@host:5432/db");
1921
1922 // With query parameters
1923 let url = "postgresql://user:pass@host:5432/db?sslmode=require";
1924 let stripped = strip_password_from_url(url).unwrap();
1925 assert_eq!(stripped, "postgresql://user@host:5432/db?sslmode=require");
1926
1927 // No user
1928 let url = "postgresql://host:5432/db";
1929 let stripped = strip_password_from_url(url).unwrap();
1930 assert_eq!(stripped, "postgresql://host:5432/db");
1931 }
1932
1933 #[test]
1934 fn test_validate_postgres_identifier_invalid() {
1935 // SQL injection attempts
1936 assert!(validate_postgres_identifier("mydb\"; DROP DATABASE production; --").is_err());
1937 assert!(validate_postgres_identifier("db'; DELETE FROM users; --").is_err());
1938
1939 // Invalid start characters
1940 assert!(validate_postgres_identifier("123db").is_err()); // Starts with digit
1941 assert!(validate_postgres_identifier("$db").is_err()); // Starts with special char
1942 assert!(validate_postgres_identifier("-db").is_err()); // Starts with dash
1943
1944 // Contains invalid characters
1945 assert!(validate_postgres_identifier("my-database").is_err()); // Contains dash
1946 assert!(validate_postgres_identifier("my.database").is_err()); // Contains dot
1947 assert!(validate_postgres_identifier("my database").is_err()); // Contains space
1948 assert!(validate_postgres_identifier("my@db").is_err()); // Contains @
1949 assert!(validate_postgres_identifier("my#db").is_err()); // Contains #
1950
1951 // Empty or too long
1952 assert!(validate_postgres_identifier("").is_err());
1953 assert!(validate_postgres_identifier(" ").is_err());
1954
1955 // Over maximum length (64+ characters)
1956 let too_long = "a".repeat(64);
1957 assert!(validate_postgres_identifier(&too_long).is_err());
1958
1959 // Control characters
1960 assert!(validate_postgres_identifier("my\ndb").is_err());
1961 assert!(validate_postgres_identifier("my\tdb").is_err());
1962 assert!(validate_postgres_identifier("my\x00db").is_err());
1963 }
1964
1965 #[test]
1966 fn test_is_serendb_target() {
1967 // Positive cases - SerenDB hosts
1968 assert!(is_serendb_target(
1969 "postgresql://user:pass@db.serendb.com/mydb"
1970 ));
1971 assert!(is_serendb_target(
1972 "postgresql://user:pass@cluster.console.serendb.com/mydb"
1973 ));
1974 assert!(is_serendb_target(
1975 "postgres://u:p@x.serendb.com:5432/db?sslmode=require"
1976 ));
1977 assert!(is_serendb_target("postgresql://user:pass@serendb.com/mydb"));
1978
1979 // Negative cases - not SerenDB
1980 assert!(!is_serendb_target("postgresql://user:pass@localhost/mydb"));
1981 assert!(!is_serendb_target(
1982 "postgresql://user:pass@rds.amazonaws.com/mydb"
1983 ));
1984 assert!(!is_serendb_target("postgresql://user:pass@neon.tech/mydb"));
1985 // Domain spoofing attempt - should NOT match
1986 assert!(!is_serendb_target(
1987 "postgresql://user:pass@serendb.com.evil.com/mydb"
1988 ));
1989 assert!(!is_serendb_target(
1990 "postgresql://user:pass@notserendb.com/mydb"
1991 ));
1992 // Invalid URL
1993 assert!(!is_serendb_target("not-a-url"));
1994 }
1995
1996 #[test]
1997 fn test_parse_pg_version_string() {
1998 // Standard pg_dump output
1999 assert_eq!(
2000 parse_pg_version_string("pg_dump (PostgreSQL) 16.10 (Ubuntu 16.10-0ubuntu0.24.04.1)")
2001 .unwrap(),
2002 16
2003 );
2004
2005 // Standard psql output
2006 assert_eq!(
2007 parse_pg_version_string("psql (PostgreSQL) 17.2").unwrap(),
2008 17
2009 );
2010
2011 // pg_restore output
2012 assert_eq!(
2013 parse_pg_version_string("pg_restore (PostgreSQL) 15.4").unwrap(),
2014 15
2015 );
2016
2017 // Debian-style version
2018 assert_eq!(
2019 parse_pg_version_string("17.2 (Debian 17.2-1.pgdg120+1)").unwrap(),
2020 17
2021 );
2022
2023 // Should fail on invalid input
2024 assert!(parse_pg_version_string("not a version").is_err());
2025 assert!(parse_pg_version_string("version 1.2.3").is_err()); // 1 is < 9
2026 assert!(parse_pg_version_string("").is_err());
2027 }
2028
2029 #[test]
2030 fn test_get_pg_tool_version() {
2031 // This test will only pass if pg_dump is installed
2032 // Skip gracefully if not available
2033 if which("pg_dump").is_ok() {
2034 let version = get_pg_tool_version("pg_dump").unwrap();
2035 assert!(
2036 version >= 12,
2037 "Expected pg_dump version >= 12, got {}",
2038 version
2039 );
2040 assert!(
2041 version <= 99,
2042 "Expected pg_dump version <= 99, got {}",
2043 version
2044 );
2045 }
2046
2047 // Non-existent tool should fail
2048 assert!(get_pg_tool_version("nonexistent_pg_tool_xyz").is_err());
2049 }
2050}