database_replicator/postgres/privileges.rs
1// ABOUTME: Privilege checking utilities for migration prerequisites
2// ABOUTME: Validates source and target databases have required permissions
3
4use anyhow::{Context, Result};
5use tokio_postgres::Client;
6
/// Result of a privilege check for the currently connected PostgreSQL user.
///
/// Holds the role attributes and role memberships that matter for migration.
/// Prefer the [`PrivilegeCheck::can_replicate`] and [`PrivilegeCheck::can_create_db`]
/// predicates over testing individual flags. They also account for superusers
/// and for the AWS RDS replication role.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct PrivilegeCheck {
    /// User has the REPLICATION role attribute (required for the source database)
    pub has_replication: bool,
    /// User has the CREATEDB role attribute (required for the target database)
    pub has_create_db: bool,
    /// User has the CREATEROLE role attribute (optional, for role migration)
    pub has_create_role: bool,
    /// User is a superuser (bypasses other privilege requirements)
    pub is_superuser: bool,
    /// User is a member of the AWS RDS `rds_replication` role
    /// (RDS-specific alternative to the REPLICATION attribute)
    pub has_rds_replication: bool,
}

impl PrivilegeCheck {
    /// Returns true if the user can perform replication by any supported means.
    ///
    /// Accepts the standard PostgreSQL REPLICATION attribute, superuser status,
    /// or membership in the AWS RDS `rds_replication` role.
    #[must_use]
    pub fn can_replicate(&self) -> bool {
        self.has_replication || self.is_superuser || self.has_rds_replication
    }

    /// Returns true if the user can create databases.
    ///
    /// Accepts either the CREATEDB attribute or superuser status.
    #[must_use]
    pub fn can_create_db(&self) -> bool {
        self.has_create_db || self.is_superuser
    }
}
32
33/// Check if connected user has replication privileges (needed for source)
34///
35/// Queries `pg_roles` to determine the privileges of the currently connected user.
36/// For source databases, the user must have REPLICATION privilege (or be a superuser)
37/// to enable logical replication.
38///
39/// # Arguments
40///
41/// * `client` - Connected PostgreSQL client
42///
43/// # Returns
44///
45/// Returns a `PrivilegeCheck` containing the user's privileges.
46///
47/// # Errors
48///
49/// This function will return an error if the database query fails.
50///
51/// # Examples
52///
53/// ```no_run
54/// # use anyhow::Result;
55/// # use database_replicator::postgres::{connect, check_source_privileges};
56/// # async fn example() -> Result<()> {
57/// let client = connect("postgresql://user:pass@localhost:5432/mydb").await?;
58/// let privs = check_source_privileges(&client).await?;
59/// assert!(privs.has_replication || privs.is_superuser);
60/// # Ok(())
61/// # }
62/// ```
63pub async fn check_source_privileges(client: &Client) -> Result<PrivilegeCheck> {
64 let row = client
65 .query_one(
66 "SELECT rolreplication, rolcreatedb, rolcreaterole, rolsuper
67 FROM pg_roles
68 WHERE rolname = current_user",
69 &[],
70 )
71 .await
72 .context("Failed to query user privileges")?;
73
74 // Check for AWS RDS rds_replication role membership
75 // This role exists only on AWS RDS and provides replication capability
76 let has_rds_replication = client
77 .query_opt(
78 "SELECT 1 FROM pg_roles WHERE rolname = 'rds_replication'
79 AND pg_has_role(current_user, 'rds_replication', 'MEMBER')",
80 &[],
81 )
82 .await
83 .unwrap_or(None)
84 .is_some();
85
86 Ok(PrivilegeCheck {
87 has_replication: row.get(0),
88 has_create_db: row.get(1),
89 has_create_role: row.get(2),
90 is_superuser: row.get(3),
91 has_rds_replication,
92 })
93}
94
95/// Check if connected user has sufficient privileges for target database
96///
97/// Queries `pg_roles` to determine the privileges of the currently connected user.
98/// For target databases, the user must have CREATEDB privilege (or be a superuser)
99/// to create new databases during migration.
100///
101/// # Arguments
102///
103/// * `client` - Connected PostgreSQL client
104///
105/// # Returns
106///
107/// Returns a `PrivilegeCheck` containing the user's privileges.
108///
109/// # Errors
110///
111/// This function will return an error if the database query fails.
112///
113/// # Examples
114///
115/// ```no_run
116/// # use anyhow::Result;
117/// # use database_replicator::postgres::{connect, check_target_privileges};
118/// # async fn example() -> Result<()> {
119/// let client = connect("postgresql://user:pass@localhost:5432/mydb").await?;
120/// let privs = check_target_privileges(&client).await?;
121/// assert!(privs.has_create_db || privs.is_superuser);
122/// # Ok(())
123/// # }
124/// ```
125pub async fn check_target_privileges(client: &Client) -> Result<PrivilegeCheck> {
126 // Same query as source
127 check_source_privileges(client).await
128}
129
130/// Check the wal_level setting on the target database
131///
132/// Queries the current `wal_level` configuration parameter.
133/// For logical replication (subscriptions), `wal_level` must be set to `logical`.
134///
135/// # Arguments
136///
137/// * `client` - Connected PostgreSQL client
138///
139/// # Returns
140///
141/// Returns the current `wal_level` setting as a String (e.g., "replica", "logical").
142///
143/// # Errors
144///
145/// This function will return an error if the database query fails.
146///
147/// # Examples
148///
149/// ```no_run
150/// # use anyhow::Result;
151/// # use database_replicator::postgres::{connect, check_wal_level};
152/// # async fn example() -> Result<()> {
153/// let client = connect("postgresql://user:pass@localhost:5432/mydb").await?;
154/// let wal_level = check_wal_level(&client).await?;
155/// assert_eq!(wal_level, "logical");
156/// # Ok(())
157/// # }
158/// ```
159pub async fn check_wal_level(client: &Client) -> Result<String> {
160 let row = client
161 .query_one("SHOW wal_level", &[])
162 .await
163 .context("Failed to query wal_level setting")?;
164
165 let wal_level: String = row.get(0);
166 Ok(wal_level)
167}
168
/// Outcome of checking SELECT privilege on every user table in a database.
///
/// Table names are stored as `schema.table` strings.
#[derive(Debug, Clone)]
pub struct TablePermissionCheck {
    /// Tables the current user may SELECT from
    pub accessible_tables: Vec<String>,
    /// Tables the current user may NOT SELECT from
    pub inaccessible_tables: Vec<String>,
}

