// database_replicator/migration/dump.rs
use crate::filters::ReplicationFilter;

use anyhow::{Context, Result};
use std::collections::BTreeSet;
use std::fs;
use std::path::Path;
use std::process::{Command, Stdio};
use std::time::Duration;
10
11pub async fn dump_globals(source_url: &str, output_path: &str) -> Result<()> {
13 tracing::info!("Dumping global objects to {}", output_path);
14
15 let parts = crate::utils::parse_postgres_url(source_url)
17 .with_context(|| format!("Failed to parse source URL: {}", source_url))?;
18 let pgpass = crate::utils::PgPassFile::new(&parts)
19 .context("Failed to create .pgpass file for authentication")?;
20
21 let env_vars = parts.to_pg_env_vars();
22 let output_path_owned = output_path.to_string();
23
24 crate::utils::retry_subprocess_with_backoff(
26 || {
27 let mut cmd = Command::new("pg_dumpall");
28 cmd.arg("--globals-only")
29 .arg("--no-role-passwords") .arg("--verbose") .arg("--host")
32 .arg(&parts.host)
33 .arg("--port")
34 .arg(parts.port.to_string())
35 .arg("--database")
36 .arg(&parts.database)
37 .arg(format!("--file={}", output_path_owned))
38 .env("PGPASSFILE", pgpass.path())
39 .stdout(Stdio::inherit())
40 .stderr(Stdio::inherit());
41
42 if let Some(user) = &parts.user {
44 cmd.arg("--username").arg(user);
45 }
46
47 for (env_var, value) in &env_vars {
49 cmd.env(env_var, value);
50 }
51
52 for (env_var, value) in crate::utils::get_keepalive_env_vars() {
54 cmd.env(env_var, value);
55 }
56
57 cmd.status().context(
58 "Failed to execute pg_dumpall. Is PostgreSQL client installed?\n\
59 Install with:\n\
60 - Ubuntu/Debian: sudo apt-get install postgresql-client\n\
61 - macOS: brew install postgresql\n\
62 - RHEL/CentOS: sudo yum install postgresql",
63 )
64 },
65 3, Duration::from_secs(1), "pg_dumpall (dump globals)",
68 )
69 .await
70 .context(
71 "pg_dumpall failed to dump global objects.\n\
72 \n\
73 Common causes:\n\
74 - Connection authentication failed\n\
75 - User lacks sufficient privileges (need SUPERUSER or pg_read_all_settings role)\n\
76 - Network connectivity issues\n\
77 - Invalid connection string\n\
78 - Connection timeout or network issues",
79 )?;
80
81 tracing::info!("✓ Global objects dumped successfully");
82 Ok(())
83}
84
85pub fn sanitize_globals_dump(path: &str) -> Result<()> {
96 let content = fs::read_to_string(path)
97 .with_context(|| format!("Failed to read globals dump at {}", path))?;
98
99 if let Some(updated) = rewrite_create_role_statements(&content) {
100 fs::write(path, updated)
101 .with_context(|| format!("Failed to update globals dump at {}", path))?;
102 }
103
104 Ok(())
105}
106
107pub fn remove_superuser_from_globals(path: &str) -> Result<()> {
113 let content = fs::read_to_string(path)
114 .with_context(|| format!("Failed to read globals dump at {}", path))?;
115
116 let mut updated = String::with_capacity(content.len());
117 let mut modified = false;
118 for line in content.lines() {
119 if line.contains("ALTER ROLE") && line.contains("SUPERUSER") {
120 updated.push_str("-- ");
121 updated.push_str(line);
122 updated.push('\n');
123 modified = true;
124 } else {
125 updated.push_str(line);
126 updated.push('\n');
127 }
128 }
129
130 if modified {
131 fs::write(path, updated)
132 .with_context(|| format!("Failed to write sanitized globals dump to {}", path))?;
133 }
134
135 Ok(())
136}
137
138pub fn remove_restricted_guc_settings(path: &str) -> Result<()> {
144 let content = fs::read_to_string(path)
145 .with_context(|| format!("Failed to read globals dump at {}", path))?;
146
147 let mut updated = String::with_capacity(content.len());
148 let mut modified = false;
149
150 for line in content.lines() {
151 let lower_line = line.to_ascii_lowercase();
152 if lower_line.contains("alter role") && lower_line.contains("set") {
153 updated.push_str("-- ");
154 updated.push_str(line);
155 updated.push('\n');
156 modified = true;
157 } else {
158 updated.push_str(line);
159 updated.push('\n');
160 }
161 }
162
163 if modified {
164 fs::write(path, updated)
165 .with_context(|| format!("Failed to write sanitized globals dump to {}", path))?;
166 }
167
168 Ok(())
169}
170
171pub fn remove_tablespace_statements(path: &str) -> Result<()> {
178 let content = fs::read_to_string(path)
179 .with_context(|| format!("Failed to read globals dump at {}", path))?;
180
181 let mut updated = String::with_capacity(content.len());
182 let mut modified = false;
183
184 for line in content.lines() {
185 let lower_trimmed = line.trim().to_ascii_lowercase();
186
187 let is_create_tablespace = lower_trimmed.starts_with("create tablespace");
189
190 let references_rds_tablespace =
193 lower_trimmed.contains("'rds_") || lower_trimmed.contains("\"rds_");
194
195 if is_create_tablespace || references_rds_tablespace {
196 updated.push_str("-- ");
197 updated.push_str(line);
198 updated.push('\n');
199 modified = true;
200 } else {
201 updated.push_str(line);
202 updated.push('\n');
203 }
204 }
205
206 if modified {
207 fs::write(path, updated)
208 .with_context(|| format!("Failed to write sanitized globals dump to {}", path))?;
209 }
210
211 Ok(())
212}
213
214pub fn remove_restricted_role_grants(path: &str) -> Result<()> {
220 const RESTRICTED_ROLES: &[&str] = &[
222 "pg_checkpoint",
223 "pg_read_all_data",
224 "pg_write_all_data",
225 "pg_read_all_settings",
226 "pg_read_all_stats",
227 "pg_stat_scan_tables",
228 "pg_monitor",
229 "pg_signal_backend",
230 "pg_read_server_files",
231 "pg_write_server_files",
232 "pg_execute_server_program",
233 "pg_create_subscription",
234 "pg_maintain",
235 "pg_use_reserved_connections",
236 ];
237
238 const RESTRICTED_GRANTORS: &[&str] = &[
240 "rdsadmin",
241 "rds_superuser",
242 "rdsrepladmin",
243 "rds_replication",
244 ];
245
246 let content = fs::read_to_string(path)
247 .with_context(|| format!("Failed to read globals dump at {}", path))?;
248
249 let mut updated = String::with_capacity(content.len());
250 let mut modified = false;
251
252 for line in content.lines() {
253 let lower_trimmed = line.trim().to_ascii_lowercase();
254 if lower_trimmed.starts_with("grant ") {
255 let is_restricted_role = RESTRICTED_ROLES.iter().any(|role| {
257 lower_trimmed
260 .split_whitespace()
261 .nth(1)
262 .map(|r| r.trim_matches('"') == *role)
263 .unwrap_or(false)
264 });
265
266 let has_restricted_grantor = RESTRICTED_GRANTORS.iter().any(|grantor| {
268 lower_trimmed.contains(&format!("granted by {}", grantor))
270 || lower_trimmed.contains(&format!("granted by \"{}\"", grantor))
271 });
272
273 if is_restricted_role || has_restricted_grantor {
274 updated.push_str("-- ");
275 updated.push_str(line);
276 updated.push('\n');
277 modified = true;
278 continue;
279 }
280 }
281
282 updated.push_str(line);
283 updated.push('\n');
284 }
285
286 if modified {
287 fs::write(path, updated)
288 .with_context(|| format!("Failed to write sanitized globals dump to {}", path))?;
289 }
290
291 Ok(())
292}
293
294fn rewrite_create_role_statements(sql: &str) -> Option<String> {
295 if sql.is_empty() {
296 return None;
297 }
298
299 let mut output = String::with_capacity(sql.len() + 1024);
300 let mut modified = false;
301 let mut cursor = 0;
302
303 while cursor < sql.len() {
304 if let Some(rel_pos) = sql[cursor..].find('\n') {
305 let end = cursor + rel_pos + 1;
306 let chunk = &sql[cursor..end];
307 if let Some(transformed) = wrap_create_role_line(chunk) {
308 output.push_str(&transformed);
309 modified = true;
310 } else {
311 output.push_str(chunk);
312 }
313 cursor = end;
314 } else {
315 let chunk = &sql[cursor..];
316 if let Some(transformed) = wrap_create_role_line(chunk) {
317 output.push_str(&transformed);
318 modified = true;
319 } else {
320 output.push_str(chunk);
321 }
322 break;
323 }
324 }
325
326 if modified {
327 Some(output)
328 } else {
329 None
330 }
331}
332
333fn wrap_create_role_line(chunk: &str) -> Option<String> {
334 let trimmed = chunk.trim_start();
335 if !trimmed.starts_with("CREATE ROLE ") {
336 return None;
337 }
338
339 let statement = trimmed.trim_end();
340 let statement_body = statement.trim_end_matches(';').trim_end();
341 let leading_ws_len = chunk.len() - trimmed.len();
342 let leading_ws = &chunk[..leading_ws_len];
343 let newline = if chunk.ends_with("\r\n") {
344 "\r\n"
345 } else if chunk.ends_with('\n') {
346 "\n"
347 } else {
348 ""
349 };
350
351 let role_token = extract_role_token(statement_body)?;
352
353 let notice_name = escape_single_quotes(&unquote_role_name(&role_token));
354
355 let mut block = String::with_capacity(chunk.len() + 128);
356 block.push_str(leading_ws);
357 block.push_str("DO $$\n");
358 block.push_str(leading_ws);
359 block.push_str("BEGIN\n");
360 block.push_str(leading_ws);
361 block.push_str(" ");
362 block.push_str(statement_body);
363 block.push_str(";\n");
364 block.push_str(leading_ws);
365 block.push_str("EXCEPTION\n");
366 block.push_str(leading_ws);
367 block.push_str(" WHEN duplicate_object THEN\n");
368 block.push_str(leading_ws);
369 block.push_str(" RAISE NOTICE 'Role ");
370 block.push_str(¬ice_name);
371 block.push_str(" already exists on target, skipping CREATE ROLE';\n");
372 block.push_str(leading_ws);
373 block.push_str("END $$;");
374
375 if !newline.is_empty() {
376 block.push_str(newline);
377 }
378
379 Some(block)
380}
381
/// Extracts the role-name token that follows `CREATE ROLE` in `statement`.
///
/// For a double-quoted identifier, the token is returned including its quotes. Doubled
/// `""` inside it is an escaped quote. For an unquoted name, the token runs up to the
/// first whitespace or `;`.
///
/// Returns `None` in these cases:
/// - `statement` does not start with `CREATE ROLE`.
/// - No name follows the keyword.
/// - A quoted identifier is never closed. Previously such input was returned as a bogus
///   token.
fn extract_role_token(statement: &str) -> Option<String> {
    let remainder = statement.strip_prefix("CREATE ROLE")?.trim_start();

    if let Some(quoted) = remainder.strip_prefix('"') {
        let bytes = quoted.as_bytes();
        let mut idx = 0;
        while idx < bytes.len() {
            if bytes[idx] == b'"' {
                if bytes.get(idx + 1) == Some(&b'"') {
                    // Escaped `""` inside the identifier: skip both quotes.
                    idx += 2;
                    continue;
                }
                // Closing quote at `idx` within `quoted`: +1 for the opening quote in
                // `remainder`, +1 to include the closing quote itself.
                return Some(remainder[..idx + 2].to_string());
            }
            idx += 1;
        }
        // No closing quote: malformed identifier, do not guess.
        return None;
    }

    let end = remainder
        .find(|ch: char| ch.is_whitespace() || ch == ';')
        .unwrap_or(remainder.len());
    if end == 0 {
        None
    } else {
        Some(remainder[..end].to_string())
    }
}
419
/// Turns a role token into the bare role name.
///
/// A token wrapped in double quotes (length >= 2) has the outer quotes removed and each
/// escaped `""` collapsed to `"`. Any other token is returned unchanged.
fn unquote_role_name(token: &str) -> String {
    match token
        .strip_prefix('"')
        .and_then(|rest| rest.strip_suffix('"'))
    {
        Some(inner) => inner.replace("\"\"", "\""),
        None => token.to_string(),
    }
}
428
/// Escapes `value` for use inside a single-quoted SQL literal by doubling every `'`.
fn escape_single_quotes(value: &str) -> String {
    let mut escaped = String::with_capacity(value.len());
    for ch in value.chars() {
        if ch == '\'' {
            escaped.push('\'');
        }
        escaped.push(ch);
    }
    escaped
}
432
433pub async fn dump_schema(
435 source_url: &str,
436 database: &str,
437 output_path: &str,
438 filter: &ReplicationFilter,
439) -> Result<()> {
440 tracing::info!(
441 "Dumping schema for database '{}' to {}",
442 database,
443 output_path
444 );
445
446 let parts = crate::utils::parse_postgres_url(source_url)
448 .with_context(|| format!("Failed to parse source URL: {}", source_url))?;
449 let pgpass = crate::utils::PgPassFile::new(&parts)
450 .context("Failed to create .pgpass file for authentication")?;
451
452 let env_vars = parts.to_pg_env_vars();
453 let output_path_owned = output_path.to_string();
454
455 let exclude_tables = get_schema_excluded_tables_for_db(filter, database);
457 let include_tables = get_included_tables_for_db(filter, database);
458
459 crate::utils::retry_subprocess_with_backoff(
461 || {
462 let mut cmd = Command::new("pg_dump");
463 cmd.arg("--schema-only")
464 .arg("--no-owner") .arg("--no-privileges") .arg("--verbose"); if let Some(ref exclude) = exclude_tables {
471 if !exclude.is_empty() {
472 for table in exclude {
473 cmd.arg("--exclude-table").arg(table);
474 }
475 }
476 }
477
478 if let Some(ref include) = include_tables {
480 if !include.is_empty() {
481 for table in include {
482 cmd.arg("--table").arg(table);
483 }
484 }
485 }
486
487 cmd.arg("--host")
488 .arg(&parts.host)
489 .arg("--port")
490 .arg(parts.port.to_string())
491 .arg("--dbname")
492 .arg(&parts.database)
493 .arg(format!("--file={}", output_path_owned))
494 .env("PGPASSFILE", pgpass.path())
495 .stdout(Stdio::inherit())
496 .stderr(Stdio::inherit());
497
498 if let Some(user) = &parts.user {
500 cmd.arg("--username").arg(user);
501 }
502
503 for (env_var, value) in &env_vars {
505 cmd.env(env_var, value);
506 }
507
508 for (env_var, value) in crate::utils::get_keepalive_env_vars() {
510 cmd.env(env_var, value);
511 }
512
513 cmd.status().context(
514 "Failed to execute pg_dump. Is PostgreSQL client installed?\n\
515 Install with:\n\
516 - Ubuntu/Debian: sudo apt-get install postgresql-client\n\
517 - macOS: brew install postgresql\n\
518 - RHEL/CentOS: sudo yum install postgresql",
519 )
520 },
521 3, Duration::from_secs(1), "pg_dump (dump schema)",
524 )
525 .await
526 .with_context(|| {
527 format!(
528 "pg_dump failed to dump schema for database '{}'.\n\
529 \n\
530 Common causes:\n\
531 - Database does not exist\n\
532 - Connection authentication failed\n\
533 - User lacks privileges to read database schema\n\
534 - Network connectivity issues\n\
535 - Connection timeout or network issues",
536 database
537 )
538 })?;
539
540 tracing::info!("✓ Schema dumped successfully");
541 Ok(())
542}
543
544pub async fn dump_data(
554 source_url: &str,
555 database: &str,
556 output_path: &str,
557 filter: &ReplicationFilter,
558) -> Result<()> {
559 let num_cpus = std::thread::available_parallelism()
561 .map(|n| n.get().min(8))
562 .unwrap_or(4);
563
564 tracing::info!(
565 "Dumping data for database '{}' to {} (parallel={}, compression=9, format=directory)",
566 database,
567 output_path,
568 num_cpus
569 );
570
571 let parts = crate::utils::parse_postgres_url(source_url)
573 .with_context(|| format!("Failed to parse source URL: {}", source_url))?;
574 let pgpass = crate::utils::PgPassFile::new(&parts)
575 .context("Failed to create .pgpass file for authentication")?;
576
577 let env_vars = parts.to_pg_env_vars();
578 let output_path_owned = output_path.to_string();
579
580 let exclude_tables = get_data_excluded_tables_for_db(filter, database);
582 let include_tables = get_included_tables_for_db(filter, database);
583
584 crate::utils::retry_subprocess_with_backoff(
586 || {
587 let mut cmd = Command::new("pg_dump");
588 cmd.arg("--data-only")
589 .arg("--no-owner")
590 .arg("--format=directory") .arg("--blobs") .arg("--compress=9") .arg(format!("--jobs={}", num_cpus)) .arg("--verbose"); if let Some(ref exclude) = exclude_tables {
599 if !exclude.is_empty() {
600 for table in exclude {
601 cmd.arg("--exclude-table-data").arg(table);
602 }
603 }
604 }
605
606 if let Some(ref include) = include_tables {
608 if !include.is_empty() {
609 for table in include {
610 cmd.arg("--table").arg(table);
611 }
612 }
613 }
614
615 cmd.arg("--host")
616 .arg(&parts.host)
617 .arg("--port")
618 .arg(parts.port.to_string())
619 .arg("--dbname")
620 .arg(&parts.database)
621 .arg(format!("--file={}", output_path_owned))
622 .env("PGPASSFILE", pgpass.path())
623 .stdout(Stdio::inherit())
624 .stderr(Stdio::inherit());
625
626 if let Some(user) = &parts.user {
628 cmd.arg("--username").arg(user);
629 }
630
631 for (env_var, value) in &env_vars {
633 cmd.env(env_var, value);
634 }
635
636 for (env_var, value) in crate::utils::get_keepalive_env_vars() {
638 cmd.env(env_var, value);
639 }
640
641 cmd.status().context(
642 "Failed to execute pg_dump. Is PostgreSQL client installed?\n\
643 Install with:\n\
644 - Ubuntu/Debian: sudo apt-get install postgresql-client\n\
645 - macOS: brew install postgresql\n\
646 - RHEL/CentOS: sudo yum install postgresql",
647 )
648 },
649 3, Duration::from_secs(1), "pg_dump (dump data)",
652 )
653 .await
654 .with_context(|| {
655 format!(
656 "pg_dump failed to dump data for database '{}'.\n\
657 \n\
658 Common causes:\n\
659 - Database does not exist\n\
660 - Connection authentication failed\n\
661 - User lacks privileges to read table data\n\
662 - Network connectivity issues\n\
663 - Insufficient disk space for dump directory\n\
664 - Output directory already exists (pg_dump requires non-existent path)\n\
665 - Connection timeout or network issues",
666 database
667 )
668 })?;
669
670 tracing::info!(
671 "✓ Data dumped successfully using {} parallel jobs",
672 num_cpus
673 );
674 Ok(())
675}
676
677fn get_schema_excluded_tables_for_db(
682 filter: &ReplicationFilter,
683 db_name: &str,
684) -> Option<Vec<String>> {
685 let mut tables = BTreeSet::new();
686
687 if let Some(explicit) = filter.exclude_tables() {
690 for full_name in explicit {
691 let parts: Vec<&str> = full_name.split('.').collect();
692 if parts.len() == 2 && parts[0] == db_name {
693 tables.insert(format!("\"public\".\"{}\"", parts[1]));
695 }
696 }
697 }
698
699 if tables.is_empty() {
700 None
701 } else {
702 Some(tables.into_iter().collect())
703 }
704}
705
706fn get_data_excluded_tables_for_db(
711 filter: &ReplicationFilter,
712 db_name: &str,
713) -> Option<Vec<String>> {
714 let mut tables = BTreeSet::new();
715
716 if let Some(explicit) = filter.exclude_tables() {
719 for full_name in explicit {
720 let parts: Vec<&str> = full_name.split('.').collect();
721 if parts.len() == 2 && parts[0] == db_name {
722 tables.insert(format!("\"public\".\"{}\"", parts[1]));
724 }
725 }
726 }
727
728 for table in filter.schema_only_tables(db_name) {
730 tables.insert(table);
731 }
732
733 for (table, _) in filter.predicate_tables(db_name) {
734 tables.insert(table);
735 }
736
737 if tables.is_empty() {
738 None
739 } else {
740 Some(tables.into_iter().collect())
741 }
742}
743
744fn get_included_tables_for_db(filter: &ReplicationFilter, db_name: &str) -> Option<Vec<String>> {
747 filter.include_tables().map(|tables| {
748 tables
749 .iter()
750 .filter_map(|full_name| {
751 let parts: Vec<&str> = full_name.split('.').collect();
752 if parts.len() == 2 && parts[0] == db_name {
753 Some(format!("\"public\".\"{}\"", parts[1]))
755 } else {
756 None
757 }
758 })
759 .collect()
760 })
761}
762
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    #[tokio::test]
    #[ignore]
    async fn test_dump_globals() {
        let url = std::env::var("TEST_SOURCE_URL").unwrap();
        let dir = tempdir().unwrap();
        let output = dir.path().join("globals.sql");

        let result = dump_globals(&url, output.to_str().unwrap()).await;

        assert!(result.is_ok());
        assert!(output.exists());

        // The old `contains("CREATE ROLE") || !is_empty()` reduced to this check anyway.
        let content = std::fs::read_to_string(&output).unwrap();
        assert!(!content.is_empty(), "globals dump should not be empty");
    }

    #[tokio::test]
    #[ignore]
    async fn test_dump_schema() {
        let url = std::env::var("TEST_SOURCE_URL").unwrap();
        let dir = tempdir().unwrap();
        let output = dir.path().join("schema.sql");

        let db = url.split('/').next_back().unwrap_or("postgres");

        let filter = crate::filters::ReplicationFilter::empty();
        let result = dump_schema(&url, db, output.to_str().unwrap(), &filter).await;

        assert!(result.is_ok());
        assert!(output.exists());
    }

    #[test]
    fn test_get_schema_excluded_tables_for_db() {
        let filter = crate::filters::ReplicationFilter::new(
            None,
            None,
            None,
            Some(vec![
                "db1.table1".to_string(),
                "db1.table2".to_string(),
                "db2.table3".to_string(),
            ]),
        )
        .unwrap();

        let tables = get_schema_excluded_tables_for_db(&filter, "db1").unwrap();
        assert_eq!(
            tables,
            vec!["\"public\".\"table1\"", "\"public\".\"table2\""]
        );

        let tables = get_schema_excluded_tables_for_db(&filter, "db2").unwrap();
        assert_eq!(tables, vec!["\"public\".\"table3\""]);

        let tables = get_schema_excluded_tables_for_db(&filter, "db3");
        assert!(tables.is_none() || tables.unwrap().is_empty());
    }

    #[test]
    fn test_get_data_excluded_tables_for_db() {
        let filter = crate::filters::ReplicationFilter::new(
            None,
            None,
            None,
            Some(vec![
                "db1.table1".to_string(),
                "db1.table2".to_string(),
                "db2.table3".to_string(),
            ]),
        )
        .unwrap();

        let tables = get_data_excluded_tables_for_db(&filter, "db1").unwrap();
        assert_eq!(
            tables,
            vec!["\"public\".\"table1\"", "\"public\".\"table2\""]
        );

        let tables = get_data_excluded_tables_for_db(&filter, "db2").unwrap();
        assert_eq!(tables, vec!["\"public\".\"table3\""]);

        let tables = get_data_excluded_tables_for_db(&filter, "db3");
        assert!(tables.is_none() || tables.unwrap().is_empty());
    }

    #[test]
    fn test_get_included_tables_for_db() {
        let filter = crate::filters::ReplicationFilter::new(
            None,
            None,
            Some(vec![
                "db1.users".to_string(),
                "db1.orders".to_string(),
                "db2.products".to_string(),
            ]),
            None,
        )
        .unwrap();

        let tables = get_included_tables_for_db(&filter, "db1").unwrap();
        assert_eq!(
            tables,
            vec!["\"public\".\"users\"", "\"public\".\"orders\""]
        );

        let tables = get_included_tables_for_db(&filter, "db2").unwrap();
        assert_eq!(tables, vec!["\"public\".\"products\""]);

        let tables = get_included_tables_for_db(&filter, "db3");
        assert!(tables.is_none() || tables.unwrap().is_empty());
    }

    #[test]
    fn test_get_schema_excluded_tables_for_db_with_empty_filter() {
        let filter = crate::filters::ReplicationFilter::empty();
        let tables = get_schema_excluded_tables_for_db(&filter, "db1");
        assert!(tables.is_none());
    }

    #[test]
    fn test_get_data_excluded_tables_for_db_with_empty_filter() {
        let filter = crate::filters::ReplicationFilter::empty();
        let tables = get_data_excluded_tables_for_db(&filter, "db1");
        assert!(tables.is_none());
    }

    #[test]
    fn test_get_included_tables_for_db_with_empty_filter() {
        let filter = crate::filters::ReplicationFilter::empty();
        let tables = get_included_tables_for_db(&filter, "db1");
        assert!(tables.is_none());
    }

    #[test]
    fn test_rewrite_create_role_statements_wraps_unquoted_role() {
        let sql = "CREATE ROLE replicator WITH LOGIN;\nALTER ROLE replicator WITH LOGIN;\n";
        let rewritten = rewrite_create_role_statements(sql).expect("rewrite expected");

        assert!(rewritten.contains("DO $$"));
        assert!(rewritten.contains("Role replicator already exists"));
        assert!(rewritten.contains("CREATE ROLE replicator WITH LOGIN;"));
        assert!(rewritten.contains("ALTER ROLE replicator WITH LOGIN;"));
    }

    #[test]
    fn test_rewrite_create_role_statements_wraps_quoted_role() {
        let sql = " CREATE ROLE \"Andre Admin\";\n";
        let rewritten = rewrite_create_role_statements(sql).expect("rewrite expected");

        assert!(rewritten.contains("DO $$"));
        assert!(rewritten.contains("Andre Admin already exists"));
        assert!(rewritten.contains("CREATE ROLE \"Andre Admin\""));
        assert!(rewritten.starts_with(" DO $$"));
    }

    #[test]
    fn test_rewrite_create_role_statements_noop_when_absent() {
        let sql = "ALTER ROLE existing WITH LOGIN;\n";
        assert!(rewrite_create_role_statements(sql).is_none());
    }

    #[test]
    fn test_extract_role_token_handles_quoted_and_unquoted_names() {
        assert_eq!(
            extract_role_token("CREATE ROLE replicator WITH LOGIN"),
            Some("replicator".to_string())
        );
        assert_eq!(
            extract_role_token("CREATE ROLE \"a\"\"b\" WITH LOGIN"),
            Some("\"a\"\"b\"".to_string())
        );
        assert_eq!(extract_role_token("ALTER ROLE replicator"), None);
    }

    #[test]
    fn test_unquote_role_name() {
        assert_eq!(unquote_role_name("replicator"), "replicator");
        assert_eq!(unquote_role_name("\"Andre Admin\""), "Andre Admin");
        assert_eq!(unquote_role_name("\"a\"\"b\""), "a\"b");
    }

    #[test]
    fn test_escape_single_quotes() {
        assert_eq!(escape_single_quotes("o'brien"), "o''brien");
        assert_eq!(escape_single_quotes("plain"), "plain");
    }

    #[test]
    fn test_remove_superuser_from_globals() {
        let dir = tempdir().unwrap();
        let globals_file = dir.path().join("globals.sql");

        let content = "CREATE ROLE admin;\n\
                       ALTER ROLE admin WITH SUPERUSER LOGIN;\n\
                       ALTER ROLE app WITH NOSUPERUSER LOGIN;\n";
        std::fs::write(&globals_file, content).unwrap();

        remove_superuser_from_globals(globals_file.to_str().unwrap()).unwrap();

        let result = std::fs::read_to_string(&globals_file).unwrap();
        assert_eq!(
            result,
            "CREATE ROLE admin;\n\
             -- ALTER ROLE admin WITH SUPERUSER LOGIN;\n\
             -- ALTER ROLE app WITH NOSUPERUSER LOGIN;\n"
        );
    }

    #[test]
    fn test_remove_superuser_from_globals_leaves_unmatched_file_untouched() {
        let dir = tempdir().unwrap();
        let globals_file = dir.path().join("globals.sql");

        // No trailing newline: the file must not be rewritten (and so gain one).
        std::fs::write(&globals_file, "CREATE ROLE a;").unwrap();

        remove_superuser_from_globals(globals_file.to_str().unwrap()).unwrap();

        let result = std::fs::read_to_string(&globals_file).unwrap();
        assert_eq!(result, "CREATE ROLE a;");
    }

    #[test]
    fn test_remove_restricted_guc_settings() {
        let dir = tempdir().unwrap();
        let globals_file = dir.path().join("globals.sql");

        let content = "CREATE ROLE app;\n\
                       ALTER ROLE app SET search_path TO public;\n\
                       ALTER ROLE app IN DATABASE appdb SET work_mem TO '64MB';\n";
        std::fs::write(&globals_file, content).unwrap();

        remove_restricted_guc_settings(globals_file.to_str().unwrap()).unwrap();

        let result = std::fs::read_to_string(&globals_file).unwrap();
        assert_eq!(
            result,
            "CREATE ROLE app;\n\
             -- ALTER ROLE app SET search_path TO public;\n\
             -- ALTER ROLE app IN DATABASE appdb SET work_mem TO '64MB';\n"
        );
    }

    #[test]
    fn test_remove_tablespace_statements() {
        let dir = tempdir().unwrap();
        let globals_file = dir.path().join("globals.sql");

        let content = "CREATE TABLESPACE fast OWNER app LOCATION '/mnt/fast';\n\
                       GRANT app TO other;\n";
        std::fs::write(&globals_file, content).unwrap();

        remove_tablespace_statements(globals_file.to_str().unwrap()).unwrap();

        let result = std::fs::read_to_string(&globals_file).unwrap();
        assert_eq!(
            result,
            "-- CREATE TABLESPACE fast OWNER app LOCATION '/mnt/fast';\n\
             GRANT app TO other;\n"
        );
    }

    #[test]
    fn test_remove_restricted_role_grants() {
        let dir = tempdir().unwrap();
        let globals_file = dir.path().join("globals.sql");

        let content = r#"CREATE ROLE myuser;
ALTER ROLE myuser WITH LOGIN;
GRANT pg_checkpoint TO myuser;
GRANT "pg_read_all_stats" TO myuser;
GRANT pg_monitor TO myuser;
GRANT myrole TO myuser;
GRANT SELECT ON TABLE users TO myuser;
GRANT rds_superuser TO myuser GRANTED BY rdsadmin;
GRANT ALL ON SCHEMA public TO myuser GRANTED BY "rdsadmin";
GRANT SELECT ON TABLE orders TO myuser GRANTED BY postgres;
"#;
        std::fs::write(&globals_file, content).unwrap();

        remove_restricted_role_grants(globals_file.to_str().unwrap()).unwrap();

        let result = std::fs::read_to_string(&globals_file).unwrap();

        // Predefined pg_* roles are commented out, quoted or not.
        assert!(result.contains("-- GRANT pg_checkpoint TO myuser;"));
        assert!(result.contains("-- GRANT \"pg_read_all_stats\" TO myuser;"));
        assert!(result.contains("-- GRANT pg_monitor TO myuser;"));

        // Grants made by RDS administrative roles are commented out.
        assert!(result.contains("-- GRANT rds_superuser TO myuser GRANTED BY rdsadmin;"));
        assert!(result.contains("-- GRANT ALL ON SCHEMA public TO myuser GRANTED BY \"rdsadmin\";"));

        // Ordinary grants are kept verbatim.
        assert!(result.contains("\nGRANT myrole TO myuser;\n"));
        assert!(result.contains("\nGRANT SELECT ON TABLE users TO myuser;\n"));
        assert!(result.contains("\nGRANT SELECT ON TABLE orders TO myuser GRANTED BY postgres;\n"));

        // Non-GRANT statements are untouched.
        assert!(result.contains("CREATE ROLE myuser;"));
        assert!(result.contains("ALTER ROLE myuser WITH LOGIN;"));
    }
}