database_replicator/migration/
dump.rs1use crate::filters::ReplicationFilter;
5use anyhow::{Context, Result};
6use std::collections::BTreeSet;
7use std::fs;
8use std::process::{Command, Stdio};
9use std::time::Duration;
10
11pub async fn dump_globals(source_url: &str, output_path: &str) -> Result<()> {
13 tracing::info!("Dumping global objects to {}", output_path);
14
15 let parts = crate::utils::parse_postgres_url(source_url)
17 .with_context(|| format!("Failed to parse source URL: {}", source_url))?;
18 let pgpass = crate::utils::PgPassFile::new(&parts)
19 .context("Failed to create .pgpass file for authentication")?;
20
21 let env_vars = parts.to_pg_env_vars();
22 let output_path_owned = output_path.to_string();
23
24 crate::utils::retry_subprocess_with_backoff(
26 || {
27 let mut cmd = Command::new("pg_dumpall");
28 cmd.arg("--globals-only")
29 .arg("--no-role-passwords") .arg("--verbose") .arg("--host")
32 .arg(&parts.host)
33 .arg("--port")
34 .arg(parts.port.to_string())
35 .arg("--database")
36 .arg(&parts.database)
37 .arg(format!("--file={}", output_path_owned))
38 .env("PGPASSFILE", pgpass.path())
39 .stdout(Stdio::inherit())
40 .stderr(Stdio::inherit());
41
42 if let Some(user) = &parts.user {
44 cmd.arg("--username").arg(user);
45 }
46
47 for (env_var, value) in &env_vars {
49 cmd.env(env_var, value);
50 }
51
52 for (env_var, value) in crate::utils::get_keepalive_env_vars() {
54 cmd.env(env_var, value);
55 }
56
57 cmd.status().context(
58 "Failed to execute pg_dumpall. Is PostgreSQL client installed?\n\
59 Install with:\n\
60 - Ubuntu/Debian: sudo apt-get install postgresql-client\n\
61 - macOS: brew install postgresql\n\
62 - RHEL/CentOS: sudo yum install postgresql",
63 )
64 },
65 3, Duration::from_secs(1), "pg_dumpall (dump globals)",
68 )
69 .await
70 .context(
71 "pg_dumpall failed to dump global objects.\n\
72 \n\
73 Common causes:\n\
74 - Connection authentication failed\n\
75 - User lacks sufficient privileges (need SUPERUSER or pg_read_all_settings role)\n\
76 - Network connectivity issues\n\
77 - Invalid connection string\n\
78 - Connection timeout or network issues",
79 )?;
80
81 tracing::info!("✓ Global objects dumped successfully");
82 Ok(())
83}
84
85pub fn sanitize_globals_dump(path: &str) -> Result<()> {
96 let content = fs::read_to_string(path)
97 .with_context(|| format!("Failed to read globals dump at {}", path))?;
98
99 if let Some(updated) = rewrite_create_role_statements(&content) {
100 fs::write(path, updated)
101 .with_context(|| format!("Failed to update globals dump at {}", path))?;
102 }
103
104 Ok(())
105}
106
107pub fn remove_superuser_from_globals(path: &str) -> Result<()> {
113 let content = fs::read_to_string(path)
114 .with_context(|| format!("Failed to read globals dump at {}", path))?;
115
116 let mut updated = String::with_capacity(content.len());
117 let mut modified = false;
118 for line in content.lines() {
119 if line.contains("ALTER ROLE") && line.contains("SUPERUSER") {
120 updated.push_str("-- ");
121 updated.push_str(line);
122 updated.push('\n');
123 modified = true;
124 } else {
125 updated.push_str(line);
126 updated.push('\n');
127 }
128 }
129
130 if modified {
131 fs::write(path, updated)
132 .with_context(|| format!("Failed to write sanitized globals dump to {}", path))?;
133 }
134
135 Ok(())
136}
137
138pub fn remove_restricted_guc_settings(path: &str) -> Result<()> {
144 let content = fs::read_to_string(path)
145 .with_context(|| format!("Failed to read globals dump at {}", path))?;
146
147 let mut updated = String::with_capacity(content.len());
148 let mut modified = false;
149
150 for line in content.lines() {
151 let lower_line = line.to_ascii_lowercase();
152 if lower_line.contains("alter role") && lower_line.contains("set") {
153 updated.push_str("-- ");
154 updated.push_str(line);
155 updated.push('\n');
156 modified = true;
157 } else {
158 updated.push_str(line);
159 updated.push('\n');
160 }
161 }
162
163 if modified {
164 fs::write(path, updated)
165 .with_context(|| format!("Failed to write sanitized globals dump to {}", path))?;
166 }
167
168 Ok(())
169}
170
171pub fn remove_tablespace_statements(path: &str) -> Result<()> {
177 let content = fs::read_to_string(path)
178 .with_context(|| format!("Failed to read globals dump at {}", path))?;
179
180 let mut updated = String::with_capacity(content.len());
181 let mut modified = false;
182
183 for line in content.lines() {
184 let lower_trimmed = line.trim().to_ascii_lowercase();
185 if lower_trimmed.starts_with("create tablespace") {
186 updated.push_str("-- ");
187 updated.push_str(line);
188 updated.push('\n');
189 modified = true;
190 } else {
191 updated.push_str(line);
192 updated.push('\n');
193 }
194 }
195
196 if modified {
197 fs::write(path, updated)
198 .with_context(|| format!("Failed to write sanitized globals dump to {}", path))?;
199 }
200
201 Ok(())
202}
203
204pub fn remove_restricted_role_grants(path: &str) -> Result<()> {
210 const RESTRICTED_ROLES: &[&str] = &[
212 "pg_checkpoint",
213 "pg_read_all_data",
214 "pg_write_all_data",
215 "pg_read_all_settings",
216 "pg_read_all_stats",
217 "pg_stat_scan_tables",
218 "pg_monitor",
219 "pg_signal_backend",
220 "pg_read_server_files",
221 "pg_write_server_files",
222 "pg_execute_server_program",
223 "pg_create_subscription",
224 "pg_maintain",
225 "pg_use_reserved_connections",
226 ];
227
228 const RESTRICTED_GRANTORS: &[&str] = &[
230 "rdsadmin",
231 "rds_superuser",
232 "rdsrepladmin",
233 "rds_replication",
234 ];
235
236 let content = fs::read_to_string(path)
237 .with_context(|| format!("Failed to read globals dump at {}", path))?;
238
239 let mut updated = String::with_capacity(content.len());
240 let mut modified = false;
241
242 for line in content.lines() {
243 let lower_trimmed = line.trim().to_ascii_lowercase();
244 if lower_trimmed.starts_with("grant ") {
245 let is_restricted_role = RESTRICTED_ROLES.iter().any(|role| {
247 lower_trimmed
250 .split_whitespace()
251 .nth(1)
252 .map(|r| r.trim_matches('"') == *role)
253 .unwrap_or(false)
254 });
255
256 let has_restricted_grantor = RESTRICTED_GRANTORS.iter().any(|grantor| {
258 lower_trimmed.contains(&format!("granted by {}", grantor))
260 || lower_trimmed.contains(&format!("granted by \"{}\"", grantor))
261 });
262
263 if is_restricted_role || has_restricted_grantor {
264 updated.push_str("-- ");
265 updated.push_str(line);
266 updated.push('\n');
267 modified = true;
268 continue;
269 }
270 }
271
272 updated.push_str(line);
273 updated.push('\n');
274 }
275
276 if modified {
277 fs::write(path, updated)
278 .with_context(|| format!("Failed to write sanitized globals dump to {}", path))?;
279 }
280
281 Ok(())
282}
283
284fn rewrite_create_role_statements(sql: &str) -> Option<String> {
285 if sql.is_empty() {
286 return None;
287 }
288
289 let mut output = String::with_capacity(sql.len() + 1024);
290 let mut modified = false;
291 let mut cursor = 0;
292
293 while cursor < sql.len() {
294 if let Some(rel_pos) = sql[cursor..].find('\n') {
295 let end = cursor + rel_pos + 1;
296 let chunk = &sql[cursor..end];
297 if let Some(transformed) = wrap_create_role_line(chunk) {
298 output.push_str(&transformed);
299 modified = true;
300 } else {
301 output.push_str(chunk);
302 }
303 cursor = end;
304 } else {
305 let chunk = &sql[cursor..];
306 if let Some(transformed) = wrap_create_role_line(chunk) {
307 output.push_str(&transformed);
308 modified = true;
309 } else {
310 output.push_str(chunk);
311 }
312 break;
313 }
314 }
315
316 if modified {
317 Some(output)
318 } else {
319 None
320 }
321}
322
323fn wrap_create_role_line(chunk: &str) -> Option<String> {
324 let trimmed = chunk.trim_start();
325 if !trimmed.starts_with("CREATE ROLE ") {
326 return None;
327 }
328
329 let statement = trimmed.trim_end();
330 let statement_body = statement.trim_end_matches(';').trim_end();
331 let leading_ws_len = chunk.len() - trimmed.len();
332 let leading_ws = &chunk[..leading_ws_len];
333 let newline = if chunk.ends_with("\r\n") {
334 "\r\n"
335 } else if chunk.ends_with('\n') {
336 "\n"
337 } else {
338 ""
339 };
340
341 let role_token = extract_role_token(statement_body)?;
342
343 let notice_name = escape_single_quotes(&unquote_role_name(&role_token));
344
345 let mut block = String::with_capacity(chunk.len() + 128);
346 block.push_str(leading_ws);
347 block.push_str("DO $$\n");
348 block.push_str(leading_ws);
349 block.push_str("BEGIN\n");
350 block.push_str(leading_ws);
351 block.push_str(" ");
352 block.push_str(statement_body);
353 block.push_str(";\n");
354 block.push_str(leading_ws);
355 block.push_str("EXCEPTION\n");
356 block.push_str(leading_ws);
357 block.push_str(" WHEN duplicate_object THEN\n");
358 block.push_str(leading_ws);
359 block.push_str(" RAISE NOTICE 'Role ");
360 block.push_str(¬ice_name);
361 block.push_str(" already exists on target, skipping CREATE ROLE';\n");
362 block.push_str(leading_ws);
363 block.push_str("END $$;");
364
365 if !newline.is_empty() {
366 block.push_str(newline);
367 }
368
369 Some(block)
370}
371
/// Extracts the role identifier following `CREATE ROLE` in `statement`.
///
/// A double-quoted name is returned with its quotes and any `""` escapes
/// intact. An unterminated quoted name yields the rest of the statement. An
/// unquoted name ends at the first whitespace or `;`. Returns `None` if
/// `statement` does not start with `CREATE ROLE` or the unquoted name is empty.
fn extract_role_token(statement: &str) -> Option<String> {
    let rest = statement.strip_prefix("CREATE ROLE")?.trim_start();

    if let Some(after_open) = rest.strip_prefix('"') {
        let bytes = after_open.as_bytes();
        let mut pos = 0;
        let end = loop {
            match bytes.get(pos) {
                // Unterminated: take everything that is left.
                None => break rest.len(),
                // `""` is an escaped quote inside the identifier.
                Some(b'"') if bytes.get(pos + 1) == Some(&b'"') => pos += 2,
                // Closing quote: +1 for the opening quote, +1 to include this one.
                Some(b'"') => break pos + 2,
                Some(_) => pos += 1,
            }
        };
        return Some(rest[..end].to_string());
    }

    let end = rest
        .find(|c: char| c.is_whitespace() || c == ';')
        .unwrap_or(rest.len());
    if end == 0 {
        None
    } else {
        Some(rest[..end].to_string())
    }
}
409
/// Removes the surrounding double quotes from a quoted identifier and collapses
/// its `""` escapes to `"`. Any other token is returned unchanged.
fn unquote_role_name(token: &str) -> String {
    match token.strip_prefix('"').and_then(|t| t.strip_suffix('"')) {
        Some(inner) => inner.replace("\"\"", "\""),
        None => token.to_string(),
    }
}
418
/// Doubles every `'` in `value` so it can be embedded in a single-quoted SQL
/// string literal.
fn escape_single_quotes(value: &str) -> String {
    let mut escaped = String::with_capacity(value.len());
    for ch in value.chars() {
        if ch == '\'' {
            escaped.push_str("''");
        } else {
            escaped.push(ch);
        }
    }
    escaped
}
422
423pub async fn dump_schema(
425 source_url: &str,
426 database: &str,
427 output_path: &str,
428 filter: &ReplicationFilter,
429) -> Result<()> {
430 tracing::info!(
431 "Dumping schema for database '{}' to {}",
432 database,
433 output_path
434 );
435
436 let parts = crate::utils::parse_postgres_url(source_url)
438 .with_context(|| format!("Failed to parse source URL: {}", source_url))?;
439 let pgpass = crate::utils::PgPassFile::new(&parts)
440 .context("Failed to create .pgpass file for authentication")?;
441
442 let env_vars = parts.to_pg_env_vars();
443 let output_path_owned = output_path.to_string();
444
445 let exclude_tables = get_schema_excluded_tables_for_db(filter, database);
447 let include_tables = get_included_tables_for_db(filter, database);
448
449 crate::utils::retry_subprocess_with_backoff(
451 || {
452 let mut cmd = Command::new("pg_dump");
453 cmd.arg("--schema-only")
454 .arg("--no-owner") .arg("--no-privileges") .arg("--verbose"); if let Some(ref exclude) = exclude_tables {
461 if !exclude.is_empty() {
462 for table in exclude {
463 cmd.arg("--exclude-table").arg(table);
464 }
465 }
466 }
467
468 if let Some(ref include) = include_tables {
470 if !include.is_empty() {
471 for table in include {
472 cmd.arg("--table").arg(table);
473 }
474 }
475 }
476
477 cmd.arg("--host")
478 .arg(&parts.host)
479 .arg("--port")
480 .arg(parts.port.to_string())
481 .arg("--dbname")
482 .arg(&parts.database)
483 .arg(format!("--file={}", output_path_owned))
484 .env("PGPASSFILE", pgpass.path())
485 .stdout(Stdio::inherit())
486 .stderr(Stdio::inherit());
487
488 if let Some(user) = &parts.user {
490 cmd.arg("--username").arg(user);
491 }
492
493 for (env_var, value) in &env_vars {
495 cmd.env(env_var, value);
496 }
497
498 for (env_var, value) in crate::utils::get_keepalive_env_vars() {
500 cmd.env(env_var, value);
501 }
502
503 cmd.status().context(
504 "Failed to execute pg_dump. Is PostgreSQL client installed?\n\
505 Install with:\n\
506 - Ubuntu/Debian: sudo apt-get install postgresql-client\n\
507 - macOS: brew install postgresql\n\
508 - RHEL/CentOS: sudo yum install postgresql",
509 )
510 },
511 3, Duration::from_secs(1), "pg_dump (dump schema)",
514 )
515 .await
516 .with_context(|| {
517 format!(
518 "pg_dump failed to dump schema for database '{}'.\n\
519 \n\
520 Common causes:\n\
521 - Database does not exist\n\
522 - Connection authentication failed\n\
523 - User lacks privileges to read database schema\n\
524 - Network connectivity issues\n\
525 - Connection timeout or network issues",
526 database
527 )
528 })?;
529
530 tracing::info!("✓ Schema dumped successfully");
531 Ok(())
532}
533
534pub async fn dump_data(
544 source_url: &str,
545 database: &str,
546 output_path: &str,
547 filter: &ReplicationFilter,
548) -> Result<()> {
549 let num_cpus = std::thread::available_parallelism()
551 .map(|n| n.get().min(8))
552 .unwrap_or(4);
553
554 tracing::info!(
555 "Dumping data for database '{}' to {} (parallel={}, compression=9, format=directory)",
556 database,
557 output_path,
558 num_cpus
559 );
560
561 let parts = crate::utils::parse_postgres_url(source_url)
563 .with_context(|| format!("Failed to parse source URL: {}", source_url))?;
564 let pgpass = crate::utils::PgPassFile::new(&parts)
565 .context("Failed to create .pgpass file for authentication")?;
566
567 let env_vars = parts.to_pg_env_vars();
568 let output_path_owned = output_path.to_string();
569
570 let exclude_tables = get_data_excluded_tables_for_db(filter, database);
572 let include_tables = get_included_tables_for_db(filter, database);
573
574 crate::utils::retry_subprocess_with_backoff(
576 || {
577 let mut cmd = Command::new("pg_dump");
578 cmd.arg("--data-only")
579 .arg("--no-owner")
580 .arg("--format=directory") .arg("--blobs") .arg("--compress=9") .arg(format!("--jobs={}", num_cpus)) .arg("--verbose"); if let Some(ref exclude) = exclude_tables {
589 if !exclude.is_empty() {
590 for table in exclude {
591 cmd.arg("--exclude-table-data").arg(table);
592 }
593 }
594 }
595
596 if let Some(ref include) = include_tables {
598 if !include.is_empty() {
599 for table in include {
600 cmd.arg("--table").arg(table);
601 }
602 }
603 }
604
605 cmd.arg("--host")
606 .arg(&parts.host)
607 .arg("--port")
608 .arg(parts.port.to_string())
609 .arg("--dbname")
610 .arg(&parts.database)
611 .arg(format!("--file={}", output_path_owned))
612 .env("PGPASSFILE", pgpass.path())
613 .stdout(Stdio::inherit())
614 .stderr(Stdio::inherit());
615
616 if let Some(user) = &parts.user {
618 cmd.arg("--username").arg(user);
619 }
620
621 for (env_var, value) in &env_vars {
623 cmd.env(env_var, value);
624 }
625
626 for (env_var, value) in crate::utils::get_keepalive_env_vars() {
628 cmd.env(env_var, value);
629 }
630
631 cmd.status().context(
632 "Failed to execute pg_dump. Is PostgreSQL client installed?\n\
633 Install with:\n\
634 - Ubuntu/Debian: sudo apt-get install postgresql-client\n\
635 - macOS: brew install postgresql\n\
636 - RHEL/CentOS: sudo yum install postgresql",
637 )
638 },
639 3, Duration::from_secs(1), "pg_dump (dump data)",
642 )
643 .await
644 .with_context(|| {
645 format!(
646 "pg_dump failed to dump data for database '{}'.\n\
647 \n\
648 Common causes:\n\
649 - Database does not exist\n\
650 - Connection authentication failed\n\
651 - User lacks privileges to read table data\n\
652 - Network connectivity issues\n\
653 - Insufficient disk space for dump directory\n\
654 - Output directory already exists (pg_dump requires non-existent path)\n\
655 - Connection timeout or network issues",
656 database
657 )
658 })?;
659
660 tracing::info!(
661 "✓ Data dumped successfully using {} parallel jobs",
662 num_cpus
663 );
664 Ok(())
665}
666
667fn get_schema_excluded_tables_for_db(
672 filter: &ReplicationFilter,
673 db_name: &str,
674) -> Option<Vec<String>> {
675 let mut tables = BTreeSet::new();
676
677 if let Some(explicit) = filter.exclude_tables() {
680 for full_name in explicit {
681 let parts: Vec<&str> = full_name.split('.').collect();
682 if parts.len() == 2 && parts[0] == db_name {
683 tables.insert(format!("\"public\".\"{}\"", parts[1]));
685 }
686 }
687 }
688
689 if tables.is_empty() {
690 None
691 } else {
692 Some(tables.into_iter().collect())
693 }
694}
695
696fn get_data_excluded_tables_for_db(
701 filter: &ReplicationFilter,
702 db_name: &str,
703) -> Option<Vec<String>> {
704 let mut tables = BTreeSet::new();
705
706 if let Some(explicit) = filter.exclude_tables() {
709 for full_name in explicit {
710 let parts: Vec<&str> = full_name.split('.').collect();
711 if parts.len() == 2 && parts[0] == db_name {
712 tables.insert(format!("\"public\".\"{}\"", parts[1]));
714 }
715 }
716 }
717
718 for table in filter.schema_only_tables(db_name) {
720 tables.insert(table);
721 }
722
723 for (table, _) in filter.predicate_tables(db_name) {
724 tables.insert(table);
725 }
726
727 if tables.is_empty() {
728 None
729 } else {
730 Some(tables.into_iter().collect())
731 }
732}
733
734fn get_included_tables_for_db(filter: &ReplicationFilter, db_name: &str) -> Option<Vec<String>> {
737 filter.include_tables().map(|tables| {
738 tables
739 .iter()
740 .filter_map(|full_name| {
741 let parts: Vec<&str> = full_name.split('.').collect();
742 if parts.len() == 2 && parts[0] == db_name {
743 Some(format!("\"public\".\"{}\"", parts[1]))
745 } else {
746 None
747 }
748 })
749 .collect()
750 })
751}
752
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// Builds a filter whose only setting is an explicit `db.table` exclude list.
    fn filter_excluding(tables: &[&str]) -> crate::filters::ReplicationFilter {
        let owned: Vec<String> = tables.iter().map(|t| t.to_string()).collect();
        crate::filters::ReplicationFilter::new(None, None, None, Some(owned)).unwrap()
    }

    /// Builds a filter whose only setting is an explicit `db.table` include list.
    fn filter_including(tables: &[&str]) -> crate::filters::ReplicationFilter {
        let owned: Vec<String> = tables.iter().map(|t| t.to_string()).collect();
        crate::filters::ReplicationFilter::new(None, None, Some(owned), None).unwrap()
    }

    // Needs a live server: set TEST_SOURCE_URL and run with `--ignored`.
    #[tokio::test]
    #[ignore]
    async fn test_dump_globals() {
        let url = std::env::var("TEST_SOURCE_URL").unwrap();
        let dir = tempdir().unwrap();
        let output = dir.path().join("globals.sql");

        let result = dump_globals(&url, output.to_str().unwrap()).await;

        assert!(result.is_ok());
        assert!(output.exists());

        let content = std::fs::read_to_string(&output).unwrap();
        assert!(content.contains("CREATE ROLE") || !content.is_empty());
    }

    // Needs a live server: set TEST_SOURCE_URL and run with `--ignored`.
    #[tokio::test]
    #[ignore]
    async fn test_dump_schema() {
        let url = std::env::var("TEST_SOURCE_URL").unwrap();
        let dir = tempdir().unwrap();
        let output = dir.path().join("schema.sql");

        // The database name is taken from the last path segment of the URL.
        let db = url.split('/').next_back().unwrap_or("postgres");

        let filter = crate::filters::ReplicationFilter::empty();
        let result = dump_schema(&url, db, output.to_str().unwrap(), &filter).await;

        assert!(result.is_ok());
        assert!(output.exists());
    }

    #[test]
    fn test_get_schema_excluded_tables_for_db() {
        let filter = filter_excluding(&["db1.table1", "db1.table2", "db2.table3"]);

        let tables = get_schema_excluded_tables_for_db(&filter, "db1").unwrap();
        assert_eq!(
            tables,
            vec!["\"public\".\"table1\"", "\"public\".\"table2\""]
        );

        let tables = get_schema_excluded_tables_for_db(&filter, "db2").unwrap();
        assert_eq!(tables, vec!["\"public\".\"table3\""]);

        let tables = get_schema_excluded_tables_for_db(&filter, "db3");
        assert!(tables.is_none() || tables.unwrap().is_empty());
    }

    #[test]
    fn test_get_data_excluded_tables_for_db() {
        let filter = filter_excluding(&["db1.table1", "db1.table2", "db2.table3"]);

        let tables = get_data_excluded_tables_for_db(&filter, "db1").unwrap();
        assert_eq!(
            tables,
            vec!["\"public\".\"table1\"", "\"public\".\"table2\""]
        );

        let tables = get_data_excluded_tables_for_db(&filter, "db2").unwrap();
        assert_eq!(tables, vec!["\"public\".\"table3\""]);

        let tables = get_data_excluded_tables_for_db(&filter, "db3");
        assert!(tables.is_none() || tables.unwrap().is_empty());
    }

    #[test]
    fn test_get_included_tables_for_db() {
        let filter = filter_including(&["db1.users", "db1.orders", "db2.products"]);

        let tables = get_included_tables_for_db(&filter, "db1").unwrap();
        assert_eq!(
            tables,
            vec!["\"public\".\"users\"", "\"public\".\"orders\""]
        );

        let tables = get_included_tables_for_db(&filter, "db2").unwrap();
        assert_eq!(tables, vec!["\"public\".\"products\""]);

        let tables = get_included_tables_for_db(&filter, "db3");
        assert!(tables.is_none() || tables.unwrap().is_empty());
    }

    #[test]
    fn test_get_schema_excluded_tables_for_db_with_empty_filter() {
        let filter = crate::filters::ReplicationFilter::empty();
        assert!(get_schema_excluded_tables_for_db(&filter, "db1").is_none());
    }

    #[test]
    fn test_get_data_excluded_tables_for_db_with_empty_filter() {
        let filter = crate::filters::ReplicationFilter::empty();
        assert!(get_data_excluded_tables_for_db(&filter, "db1").is_none());
    }

    #[test]
    fn test_get_included_tables_for_db_with_empty_filter() {
        let filter = crate::filters::ReplicationFilter::empty();
        assert!(get_included_tables_for_db(&filter, "db1").is_none());
    }

    #[test]
    fn test_rewrite_create_role_statements_wraps_unquoted_role() {
        let sql = "CREATE ROLE replicator WITH LOGIN;\nALTER ROLE replicator WITH LOGIN;\n";
        let rewritten = rewrite_create_role_statements(sql).expect("rewrite expected");

        assert!(rewritten.contains("DO $$"));
        assert!(rewritten.contains("Role replicator already exists"));
        assert!(rewritten.contains("CREATE ROLE replicator WITH LOGIN;"));
        assert!(rewritten.contains("ALTER ROLE replicator WITH LOGIN;"));
    }

    #[test]
    fn test_rewrite_create_role_statements_wraps_quoted_role() {
        let sql = " CREATE ROLE \"Andre Admin\";\n";
        let rewritten = rewrite_create_role_statements(sql).expect("rewrite expected");

        assert!(rewritten.contains("DO $$"));
        assert!(rewritten.contains("Andre Admin already exists"));
        assert!(rewritten.contains("CREATE ROLE \"Andre Admin\""));
        assert!(rewritten.starts_with(" DO $$"));
    }

    #[test]
    fn test_rewrite_create_role_statements_noop_when_absent() {
        let sql = "ALTER ROLE existing WITH LOGIN;\n";
        assert!(rewrite_create_role_statements(sql).is_none());
    }

    #[test]
    fn test_remove_restricted_role_grants() {
        let dir = tempdir().unwrap();
        let globals_file = dir.path().join("globals.sql");

        let content = r#"CREATE ROLE myuser;
ALTER ROLE myuser WITH LOGIN;
GRANT pg_checkpoint TO myuser;
GRANT "pg_read_all_stats" TO myuser;
GRANT pg_monitor TO myuser;
GRANT myrole TO myuser;
GRANT SELECT ON TABLE users TO myuser;
GRANT rds_superuser TO myuser GRANTED BY rdsadmin;
GRANT ALL ON SCHEMA public TO myuser GRANTED BY "rdsadmin";
GRANT SELECT ON TABLE orders TO myuser GRANTED BY postgres;
"#;
        std::fs::write(&globals_file, content).unwrap();

        remove_restricted_role_grants(globals_file.to_str().unwrap()).unwrap();

        let result = std::fs::read_to_string(&globals_file).unwrap();

        // Grants of restricted predefined roles are disabled.
        assert!(result.contains("-- GRANT pg_checkpoint TO myuser;"));
        assert!(result.contains("-- GRANT \"pg_read_all_stats\" TO myuser;"));
        assert!(result.contains("-- GRANT pg_monitor TO myuser;"));

        // Grants made by restricted grantors are disabled.
        assert!(result.contains("-- GRANT rds_superuser TO myuser GRANTED BY rdsadmin;"));
        assert!(result.contains("-- GRANT ALL ON SCHEMA public TO myuser GRANTED BY \"rdsadmin\";"));

        // Everything else is left intact.
        assert!(result.contains("\nGRANT myrole TO myuser;\n"));
        assert!(result.contains("\nGRANT SELECT ON TABLE users TO myuser;\n"));
        assert!(result.contains("\nGRANT SELECT ON TABLE orders TO myuser GRANTED BY postgres;\n"));
        assert!(result.contains("CREATE ROLE myuser;"));
        assert!(result.contains("ALTER ROLE myuser WITH LOGIN;"));
    }
}