impl TablePermissionCheck {
    /// Returns true when no checked table is unreadable, i.e. the user holds
    /// SELECT on every table that was examined.
    pub fn all_accessible(&self) -> bool {
        self.inaccessible_count() == 0
    }

    /// Number of tables the user lacks SELECT privilege on.
    pub fn inaccessible_count(&self) -> usize {
        self.inaccessible_tables.len()
    }
}
189
190/// Check SELECT permission on all user tables in a database
191///
192/// Queries pg_tables to find all user tables (excluding pg_catalog and
193/// information_schema) and checks if current user has SELECT privilege.
194///
195/// # Arguments
196///
197/// * `client` - Connected PostgreSQL client (must be connected to the target database)
198///
199/// # Returns
200///
201/// Returns `TablePermissionCheck` with lists of accessible and inaccessible tables.
202///
203/// # Errors
204///
205/// Returns an error if the permission query fails.
206///
207/// # Examples
208///
209/// ```no_run
210/// # use anyhow::Result;
211/// # use database_replicator::postgres::{connect, check_table_select_permissions};
212/// # async fn example() -> Result<()> {
213/// let client = connect("postgresql://user:pass@localhost:5432/mydb").await?;
214/// let perms = check_table_select_permissions(&client).await?;
215/// if !perms.all_accessible() {
216/// println!("Cannot read {} tables", perms.inaccessible_count());
217/// }
218/// # Ok(())
219/// # }
220/// ```
221pub async fn check_table_select_permissions(client: &Client) -> Result<TablePermissionCheck> {
222 // Query all user tables and check SELECT permission
223 let query = r#"
224 SELECT
225 schemaname,
226 tablename,
227 has_table_privilege(current_user, quote_ident(schemaname) || '.' || quote_ident(tablename), 'SELECT') as has_select
228 FROM pg_tables
229 WHERE schemaname NOT IN ('pg_catalog', 'information_schema')
230 ORDER BY schemaname, tablename
231 "#;
232
233 let rows = client
234 .query(query, &[])
235 .await
236 .context("Failed to query table permissions")?;
237
238 let mut accessible = Vec::new();
239 let mut inaccessible = Vec::new();
240
241 for row in rows {
242 let schema: String = row.get(0);
243 let table: String = row.get(1);
244 let has_select: bool = row.get(2);
245
246 let full_name = format!("{}.{}", schema, table);
247
248 if has_select {
249 accessible.push(full_name);
250 } else {
251 inaccessible.push(full_name);
252 }
253 }
254
255 Ok(TablePermissionCheck {
256 accessible_tables: accessible,
257 inaccessible_tables: inaccessible,
258 })
259}
260
#[cfg(test)]
mod tests {
    use super::*;
    use crate::postgres::connect;

