database_replicator/migration/
dump.rs

1// ABOUTME: Wrapper for pg_dump command to export database objects
2// ABOUTME: Handles global objects, schema, and data export
3
4use crate::filters::ReplicationFilter;
5use anyhow::{Context, Result};
6use std::collections::BTreeSet;
7use std::fs;
8use std::process::{Command, Stdio};
9use std::time::Duration;
10
11/// Dump global objects (roles, tablespaces) using pg_dumpall
12pub async fn dump_globals(source_url: &str, output_path: &str) -> Result<()> {
13    tracing::info!("Dumping global objects to {}", output_path);
14
15    // Parse URL and create .pgpass file for secure authentication
16    let parts = crate::utils::parse_postgres_url(source_url)
17        .with_context(|| format!("Failed to parse source URL: {}", source_url))?;
18    let pgpass = crate::utils::PgPassFile::new(&parts)
19        .context("Failed to create .pgpass file for authentication")?;
20
21    let env_vars = parts.to_pg_env_vars();
22    let output_path_owned = output_path.to_string();
23
24    // Wrap subprocess execution with retry logic
25    crate::utils::retry_subprocess_with_backoff(
26        || {
27            let mut cmd = Command::new("pg_dumpall");
28            cmd.arg("--globals-only")
29                .arg("--no-role-passwords") // Don't dump passwords
30                .arg("--verbose") // Show progress
31                .arg("--host")
32                .arg(&parts.host)
33                .arg("--port")
34                .arg(parts.port.to_string())
35                .arg("--database")
36                .arg(&parts.database)
37                .arg(format!("--file={}", output_path_owned))
38                .env("PGPASSFILE", pgpass.path())
39                .stdout(Stdio::inherit())
40                .stderr(Stdio::inherit());
41
42            // Add username if specified
43            if let Some(user) = &parts.user {
44                cmd.arg("--username").arg(user);
45            }
46
47            // Apply query parameters as environment variables (SSL, channel_binding, etc.)
48            for (env_var, value) in &env_vars {
49                cmd.env(env_var, value);
50            }
51
52            // Apply TCP keepalive parameters to prevent idle connection timeouts
53            for (env_var, value) in crate::utils::get_keepalive_env_vars() {
54                cmd.env(env_var, value);
55            }
56
57            cmd.status().context(
58                "Failed to execute pg_dumpall. Is PostgreSQL client installed?\n\
59                 Install with:\n\
60                 - Ubuntu/Debian: sudo apt-get install postgresql-client\n\
61                 - macOS: brew install postgresql\n\
62                 - RHEL/CentOS: sudo yum install postgresql",
63            )
64        },
65        3,                      // Max 3 retries
66        Duration::from_secs(1), // Start with 1 second delay
67        "pg_dumpall (dump globals)",
68    )
69    .await
70    .context(
71        "pg_dumpall failed to dump global objects.\n\
72         \n\
73         Common causes:\n\
74         - Connection authentication failed\n\
75         - User lacks sufficient privileges (need SUPERUSER or pg_read_all_settings role)\n\
76         - Network connectivity issues\n\
77         - Invalid connection string\n\
78         - Connection timeout or network issues",
79    )?;
80
81    tracing::info!("✓ Global objects dumped successfully");
82    Ok(())
83}
84
85/// Update a globals dump so duplicate role creation errors become harmless notices.
86///
87/// `pg_dumpall --globals-only` emits `CREATE ROLE` statements that fail if the
88/// role already exists on the target cluster. When an operator reruns
89/// replication against the same target, those statements cause `psql` to exit
90/// with status 3 which previously triggered noisy retries and prevented the rest
91/// of the globals from being applied. By wrapping each `CREATE ROLE` statement
92/// in a `DO $$ ... EXCEPTION WHEN duplicate_object` block, we allow Postgres to
93/// skip recreating existing roles while still applying subsequent `ALTER ROLE`
94/// and `GRANT` statements.
95pub fn sanitize_globals_dump(path: &str) -> Result<()> {
96    let content = fs::read_to_string(path)
97        .with_context(|| format!("Failed to read globals dump at {}", path))?;
98
99    if let Some(updated) = rewrite_create_role_statements(&content) {
100        fs::write(path, updated)
101            .with_context(|| format!("Failed to update globals dump at {}", path))?;
102    }
103
104    Ok(())
105}
106
107/// Comments out `ALTER ROLE ... SUPERUSER` statements in a globals dump file.
108///
109/// Managed Postgres services (e.g., AWS RDS) often prevent the restore user
110/// from granting SUPERUSER. Commenting out those lines keeps the restore
111/// moving without permission errors.
112pub fn remove_superuser_from_globals(path: &str) -> Result<()> {
113    let content = fs::read_to_string(path)
114        .with_context(|| format!("Failed to read globals dump at {}", path))?;
115
116    let mut updated = String::with_capacity(content.len());
117    let mut modified = false;
118    for line in content.lines() {
119        if line.contains("ALTER ROLE") && line.contains("SUPERUSER") {
120            updated.push_str("-- ");
121            updated.push_str(line);
122            updated.push('\n');
123            modified = true;
124        } else {
125            updated.push_str(line);
126            updated.push('\n');
127        }
128    }
129
130    if modified {
131        fs::write(path, updated)
132            .with_context(|| format!("Failed to write sanitized globals dump to {}", path))?;
133    }
134
135    Ok(())
136}
137
138/// Removes parameter settings that require superuser privileges (e.g. `log_statement`).
139///
140/// AWS RDS prevents standard replication roles from altering certain GUCs via
141/// `ALTER ROLE ... SET`. Each offending line is commented out so `psql` skips
142/// them without aborting the rest of the globals restore.
143pub fn remove_restricted_guc_settings(path: &str) -> Result<()> {
144    let content = fs::read_to_string(path)
145        .with_context(|| format!("Failed to read globals dump at {}", path))?;
146
147    let mut updated = String::with_capacity(content.len());
148    let mut modified = false;
149
150    for line in content.lines() {
151        let lower_line = line.to_ascii_lowercase();
152        if lower_line.contains("alter role") && lower_line.contains("set") {
153            updated.push_str("-- ");
154            updated.push_str(line);
155            updated.push('\n');
156            modified = true;
157        } else {
158            updated.push_str(line);
159            updated.push('\n');
160        }
161    }
162
163    if modified {
164        fs::write(path, updated)
165            .with_context(|| format!("Failed to write sanitized globals dump to {}", path))?;
166    }
167
168    Ok(())
169}
170
171/// Comments out tablespace-related statements in a globals dump file.
172///
173/// Some managed PostgreSQL targets (e.g., SerenDB) do not support custom tablespaces.
174/// This function filters out:
175/// - `CREATE TABLESPACE` statements
176/// - Any statement referencing RDS-specific tablespaces (any `rds_*` tablespace)
177pub fn remove_tablespace_statements(path: &str) -> Result<()> {
178    let content = fs::read_to_string(path)
179        .with_context(|| format!("Failed to read globals dump at {}", path))?;
180
181    let mut updated = String::with_capacity(content.len());
182    let mut modified = false;
183
184    for line in content.lines() {
185        let lower_trimmed = line.trim().to_ascii_lowercase();
186
187        // Filter CREATE TABLESPACE statements
188        let is_create_tablespace = lower_trimmed.starts_with("create tablespace");
189
190        // Filter any statement referencing RDS tablespaces (rds_* pattern)
191        // Matches 'rds_something' or "rds_something" in SQL statements
192        let references_rds_tablespace =
193            lower_trimmed.contains("'rds_") || lower_trimmed.contains("\"rds_");
194
195        if is_create_tablespace || references_rds_tablespace {
196            updated.push_str("-- ");
197            updated.push_str(line);
198            updated.push('\n');
199            modified = true;
200        } else {
201            updated.push_str(line);
202            updated.push('\n');
203        }
204    }
205
206    if modified {
207        fs::write(path, updated)
208            .with_context(|| format!("Failed to write sanitized globals dump to {}", path))?;
209    }
210
211    Ok(())
212}
213
214/// Comments out `GRANT` statements for roles that are restricted on managed services.
215///
216/// AWS RDS and other managed services may prevent granting certain default roles
217/// like `pg_checkpoint`. This function also filters out GRANT statements that use
218/// `GRANTED BY` clauses referencing RDS admin roles (e.g., `rdsadmin`).
219pub fn remove_restricted_role_grants(path: &str) -> Result<()> {
220    // Roles that cannot be granted on managed PostgreSQL services (AWS RDS, etc.)
221    const RESTRICTED_ROLES: &[&str] = &[
222        "pg_checkpoint",
223        "pg_read_all_data",
224        "pg_write_all_data",
225        "pg_read_all_settings",
226        "pg_read_all_stats",
227        "pg_stat_scan_tables",
228        "pg_monitor",
229        "pg_signal_backend",
230        "pg_read_server_files",
231        "pg_write_server_files",
232        "pg_execute_server_program",
233        "pg_create_subscription",
234        "pg_maintain",
235        "pg_use_reserved_connections",
236    ];
237
238    // Roles that cannot be used as grantors in GRANTED BY clauses
239    const RESTRICTED_GRANTORS: &[&str] = &[
240        "rdsadmin",
241        "rds_superuser",
242        "rdsrepladmin",
243        "rds_replication",
244    ];
245
246    let content = fs::read_to_string(path)
247        .with_context(|| format!("Failed to read globals dump at {}", path))?;
248
249    let mut updated = String::with_capacity(content.len());
250    let mut modified = false;
251
252    for line in content.lines() {
253        let lower_trimmed = line.trim().to_ascii_lowercase();
254        if lower_trimmed.starts_with("grant ") {
255            // Check if granting a restricted role
256            let is_restricted_role = RESTRICTED_ROLES.iter().any(|role| {
257                // Get the role being granted (2nd word), stripping any quotes
258                // e.g. "grant pg_checkpoint to some_user" or "grant \"pg_checkpoint\" to some_user"
259                lower_trimmed
260                    .split_whitespace()
261                    .nth(1)
262                    .map(|r| r.trim_matches('"') == *role)
263                    .unwrap_or(false)
264            });
265
266            // Check if using a restricted grantor in GRANTED BY clause
267            let has_restricted_grantor = RESTRICTED_GRANTORS.iter().any(|grantor| {
268                // Look for "granted by rdsadmin" or "granted by \"rdsadmin\""
269                lower_trimmed.contains(&format!("granted by {}", grantor))
270                    || lower_trimmed.contains(&format!("granted by \"{}\"", grantor))
271            });
272
273            if is_restricted_role || has_restricted_grantor {
274                updated.push_str("-- ");
275                updated.push_str(line);
276                updated.push('\n');
277                modified = true;
278                continue;
279            }
280        }
281
282        updated.push_str(line);
283        updated.push('\n');
284    }
285
286    if modified {
287        fs::write(path, updated)
288            .with_context(|| format!("Failed to write sanitized globals dump to {}", path))?;
289    }
290
291    Ok(())
292}
293
294fn rewrite_create_role_statements(sql: &str) -> Option<String> {
295    if sql.is_empty() {
296        return None;
297    }
298
299    let mut output = String::with_capacity(sql.len() + 1024);
300    let mut modified = false;
301    let mut cursor = 0;
302
303    while cursor < sql.len() {
304        if let Some(rel_pos) = sql[cursor..].find('\n') {
305            let end = cursor + rel_pos + 1;
306            let chunk = &sql[cursor..end];
307            if let Some(transformed) = wrap_create_role_line(chunk) {
308                output.push_str(&transformed);
309                modified = true;
310            } else {
311                output.push_str(chunk);
312            }
313            cursor = end;
314        } else {
315            let chunk = &sql[cursor..];
316            if let Some(transformed) = wrap_create_role_line(chunk) {
317                output.push_str(&transformed);
318                modified = true;
319            } else {
320                output.push_str(chunk);
321            }
322            break;
323        }
324    }
325
326    if modified {
327        Some(output)
328    } else {
329        None
330    }
331}
332
333fn wrap_create_role_line(chunk: &str) -> Option<String> {
334    let trimmed = chunk.trim_start();
335    if !trimmed.starts_with("CREATE ROLE ") {
336        return None;
337    }
338
339    let statement = trimmed.trim_end();
340    let statement_body = statement.trim_end_matches(';').trim_end();
341    let leading_ws_len = chunk.len() - trimmed.len();
342    let leading_ws = &chunk[..leading_ws_len];
343    let newline = if chunk.ends_with("\r\n") {
344        "\r\n"
345    } else if chunk.ends_with('\n') {
346        "\n"
347    } else {
348        ""
349    };
350
351    let role_token = extract_role_token(statement_body)?;
352
353    let notice_name = escape_single_quotes(&unquote_role_name(&role_token));
354
355    let mut block = String::with_capacity(chunk.len() + 128);
356    block.push_str(leading_ws);
357    block.push_str("DO $$\n");
358    block.push_str(leading_ws);
359    block.push_str("BEGIN\n");
360    block.push_str(leading_ws);
361    block.push_str("    ");
362    block.push_str(statement_body);
363    block.push_str(";\n");
364    block.push_str(leading_ws);
365    block.push_str("EXCEPTION\n");
366    block.push_str(leading_ws);
367    block.push_str("    WHEN duplicate_object THEN\n");
368    block.push_str(leading_ws);
369    block.push_str("        RAISE NOTICE 'Role ");
370    block.push_str(&notice_name);
371    block.push_str(" already exists on target, skipping CREATE ROLE';\n");
372    block.push_str(leading_ws);
373    block.push_str("END $$;");
374
375    if !newline.is_empty() {
376        block.push_str(newline);
377    }
378
379    Some(block)
380}
381
/// Extract the role identifier that follows `CREATE ROLE` in `statement`.
///
/// A double-quoted identifier is returned with its quotes, and embedded `""`
/// escapes are kept as-is, e.g. `"Bob""s"`. An unquoted identifier runs up to
/// the first whitespace or `;`.
///
/// Returns `None` when:
/// - `statement` does not start with `CREATE ROLE`;
/// - no identifier follows it; or
/// - a quoted identifier has no closing quote. Previously the malformed text was
///   returned as if it were a name, because the guard meant to reject it was
///   always true.
fn extract_role_token(statement: &str) -> Option<String> {
    let remainder = statement.strip_prefix("CREATE ROLE")?.trim_start();

    if let Some(quoted) = remainder.strip_prefix('"') {
        let bytes = quoted.as_bytes();
        let mut idx = 0;
        while idx < bytes.len() {
            if bytes[idx] == b'"' {
                // `""` inside a quoted identifier is an escaped quote, not the end.
                if bytes.get(idx + 1) == Some(&b'"') {
                    idx += 2;
                    continue;
                }
                // +1 for the opening quote stripped above, +1 to include the closing one.
                return Some(remainder[..idx + 2].to_string());
            }
            idx += 1;
        }
        // Reached the end without a closing quote: malformed identifier.
        None
    } else {
        let end = remainder
            .find(|c: char| c.is_whitespace() || c == ';')
            .unwrap_or(remainder.len());
        if end == 0 {
            None
        } else {
            Some(remainder[..end].to_string())
        }
    }
}
419
/// Turn a possibly double-quoted SQL identifier into the plain role name.
///
/// If `token` is wrapped in double quotes (at least two characters long), the
/// outer quotes are removed and each `""` escape becomes a single `"`. Any other
/// token is returned unchanged.
fn unquote_role_name(token: &str) -> String {
    match token.strip_prefix('"').and_then(|rest| rest.strip_suffix('"')) {
        Some(inner) => inner.replace("\"\"", "\""),
        None => token.to_string(),
    }
}
428
/// Escape `value` for use inside a single-quoted SQL string literal by doubling
/// every `'`.
fn escape_single_quotes(value: &str) -> String {
    let mut escaped = String::with_capacity(value.len());
    for ch in value.chars() {
        if ch == '\'' {
            escaped.push_str("''");
        } else {
            escaped.push(ch);
        }
    }
    escaped
}
432
433/// Dump schema (DDL) for a specific database
434pub async fn dump_schema(
435    source_url: &str,
436    database: &str,
437    output_path: &str,
438    filter: &ReplicationFilter,
439) -> Result<()> {
440    tracing::info!(
441        "Dumping schema for database '{}' to {}",
442        database,
443        output_path
444    );
445
446    // Parse URL and create .pgpass file for secure authentication
447    let parts = crate::utils::parse_postgres_url(source_url)
448        .with_context(|| format!("Failed to parse source URL: {}", source_url))?;
449    let pgpass = crate::utils::PgPassFile::new(&parts)
450        .context("Failed to create .pgpass file for authentication")?;
451
452    let env_vars = parts.to_pg_env_vars();
453    let output_path_owned = output_path.to_string();
454
455    // Collect filter options
456    let exclude_tables = get_schema_excluded_tables_for_db(filter, database);
457    let include_tables = get_included_tables_for_db(filter, database);
458
459    // Wrap subprocess execution with retry logic
460    crate::utils::retry_subprocess_with_backoff(
461        || {
462            let mut cmd = Command::new("pg_dump");
463            cmd.arg("--schema-only")
464                .arg("--no-owner") // Don't include ownership commands
465                .arg("--no-privileges") // We'll handle privileges separately
466                .arg("--verbose"); // Show progress
467
468            // Add table filtering if specified
469            // Only exclude explicit exclude_tables from schema dump (NOT schema_only or predicate tables)
470            if let Some(ref exclude) = exclude_tables {
471                if !exclude.is_empty() {
472                    for table in exclude {
473                        cmd.arg("--exclude-table").arg(table);
474                    }
475                }
476            }
477
478            // If include_tables is specified, only dump those tables
479            if let Some(ref include) = include_tables {
480                if !include.is_empty() {
481                    for table in include {
482                        cmd.arg("--table").arg(table);
483                    }
484                }
485            }
486
487            cmd.arg("--host")
488                .arg(&parts.host)
489                .arg("--port")
490                .arg(parts.port.to_string())
491                .arg("--dbname")
492                .arg(&parts.database)
493                .arg(format!("--file={}", output_path_owned))
494                .env("PGPASSFILE", pgpass.path())
495                .stdout(Stdio::inherit())
496                .stderr(Stdio::inherit());
497
498            // Add username if specified
499            if let Some(user) = &parts.user {
500                cmd.arg("--username").arg(user);
501            }
502
503            // Apply query parameters as environment variables (SSL, channel_binding, etc.)
504            for (env_var, value) in &env_vars {
505                cmd.env(env_var, value);
506            }
507
508            // Apply TCP keepalive parameters to prevent idle connection timeouts
509            for (env_var, value) in crate::utils::get_keepalive_env_vars() {
510                cmd.env(env_var, value);
511            }
512
513            cmd.status().context(
514                "Failed to execute pg_dump. Is PostgreSQL client installed?\n\
515                 Install with:\n\
516                 - Ubuntu/Debian: sudo apt-get install postgresql-client\n\
517                 - macOS: brew install postgresql\n\
518                 - RHEL/CentOS: sudo yum install postgresql",
519            )
520        },
521        3,                      // Max 3 retries
522        Duration::from_secs(1), // Start with 1 second delay
523        "pg_dump (dump schema)",
524    )
525    .await
526    .with_context(|| {
527        format!(
528            "pg_dump failed to dump schema for database '{}'.\n\
529             \n\
530             Common causes:\n\
531             - Database does not exist\n\
532             - Connection authentication failed\n\
533             - User lacks privileges to read database schema\n\
534             - Network connectivity issues\n\
535             - Connection timeout or network issues",
536            database
537        )
538    })?;
539
540    tracing::info!("✓ Schema dumped successfully");
541    Ok(())
542}
543
544/// Dump data for a specific database using optimized directory format
545///
546/// Uses PostgreSQL directory format dump with:
547/// - Parallel dumps for faster performance
548/// - Maximum compression (level 9)
549/// - Large object (blob) support
550/// - Directory output for efficient parallel restore
551///
552/// The number of parallel jobs is automatically determined based on available CPU cores.
553pub async fn dump_data(
554    source_url: &str,
555    database: &str,
556    output_path: &str,
557    filter: &ReplicationFilter,
558) -> Result<()> {
559    // Determine optimal number of parallel jobs (number of CPUs, capped at 8)
560    let num_cpus = std::thread::available_parallelism()
561        .map(|n| n.get().min(8))
562        .unwrap_or(4);
563
564    tracing::info!(
565        "Dumping data for database '{}' to {} (parallel={}, compression=9, format=directory)",
566        database,
567        output_path,
568        num_cpus
569    );
570
571    // Parse URL and create .pgpass file for secure authentication
572    let parts = crate::utils::parse_postgres_url(source_url)
573        .with_context(|| format!("Failed to parse source URL: {}", source_url))?;
574    let pgpass = crate::utils::PgPassFile::new(&parts)
575        .context("Failed to create .pgpass file for authentication")?;
576
577    let env_vars = parts.to_pg_env_vars();
578    let output_path_owned = output_path.to_string();
579
580    // Collect filter options
581    let exclude_tables = get_data_excluded_tables_for_db(filter, database);
582    let include_tables = get_included_tables_for_db(filter, database);
583
584    // Wrap subprocess execution with retry logic
585    crate::utils::retry_subprocess_with_backoff(
586        || {
587            let mut cmd = Command::new("pg_dump");
588            cmd.arg("--data-only")
589                .arg("--no-owner")
590                .arg("--format=directory") // Directory format enables parallel operations
591                .arg("--blobs") // Include large objects (blobs)
592                .arg("--compress=9") // Maximum compression for smaller dump size
593                .arg(format!("--jobs={}", num_cpus)) // Parallel dump jobs
594                .arg("--verbose"); // Show progress
595
596            // Add table filtering if specified
597            // Exclude explicit excludes, schema_only tables, and predicate tables from data dump
598            if let Some(ref exclude) = exclude_tables {
599                if !exclude.is_empty() {
600                    for table in exclude {
601                        cmd.arg("--exclude-table-data").arg(table);
602                    }
603                }
604            }
605
606            // If include_tables is specified, only dump data for those tables
607            if let Some(ref include) = include_tables {
608                if !include.is_empty() {
609                    for table in include {
610                        cmd.arg("--table").arg(table);
611                    }
612                }
613            }
614
615            cmd.arg("--host")
616                .arg(&parts.host)
617                .arg("--port")
618                .arg(parts.port.to_string())
619                .arg("--dbname")
620                .arg(&parts.database)
621                .arg(format!("--file={}", output_path_owned))
622                .env("PGPASSFILE", pgpass.path())
623                .stdout(Stdio::inherit())
624                .stderr(Stdio::inherit());
625
626            // Add username if specified
627            if let Some(user) = &parts.user {
628                cmd.arg("--username").arg(user);
629            }
630
631            // Apply query parameters as environment variables (SSL, channel_binding, etc.)
632            for (env_var, value) in &env_vars {
633                cmd.env(env_var, value);
634            }
635
636            // Apply TCP keepalive parameters to prevent idle connection timeouts
637            for (env_var, value) in crate::utils::get_keepalive_env_vars() {
638                cmd.env(env_var, value);
639            }
640
641            cmd.status().context(
642                "Failed to execute pg_dump. Is PostgreSQL client installed?\n\
643                 Install with:\n\
644                 - Ubuntu/Debian: sudo apt-get install postgresql-client\n\
645                 - macOS: brew install postgresql\n\
646                 - RHEL/CentOS: sudo yum install postgresql",
647            )
648        },
649        3,                      // Max 3 retries
650        Duration::from_secs(1), // Start with 1 second delay
651        "pg_dump (dump data)",
652    )
653    .await
654    .with_context(|| {
655        format!(
656            "pg_dump failed to dump data for database '{}'.\n\
657             \n\
658             Common causes:\n\
659             - Database does not exist\n\
660             - Connection authentication failed\n\
661             - User lacks privileges to read table data\n\
662             - Network connectivity issues\n\
663             - Insufficient disk space for dump directory\n\
664             - Output directory already exists (pg_dump requires non-existent path)\n\
665             - Connection timeout or network issues",
666            database
667        )
668    })?;
669
670    tracing::info!(
671        "✓ Data dumped successfully using {} parallel jobs",
672        num_cpus
673    );
674    Ok(())
675}
676
677/// Extract table names to exclude from SCHEMA dumps (--exclude-table flag)
678/// Only excludes explicit exclude_tables - NOT schema_only or predicate tables
679/// (those need their schema created, just not bulk data copied)
680/// Returns schema-qualified names in format: "schema"."table"
681fn get_schema_excluded_tables_for_db(
682    filter: &ReplicationFilter,
683    db_name: &str,
684) -> Option<Vec<String>> {
685    let mut tables = BTreeSet::new();
686
687    // Handle explicit exclude_tables (format: "database.table")
688    // These tables are completely excluded (no schema, no data)
689    if let Some(explicit) = filter.exclude_tables() {
690        for full_name in explicit {
691            let parts: Vec<&str> = full_name.split('.').collect();
692            if parts.len() == 2 && parts[0] == db_name {
693                // Format as "public"."table" for consistency
694                tables.insert(format!("\"public\".\"{}\"", parts[1]));
695            }
696        }
697    }
698
699    if tables.is_empty() {
700        None
701    } else {
702        Some(tables.into_iter().collect())
703    }
704}
705
706/// Extract table names to exclude from DATA dumps (--exclude-table-data flag)
707/// Excludes explicit excludes, schema_only tables, and predicate tables
708/// (predicate tables will be copied separately with filtering)
709/// Returns schema-qualified names in format: "schema"."table"
710fn get_data_excluded_tables_for_db(
711    filter: &ReplicationFilter,
712    db_name: &str,
713) -> Option<Vec<String>> {
714    let mut tables = BTreeSet::new();
715
716    // Handle explicit exclude_tables (format: "database.table")
717    // Default to public schema for backward compatibility
718    if let Some(explicit) = filter.exclude_tables() {
719        for full_name in explicit {
720            let parts: Vec<&str> = full_name.split('.').collect();
721            if parts.len() == 2 && parts[0] == db_name {
722                // Format as "public"."table" for consistency
723                tables.insert(format!("\"public\".\"{}\"", parts[1]));
724            }
725        }
726    }
727
728    // schema_only_tables and predicate_tables already return schema-qualified names
729    for table in filter.schema_only_tables(db_name) {
730        tables.insert(table);
731    }
732
733    for (table, _) in filter.predicate_tables(db_name) {
734        tables.insert(table);
735    }
736
737    if tables.is_empty() {
738        None
739    } else {
740        Some(tables.into_iter().collect())
741    }
742}
743
744/// Extract table names for a specific database from include_tables filter
745/// Returns schema-qualified names in format: "schema"."table"
746fn get_included_tables_for_db(filter: &ReplicationFilter, db_name: &str) -> Option<Vec<String>> {
747    filter.include_tables().map(|tables| {
748        tables
749            .iter()
750            .filter_map(|full_name| {
751                let parts: Vec<&str> = full_name.split('.').collect();
752                if parts.len() == 2 && parts[0] == db_name {
753                    // Format as "public"."table" for consistency
754                    Some(format!("\"public\".\"{}\"", parts[1]))
755                } else {
756                    None
757                }
758            })
759            .collect()
760    })
761}
762
763#[cfg(test)]
764mod tests {
765    use super::*;
766    use tempfile::tempdir;
767
768    #[tokio::test]
769    #[ignore]
770    async fn test_dump_globals() {
771        let url = std::env::var("TEST_SOURCE_URL").unwrap();
772        let dir = tempdir().unwrap();
773        let output = dir.path().join("globals.sql");
774
775        let result = dump_globals(&url, output.to_str().unwrap()).await;
776
777        assert!(result.is_ok());
778        assert!(output.exists());
779
780        // Verify file contains SQL
781        let content = std::fs::read_to_string(&output).unwrap();
782        assert!(content.contains("CREATE ROLE") || !content.is_empty());
783    }
784
785    #[tokio::test]
786    #[ignore]
787    async fn test_dump_schema() {
788        let url = std::env::var("TEST_SOURCE_URL").unwrap();
789        let dir = tempdir().unwrap();
790        let output = dir.path().join("schema.sql");
791
792        // Extract database name from URL
793        let db = url.split('/').next_back().unwrap_or("postgres");
794
795        let filter = crate::filters::ReplicationFilter::empty();
796        let result = dump_schema(&url, db, output.to_str().unwrap(), &filter).await;
797
798        assert!(result.is_ok());
799        assert!(output.exists());
800    }
801
802    #[test]
803    fn test_get_schema_excluded_tables_for_db() {
804        let filter = crate::filters::ReplicationFilter::new(
805            None,
806            None,
807            None,
808            Some(vec![
809                "db1.table1".to_string(),
810                "db1.table2".to_string(),
811                "db2.table3".to_string(),
812            ]),
813        )
814        .unwrap();
815
816        // Schema exclusion only includes explicit exclude_tables
817        let tables = get_schema_excluded_tables_for_db(&filter, "db1").unwrap();
818        // Should return schema-qualified names
819        assert_eq!(
820            tables,
821            vec!["\"public\".\"table1\"", "\"public\".\"table2\""]
822        );
823
824        let tables = get_schema_excluded_tables_for_db(&filter, "db2").unwrap();
825        assert_eq!(tables, vec!["\"public\".\"table3\""]);
826
827        let tables = get_schema_excluded_tables_for_db(&filter, "db3");
828        assert!(tables.is_none() || tables.unwrap().is_empty());
829    }
830
831    #[test]
832    fn test_get_data_excluded_tables_for_db() {
833        let filter = crate::filters::ReplicationFilter::new(
834            None,
835            None,
836            None,
837            Some(vec![
838                "db1.table1".to_string(),
839                "db1.table2".to_string(),
840                "db2.table3".to_string(),
841            ]),
842        )
843        .unwrap();
844
845        // Data exclusion includes explicit exclude_tables, schema_only, and predicate tables
846        let tables = get_data_excluded_tables_for_db(&filter, "db1").unwrap();
847        // Should return schema-qualified names
848        assert_eq!(
849            tables,
850            vec!["\"public\".\"table1\"", "\"public\".\"table2\""]
851        );
852
853        let tables = get_data_excluded_tables_for_db(&filter, "db2").unwrap();
854        assert_eq!(tables, vec!["\"public\".\"table3\""]);
855
856        let tables = get_data_excluded_tables_for_db(&filter, "db3");
857        assert!(tables.is_none() || tables.unwrap().is_empty());
858    }
859
860    #[test]
861    fn test_get_included_tables_for_db() {
862        let filter = crate::filters::ReplicationFilter::new(
863            None,
864            None,
865            Some(vec![
866                "db1.users".to_string(),
867                "db1.orders".to_string(),
868                "db2.products".to_string(),
869            ]),
870            None,
871        )
872        .unwrap();
873
874        let tables = get_included_tables_for_db(&filter, "db1").unwrap();
875        // Should return schema-qualified names in original order
876        assert_eq!(
877            tables,
878            vec!["\"public\".\"users\"", "\"public\".\"orders\""]
879        );
880
881        let tables = get_included_tables_for_db(&filter, "db2").unwrap();
882        assert_eq!(tables, vec!["\"public\".\"products\""]);
883
884        let tables = get_included_tables_for_db(&filter, "db3");
885        assert!(tables.is_none() || tables.unwrap().is_empty());
886    }
887
888    #[test]
889    fn test_get_schema_excluded_tables_for_db_with_empty_filter() {
890        let filter = crate::filters::ReplicationFilter::empty();
891        let tables = get_schema_excluded_tables_for_db(&filter, "db1");
892        assert!(tables.is_none());
893    }
894
895    #[test]
896    fn test_get_data_excluded_tables_for_db_with_empty_filter() {
897        let filter = crate::filters::ReplicationFilter::empty();
898        let tables = get_data_excluded_tables_for_db(&filter, "db1");
899        assert!(tables.is_none());
900    }
901
902    #[test]
903    fn test_get_included_tables_for_db_with_empty_filter() {
904        let filter = crate::filters::ReplicationFilter::empty();
905        let tables = get_included_tables_for_db(&filter, "db1");
906        assert!(tables.is_none());
907    }
908
909    #[test]
910    fn test_rewrite_create_role_statements_wraps_unquoted_role() {
911        let sql = "CREATE ROLE replicator WITH LOGIN;\nALTER ROLE replicator WITH LOGIN;\n";
912        let rewritten = rewrite_create_role_statements(sql).expect("rewrite expected");
913
914        assert!(rewritten.contains("DO $$"));
915        assert!(rewritten.contains("Role replicator already exists"));
916        assert!(rewritten.contains("CREATE ROLE replicator WITH LOGIN;"));
917        assert!(rewritten.contains("ALTER ROLE replicator WITH LOGIN;"));
918    }
919
920    #[test]
921    fn test_rewrite_create_role_statements_wraps_quoted_role() {
922        let sql = "    CREATE ROLE \"Andre Admin\";\n";
923        let rewritten = rewrite_create_role_statements(sql).expect("rewrite expected");
924
925        assert!(rewritten.contains("DO $$"));
926        assert!(rewritten.contains("Andre Admin already exists"));
927        assert!(rewritten.contains("CREATE ROLE \"Andre Admin\""));
928        assert!(rewritten.starts_with("    DO $$"));
929    }
930
931    #[test]
932    fn test_rewrite_create_role_statements_noop_when_absent() {
933        let sql = "ALTER ROLE existing WITH LOGIN;\n";
934        assert!(rewrite_create_role_statements(sql).is_none());
935    }
936
937    #[test]
938    fn test_remove_restricted_role_grants() {
939        let dir = tempdir().unwrap();
940        let globals_file = dir.path().join("globals.sql");
941
942        // Write a sample globals dump with restricted role grants
943        let content = r#"CREATE ROLE myuser;
944ALTER ROLE myuser WITH LOGIN;
945GRANT pg_checkpoint TO myuser;
946GRANT "pg_read_all_stats" TO myuser;
947GRANT pg_monitor TO myuser;
948GRANT myrole TO myuser;
949GRANT SELECT ON TABLE users TO myuser;
950GRANT rds_superuser TO myuser GRANTED BY rdsadmin;
951GRANT ALL ON SCHEMA public TO myuser GRANTED BY "rdsadmin";
952GRANT SELECT ON TABLE orders TO myuser GRANTED BY postgres;
953"#;
954        std::fs::write(&globals_file, content).unwrap();
955
956        // Run the sanitization
957        remove_restricted_role_grants(globals_file.to_str().unwrap()).unwrap();
958
959        // Verify restricted grants are commented out
960        let result = std::fs::read_to_string(&globals_file).unwrap();
961
962        // Restricted role grants should be commented out
963        assert!(result.contains("-- GRANT pg_checkpoint TO myuser;"));
964        assert!(result.contains("-- GRANT \"pg_read_all_stats\" TO myuser;"));
965        assert!(result.contains("-- GRANT pg_monitor TO myuser;"));
966
967        // GRANTED BY rdsadmin clauses should be commented out
968        assert!(result.contains("-- GRANT rds_superuser TO myuser GRANTED BY rdsadmin;"));
969        assert!(result.contains("-- GRANT ALL ON SCHEMA public TO myuser GRANTED BY \"rdsadmin\";"));
970
971        // Non-restricted grants should remain
972        assert!(result.contains("\nGRANT myrole TO myuser;\n"));
973        assert!(result.contains("\nGRANT SELECT ON TABLE users TO myuser;\n"));
974        assert!(result.contains("\nGRANT SELECT ON TABLE orders TO myuser GRANTED BY postgres;\n"));
975
976        // Other statements should remain unchanged
977        assert!(result.contains("CREATE ROLE myuser;"));
978        assert!(result.contains("ALTER ROLE myuser WITH LOGIN;"));
979    }
980}