database_replicator/postgres/
privileges.rs

1// ABOUTME: Privilege checking utilities for migration prerequisites
2// ABOUTME: Validates source and target databases have required permissions
3
4use anyhow::{Context, Result};
5use tokio_postgres::Client;
6
/// Result of a privilege check for a PostgreSQL user.
///
/// Holds the role attributes of the connected user that matter for migration
/// (as populated by [`check_source_privileges`]), plus AWS RDS
/// `rds_replication` role membership.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct PrivilegeCheck {
    /// User has REPLICATION privilege (required for source database)
    pub has_replication: bool,
    /// User has CREATEDB privilege (required for target database)
    pub has_create_db: bool,
    /// User has CREATEROLE privilege (optional, for role migration)
    pub has_create_role: bool,
    /// User is a superuser (bypasses other privilege requirements)
    pub is_superuser: bool,
    /// User has AWS RDS rds_replication role (RDS-specific alternative to REPLICATION)
    pub has_rds_replication: bool,
}

impl PrivilegeCheck {
    /// Returns `true` if the user can perform replication by any supported means.
    ///
    /// Any one of these is sufficient:
    /// - the standard PostgreSQL REPLICATION attribute;
    /// - superuser status;
    /// - membership in the AWS RDS `rds_replication` role.
    #[must_use]
    pub fn can_replicate(&self) -> bool {
        self.has_replication || self.is_superuser || self.has_rds_replication
    }
}
32
33/// Check if connected user has replication privileges (needed for source)
34///
35/// Queries `pg_roles` to determine the privileges of the currently connected user.
36/// For source databases, the user must have REPLICATION privilege (or be a superuser)
37/// to enable logical replication.
38///
39/// # Arguments
40///
41/// * `client` - Connected PostgreSQL client
42///
43/// # Returns
44///
45/// Returns a `PrivilegeCheck` containing the user's privileges.
46///
47/// # Errors
48///
49/// This function will return an error if the database query fails.
50///
51/// # Examples
52///
53/// ```no_run
54/// # use anyhow::Result;
55/// # use database_replicator::postgres::{connect, check_source_privileges};
56/// # async fn example() -> Result<()> {
57/// let client = connect("postgresql://user:pass@localhost:5432/mydb").await?;
58/// let privs = check_source_privileges(&client).await?;
59/// assert!(privs.has_replication || privs.is_superuser);
60/// # Ok(())
61/// # }
62/// ```
63pub async fn check_source_privileges(client: &Client) -> Result<PrivilegeCheck> {
64    let row = client
65        .query_one(
66            "SELECT rolreplication, rolcreatedb, rolcreaterole, rolsuper
67             FROM pg_roles
68             WHERE rolname = current_user",
69            &[],
70        )
71        .await
72        .context("Failed to query user privileges")?;
73
74    // Check for AWS RDS rds_replication role membership
75    // This role exists only on AWS RDS and provides replication capability
76    let has_rds_replication = client
77        .query_opt(
78            "SELECT 1 FROM pg_roles WHERE rolname = 'rds_replication'
79             AND pg_has_role(current_user, 'rds_replication', 'MEMBER')",
80            &[],
81        )
82        .await
83        .unwrap_or(None)
84        .is_some();
85
86    Ok(PrivilegeCheck {
87        has_replication: row.get(0),
88        has_create_db: row.get(1),
89        has_create_role: row.get(2),
90        is_superuser: row.get(3),
91        has_rds_replication,
92    })
93}
94
95/// Check if connected user has sufficient privileges for target database
96///
97/// Queries `pg_roles` to determine the privileges of the currently connected user.
98/// For target databases, the user must have CREATEDB privilege (or be a superuser)
99/// to create new databases during migration.
100///
101/// # Arguments
102///
103/// * `client` - Connected PostgreSQL client
104///
105/// # Returns
106///
107/// Returns a `PrivilegeCheck` containing the user's privileges.
108///
109/// # Errors
110///
111/// This function will return an error if the database query fails.
112///
113/// # Examples
114///
115/// ```no_run
116/// # use anyhow::Result;
117/// # use database_replicator::postgres::{connect, check_target_privileges};
118/// # async fn example() -> Result<()> {
119/// let client = connect("postgresql://user:pass@localhost:5432/mydb").await?;
120/// let privs = check_target_privileges(&client).await?;
121/// assert!(privs.has_create_db || privs.is_superuser);
122/// # Ok(())
123/// # }
124/// ```
125pub async fn check_target_privileges(client: &Client) -> Result<PrivilegeCheck> {
126    // Same query as source
127    check_source_privileges(client).await
128}
129
130/// Check the wal_level setting on the target database
131///
132/// Queries the current `wal_level` configuration parameter.
133/// For logical replication (subscriptions), `wal_level` must be set to `logical`.
134///
135/// # Arguments
136///
137/// * `client` - Connected PostgreSQL client
138///
139/// # Returns
140///
141/// Returns the current `wal_level` setting as a String (e.g., "replica", "logical").
142///
143/// # Errors
144///
145/// This function will return an error if the database query fails.
146///
147/// # Examples
148///
149/// ```no_run
150/// # use anyhow::Result;
151/// # use database_replicator::postgres::{connect, check_wal_level};
152/// # async fn example() -> Result<()> {
153/// let client = connect("postgresql://user:pass@localhost:5432/mydb").await?;
154/// let wal_level = check_wal_level(&client).await?;
155/// assert_eq!(wal_level, "logical");
156/// # Ok(())
157/// # }
158/// ```
159pub async fn check_wal_level(client: &Client) -> Result<String> {
160    let row = client
161        .query_one("SHOW wal_level", &[])
162        .await
163        .context("Failed to query wal_level setting")?;
164
165    let wal_level: String = row.get(0);
166    Ok(wal_level)
167}
168
/// Outcome of checking SELECT permission on every user table.
///
/// [`check_table_select_permissions`] fills both lists with names in
/// `schema.table` form.
#[derive(Debug, Clone)]
pub struct TablePermissionCheck {
    /// Tables the user CAN read (has SELECT privilege)
    pub accessible_tables: Vec<String>,
    /// Tables the user CANNOT read (missing SELECT privilege)
    pub inaccessible_tables: Vec<String>,
}