    #[tokio::test]
    #[ignore] // Requires database connection (TEST_SOURCE_URL)
    async fn test_check_source_privileges() {
        let url = std::env::var("TEST_SOURCE_URL").expect("TEST_SOURCE_URL not set");
        let client = connect(&url)
            .await
            .expect("failed to connect to source database");

        let privileges = check_source_privileges(&client)
            .await
            .expect("source privilege query failed");

        // Should have at least one replication method
        assert!(
            privileges.can_replicate(),
            "Source user should have REPLICATION privilege, rds_replication role, or be superuser"
        );
    }

    #[tokio::test]
    #[ignore] // Requires database connection (TEST_TARGET_URL)
    async fn test_check_target_privileges() {
        let url = std::env::var("TEST_TARGET_URL").expect("TEST_TARGET_URL not set");
        let client = connect(&url)
            .await
            .expect("failed to connect to target database");

        let privileges = check_target_privileges(&client)
            .await
            .expect("target privilege query failed");

        // Should have create privileges or be superuser
        assert!(
            privileges.has_create_db || privileges.is_superuser,
            "Target user should have CREATE DATABASE privilege or be superuser"
        );
    }

    #[tokio::test]
    #[ignore] // Requires database connection (TEST_SOURCE_URL)
    async fn test_check_table_select_permissions() {
        let url = std::env::var("TEST_SOURCE_URL").expect("TEST_SOURCE_URL not set");
        let client = connect(&url)
            .await
            .expect("failed to connect to source database");

        let result = check_table_select_permissions(&client)
            .await
            .expect("table permission query failed");

        // Just verify the function runs without error
        // In a real database, results depend on actual permissions
        println!("Accessible tables: {}", result.accessible_tables.len());
        println!("Inaccessible tables: {}", result.inaccessible_tables.len());
    }

    #[test]
    fn test_can_replicate_accepts_any_method() {
        let none = PrivilegeCheck {
            has_replication: false,
            has_create_db: false,
            has_create_role: false,
            is_superuser: false,
            has_rds_replication: false,
        };
        assert!(!none.can_replicate());

        // Each replication path on its own is sufficient.
        assert!(PrivilegeCheck { has_replication: true, ..none }.can_replicate());
        assert!(PrivilegeCheck { is_superuser: true, ..none }.can_replicate());
        assert!(PrivilegeCheck { has_rds_replication: true, ..none }.can_replicate());

        // Unrelated attributes do not grant replication.
        assert!(!PrivilegeCheck {
            has_create_db: true,
            has_create_role: true,
            ..none
        }
        .can_replicate());
    }

    #[test]
    fn test_table_permission_check_struct() {
        let check = TablePermissionCheck {
            accessible_tables: vec!["public.users".to_string()],
            inaccessible_tables: vec![],
        };
        assert!(check.all_accessible());
        assert_eq!(check.inaccessible_count(), 0);

        let check_with_issues = TablePermissionCheck {
            accessible_tables: vec!["public.users".to_string()],
            inaccessible_tables: vec!["public.secrets".to_string()],
        };
        assert!(!check_with_issues.all_accessible());
        assert_eq!(check_with_issues.inaccessible_count(), 1);
    }
}