database_replicator/migration/
dump.rs1use crate::filters::ReplicationFilter;
5use anyhow::{Context, Result};
6use std::collections::BTreeSet;
7use std::fs;
8use std::process::{Command, Stdio};
9use std::time::Duration;
10
11pub async fn dump_globals(source_url: &str, output_path: &str) -> Result<()> {
13 tracing::info!("Dumping global objects to {}", output_path);
14
15 let parts = crate::utils::parse_postgres_url(source_url)
17 .with_context(|| format!("Failed to parse source URL: {}", source_url))?;
18 let pgpass = crate::utils::PgPassFile::new(&parts)
19 .context("Failed to create .pgpass file for authentication")?;
20
21 let env_vars = parts.to_pg_env_vars();
22 let output_path_owned = output_path.to_string();
23
24 crate::utils::retry_subprocess_with_backoff(
26 || {
27 let mut cmd = Command::new("pg_dumpall");
28 cmd.arg("--globals-only")
29 .arg("--no-role-passwords") .arg("--verbose") .arg("--host")
32 .arg(&parts.host)
33 .arg("--port")
34 .arg(parts.port.to_string())
35 .arg("--database")
36 .arg(&parts.database)
37 .arg(format!("--file={}", output_path_owned))
38 .env("PGPASSFILE", pgpass.path())
39 .stdout(Stdio::inherit())
40 .stderr(Stdio::inherit());
41
42 if let Some(user) = &parts.user {
44 cmd.arg("--username").arg(user);
45 }
46
47 for (env_var, value) in &env_vars {
49 cmd.env(env_var, value);
50 }
51
52 for (env_var, value) in crate::utils::get_keepalive_env_vars() {
54 cmd.env(env_var, value);
55 }
56
57 cmd.status().context(
58 "Failed to execute pg_dumpall. Is PostgreSQL client installed?\n\
59 Install with:\n\
60 - Ubuntu/Debian: sudo apt-get install postgresql-client\n\
61 - macOS: brew install postgresql\n\
62 - RHEL/CentOS: sudo yum install postgresql",
63 )
64 },
65 3, Duration::from_secs(1), "pg_dumpall (dump globals)",
68 )
69 .await
70 .context(
71 "pg_dumpall failed to dump global objects.\n\
72 \n\
73 Common causes:\n\
74 - Connection authentication failed\n\
75 - User lacks sufficient privileges (need SUPERUSER or pg_read_all_settings role)\n\
76 - Network connectivity issues\n\
77 - Invalid connection string\n\
78 - Connection timeout or network issues",
79 )?;
80
81 tracing::info!("✓ Global objects dumped successfully");
82 Ok(())
83}
84
85pub fn sanitize_globals_dump(path: &str) -> Result<()> {
96 let content = fs::read_to_string(path)
97 .with_context(|| format!("Failed to read globals dump at {}", path))?;
98
99 if let Some(updated) = rewrite_create_role_statements(&content) {
100 fs::write(path, updated)
101 .with_context(|| format!("Failed to update globals dump at {}", path))?;
102 }
103
104 Ok(())
105}
106
107pub fn remove_superuser_from_globals(path: &str) -> Result<()> {
113 let content = fs::read_to_string(path)
114 .with_context(|| format!("Failed to read globals dump at {}", path))?;
115
116 let mut updated = String::with_capacity(content.len());
117 let mut modified = false;
118 for line in content.lines() {
119 if line.contains("ALTER ROLE") && line.contains("SUPERUSER") {
120 updated.push_str("-- ");
121 updated.push_str(line);
122 updated.push('\n');
123 modified = true;
124 } else {
125 updated.push_str(line);
126 updated.push('\n');
127 }
128 }
129
130 if modified {
131 fs::write(path, updated)
132 .with_context(|| format!("Failed to write sanitized globals dump to {}", path))?;
133 }
134
135 Ok(())
136}
137
138pub fn remove_restricted_guc_settings(path: &str) -> Result<()> {
144 let content = fs::read_to_string(path)
145 .with_context(|| format!("Failed to read globals dump at {}", path))?;
146
147 let mut updated = String::with_capacity(content.len());
148 let mut modified = false;
149
150 for line in content.lines() {
151 let lower_line = line.to_ascii_lowercase();
152 if lower_line.contains("alter role") && lower_line.contains("set") {
153 updated.push_str("-- ");
154 updated.push_str(line);
155 updated.push('\n');
156 modified = true;
157 } else {
158 updated.push_str(line);
159 updated.push('\n');
160 }
161 }
162
163 if modified {
164 fs::write(path, updated)
165 .with_context(|| format!("Failed to write sanitized globals dump to {}", path))?;
166 }
167
168 Ok(())
169}
170
171pub fn remove_tablespace_statements(path: &str) -> Result<()> {
179 let content = fs::read_to_string(path)
180 .with_context(|| format!("Failed to read globals dump at {}", path))?;
181
182 let mut updated = String::with_capacity(content.len());
183 let mut modified = false;
184
185 for line in content.lines() {
186 let lower_trimmed = line.trim().to_ascii_lowercase();
187
188 let is_create_tablespace = lower_trimmed.starts_with("create tablespace");
190
191 let references_rds_tablespace = lower_trimmed.contains("'rds_")
198 || lower_trimmed.contains("\"rds_")
199 || lower_trimmed.contains("tablespace rds_");
200
201 if is_create_tablespace || references_rds_tablespace {
202 updated.push_str("-- ");
203 updated.push_str(line);
204 updated.push('\n');
205 modified = true;
206 } else {
207 updated.push_str(line);
208 updated.push('\n');
209 }
210 }
211
212 if modified {
213 fs::write(path, updated)
214 .with_context(|| format!("Failed to write sanitized globals dump to {}", path))?;
215 }
216
217 Ok(())
218}
219
220pub fn remove_restricted_role_grants(path: &str) -> Result<()> {
226 const RESTRICTED_ROLES: &[&str] = &[
228 "pg_checkpoint",
229 "pg_read_all_data",
230 "pg_write_all_data",
231 "pg_read_all_settings",
232 "pg_read_all_stats",
233 "pg_stat_scan_tables",
234 "pg_monitor",
235 "pg_signal_backend",
236 "pg_read_server_files",
237 "pg_write_server_files",
238 "pg_execute_server_program",
239 "pg_create_subscription",
240 "pg_maintain",
241 "pg_use_reserved_connections",
242 ];
243
244 const RESTRICTED_GRANTORS: &[&str] = &[
246 "rdsadmin",
247 "rds_superuser",
248 "rdsrepladmin",
249 "rds_replication",
250 ];
251
252 let content = fs::read_to_string(path)
253 .with_context(|| format!("Failed to read globals dump at {}", path))?;
254
255 let mut updated = String::with_capacity(content.len());
256 let mut modified = false;
257
258 for line in content.lines() {
259 let lower_trimmed = line.trim().to_ascii_lowercase();
260 if lower_trimmed.starts_with("grant ") {
261 let is_restricted_role = RESTRICTED_ROLES.iter().any(|role| {
263 lower_trimmed
266 .split_whitespace()
267 .nth(1)
268 .map(|r| r.trim_matches('"') == *role)
269 .unwrap_or(false)
270 });
271
272 let has_restricted_grantor = RESTRICTED_GRANTORS.iter().any(|grantor| {
274 lower_trimmed.contains(&format!("granted by {}", grantor))
276 || lower_trimmed.contains(&format!("granted by \"{}\"", grantor))
277 });
278
279 if is_restricted_role || has_restricted_grantor {
280 updated.push_str("-- ");
281 updated.push_str(line);
282 updated.push('\n');
283 modified = true;
284 continue;
285 }
286 }
287
288 updated.push_str(line);
289 updated.push('\n');
290 }
291
292 if modified {
293 fs::write(path, updated)
294 .with_context(|| format!("Failed to write sanitized globals dump to {}", path))?;
295 }
296
297 Ok(())
298}
299
300fn rewrite_create_role_statements(sql: &str) -> Option<String> {
301 if sql.is_empty() {
302 return None;
303 }
304
305 let mut output = String::with_capacity(sql.len() + 1024);
306 let mut modified = false;
307 let mut cursor = 0;
308
309 while cursor < sql.len() {
310 if let Some(rel_pos) = sql[cursor..].find('\n') {
311 let end = cursor + rel_pos + 1;
312 let chunk = &sql[cursor..end];
313 if let Some(transformed) = wrap_create_role_line(chunk) {
314 output.push_str(&transformed);
315 modified = true;
316 } else {
317 output.push_str(chunk);
318 }
319 cursor = end;
320 } else {
321 let chunk = &sql[cursor..];
322 if let Some(transformed) = wrap_create_role_line(chunk) {
323 output.push_str(&transformed);
324 modified = true;
325 } else {
326 output.push_str(chunk);
327 }
328 break;
329 }
330 }
331
332 if modified {
333 Some(output)
334 } else {
335 None
336 }
337}
338
339fn wrap_create_role_line(chunk: &str) -> Option<String> {
340 let trimmed = chunk.trim_start();
341 if !trimmed.starts_with("CREATE ROLE ") {
342 return None;
343 }
344
345 let statement = trimmed.trim_end();
346 let statement_body = statement.trim_end_matches(';').trim_end();
347 let leading_ws_len = chunk.len() - trimmed.len();
348 let leading_ws = &chunk[..leading_ws_len];
349 let newline = if chunk.ends_with("\r\n") {
350 "\r\n"
351 } else if chunk.ends_with('\n') {
352 "\n"
353 } else {
354 ""
355 };
356
357 let role_token = extract_role_token(statement_body)?;
358
359 let notice_name = escape_single_quotes(&unquote_role_name(&role_token));
360
361 let mut block = String::with_capacity(chunk.len() + 128);
362 block.push_str(leading_ws);
363 block.push_str("DO $$\n");
364 block.push_str(leading_ws);
365 block.push_str("BEGIN\n");
366 block.push_str(leading_ws);
367 block.push_str(" ");
368 block.push_str(statement_body);
369 block.push_str(";\n");
370 block.push_str(leading_ws);
371 block.push_str("EXCEPTION\n");
372 block.push_str(leading_ws);
373 block.push_str(" WHEN duplicate_object THEN\n");
374 block.push_str(leading_ws);
375 block.push_str(" RAISE NOTICE 'Role ");
376 block.push_str(¬ice_name);
377 block.push_str(" already exists on target, skipping CREATE ROLE';\n");
378 block.push_str(leading_ws);
379 block.push_str("END $$;");
380
381 if !newline.is_empty() {
382 block.push_str(newline);
383 }
384
385 Some(block)
386}
387
/// Extracts the role identifier from a `CREATE ROLE …` statement.
///
/// `statement` must start with `CREATE ROLE` (case-sensitive). The identifier is
/// returned exactly as written:
/// - a double-quoted identifier keeps its quotes and any doubled `""` escapes;
/// - an unquoted identifier ends at the first whitespace character or `;`.
///
/// Returns `None` if the prefix is missing, no identifier follows it, or a quoted
/// identifier is never closed. The original code accepted an unterminated quote
/// here, and its `None` branch was unreachable.
fn extract_role_token(statement: &str) -> Option<String> {
    let remainder = statement.strip_prefix("CREATE ROLE")?.trim_start();

    if let Some(quoted) = remainder.strip_prefix('"') {
        // Byte scanning is safe: `"` (0x22) never occurs inside a multi-byte
        // UTF-8 sequence, so every slice boundary below is a char boundary.
        let bytes = quoted.as_bytes();
        let mut idx = 0;
        while idx < bytes.len() {
            if bytes[idx] == b'"' {
                if bytes.get(idx + 1) == Some(&b'"') {
                    // `""` is an escaped quote inside the identifier.
                    idx += 2;
                    continue;
                }
                // Closing quote: +1 for the opening quote, +1 to include this one.
                return Some(remainder[..idx + 2].to_string());
            }
            idx += 1;
        }
        // Reached the end without a closing quote: malformed identifier.
        None
    } else {
        let end = remainder
            .find(|ch: char| ch.is_whitespace() || ch == ';')
            .unwrap_or(remainder.len());
        if end == 0 {
            None
        } else {
            Some(remainder[..end].to_string())
        }
    }
}
425
/// Converts a role identifier token into the role's actual name.
///
/// A token wrapped in double quotes has the quotes removed and each doubled `""`
/// collapsed to a single `"`. Any other token is returned unchanged.
fn unquote_role_name(token: &str) -> String {
    match token.strip_prefix('"').and_then(|rest| rest.strip_suffix('"')) {
        Some(inner) => inner.replace("\"\"", "\""),
        None => token.to_string(),
    }
}
434
/// Escapes `value` for use inside a single-quoted SQL string literal by doubling
/// every `'`.
fn escape_single_quotes(value: &str) -> String {
    let mut escaped = String::with_capacity(value.len());
    for ch in value.chars() {
        if ch == '\'' {
            escaped.push_str("''");
        } else {
            escaped.push(ch);
        }
    }
    escaped
}
438
439pub async fn dump_schema(
441 source_url: &str,
442 database: &str,
443 output_path: &str,
444 filter: &ReplicationFilter,
445) -> Result<()> {
446 tracing::info!(
447 "Dumping schema for database '{}' to {}",
448 database,
449 output_path
450 );
451
452 let parts = crate::utils::parse_postgres_url(source_url)
454 .with_context(|| format!("Failed to parse source URL: {}", source_url))?;
455 let pgpass = crate::utils::PgPassFile::new(&parts)
456 .context("Failed to create .pgpass file for authentication")?;
457
458 let env_vars = parts.to_pg_env_vars();
459 let output_path_owned = output_path.to_string();
460
461 let exclude_tables = get_schema_excluded_tables_for_db(filter, database);
463 let include_tables = get_included_tables_for_db(filter, database);
464
465 crate::utils::retry_subprocess_with_backoff(
467 || {
468 let mut cmd = Command::new("pg_dump");
469 cmd.arg("--schema-only")
470 .arg("--no-owner") .arg("--no-privileges") .arg("--verbose"); if let Some(ref exclude) = exclude_tables {
477 if !exclude.is_empty() {
478 for table in exclude {
479 cmd.arg("--exclude-table").arg(table);
480 }
481 }
482 }
483
484 if let Some(ref include) = include_tables {
486 if !include.is_empty() {
487 for table in include {
488 cmd.arg("--table").arg(table);
489 }
490 }
491 }
492
493 cmd.arg("--host")
494 .arg(&parts.host)
495 .arg("--port")
496 .arg(parts.port.to_string())
497 .arg("--dbname")
498 .arg(&parts.database)
499 .arg(format!("--file={}", output_path_owned))
500 .env("PGPASSFILE", pgpass.path())
501 .stdout(Stdio::inherit())
502 .stderr(Stdio::inherit());
503
504 if let Some(user) = &parts.user {
506 cmd.arg("--username").arg(user);
507 }
508
509 for (env_var, value) in &env_vars {
511 cmd.env(env_var, value);
512 }
513
514 for (env_var, value) in crate::utils::get_keepalive_env_vars() {
516 cmd.env(env_var, value);
517 }
518
519 cmd.status().context(
520 "Failed to execute pg_dump. Is PostgreSQL client installed?\n\
521 Install with:\n\
522 - Ubuntu/Debian: sudo apt-get install postgresql-client\n\
523 - macOS: brew install postgresql\n\
524 - RHEL/CentOS: sudo yum install postgresql",
525 )
526 },
527 3, Duration::from_secs(1), "pg_dump (dump schema)",
530 )
531 .await
532 .with_context(|| {
533 format!(
534 "pg_dump failed to dump schema for database '{}'.\n\
535 \n\
536 Common causes:\n\
537 - Database does not exist\n\
538 - Connection authentication failed\n\
539 - User lacks privileges to read database schema\n\
540 - Network connectivity issues\n\
541 - Connection timeout or network issues",
542 database
543 )
544 })?;
545
546 tracing::info!("✓ Schema dumped successfully");
547 Ok(())
548}
549
550pub async fn dump_data(
560 source_url: &str,
561 database: &str,
562 output_path: &str,
563 filter: &ReplicationFilter,
564) -> Result<()> {
565 let num_cpus = std::thread::available_parallelism()
567 .map(|n| n.get().min(8))
568 .unwrap_or(4);
569
570 tracing::info!(
571 "Dumping data for database '{}' to {} (parallel={}, compression=9, format=directory)",
572 database,
573 output_path,
574 num_cpus
575 );
576
577 let parts = crate::utils::parse_postgres_url(source_url)
579 .with_context(|| format!("Failed to parse source URL: {}", source_url))?;
580 let pgpass = crate::utils::PgPassFile::new(&parts)
581 .context("Failed to create .pgpass file for authentication")?;
582
583 let env_vars = parts.to_pg_env_vars();
584 let output_path_owned = output_path.to_string();
585
586 let exclude_tables = get_data_excluded_tables_for_db(filter, database);
588 let include_tables = get_included_tables_for_db(filter, database);
589
590 crate::utils::retry_subprocess_with_backoff(
592 || {
593 let mut cmd = Command::new("pg_dump");
594 cmd.arg("--data-only")
595 .arg("--no-owner")
596 .arg("--format=directory") .arg("--blobs") .arg("--compress=9") .arg(format!("--jobs={}", num_cpus)) .arg("--verbose"); if let Some(ref exclude) = exclude_tables {
605 if !exclude.is_empty() {
606 for table in exclude {
607 cmd.arg("--exclude-table-data").arg(table);
608 }
609 }
610 }
611
612 if let Some(ref include) = include_tables {
614 if !include.is_empty() {
615 for table in include {
616 cmd.arg("--table").arg(table);
617 }
618 }
619 }
620
621 cmd.arg("--host")
622 .arg(&parts.host)
623 .arg("--port")
624 .arg(parts.port.to_string())
625 .arg("--dbname")
626 .arg(&parts.database)
627 .arg(format!("--file={}", output_path_owned))
628 .env("PGPASSFILE", pgpass.path())
629 .stdout(Stdio::inherit())
630 .stderr(Stdio::inherit());
631
632 if let Some(user) = &parts.user {
634 cmd.arg("--username").arg(user);
635 }
636
637 for (env_var, value) in &env_vars {
639 cmd.env(env_var, value);
640 }
641
642 for (env_var, value) in crate::utils::get_keepalive_env_vars() {
644 cmd.env(env_var, value);
645 }
646
647 cmd.status().context(
648 "Failed to execute pg_dump. Is PostgreSQL client installed?\n\
649 Install with:\n\
650 - Ubuntu/Debian: sudo apt-get install postgresql-client\n\
651 - macOS: brew install postgresql\n\
652 - RHEL/CentOS: sudo yum install postgresql",
653 )
654 },
655 3, Duration::from_secs(1), "pg_dump (dump data)",
658 )
659 .await
660 .with_context(|| {
661 format!(
662 "pg_dump failed to dump data for database '{}'.\n\
663 \n\
664 Common causes:\n\
665 - Database does not exist\n\
666 - Connection authentication failed\n\
667 - User lacks privileges to read table data\n\
668 - Network connectivity issues\n\
669 - Insufficient disk space for dump directory\n\
670 - Output directory already exists (pg_dump requires non-existent path)\n\
671 - Connection timeout or network issues",
672 database
673 )
674 })?;
675
676 tracing::info!(
677 "✓ Data dumped successfully using {} parallel jobs",
678 num_cpus
679 );
680 Ok(())
681}
682
683fn get_schema_excluded_tables_for_db(
688 filter: &ReplicationFilter,
689 db_name: &str,
690) -> Option<Vec<String>> {
691 let mut tables = BTreeSet::new();
692
693 if let Some(explicit) = filter.exclude_tables() {
696 for full_name in explicit {
697 let parts: Vec<&str> = full_name.split('.').collect();
698 if parts.len() == 2 && parts[0] == db_name {
699 tables.insert(format!("\"public\".\"{}\"", parts[1]));
701 }
702 }
703 }
704
705 if tables.is_empty() {
706 None
707 } else {
708 Some(tables.into_iter().collect())
709 }
710}
711
712fn get_data_excluded_tables_for_db(
717 filter: &ReplicationFilter,
718 db_name: &str,
719) -> Option<Vec<String>> {
720 let mut tables = BTreeSet::new();
721
722 if let Some(explicit) = filter.exclude_tables() {
725 for full_name in explicit {
726 let parts: Vec<&str> = full_name.split('.').collect();
727 if parts.len() == 2 && parts[0] == db_name {
728 tables.insert(format!("\"public\".\"{}\"", parts[1]));
730 }
731 }
732 }
733
734 for table in filter.schema_only_tables(db_name) {
736 tables.insert(table);
737 }
738
739 for (table, _) in filter.predicate_tables(db_name) {
740 tables.insert(table);
741 }
742
743 if tables.is_empty() {
744 None
745 } else {
746 Some(tables.into_iter().collect())
747 }
748}
749
750fn get_included_tables_for_db(filter: &ReplicationFilter, db_name: &str) -> Option<Vec<String>> {
753 filter.include_tables().map(|tables| {
754 tables
755 .iter()
756 .filter_map(|full_name| {
757 let parts: Vec<&str> = full_name.split('.').collect();
758 if parts.len() == 2 && parts[0] == db_name {
759 Some(format!("\"public\".\"{}\"", parts[1]))
761 } else {
762 None
763 }
764 })
765 .collect()
766 })
767}
768
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// Turns string literals into the owned list the filter constructor expects.
    fn owned(names: &[&str]) -> Vec<String> {
        names.iter().map(|name| name.to_string()).collect()
    }

    #[tokio::test]
    #[ignore]
    async fn test_dump_globals() {
        let source = std::env::var("TEST_SOURCE_URL").unwrap();
        let workdir = tempdir().unwrap();
        let globals_path = workdir.path().join("globals.sql");

        let outcome = dump_globals(&source, globals_path.to_str().unwrap()).await;

        assert!(outcome.is_ok());
        assert!(globals_path.exists());

        let dumped = std::fs::read_to_string(&globals_path).unwrap();
        assert!(dumped.contains("CREATE ROLE") || !dumped.is_empty());
    }

    #[tokio::test]
    #[ignore]
    async fn test_dump_schema() {
        let source = std::env::var("TEST_SOURCE_URL").unwrap();
        let workdir = tempdir().unwrap();
        let schema_path = workdir.path().join("schema.sql");

        // Use the last path segment of the URL as the database name.
        let db = source.rsplit('/').next().unwrap_or("postgres");

        let filter = crate::filters::ReplicationFilter::empty();
        let outcome = dump_schema(&source, db, schema_path.to_str().unwrap(), &filter).await;

        assert!(outcome.is_ok());
        assert!(schema_path.exists());
    }

    #[test]
    fn test_get_schema_excluded_tables_for_db() {
        let filter = crate::filters::ReplicationFilter::new(
            None,
            None,
            None,
            Some(owned(&["db1.table1", "db1.table2", "db2.table3"])),
        )
        .unwrap();

        let first = get_schema_excluded_tables_for_db(&filter, "db1").unwrap();
        assert_eq!(first, vec!["\"public\".\"table1\"", "\"public\".\"table2\""]);

        let second = get_schema_excluded_tables_for_db(&filter, "db2").unwrap();
        assert_eq!(second, vec!["\"public\".\"table3\""]);

        let missing = get_schema_excluded_tables_for_db(&filter, "db3");
        assert!(missing.is_none() || missing.unwrap().is_empty());
    }

    #[test]
    fn test_get_data_excluded_tables_for_db() {
        let filter = crate::filters::ReplicationFilter::new(
            None,
            None,
            None,
            Some(owned(&["db1.table1", "db1.table2", "db2.table3"])),
        )
        .unwrap();

        let first = get_data_excluded_tables_for_db(&filter, "db1").unwrap();
        assert_eq!(first, vec!["\"public\".\"table1\"", "\"public\".\"table2\""]);

        let second = get_data_excluded_tables_for_db(&filter, "db2").unwrap();
        assert_eq!(second, vec!["\"public\".\"table3\""]);

        let missing = get_data_excluded_tables_for_db(&filter, "db3");
        assert!(missing.is_none() || missing.unwrap().is_empty());
    }

    #[test]
    fn test_get_included_tables_for_db() {
        let filter = crate::filters::ReplicationFilter::new(
            None,
            None,
            Some(owned(&["db1.users", "db1.orders", "db2.products"])),
            None,
        )
        .unwrap();

        let first = get_included_tables_for_db(&filter, "db1").unwrap();
        assert_eq!(first, vec!["\"public\".\"users\"", "\"public\".\"orders\""]);

        let second = get_included_tables_for_db(&filter, "db2").unwrap();
        assert_eq!(second, vec!["\"public\".\"products\""]);

        let missing = get_included_tables_for_db(&filter, "db3");
        assert!(missing.is_none() || missing.unwrap().is_empty());
    }

    #[test]
    fn test_get_schema_excluded_tables_for_db_with_empty_filter() {
        let filter = crate::filters::ReplicationFilter::empty();
        assert!(get_schema_excluded_tables_for_db(&filter, "db1").is_none());
    }

    #[test]
    fn test_get_data_excluded_tables_for_db_with_empty_filter() {
        let filter = crate::filters::ReplicationFilter::empty();
        assert!(get_data_excluded_tables_for_db(&filter, "db1").is_none());
    }

    #[test]
    fn test_get_included_tables_for_db_with_empty_filter() {
        let filter = crate::filters::ReplicationFilter::empty();
        assert!(get_included_tables_for_db(&filter, "db1").is_none());
    }

    #[test]
    fn test_rewrite_create_role_statements_wraps_unquoted_role() {
        let input = "CREATE ROLE replicator WITH LOGIN;\nALTER ROLE replicator WITH LOGIN;\n";
        let output = rewrite_create_role_statements(input).expect("rewrite expected");

        assert!(output.contains("DO $$"));
        assert!(output.contains("Role replicator already exists"));
        assert!(output.contains("CREATE ROLE replicator WITH LOGIN;"));
        assert!(output.contains("ALTER ROLE replicator WITH LOGIN;"));
    }

    #[test]
    fn test_rewrite_create_role_statements_wraps_quoted_role() {
        let input = " CREATE ROLE \"Andre Admin\";\n";
        let output = rewrite_create_role_statements(input).expect("rewrite expected");

        assert!(output.contains("DO $$"));
        assert!(output.contains("Andre Admin already exists"));
        assert!(output.contains("CREATE ROLE \"Andre Admin\""));
        assert!(output.starts_with(" DO $$"));
    }

    #[test]
    fn test_rewrite_create_role_statements_noop_when_absent() {
        let input = "ALTER ROLE existing WITH LOGIN;\n";
        assert!(rewrite_create_role_statements(input).is_none());
    }

    #[test]
    fn test_remove_restricted_role_grants() {
        let workdir = tempdir().unwrap();
        let globals_path = workdir.path().join("globals.sql");

        let fixture = r#"CREATE ROLE myuser;
ALTER ROLE myuser WITH LOGIN;
GRANT pg_checkpoint TO myuser;
GRANT "pg_read_all_stats" TO myuser;
GRANT pg_monitor TO myuser;
GRANT myrole TO myuser;
GRANT SELECT ON TABLE users TO myuser;
GRANT rds_superuser TO myuser GRANTED BY rdsadmin;
GRANT ALL ON SCHEMA public TO myuser GRANTED BY "rdsadmin";
GRANT SELECT ON TABLE orders TO myuser GRANTED BY postgres;
"#;
        std::fs::write(&globals_path, fixture).unwrap();

        remove_restricted_role_grants(globals_path.to_str().unwrap()).unwrap();

        let sanitized = std::fs::read_to_string(&globals_path).unwrap();

        // Grants of predefined roles are disabled.
        assert!(sanitized.contains("-- GRANT pg_checkpoint TO myuser;"));
        assert!(sanitized.contains("-- GRANT \"pg_read_all_stats\" TO myuser;"));
        assert!(sanitized.contains("-- GRANT pg_monitor TO myuser;"));

        // Grants issued by RDS administrative roles are disabled.
        assert!(sanitized.contains("-- GRANT rds_superuser TO myuser GRANTED BY rdsadmin;"));
        assert!(sanitized
            .contains("-- GRANT ALL ON SCHEMA public TO myuser GRANTED BY \"rdsadmin\";"));

        // Ordinary grants survive untouched.
        assert!(sanitized.contains("\nGRANT myrole TO myuser;\n"));
        assert!(sanitized.contains("\nGRANT SELECT ON TABLE users TO myuser;\n"));
        assert!(sanitized.contains("\nGRANT SELECT ON TABLE orders TO myuser GRANTED BY postgres;\n"));

        // Non-GRANT statements are never touched.
        assert!(sanitized.contains("CREATE ROLE myuser;"));
        assert!(sanitized.contains("ALTER ROLE myuser WITH LOGIN;"));
    }
}