database_replicator/migration/
dump.rs1use crate::filters::ReplicationFilter;
5use anyhow::{Context, Result};
6use std::collections::BTreeSet;
7use std::fs;
8use std::process::{Command, Stdio};
9use std::time::Duration;
10
11pub async fn dump_globals(source_url: &str, output_path: &str) -> Result<()> {
13 tracing::info!("Dumping global objects to {}", output_path);
14
15 let parts = crate::utils::parse_postgres_url(source_url)
17 .with_context(|| format!("Failed to parse source URL: {}", source_url))?;
18 let pgpass = crate::utils::PgPassFile::new(&parts)
19 .context("Failed to create .pgpass file for authentication")?;
20
21 let env_vars = parts.to_pg_env_vars();
22 let output_path_owned = output_path.to_string();
23
24 crate::utils::retry_subprocess_with_backoff(
26 || {
27 let mut cmd = Command::new("pg_dumpall");
28 cmd.arg("--globals-only")
29 .arg("--no-role-passwords") .arg("--verbose") .arg("--host")
32 .arg(&parts.host)
33 .arg("--port")
34 .arg(parts.port.to_string())
35 .arg("--database")
36 .arg(&parts.database)
37 .arg(format!("--file={}", output_path_owned))
38 .env("PGPASSFILE", pgpass.path())
39 .stdout(Stdio::inherit())
40 .stderr(Stdio::inherit());
41
42 if let Some(user) = &parts.user {
44 cmd.arg("--username").arg(user);
45 }
46
47 for (env_var, value) in &env_vars {
49 cmd.env(env_var, value);
50 }
51
52 for (env_var, value) in crate::utils::get_keepalive_env_vars() {
54 cmd.env(env_var, value);
55 }
56
57 cmd.env("PGCONNECT_TIMEOUT", "30"); cmd.status().context(
61 "Failed to execute pg_dumpall. Is PostgreSQL client installed?\n\
62 Install with:\n\
63 - Ubuntu/Debian: sudo apt-get install postgresql-client\n\
64 - macOS: brew install postgresql\n\
65 - RHEL/CentOS: sudo yum install postgresql",
66 )
67 },
68 3, Duration::from_secs(1), "pg_dumpall (dump globals)",
71 )
72 .await
73 .context(
74 "pg_dumpall failed to dump global objects.\n\
75 \n\
76 Common causes:\n\
77 - Connection authentication failed\n\
78 - User lacks sufficient privileges (need SUPERUSER or pg_read_all_settings role)\n\
79 - Network connectivity issues\n\
80 - Invalid connection string\n\
81 - Connection timeout or network issues",
82 )?;
83
84 tracing::info!("✓ Global objects dumped successfully");
85 Ok(())
86}
87
88pub fn sanitize_globals_dump(path: &str) -> Result<()> {
99 let content = fs::read_to_string(path)
100 .with_context(|| format!("Failed to read globals dump at {}", path))?;
101
102 if let Some(updated) = rewrite_create_role_statements(&content) {
103 fs::write(path, updated)
104 .with_context(|| format!("Failed to update globals dump at {}", path))?;
105 }
106
107 Ok(())
108}
109
110pub fn remove_superuser_from_globals(path: &str) -> Result<()> {
116 let content = fs::read_to_string(path)
117 .with_context(|| format!("Failed to read globals dump at {}", path))?;
118
119 let mut updated = String::with_capacity(content.len());
120 let mut modified = false;
121 for line in content.lines() {
122 if line.contains("ALTER ROLE") && line.contains("SUPERUSER") {
123 updated.push_str("-- ");
124 updated.push_str(line);
125 updated.push('\n');
126 modified = true;
127 } else {
128 updated.push_str(line);
129 updated.push('\n');
130 }
131 }
132
133 if modified {
134 fs::write(path, updated)
135 .with_context(|| format!("Failed to write sanitized globals dump to {}", path))?;
136 }
137
138 Ok(())
139}
140
141pub fn remove_restricted_guc_settings(path: &str) -> Result<()> {
147 let content = fs::read_to_string(path)
148 .with_context(|| format!("Failed to read globals dump at {}", path))?;
149
150 let mut updated = String::with_capacity(content.len());
151 let mut modified = false;
152
153 for line in content.lines() {
154 let lower_line = line.to_ascii_lowercase();
155 if lower_line.contains("alter role") && lower_line.contains("set") {
156 updated.push_str("-- ");
157 updated.push_str(line);
158 updated.push('\n');
159 modified = true;
160 } else {
161 updated.push_str(line);
162 updated.push('\n');
163 }
164 }
165
166 if modified {
167 fs::write(path, updated)
168 .with_context(|| format!("Failed to write sanitized globals dump to {}", path))?;
169 }
170
171 Ok(())
172}
173
174pub fn remove_tablespace_statements(path: &str) -> Result<()> {
182 let content = fs::read_to_string(path)
183 .with_context(|| format!("Failed to read globals dump at {}", path))?;
184
185 let mut updated = String::with_capacity(content.len());
186 let mut modified = false;
187
188 for line in content.lines() {
189 let lower_trimmed = line.trim().to_ascii_lowercase();
190
191 let is_create_tablespace = lower_trimmed.starts_with("create tablespace");
193
194 let references_rds_tablespace = lower_trimmed.contains("'rds_")
201 || lower_trimmed.contains("\"rds_")
202 || lower_trimmed.contains("tablespace rds_");
203
204 if is_create_tablespace || references_rds_tablespace {
205 updated.push_str("-- ");
206 updated.push_str(line);
207 updated.push('\n');
208 modified = true;
209 } else {
210 updated.push_str(line);
211 updated.push('\n');
212 }
213 }
214
215 if modified {
216 fs::write(path, updated)
217 .with_context(|| format!("Failed to write sanitized globals dump to {}", path))?;
218 }
219
220 Ok(())
221}
222
223pub fn remove_restricted_role_grants(path: &str) -> Result<()> {
229 const RESTRICTED_ROLES: &[&str] = &[
231 "pg_checkpoint",
232 "pg_read_all_data",
233 "pg_write_all_data",
234 "pg_read_all_settings",
235 "pg_read_all_stats",
236 "pg_stat_scan_tables",
237 "pg_monitor",
238 "pg_signal_backend",
239 "pg_read_server_files",
240 "pg_write_server_files",
241 "pg_execute_server_program",
242 "pg_create_subscription",
243 "pg_maintain",
244 "pg_use_reserved_connections",
245 ];
246
247 const RESTRICTED_GRANTORS: &[&str] = &[
249 "rdsadmin",
250 "rds_superuser",
251 "rdsrepladmin",
252 "rds_replication",
253 ];
254
255 let content = fs::read_to_string(path)
256 .with_context(|| format!("Failed to read globals dump at {}", path))?;
257
258 let mut updated = String::with_capacity(content.len());
259 let mut modified = false;
260
261 for line in content.lines() {
262 let lower_trimmed = line.trim().to_ascii_lowercase();
263 if lower_trimmed.starts_with("grant ") {
264 let is_restricted_role = RESTRICTED_ROLES.iter().any(|role| {
266 lower_trimmed
269 .split_whitespace()
270 .nth(1)
271 .map(|r| r.trim_matches('"') == *role)
272 .unwrap_or(false)
273 });
274
275 let has_restricted_grantor = RESTRICTED_GRANTORS.iter().any(|grantor| {
277 lower_trimmed.contains(&format!("granted by {}", grantor))
279 || lower_trimmed.contains(&format!("granted by \"{}\"", grantor))
280 });
281
282 if is_restricted_role || has_restricted_grantor {
283 updated.push_str("-- ");
284 updated.push_str(line);
285 updated.push('\n');
286 modified = true;
287 continue;
288 }
289 }
290
291 updated.push_str(line);
292 updated.push('\n');
293 }
294
295 if modified {
296 fs::write(path, updated)
297 .with_context(|| format!("Failed to write sanitized globals dump to {}", path))?;
298 }
299
300 Ok(())
301}
302
303fn rewrite_create_role_statements(sql: &str) -> Option<String> {
304 if sql.is_empty() {
305 return None;
306 }
307
308 let mut output = String::with_capacity(sql.len() + 1024);
309 let mut modified = false;
310 let mut cursor = 0;
311
312 while cursor < sql.len() {
313 if let Some(rel_pos) = sql[cursor..].find('\n') {
314 let end = cursor + rel_pos + 1;
315 let chunk = &sql[cursor..end];
316 if let Some(transformed) = wrap_create_role_line(chunk) {
317 output.push_str(&transformed);
318 modified = true;
319 } else {
320 output.push_str(chunk);
321 }
322 cursor = end;
323 } else {
324 let chunk = &sql[cursor..];
325 if let Some(transformed) = wrap_create_role_line(chunk) {
326 output.push_str(&transformed);
327 modified = true;
328 } else {
329 output.push_str(chunk);
330 }
331 break;
332 }
333 }
334
335 if modified {
336 Some(output)
337 } else {
338 None
339 }
340}
341
342fn wrap_create_role_line(chunk: &str) -> Option<String> {
343 let trimmed = chunk.trim_start();
344 if !trimmed.starts_with("CREATE ROLE ") {
345 return None;
346 }
347
348 let statement = trimmed.trim_end();
349 let statement_body = statement.trim_end_matches(';').trim_end();
350 let leading_ws_len = chunk.len() - trimmed.len();
351 let leading_ws = &chunk[..leading_ws_len];
352 let newline = if chunk.ends_with("\r\n") {
353 "\r\n"
354 } else if chunk.ends_with('\n') {
355 "\n"
356 } else {
357 ""
358 };
359
360 let role_token = extract_role_token(statement_body)?;
361
362 let notice_name = escape_single_quotes(&unquote_role_name(&role_token));
363
364 let mut block = String::with_capacity(chunk.len() + 128);
365 block.push_str(leading_ws);
366 block.push_str("DO $$\n");
367 block.push_str(leading_ws);
368 block.push_str("BEGIN\n");
369 block.push_str(leading_ws);
370 block.push_str(" ");
371 block.push_str(statement_body);
372 block.push_str(";\n");
373 block.push_str(leading_ws);
374 block.push_str("EXCEPTION\n");
375 block.push_str(leading_ws);
376 block.push_str(" WHEN duplicate_object THEN\n");
377 block.push_str(leading_ws);
378 block.push_str(" RAISE NOTICE 'Role ");
379 block.push_str(¬ice_name);
380 block.push_str(" already exists on target, skipping CREATE ROLE';\n");
381 block.push_str(leading_ws);
382 block.push_str("END $$;");
383
384 if !newline.is_empty() {
385 block.push_str(newline);
386 }
387
388 Some(block)
389}
390
/// Returns the role identifier that follows `CREATE ROLE` in `statement`,
/// exactly as written (quotes included for quoted identifiers).
///
/// - A quoted identifier runs to the first `"` that is not part of a doubled
///   `""` escape. If no closing quote is found, the rest of the string is
///   returned.
/// - An unquoted identifier runs to the first whitespace character or `;`.
///
/// Returns `None` if `statement` does not start with `CREATE ROLE` or the
/// unquoted identifier would be empty.
fn extract_role_token(statement: &str) -> Option<String> {
    let rest = statement.strip_prefix("CREATE ROLE")?.trim_start();

    if rest.starts_with('"') {
        let mut chars = rest.char_indices().skip(1).peekable();
        let mut end = rest.len();
        while let Some((pos, ch)) = chars.next() {
            if ch != '"' {
                continue;
            }
            if chars.peek().map(|&(_, next)| next) == Some('"') {
                // Doubled quote: an escaped `"` inside the identifier.
                chars.next();
                continue;
            }
            end = pos + 1;
            break;
        }
        return Some(rest[..end].to_string());
    }

    let end = rest
        .find(|c: char| c.is_whitespace() || c == ';')
        .unwrap_or(rest.len());
    if end > 0 {
        Some(rest[..end].to_string())
    } else {
        None
    }
}
428
/// Converts a role token as written in SQL into the bare role name.
///
/// A token wrapped in double quotes has the quotes removed and each doubled
/// `""` collapsed to a single `"`. Any other token is returned unchanged.
fn unquote_role_name(token: &str) -> String {
    match token
        .strip_prefix('"')
        .and_then(|inner| inner.strip_suffix('"'))
    {
        Some(inner) => inner.replace("\"\"", "\""),
        None => token.to_string(),
    }
}
437
/// Doubles every `'` in `value` so it can sit inside a single-quoted SQL
/// string literal.
fn escape_single_quotes(value: &str) -> String {
    let mut escaped = String::with_capacity(value.len());
    for ch in value.chars() {
        if ch == '\'' {
            escaped.push_str("''");
        } else {
            escaped.push(ch);
        }
    }
    escaped
}
441
442pub async fn dump_schema(
444 source_url: &str,
445 database: &str,
446 output_path: &str,
447 filter: &ReplicationFilter,
448) -> Result<()> {
449 tracing::info!(
450 "Dumping schema for database '{}' to {}",
451 database,
452 output_path
453 );
454
455 let parts = crate::utils::parse_postgres_url(source_url)
457 .with_context(|| format!("Failed to parse source URL: {}", source_url))?;
458 let pgpass = crate::utils::PgPassFile::new(&parts)
459 .context("Failed to create .pgpass file for authentication")?;
460
461 let env_vars = parts.to_pg_env_vars();
462 let output_path_owned = output_path.to_string();
463
464 let exclude_tables = get_schema_excluded_tables_for_db(filter, database);
466 let include_tables = get_included_tables_for_db(filter, database);
467
468 crate::utils::retry_subprocess_with_backoff(
470 || {
471 let mut cmd = Command::new("pg_dump");
472 cmd.arg("--schema-only")
473 .arg("--no-owner") .arg("--no-privileges") .arg("--verbose"); if let Some(ref exclude) = exclude_tables {
480 if !exclude.is_empty() {
481 for table in exclude {
482 cmd.arg("--exclude-table").arg(table);
483 }
484 }
485 }
486
487 if let Some(ref include) = include_tables {
489 if !include.is_empty() {
490 for table in include {
491 cmd.arg("--table").arg(table);
492 }
493 }
494 }
495
496 cmd.arg("--host")
497 .arg(&parts.host)
498 .arg("--port")
499 .arg(parts.port.to_string())
500 .arg("--dbname")
501 .arg(&parts.database)
502 .arg(format!("--file={}", output_path_owned))
503 .env("PGPASSFILE", pgpass.path())
504 .stdout(Stdio::inherit())
505 .stderr(Stdio::inherit());
506
507 if let Some(user) = &parts.user {
509 cmd.arg("--username").arg(user);
510 }
511
512 for (env_var, value) in &env_vars {
514 cmd.env(env_var, value);
515 }
516
517 for (env_var, value) in crate::utils::get_keepalive_env_vars() {
519 cmd.env(env_var, value);
520 }
521
522 cmd.env("PGCONNECT_TIMEOUT", "30"); cmd.status().context(
526 "Failed to execute pg_dump. Is PostgreSQL client installed?\n\
527 Install with:\n\
528 - Ubuntu/Debian: sudo apt-get install postgresql-client\n\
529 - macOS: brew install postgresql\n\
530 - RHEL/CentOS: sudo yum install postgresql",
531 )
532 },
533 3, Duration::from_secs(1), "pg_dump (dump schema)",
536 )
537 .await
538 .with_context(|| {
539 format!(
540 "pg_dump failed to dump schema for database '{}'.\n\
541 \n\
542 Common causes:\n\
543 - Database does not exist\n\
544 - Connection authentication failed\n\
545 - User lacks privileges to read database schema\n\
546 - Network connectivity issues\n\
547 - Connection timeout or network issues",
548 database
549 )
550 })?;
551
552 tracing::info!("✓ Schema dumped successfully");
553 Ok(())
554}
555
556pub async fn dump_data(
566 source_url: &str,
567 database: &str,
568 output_path: &str,
569 filter: &ReplicationFilter,
570) -> Result<()> {
571 let num_cpus = std::thread::available_parallelism()
573 .map(|n| n.get().min(8))
574 .unwrap_or(4);
575
576 tracing::info!(
577 "Dumping data for database '{}' to {} (parallel={}, compression=9, format=directory)",
578 database,
579 output_path,
580 num_cpus
581 );
582
583 let parts = crate::utils::parse_postgres_url(source_url)
585 .with_context(|| format!("Failed to parse source URL: {}", source_url))?;
586 let pgpass = crate::utils::PgPassFile::new(&parts)
587 .context("Failed to create .pgpass file for authentication")?;
588
589 let env_vars = parts.to_pg_env_vars();
590 let output_path_owned = output_path.to_string();
591
592 let exclude_tables = get_data_excluded_tables_for_db(filter, database);
594 let include_tables = get_included_tables_for_db(filter, database);
595
596 crate::utils::retry_subprocess_with_backoff(
598 || {
599 let mut cmd = Command::new("pg_dump");
600 cmd.arg("--data-only")
601 .arg("--no-owner")
602 .arg("--format=directory") .arg("--blobs") .arg("--compress=9") .arg(format!("--jobs={}", num_cpus)) .arg("--verbose"); if let Some(ref exclude) = exclude_tables {
611 if !exclude.is_empty() {
612 for table in exclude {
613 cmd.arg("--exclude-table-data").arg(table);
614 }
615 }
616 }
617
618 if let Some(ref include) = include_tables {
620 if !include.is_empty() {
621 for table in include {
622 cmd.arg("--table").arg(table);
623 }
624 }
625 }
626
627 cmd.arg("--host")
628 .arg(&parts.host)
629 .arg("--port")
630 .arg(parts.port.to_string())
631 .arg("--dbname")
632 .arg(&parts.database)
633 .arg(format!("--file={}", output_path_owned))
634 .env("PGPASSFILE", pgpass.path())
635 .stdout(Stdio::inherit())
636 .stderr(Stdio::inherit());
637
638 if let Some(user) = &parts.user {
640 cmd.arg("--username").arg(user);
641 }
642
643 for (env_var, value) in &env_vars {
645 cmd.env(env_var, value);
646 }
647
648 for (env_var, value) in crate::utils::get_keepalive_env_vars() {
650 cmd.env(env_var, value);
651 }
652
653 cmd.env("PGCONNECT_TIMEOUT", "30"); cmd.status().context(
657 "Failed to execute pg_dump. Is PostgreSQL client installed?\n\
658 Install with:\n\
659 - Ubuntu/Debian: sudo apt-get install postgresql-client\n\
660 - macOS: brew install postgresql\n\
661 - RHEL/CentOS: sudo yum install postgresql",
662 )
663 },
664 3, Duration::from_secs(1), "pg_dump (dump data)",
667 )
668 .await
669 .with_context(|| {
670 format!(
671 "pg_dump failed to dump data for database '{}'.\n\
672 \n\
673 Common causes:\n\
674 - Database does not exist\n\
675 - Connection authentication failed\n\
676 - User lacks privileges to read table data\n\
677 - Network connectivity issues\n\
678 - Insufficient disk space for dump directory\n\
679 - Output directory already exists (pg_dump requires non-existent path)\n\
680 - Connection timeout or network issues",
681 database
682 )
683 })?;
684
685 tracing::info!(
686 "✓ Data dumped successfully using {} parallel jobs",
687 num_cpus
688 );
689 Ok(())
690}
691
692fn get_schema_excluded_tables_for_db(
697 filter: &ReplicationFilter,
698 db_name: &str,
699) -> Option<Vec<String>> {
700 let mut tables = BTreeSet::new();
701
702 if let Some(explicit) = filter.exclude_tables() {
705 for full_name in explicit {
706 let parts: Vec<&str> = full_name.split('.').collect();
707 if parts.len() == 2 && parts[0] == db_name {
708 tables.insert(format!("\"public\".\"{}\"", parts[1]));
710 }
711 }
712 }
713
714 if tables.is_empty() {
715 None
716 } else {
717 Some(tables.into_iter().collect())
718 }
719}
720
721fn get_data_excluded_tables_for_db(
726 filter: &ReplicationFilter,
727 db_name: &str,
728) -> Option<Vec<String>> {
729 let mut tables = BTreeSet::new();
730
731 if let Some(explicit) = filter.exclude_tables() {
734 for full_name in explicit {
735 let parts: Vec<&str> = full_name.split('.').collect();
736 if parts.len() == 2 && parts[0] == db_name {
737 tables.insert(format!("\"public\".\"{}\"", parts[1]));
739 }
740 }
741 }
742
743 for table in filter.schema_only_tables(db_name) {
745 tables.insert(table);
746 }
747
748 for (table, _) in filter.predicate_tables(db_name) {
749 tables.insert(table);
750 }
751
752 if tables.is_empty() {
753 None
754 } else {
755 Some(tables.into_iter().collect())
756 }
757}
758
759fn get_included_tables_for_db(filter: &ReplicationFilter, db_name: &str) -> Option<Vec<String>> {
762 filter.include_tables().map(|tables| {
763 tables
764 .iter()
765 .filter_map(|full_name| {
766 let parts: Vec<&str> = full_name.split('.').collect();
767 if parts.len() == 2 && parts[0] == db_name {
768 Some(format!("\"public\".\"{}\"", parts[1]))
770 } else {
771 None
772 }
773 })
774 .collect()
775 })
776}
777
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// Converts string literals into the owned list the filter constructor takes.
    fn owned(names: &[&str]) -> Vec<String> {
        names.iter().map(|name| name.to_string()).collect()
    }

    /// Filter whose only rule is `exclude_tables = [db1.table1, db1.table2, db2.table3]`.
    fn sample_exclude_filter() -> crate::filters::ReplicationFilter {
        crate::filters::ReplicationFilter::new(
            None,
            None,
            None,
            Some(owned(&["db1.table1", "db1.table2", "db2.table3"])),
        )
        .unwrap()
    }

    /// Filter whose only rule is `include_tables = [db1.users, db1.orders, db2.products]`.
    fn sample_include_filter() -> crate::filters::ReplicationFilter {
        crate::filters::ReplicationFilter::new(
            None,
            None,
            Some(owned(&["db1.users", "db1.orders", "db2.products"])),
            None,
        )
        .unwrap()
    }

    #[tokio::test]
    #[ignore]
    async fn test_dump_globals() {
        let source = std::env::var("TEST_SOURCE_URL").unwrap();
        let scratch = tempdir().unwrap();
        let target = scratch.path().join("globals.sql");

        let outcome = dump_globals(&source, target.to_str().unwrap()).await;
        assert!(outcome.is_ok());
        assert!(target.exists());

        let dumped = std::fs::read_to_string(&target).unwrap();
        assert!(dumped.contains("CREATE ROLE") || !dumped.is_empty());
    }

    #[tokio::test]
    #[ignore]
    async fn test_dump_schema() {
        let source = std::env::var("TEST_SOURCE_URL").unwrap();
        let scratch = tempdir().unwrap();
        let target = scratch.path().join("schema.sql");

        // The database name is taken to be the last path segment of the URL.
        let database = source.rsplit('/').next().unwrap_or("postgres");
        let no_filter = crate::filters::ReplicationFilter::empty();

        let outcome = dump_schema(&source, database, target.to_str().unwrap(), &no_filter).await;
        assert!(outcome.is_ok());
        assert!(target.exists());
    }

    #[test]
    fn test_get_schema_excluded_tables_for_db() {
        let filter = sample_exclude_filter();

        assert_eq!(
            get_schema_excluded_tables_for_db(&filter, "db1").unwrap(),
            vec!["\"public\".\"table1\"", "\"public\".\"table2\""]
        );
        assert_eq!(
            get_schema_excluded_tables_for_db(&filter, "db2").unwrap(),
            vec!["\"public\".\"table3\""]
        );

        let unmatched = get_schema_excluded_tables_for_db(&filter, "db3");
        assert!(unmatched.unwrap_or_default().is_empty());
    }

    #[test]
    fn test_get_data_excluded_tables_for_db() {
        let filter = sample_exclude_filter();

        assert_eq!(
            get_data_excluded_tables_for_db(&filter, "db1").unwrap(),
            vec!["\"public\".\"table1\"", "\"public\".\"table2\""]
        );
        assert_eq!(
            get_data_excluded_tables_for_db(&filter, "db2").unwrap(),
            vec!["\"public\".\"table3\""]
        );

        let unmatched = get_data_excluded_tables_for_db(&filter, "db3");
        assert!(unmatched.unwrap_or_default().is_empty());
    }

    #[test]
    fn test_get_included_tables_for_db() {
        let filter = sample_include_filter();

        assert_eq!(
            get_included_tables_for_db(&filter, "db1").unwrap(),
            vec!["\"public\".\"users\"", "\"public\".\"orders\""]
        );
        assert_eq!(
            get_included_tables_for_db(&filter, "db2").unwrap(),
            vec!["\"public\".\"products\""]
        );

        let unmatched = get_included_tables_for_db(&filter, "db3");
        assert!(unmatched.unwrap_or_default().is_empty());
    }

    #[test]
    fn test_get_schema_excluded_tables_for_db_with_empty_filter() {
        let no_filter = crate::filters::ReplicationFilter::empty();
        assert!(get_schema_excluded_tables_for_db(&no_filter, "db1").is_none());
    }

    #[test]
    fn test_get_data_excluded_tables_for_db_with_empty_filter() {
        let no_filter = crate::filters::ReplicationFilter::empty();
        assert!(get_data_excluded_tables_for_db(&no_filter, "db1").is_none());
    }

    #[test]
    fn test_get_included_tables_for_db_with_empty_filter() {
        let no_filter = crate::filters::ReplicationFilter::empty();
        assert!(get_included_tables_for_db(&no_filter, "db1").is_none());
    }

    #[test]
    fn test_rewrite_create_role_statements_wraps_unquoted_role() {
        let input = "CREATE ROLE replicator WITH LOGIN;\nALTER ROLE replicator WITH LOGIN;\n";
        let output = rewrite_create_role_statements(input).expect("rewrite expected");

        for fragment in &[
            "DO $$",
            "Role replicator already exists",
            "CREATE ROLE replicator WITH LOGIN;",
            "ALTER ROLE replicator WITH LOGIN;",
        ] {
            assert!(output.contains(*fragment), "missing fragment: {}", fragment);
        }
    }

    #[test]
    fn test_rewrite_create_role_statements_wraps_quoted_role() {
        let input = " CREATE ROLE \"Andre Admin\";\n";
        let output = rewrite_create_role_statements(input).expect("rewrite expected");

        assert!(output.contains("DO $$"));
        assert!(output.contains("Andre Admin already exists"));
        assert!(output.contains("CREATE ROLE \"Andre Admin\""));
        // Leading indentation of the original line is carried over.
        assert!(output.starts_with(" DO $$"));
    }

    #[test]
    fn test_rewrite_create_role_statements_noop_when_absent() {
        let input = "ALTER ROLE existing WITH LOGIN;\n";
        assert!(rewrite_create_role_statements(input).is_none());
    }

    #[test]
    fn test_remove_restricted_role_grants() {
        let scratch = tempdir().unwrap();
        let dump_path = scratch.path().join("globals.sql");

        let fixture = r#"CREATE ROLE myuser;
ALTER ROLE myuser WITH LOGIN;
GRANT pg_checkpoint TO myuser;
GRANT "pg_read_all_stats" TO myuser;
GRANT pg_monitor TO myuser;
GRANT myrole TO myuser;
GRANT SELECT ON TABLE users TO myuser;
GRANT rds_superuser TO myuser GRANTED BY rdsadmin;
GRANT ALL ON SCHEMA public TO myuser GRANTED BY "rdsadmin";
GRANT SELECT ON TABLE orders TO myuser GRANTED BY postgres;
"#;
        std::fs::write(&dump_path, fixture).unwrap();

        remove_restricted_role_grants(dump_path.to_str().unwrap()).unwrap();

        let sanitized = std::fs::read_to_string(&dump_path).unwrap();

        // Grants of restricted predefined roles, or issued by RDS admin roles,
        // are commented out.
        let commented = [
            "-- GRANT pg_checkpoint TO myuser;",
            "-- GRANT \"pg_read_all_stats\" TO myuser;",
            "-- GRANT pg_monitor TO myuser;",
            "-- GRANT rds_superuser TO myuser GRANTED BY rdsadmin;",
            "-- GRANT ALL ON SCHEMA public TO myuser GRANTED BY \"rdsadmin\";",
        ];
        for expected in &commented {
            assert!(sanitized.contains(*expected), "expected commented line: {}", expected);
        }

        // Ordinary grants survive as complete, uncommented lines.
        let kept = [
            "\nGRANT myrole TO myuser;\n",
            "\nGRANT SELECT ON TABLE users TO myuser;\n",
            "\nGRANT SELECT ON TABLE orders TO myuser GRANTED BY postgres;\n",
        ];
        for expected in &kept {
            assert!(sanitized.contains(*expected), "expected kept line: {}", expected);
        }

        // Non-GRANT statements are untouched.
        assert!(sanitized.contains("CREATE ROLE myuser;"));
        assert!(sanitized.contains("ALTER ROLE myuser WITH LOGIN;"));
    }
}