impl TablePermissionCheck {
    /// Number of tables on which the user lacks SELECT privilege.
    pub fn inaccessible_count(&self) -> usize {
        self.inaccessible_tables.len()
    }

    /// `true` when no table is missing SELECT privilege.
    ///
    /// This is also `true` when no tables were checked at all.
    pub fn all_accessible(&self) -> bool {
        self.inaccessible_count() == 0
    }
}
189
190/// Check SELECT permission on all user tables in a database
191///
192/// Queries pg_tables to find all user tables (excluding pg_catalog and
193/// information_schema) and checks if current user has SELECT privilege.
194///
195/// # Arguments
196///
197/// * `client` - Connected PostgreSQL client (must be connected to the target database)
198///
199/// # Returns
200///
201/// Returns `TablePermissionCheck` with lists of accessible and inaccessible tables.
202///
203/// # Errors
204///
205/// Returns an error if the permission query fails.
206///
207/// # Examples
208///
209/// ```no_run
210/// # use anyhow::Result;
211/// # use database_replicator::postgres::{connect, check_table_select_permissions};
212/// # async fn example() -> Result<()> {
213/// let client = connect("postgresql://user:pass@localhost:5432/mydb").await?;
214/// let perms = check_table_select_permissions(&client).await?;
215/// if !perms.all_accessible() {
216///     println!("Cannot read {} tables", perms.inaccessible_count());
217/// }
218/// # Ok(())
219/// # }
220/// ```
221pub async fn check_table_select_permissions(client: &Client) -> Result<TablePermissionCheck> {
222    // Query all user tables and check SELECT permission
223    let query = r#"
224        SELECT
225            schemaname,
226            tablename,
227            has_table_privilege(current_user, quote_ident(schemaname) || '.' || quote_ident(tablename), 'SELECT') as has_select
228        FROM pg_tables
229        WHERE schemaname NOT IN ('pg_catalog', 'information_schema')
230        ORDER BY schemaname, tablename
231    "#;
232
233    let rows = client
234        .query(query, &[])
235        .await
236        .context("Failed to query table permissions")?;
237
238    let mut accessible = Vec::new();
239    let mut inaccessible = Vec::new();
240
241    for row in rows {
242        let schema: String = row.get(0);
243        let table: String = row.get(1);
244        let has_select: bool = row.get(2);
245
246        let full_name = format!("{}.{}", schema, table);
247
248        if has_select {
249            accessible.push(full_name);
250        } else {
251            inaccessible.push(full_name);
252        }
253    }
254
255    Ok(TablePermissionCheck {
256        accessible_tables: accessible,
257        inaccessible_tables: inaccessible,
258    })
259}
260
#[cfg(test)]
mod tests {
    use super::*;
    use crate::postgres::connect;

