database_replicator/migration/
dump.rs

1// ABOUTME: Wrapper for pg_dump command to export database objects
2// ABOUTME: Handles global objects, schema, and data export
3
4use crate::filters::ReplicationFilter;
5use anyhow::{Context, Result};
6use std::collections::BTreeSet;
7use std::fs;
8use std::process::{Command, Stdio};
9use std::time::Duration;
10
/// Dump global objects (roles, tablespaces) using pg_dumpall
///
/// Writes plain SQL to `output_path` via pg_dumpall's `--file` option.
/// Role passwords are deliberately omitted (`--no-role-passwords`).
/// Authentication goes through a temporary .pgpass file so the password
/// never appears on the command line or in the process environment.
///
/// # Errors
/// Fails if the source URL cannot be parsed, the .pgpass file cannot be
/// created, or pg_dumpall still exits non-zero after all retries.
pub async fn dump_globals(source_url: &str, output_path: &str) -> Result<()> {
    tracing::info!("Dumping global objects to {}", output_path);

    // Parse URL and create .pgpass file for secure authentication
    let parts = crate::utils::parse_postgres_url(source_url)
        .with_context(|| format!("Failed to parse source URL: {}", source_url))?;
    // `pgpass` must stay alive for the whole call: the subprocess reads the
    // file via PGPASSFILE on every retry attempt.
    let pgpass = crate::utils::PgPassFile::new(&parts)
        .context("Failed to create .pgpass file for authentication")?;

    let env_vars = parts.to_pg_env_vars();
    // Owned copy so the retry closure (which may run multiple times) can
    // borrow it independently of the `&str` parameter's lifetime.
    let output_path_owned = output_path.to_string();

    // Wrap subprocess execution with retry logic; the closure rebuilds the
    // Command from scratch on each attempt.
    crate::utils::retry_subprocess_with_backoff(
        || {
            let mut cmd = Command::new("pg_dumpall");
            cmd.arg("--globals-only")
                .arg("--no-role-passwords") // Don't dump passwords
                .arg("--verbose") // Show progress
                .arg("--host")
                .arg(&parts.host)
                .arg("--port")
                .arg(parts.port.to_string())
                // --database here is the database to CONNECT to; with
                // --globals-only the output is cluster-wide regardless.
                .arg("--database")
                .arg(&parts.database)
                .arg(format!("--file={}", output_path_owned))
                .env("PGPASSFILE", pgpass.path())
                // Inherit stdio so pg_dumpall's --verbose progress is visible.
                .stdout(Stdio::inherit())
                .stderr(Stdio::inherit());

            // Add username if specified
            if let Some(user) = &parts.user {
                cmd.arg("--username").arg(user);
            }

            // Apply query parameters as environment variables (SSL, channel_binding, etc.)
            for (env_var, value) in &env_vars {
                cmd.env(env_var, value);
            }

            // Apply TCP keepalive parameters to prevent idle connection timeouts
            for (env_var, value) in crate::utils::get_keepalive_env_vars() {
                cmd.env(env_var, value);
            }

            // `status()` errors here mean the binary could not be spawned at
            // all (e.g. not installed), distinct from a non-zero exit status.
            cmd.status().context(
                "Failed to execute pg_dumpall. Is PostgreSQL client installed?\n\
                 Install with:\n\
                 - Ubuntu/Debian: sudo apt-get install postgresql-client\n\
                 - macOS: brew install postgresql\n\
                 - RHEL/CentOS: sudo yum install postgresql",
            )
        },
        3,                      // Max 3 retries
        Duration::from_secs(1), // Start with 1 second delay
        "pg_dumpall (dump globals)",
    )
    .await
    .context(
        "pg_dumpall failed to dump global objects.\n\
         \n\
         Common causes:\n\
         - Connection authentication failed\n\
         - User lacks sufficient privileges (need SUPERUSER or pg_read_all_settings role)\n\
         - Network connectivity issues\n\
         - Invalid connection string\n\
         - Connection timeout or network issues",
    )?;

    tracing::info!("✓ Global objects dumped successfully");
    Ok(())
}
84
85/// Update a globals dump so duplicate role creation errors become harmless notices.
86///
87/// `pg_dumpall --globals-only` emits `CREATE ROLE` statements that fail if the
88/// role already exists on the target cluster. When an operator reruns
89/// replication against the same target, those statements cause `psql` to exit
90/// with status 3 which previously triggered noisy retries and prevented the rest
91/// of the globals from being applied. By wrapping each `CREATE ROLE` statement
92/// in a `DO $$ ... EXCEPTION WHEN duplicate_object` block, we allow Postgres to
93/// skip recreating existing roles while still applying subsequent `ALTER ROLE`
94/// and `GRANT` statements.
95pub fn sanitize_globals_dump(path: &str) -> Result<()> {
96    let content = fs::read_to_string(path)
97        .with_context(|| format!("Failed to read globals dump at {}", path))?;
98
99    if let Some(updated) = rewrite_create_role_statements(&content) {
100        fs::write(path, updated)
101            .with_context(|| format!("Failed to update globals dump at {}", path))?;
102    }
103
104    Ok(())
105}
106
107/// Comments out `ALTER ROLE ... SUPERUSER` statements in a globals dump file.
108///
109/// Managed Postgres services (e.g., AWS RDS) often prevent the restore user
110/// from granting SUPERUSER. Commenting out those lines keeps the restore
111/// moving without permission errors.
112pub fn remove_superuser_from_globals(path: &str) -> Result<()> {
113    let content = fs::read_to_string(path)
114        .with_context(|| format!("Failed to read globals dump at {}", path))?;
115
116    let mut updated = String::with_capacity(content.len());
117    let mut modified = false;
118    for line in content.lines() {
119        if line.contains("ALTER ROLE") && line.contains("SUPERUSER") {
120            updated.push_str("-- ");
121            updated.push_str(line);
122            updated.push('\n');
123            modified = true;
124        } else {
125            updated.push_str(line);
126            updated.push('\n');
127        }
128    }
129
130    if modified {
131        fs::write(path, updated)
132            .with_context(|| format!("Failed to write sanitized globals dump to {}", path))?;
133    }
134
135    Ok(())
136}
137
138/// Removes parameter settings that require superuser privileges (e.g. `log_statement`).
139///
140/// AWS RDS prevents standard replication roles from altering certain GUCs via
141/// `ALTER ROLE ... SET`. Each offending line is commented out so `psql` skips
142/// them without aborting the rest of the globals restore.
143pub fn remove_restricted_guc_settings(path: &str) -> Result<()> {
144    let content = fs::read_to_string(path)
145        .with_context(|| format!("Failed to read globals dump at {}", path))?;
146
147    let mut updated = String::with_capacity(content.len());
148    let mut modified = false;
149
150    for line in content.lines() {
151        let lower_line = line.to_ascii_lowercase();
152        if lower_line.contains("alter role") && lower_line.contains("set") {
153            updated.push_str("-- ");
154            updated.push_str(line);
155            updated.push('\n');
156            modified = true;
157        } else {
158            updated.push_str(line);
159            updated.push('\n');
160        }
161    }
162
163    if modified {
164        fs::write(path, updated)
165            .with_context(|| format!("Failed to write sanitized globals dump to {}", path))?;
166    }
167
168    Ok(())
169}
170
171/// Comments out `GRANT` statements for roles that are restricted on managed services.
172///
173/// AWS RDS and other managed services may prevent granting certain default roles
174/// like `pg_checkpoint`. This function comments out those statements to allow
175/// the globals restore to proceed without permission errors.
176pub fn remove_restricted_role_grants(path: &str) -> Result<()> {
177    const RESTRICTED_ROLES: &[&str] = &["pg_checkpoint"];
178
179    let content = fs::read_to_string(path)
180        .with_context(|| format!("Failed to read globals dump at {}", path))?;
181
182    let mut updated = String::with_capacity(content.len());
183    let mut modified = false;
184
185    for line in content.lines() {
186        let lower_trimmed = line.trim().to_ascii_lowercase();
187        if lower_trimmed.starts_with("grant ") {
188            let is_restricted = RESTRICTED_ROLES.iter().any(|role| {
189                // e.g. "grant pg_checkpoint to some_user"
190                lower_trimmed.split_whitespace().nth(1) == Some(*role)
191            });
192
193            if is_restricted {
194                updated.push_str("-- ");
195                updated.push_str(line);
196                updated.push('\n');
197                modified = true;
198                continue;
199            }
200        }
201
202        updated.push_str(line);
203        updated.push('\n');
204    }
205
206    if modified {
207        fs::write(path, updated)
208            .with_context(|| format!("Failed to write sanitized globals dump to {}", path))?;
209    }
210
211    Ok(())
212}
213
214fn rewrite_create_role_statements(sql: &str) -> Option<String> {
215    if sql.is_empty() {
216        return None;
217    }
218
219    let mut output = String::with_capacity(sql.len() + 1024);
220    let mut modified = false;
221    let mut cursor = 0;
222
223    while cursor < sql.len() {
224        if let Some(rel_pos) = sql[cursor..].find('\n') {
225            let end = cursor + rel_pos + 1;
226            let chunk = &sql[cursor..end];
227            if let Some(transformed) = wrap_create_role_line(chunk) {
228                output.push_str(&transformed);
229                modified = true;
230            } else {
231                output.push_str(chunk);
232            }
233            cursor = end;
234        } else {
235            let chunk = &sql[cursor..];
236            if let Some(transformed) = wrap_create_role_line(chunk) {
237                output.push_str(&transformed);
238                modified = true;
239            } else {
240                output.push_str(chunk);
241            }
242            break;
243        }
244    }
245
246    if modified {
247        Some(output)
248    } else {
249        None
250    }
251}
252
253fn wrap_create_role_line(chunk: &str) -> Option<String> {
254    let trimmed = chunk.trim_start();
255    if !trimmed.starts_with("CREATE ROLE ") {
256        return None;
257    }
258
259    let statement = trimmed.trim_end();
260    let statement_body = statement.trim_end_matches(';').trim_end();
261    let leading_ws_len = chunk.len() - trimmed.len();
262    let leading_ws = &chunk[..leading_ws_len];
263    let newline = if chunk.ends_with("\r\n") {
264        "\r\n"
265    } else if chunk.ends_with('\n') {
266        "\n"
267    } else {
268        ""
269    };
270
271    let role_token = extract_role_token(statement_body)?;
272
273    let notice_name = escape_single_quotes(&unquote_role_name(&role_token));
274
275    let mut block = String::with_capacity(chunk.len() + 128);
276    block.push_str(leading_ws);
277    block.push_str("DO $$\n");
278    block.push_str(leading_ws);
279    block.push_str("BEGIN\n");
280    block.push_str(leading_ws);
281    block.push_str("    ");
282    block.push_str(statement_body);
283    block.push_str(";\n");
284    block.push_str(leading_ws);
285    block.push_str("EXCEPTION\n");
286    block.push_str(leading_ws);
287    block.push_str("    WHEN duplicate_object THEN\n");
288    block.push_str(leading_ws);
289    block.push_str("        RAISE NOTICE 'Role ");
290    block.push_str(&notice_name);
291    block.push_str(" already exists on target, skipping CREATE ROLE';\n");
292    block.push_str(leading_ws);
293    block.push_str("END $$;");
294
295    if !newline.is_empty() {
296        block.push_str(newline);
297    }
298
299    Some(block)
300}
301
/// Extracts the role identifier that follows `CREATE ROLE`, keeping any
/// surrounding double quotes. Handles the SQL `""` escape inside quoted
/// identifiers. Returns `None` if the statement has no role name.
fn extract_role_token(statement: &str) -> Option<String> {
    let remainder = statement.strip_prefix("CREATE ROLE")?.trim_start();

    if let Some(rest) = remainder.strip_prefix('"') {
        // Quoted identifier: scan for the terminating quote, where a doubled
        // quote ("") is an escaped quote inside the name. '"' is ASCII, so
        // byte-wise scanning is UTF-8 safe.
        let bytes = rest.as_bytes();
        let mut i = 0;
        while i < bytes.len() {
            if bytes[i] == b'"' {
                if bytes.get(i + 1) == Some(&b'"') {
                    i += 2; // escaped quote — keep scanning
                    continue;
                }
                // Opening quote + scanned content + closing quote.
                return Some(remainder[..i + 2].to_string());
            }
            i += 1;
        }
        // Unterminated quote: permissively return the whole remainder.
        Some(remainder.to_string())
    } else {
        // Unquoted identifier: runs until whitespace or ';'.
        let end = remainder
            .find(|c: char| c.is_whitespace() || c == ';')
            .unwrap_or(remainder.len());
        if end == 0 {
            None
        } else {
            Some(remainder[..end].to_string())
        }
    }
}
339
/// Strips the surrounding double quotes from a quoted role identifier and
/// collapses the `""` escape back to `"`. Unquoted tokens pass through
/// unchanged (as does a lone `"` that cannot be a quoted pair).
fn unquote_role_name(token: &str) -> String {
    match token.strip_prefix('"').and_then(|rest| rest.strip_suffix('"')) {
        Some(inner) => inner.replace("\"\"", "\""),
        None => token.to_string(),
    }
}
348
/// Doubles every single quote so `value` can be embedded in a SQL string
/// literal (standard SQL `''` escaping).
fn escape_single_quotes(value: &str) -> String {
    let mut escaped = String::with_capacity(value.len());
    for ch in value.chars() {
        if ch == '\'' {
            escaped.push_str("''");
        } else {
            escaped.push(ch);
        }
    }
    escaped
}
352
/// Dump schema (DDL) for a specific database
///
/// Runs `pg_dump --schema-only` and writes plain SQL to `output_path`.
/// Ownership and privilege commands are omitted (`--no-owner`,
/// `--no-privileges`); privileges are handled elsewhere.
///
/// Filtering: explicit `exclude_tables` entries become `--exclude-table`
/// flags; `include_tables` entries become `--table` flags. Schema-only and
/// predicate tables are NOT excluded here — their DDL must still be created
/// even though their bulk data is filtered later.
///
/// NOTE(review): the connection targets `parts.database` parsed from
/// `source_url`, while the `database` parameter only drives filter lookup
/// and log messages — callers appear to pass a URL already pointing at
/// `database`; confirm against call sites.
///
/// # Errors
/// Fails if the URL cannot be parsed, the .pgpass file cannot be created, or
/// pg_dump still exits non-zero after all retries.
pub async fn dump_schema(
    source_url: &str,
    database: &str,
    output_path: &str,
    filter: &ReplicationFilter,
) -> Result<()> {
    tracing::info!(
        "Dumping schema for database '{}' to {}",
        database,
        output_path
    );

    // Parse URL and create .pgpass file for secure authentication
    let parts = crate::utils::parse_postgres_url(source_url)
        .with_context(|| format!("Failed to parse source URL: {}", source_url))?;
    // Keep `pgpass` alive until the retries finish: the subprocess reads the
    // file via PGPASSFILE on every attempt.
    let pgpass = crate::utils::PgPassFile::new(&parts)
        .context("Failed to create .pgpass file for authentication")?;

    let env_vars = parts.to_pg_env_vars();
    // Owned copy so the retry closure can borrow it independently of the
    // `&str` parameter's lifetime.
    let output_path_owned = output_path.to_string();

    // Collect filter options
    let exclude_tables = get_schema_excluded_tables_for_db(filter, database);
    let include_tables = get_included_tables_for_db(filter, database);

    // Wrap subprocess execution with retry logic; the closure rebuilds the
    // Command from scratch on each attempt.
    crate::utils::retry_subprocess_with_backoff(
        || {
            let mut cmd = Command::new("pg_dump");
            cmd.arg("--schema-only")
                .arg("--no-owner") // Don't include ownership commands
                .arg("--no-privileges") // We'll handle privileges separately
                .arg("--verbose"); // Show progress

            // Add table filtering if specified
            // Only exclude explicit exclude_tables from schema dump (NOT schema_only or predicate tables)
            if let Some(ref exclude) = exclude_tables {
                if !exclude.is_empty() {
                    for table in exclude {
                        cmd.arg("--exclude-table").arg(table);
                    }
                }
            }

            // If include_tables is specified, only dump those tables
            if let Some(ref include) = include_tables {
                if !include.is_empty() {
                    for table in include {
                        cmd.arg("--table").arg(table);
                    }
                }
            }

            cmd.arg("--host")
                .arg(&parts.host)
                .arg("--port")
                .arg(parts.port.to_string())
                .arg("--dbname")
                .arg(&parts.database)
                .arg(format!("--file={}", output_path_owned))
                .env("PGPASSFILE", pgpass.path())
                // Inherit stdio so pg_dump's --verbose progress is visible.
                .stdout(Stdio::inherit())
                .stderr(Stdio::inherit());

            // Add username if specified
            if let Some(user) = &parts.user {
                cmd.arg("--username").arg(user);
            }

            // Apply query parameters as environment variables (SSL, channel_binding, etc.)
            for (env_var, value) in &env_vars {
                cmd.env(env_var, value);
            }

            // Apply TCP keepalive parameters to prevent idle connection timeouts
            for (env_var, value) in crate::utils::get_keepalive_env_vars() {
                cmd.env(env_var, value);
            }

            // `status()` errors mean the binary could not be spawned at all,
            // distinct from a non-zero exit status.
            cmd.status().context(
                "Failed to execute pg_dump. Is PostgreSQL client installed?\n\
                 Install with:\n\
                 - Ubuntu/Debian: sudo apt-get install postgresql-client\n\
                 - macOS: brew install postgresql\n\
                 - RHEL/CentOS: sudo yum install postgresql",
            )
        },
        3,                      // Max 3 retries
        Duration::from_secs(1), // Start with 1 second delay
        "pg_dump (dump schema)",
    )
    .await
    .with_context(|| {
        format!(
            "pg_dump failed to dump schema for database '{}'.\n\
             \n\
             Common causes:\n\
             - Database does not exist\n\
             - Connection authentication failed\n\
             - User lacks privileges to read database schema\n\
             - Network connectivity issues\n\
             - Connection timeout or network issues",
            database
        )
    })?;

    tracing::info!("✓ Schema dumped successfully");
    Ok(())
}
463
/// Dump data for a specific database using optimized directory format
///
/// Uses PostgreSQL directory format dump with:
/// - Parallel dumps for faster performance
/// - Maximum compression (level 9)
/// - Large object (blob) support
/// - Directory output for efficient parallel restore
///
/// The number of parallel jobs is automatically determined based on available CPU cores.
///
/// Filtering: explicit excludes, schema-only tables, and predicate tables all
/// become `--exclude-table-data` flags (their DDL is still dumped by
/// `dump_schema`); `include_tables` entries become `--table` flags.
///
/// NOTE(review): as in `dump_schema`, the connection targets `parts.database`
/// from `source_url`; the `database` parameter drives filtering/logging only.
///
/// # Errors
/// Fails if the URL cannot be parsed, the .pgpass file cannot be created, or
/// pg_dump still exits non-zero after all retries.
pub async fn dump_data(
    source_url: &str,
    database: &str,
    output_path: &str,
    filter: &ReplicationFilter,
) -> Result<()> {
    // Determine optimal number of parallel jobs (number of CPUs, capped at 8)
    // Falls back to 4 when the core count cannot be determined.
    let num_cpus = std::thread::available_parallelism()
        .map(|n| n.get().min(8))
        .unwrap_or(4);

    tracing::info!(
        "Dumping data for database '{}' to {} (parallel={}, compression=9, format=directory)",
        database,
        output_path,
        num_cpus
    );

    // Parse URL and create .pgpass file for secure authentication
    let parts = crate::utils::parse_postgres_url(source_url)
        .with_context(|| format!("Failed to parse source URL: {}", source_url))?;
    // Keep `pgpass` alive until the retries finish: the subprocess reads the
    // file via PGPASSFILE on every attempt.
    let pgpass = crate::utils::PgPassFile::new(&parts)
        .context("Failed to create .pgpass file for authentication")?;

    let env_vars = parts.to_pg_env_vars();
    // Owned copy so the retry closure can borrow it independently of the
    // `&str` parameter's lifetime.
    let output_path_owned = output_path.to_string();

    // Collect filter options
    let exclude_tables = get_data_excluded_tables_for_db(filter, database);
    let include_tables = get_included_tables_for_db(filter, database);

    // Wrap subprocess execution with retry logic; the closure rebuilds the
    // Command from scratch on each attempt.
    // NOTE(review): directory-format pg_dump requires the output path to not
    // exist; a failed first attempt that partially created the directory may
    // make retries fail for that reason — see the error text below. Confirm
    // whether the retry helper or caller cleans up between attempts.
    crate::utils::retry_subprocess_with_backoff(
        || {
            let mut cmd = Command::new("pg_dump");
            cmd.arg("--data-only")
                .arg("--no-owner")
                .arg("--format=directory") // Directory format enables parallel operations
                .arg("--blobs") // Include large objects (blobs)
                .arg("--compress=9") // Maximum compression for smaller dump size
                .arg(format!("--jobs={}", num_cpus)) // Parallel dump jobs
                .arg("--verbose"); // Show progress

            // Add table filtering if specified
            // Exclude explicit excludes, schema_only tables, and predicate tables from data dump
            if let Some(ref exclude) = exclude_tables {
                if !exclude.is_empty() {
                    for table in exclude {
                        cmd.arg("--exclude-table-data").arg(table);
                    }
                }
            }

            // If include_tables is specified, only dump data for those tables
            if let Some(ref include) = include_tables {
                if !include.is_empty() {
                    for table in include {
                        cmd.arg("--table").arg(table);
                    }
                }
            }

            cmd.arg("--host")
                .arg(&parts.host)
                .arg("--port")
                .arg(parts.port.to_string())
                .arg("--dbname")
                .arg(&parts.database)
                .arg(format!("--file={}", output_path_owned))
                .env("PGPASSFILE", pgpass.path())
                // Inherit stdio so pg_dump's --verbose progress is visible.
                .stdout(Stdio::inherit())
                .stderr(Stdio::inherit());

            // Add username if specified
            if let Some(user) = &parts.user {
                cmd.arg("--username").arg(user);
            }

            // Apply query parameters as environment variables (SSL, channel_binding, etc.)
            for (env_var, value) in &env_vars {
                cmd.env(env_var, value);
            }

            // Apply TCP keepalive parameters to prevent idle connection timeouts
            for (env_var, value) in crate::utils::get_keepalive_env_vars() {
                cmd.env(env_var, value);
            }

            // `status()` errors mean the binary could not be spawned at all,
            // distinct from a non-zero exit status.
            cmd.status().context(
                "Failed to execute pg_dump. Is PostgreSQL client installed?\n\
                 Install with:\n\
                 - Ubuntu/Debian: sudo apt-get install postgresql-client\n\
                 - macOS: brew install postgresql\n\
                 - RHEL/CentOS: sudo yum install postgresql",
            )
        },
        3,                      // Max 3 retries
        Duration::from_secs(1), // Start with 1 second delay
        "pg_dump (dump data)",
    )
    .await
    .with_context(|| {
        format!(
            "pg_dump failed to dump data for database '{}'.\n\
             \n\
             Common causes:\n\
             - Database does not exist\n\
             - Connection authentication failed\n\
             - User lacks privileges to read table data\n\
             - Network connectivity issues\n\
             - Insufficient disk space for dump directory\n\
             - Output directory already exists (pg_dump requires non-existent path)\n\
             - Connection timeout or network issues",
            database
        )
    })?;

    tracing::info!(
        "✓ Data dumped successfully using {} parallel jobs",
        num_cpus
    );
    Ok(())
}
596
597/// Extract table names to exclude from SCHEMA dumps (--exclude-table flag)
598/// Only excludes explicit exclude_tables - NOT schema_only or predicate tables
599/// (those need their schema created, just not bulk data copied)
600/// Returns schema-qualified names in format: "schema"."table"
601fn get_schema_excluded_tables_for_db(
602    filter: &ReplicationFilter,
603    db_name: &str,
604) -> Option<Vec<String>> {
605    let mut tables = BTreeSet::new();
606
607    // Handle explicit exclude_tables (format: "database.table")
608    // These tables are completely excluded (no schema, no data)
609    if let Some(explicit) = filter.exclude_tables() {
610        for full_name in explicit {
611            let parts: Vec<&str> = full_name.split('.').collect();
612            if parts.len() == 2 && parts[0] == db_name {
613                // Format as "public"."table" for consistency
614                tables.insert(format!("\"public\".\"{}\"", parts[1]));
615            }
616        }
617    }
618
619    if tables.is_empty() {
620        None
621    } else {
622        Some(tables.into_iter().collect())
623    }
624}
625
626/// Extract table names to exclude from DATA dumps (--exclude-table-data flag)
627/// Excludes explicit excludes, schema_only tables, and predicate tables
628/// (predicate tables will be copied separately with filtering)
629/// Returns schema-qualified names in format: "schema"."table"
630fn get_data_excluded_tables_for_db(
631    filter: &ReplicationFilter,
632    db_name: &str,
633) -> Option<Vec<String>> {
634    let mut tables = BTreeSet::new();
635
636    // Handle explicit exclude_tables (format: "database.table")
637    // Default to public schema for backward compatibility
638    if let Some(explicit) = filter.exclude_tables() {
639        for full_name in explicit {
640            let parts: Vec<&str> = full_name.split('.').collect();
641            if parts.len() == 2 && parts[0] == db_name {
642                // Format as "public"."table" for consistency
643                tables.insert(format!("\"public\".\"{}\"", parts[1]));
644            }
645        }
646    }
647
648    // schema_only_tables and predicate_tables already return schema-qualified names
649    for table in filter.schema_only_tables(db_name) {
650        tables.insert(table);
651    }
652
653    for (table, _) in filter.predicate_tables(db_name) {
654        tables.insert(table);
655    }
656
657    if tables.is_empty() {
658        None
659    } else {
660        Some(tables.into_iter().collect())
661    }
662}
663
664/// Extract table names for a specific database from include_tables filter
665/// Returns schema-qualified names in format: "schema"."table"
666fn get_included_tables_for_db(filter: &ReplicationFilter, db_name: &str) -> Option<Vec<String>> {
667    filter.include_tables().map(|tables| {
668        tables
669            .iter()
670            .filter_map(|full_name| {
671                let parts: Vec<&str> = full_name.split('.').collect();
672                if parts.len() == 2 && parts[0] == db_name {
673                    // Format as "public"."table" for consistency
674                    Some(format!("\"public\".\"{}\"", parts[1]))
675                } else {
676                    None
677                }
678            })
679            .collect()
680    })
681}
682
// Unit tests for the filter helpers and CREATE ROLE rewriting, plus
// #[ignore]d integration tests that need a live source cluster
// (set TEST_SOURCE_URL and run with `cargo test -- --ignored`).
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    // Integration: dumps real globals from TEST_SOURCE_URL into a temp file.
    #[tokio::test]
    #[ignore]
    async fn test_dump_globals() {
        let url = std::env::var("TEST_SOURCE_URL").unwrap();
        let dir = tempdir().unwrap();
        let output = dir.path().join("globals.sql");

        let result = dump_globals(&url, output.to_str().unwrap()).await;

        assert!(result.is_ok());
        assert!(output.exists());

        // Verify file contains SQL
        let content = std::fs::read_to_string(&output).unwrap();
        assert!(content.contains("CREATE ROLE") || !content.is_empty());
    }

    // Integration: dumps schema for the database named in TEST_SOURCE_URL.
    #[tokio::test]
    #[ignore]
    async fn test_dump_schema() {
        let url = std::env::var("TEST_SOURCE_URL").unwrap();
        let dir = tempdir().unwrap();
        let output = dir.path().join("schema.sql");

        // Extract database name from URL
        let db = url.split('/').next_back().unwrap_or("postgres");

        let filter = crate::filters::ReplicationFilter::empty();
        let result = dump_schema(&url, db, output.to_str().unwrap(), &filter).await;

        assert!(result.is_ok());
        assert!(output.exists());
    }

    // Explicit excludes should be split per database and schema-qualified.
    #[test]
    fn test_get_schema_excluded_tables_for_db() {
        let filter = crate::filters::ReplicationFilter::new(
            None,
            None,
            None,
            Some(vec![
                "db1.table1".to_string(),
                "db1.table2".to_string(),
                "db2.table3".to_string(),
            ]),
        )
        .unwrap();

        // Schema exclusion only includes explicit exclude_tables
        let tables = get_schema_excluded_tables_for_db(&filter, "db1").unwrap();
        // Should return schema-qualified names
        assert_eq!(
            tables,
            vec!["\"public\".\"table1\"", "\"public\".\"table2\""]
        );

        let tables = get_schema_excluded_tables_for_db(&filter, "db2").unwrap();
        assert_eq!(tables, vec!["\"public\".\"table3\""]);

        let tables = get_schema_excluded_tables_for_db(&filter, "db3");
        assert!(tables.is_none() || tables.unwrap().is_empty());
    }

    // With no schema_only/predicate tables configured, data exclusion
    // mirrors the explicit excludes.
    #[test]
    fn test_get_data_excluded_tables_for_db() {
        let filter = crate::filters::ReplicationFilter::new(
            None,
            None,
            None,
            Some(vec![
                "db1.table1".to_string(),
                "db1.table2".to_string(),
                "db2.table3".to_string(),
            ]),
        )
        .unwrap();

        // Data exclusion includes explicit exclude_tables, schema_only, and predicate tables
        let tables = get_data_excluded_tables_for_db(&filter, "db1").unwrap();
        // Should return schema-qualified names
        assert_eq!(
            tables,
            vec!["\"public\".\"table1\"", "\"public\".\"table2\""]
        );

        let tables = get_data_excluded_tables_for_db(&filter, "db2").unwrap();
        assert_eq!(tables, vec!["\"public\".\"table3\""]);

        let tables = get_data_excluded_tables_for_db(&filter, "db3");
        assert!(tables.is_none() || tables.unwrap().is_empty());
    }

    // Includes are filtered per database and keep their original order.
    #[test]
    fn test_get_included_tables_for_db() {
        let filter = crate::filters::ReplicationFilter::new(
            None,
            None,
            Some(vec![
                "db1.users".to_string(),
                "db1.orders".to_string(),
                "db2.products".to_string(),
            ]),
            None,
        )
        .unwrap();

        let tables = get_included_tables_for_db(&filter, "db1").unwrap();
        // Should return schema-qualified names in original order
        assert_eq!(
            tables,
            vec!["\"public\".\"users\"", "\"public\".\"orders\""]
        );

        let tables = get_included_tables_for_db(&filter, "db2").unwrap();
        assert_eq!(tables, vec!["\"public\".\"products\""]);

        let tables = get_included_tables_for_db(&filter, "db3");
        assert!(tables.is_none() || tables.unwrap().is_empty());
    }

    // Empty filter -> None (no --exclude-table flags at all).
    #[test]
    fn test_get_schema_excluded_tables_for_db_with_empty_filter() {
        let filter = crate::filters::ReplicationFilter::empty();
        let tables = get_schema_excluded_tables_for_db(&filter, "db1");
        assert!(tables.is_none());
    }

    // Empty filter -> None (no --exclude-table-data flags at all).
    #[test]
    fn test_get_data_excluded_tables_for_db_with_empty_filter() {
        let filter = crate::filters::ReplicationFilter::empty();
        let tables = get_data_excluded_tables_for_db(&filter, "db1");
        assert!(tables.is_none());
    }

    // Empty filter -> None (no --table flags at all).
    #[test]
    fn test_get_included_tables_for_db_with_empty_filter() {
        let filter = crate::filters::ReplicationFilter::empty();
        let tables = get_included_tables_for_db(&filter, "db1");
        assert!(tables.is_none());
    }

    // CREATE ROLE lines get wrapped; following ALTER ROLE lines are untouched.
    #[test]
    fn test_rewrite_create_role_statements_wraps_unquoted_role() {
        let sql = "CREATE ROLE replicator WITH LOGIN;\nALTER ROLE replicator WITH LOGIN;\n";
        let rewritten = rewrite_create_role_statements(sql).expect("rewrite expected");

        assert!(rewritten.contains("DO $$"));
        assert!(rewritten.contains("Role replicator already exists"));
        assert!(rewritten.contains("CREATE ROLE replicator WITH LOGIN;"));
        assert!(rewritten.contains("ALTER ROLE replicator WITH LOGIN;"));
    }

    // Quoted role names are unquoted in the NOTICE, and the line's leading
    // indentation is preserved on the generated block.
    #[test]
    fn test_rewrite_create_role_statements_wraps_quoted_role() {
        let sql = "    CREATE ROLE \"Andre Admin\";\n";
        let rewritten = rewrite_create_role_statements(sql).expect("rewrite expected");

        assert!(rewritten.contains("DO $$"));
        assert!(rewritten.contains("Andre Admin already exists"));
        assert!(rewritten.contains("CREATE ROLE \"Andre Admin\""));
        assert!(rewritten.starts_with("    DO $$"));
    }

    // No CREATE ROLE statements -> None (file left untouched).
    #[test]
    fn test_rewrite_create_role_statements_noop_when_absent() {
        let sql = "ALTER ROLE existing WITH LOGIN;\n";
        assert!(rewrite_create_role_statements(sql).is_none());
    }
}