    #[tokio::test]
    #[ignore] // Requires TEST_SOURCE_URL pointing at a live database
    async fn test_check_source_privileges() {
        let url = std::env::var("TEST_SOURCE_URL").expect("TEST_SOURCE_URL not set");
        let client = connect(&url).await.unwrap();

        let privileges = check_source_privileges(&client).await.unwrap();

        // Should have at least one replication method
        assert!(
            privileges.can_replicate(),
            "Source user should have REPLICATION privilege, rds_replication role, or be superuser"
        );
    }

    #[tokio::test]
    #[ignore] // Requires TEST_TARGET_URL pointing at a live database
    async fn test_check_target_privileges() {
        let url = std::env::var("TEST_TARGET_URL").expect("TEST_TARGET_URL not set");
        let client = connect(&url).await.unwrap();

        let privileges = check_target_privileges(&client).await.unwrap();

        // Should have create privileges or be superuser
        assert!(
            privileges.has_create_db || privileges.is_superuser,
            "Target user should have CREATE DATABASE privilege or be superuser"
        );
    }

    #[tokio::test]
    #[ignore] // Requires database connection
    async fn test_check_table_select_permissions() {
        let url = std::env::var("TEST_SOURCE_URL").expect("TEST_SOURCE_URL not set");
        let client = connect(&url).await.unwrap();

        let result = check_table_select_permissions(&client).await.unwrap();

        // Just verify the function runs without error
        // In a real database, results depend on actual permissions
        println!("Accessible tables: {}", result.accessible_tables.len());
        println!("Inaccessible tables: {}", result.inaccessible_tables.len());
    }

    #[test]
    fn test_can_replicate_accepts_any_replication_method() {
        let nothing = PrivilegeCheck {
            has_replication: false,
            has_create_db: false,
            has_create_role: false,
            is_superuser: false,
            has_rds_replication: false,
        };
        assert!(!nothing.can_replicate());

        // Each replication path on its own is sufficient.
        let via_attribute = PrivilegeCheck {
            has_replication: true,
            ..nothing
        };
        assert!(via_attribute.can_replicate());

        let via_superuser = PrivilegeCheck {
            is_superuser: true,
            ..nothing
        };
        assert!(via_superuser.can_replicate());

        let via_rds = PrivilegeCheck {
            has_rds_replication: true,
            ..nothing
        };
        assert!(via_rds.can_replicate());

        // CREATEDB / CREATEROLE do not grant replication.
        let creator = PrivilegeCheck {
            has_create_db: true,
            has_create_role: true,
            ..nothing
        };
        assert!(!creator.can_replicate());
    }

    #[test]
    fn test_table_permission_check_struct() {
        let check = TablePermissionCheck {
            accessible_tables: vec!["public.users".to_string()],
            inaccessible_tables: vec![],
        };
        assert!(check.all_accessible());
        assert_eq!(check.inaccessible_count(), 0);

        let check_with_issues = TablePermissionCheck {
            accessible_tables: vec!["public.users".to_string()],
            inaccessible_tables: vec!["public.secrets".to_string()],
        };
        assert!(!check_with_issues.all_accessible());
        assert_eq!(check_with_issues.inaccessible_count(), 1);
    }
}