database_replicator/postgres/privileges.rs

1// ABOUTME: Privilege checking utilities for migration prerequisites
2// ABOUTME: Validates source and target databases have required permissions
3
4use anyhow::{Context, Result};
5use tokio_postgres::Client;
6
/// Role attributes of the current PostgreSQL user that matter for migration.
///
/// Returned by [`check_source_privileges`] and [`check_target_privileges`]; each
/// flag mirrors one boolean column of the user's row in `pg_roles`. The type is a
/// plain bundle of `bool`s, so it is cheap to copy, compare and log.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct PrivilegeCheck {
    /// `rolreplication`: user has the REPLICATION attribute (required for the source database)
    pub has_replication: bool,
    /// `rolcreatedb`: user has the CREATEDB attribute (required for the target database)
    pub has_create_db: bool,
    /// `rolcreaterole`: user has the CREATEROLE attribute (optional, for role migration)
    pub has_create_role: bool,
    /// `rolsuper`: user is a superuser (bypasses the other privilege requirements)
    pub is_superuser: bool,
}
20
21/// Check if connected user has replication privileges (needed for source)
22///
23/// Queries `pg_roles` to determine the privileges of the currently connected user.
24/// For source databases, the user must have REPLICATION privilege (or be a superuser)
25/// to enable logical replication.
26///
27/// # Arguments
28///
29/// * `client` - Connected PostgreSQL client
30///
31/// # Returns
32///
33/// Returns a `PrivilegeCheck` containing the user's privileges.
34///
35/// # Errors
36///
37/// This function will return an error if the database query fails.
38///
39/// # Examples
40///
41/// ```no_run
42/// # use anyhow::Result;
43/// # use database_replicator::postgres::{connect, check_source_privileges};
44/// # async fn example() -> Result<()> {
45/// let client = connect("postgresql://user:pass@localhost:5432/mydb").await?;
46/// let privs = check_source_privileges(&client).await?;
47/// assert!(privs.has_replication || privs.is_superuser);
48/// # Ok(())
49/// # }
50/// ```
51pub async fn check_source_privileges(client: &Client) -> Result<PrivilegeCheck> {
52    let row = client
53        .query_one(
54            "SELECT rolreplication, rolcreatedb, rolcreaterole, rolsuper
55             FROM pg_roles
56             WHERE rolname = current_user",
57            &[],
58        )
59        .await
60        .context("Failed to query user privileges")?;
61
62    Ok(PrivilegeCheck {
63        has_replication: row.get(0),
64        has_create_db: row.get(1),
65        has_create_role: row.get(2),
66        is_superuser: row.get(3),
67    })
68}
69
70/// Check if connected user has sufficient privileges for target database
71///
72/// Queries `pg_roles` to determine the privileges of the currently connected user.
73/// For target databases, the user must have CREATEDB privilege (or be a superuser)
74/// to create new databases during migration.
75///
76/// # Arguments
77///
78/// * `client` - Connected PostgreSQL client
79///
80/// # Returns
81///
82/// Returns a `PrivilegeCheck` containing the user's privileges.
83///
84/// # Errors
85///
86/// This function will return an error if the database query fails.
87///
88/// # Examples
89///
90/// ```no_run
91/// # use anyhow::Result;
92/// # use database_replicator::postgres::{connect, check_target_privileges};
93/// # async fn example() -> Result<()> {
94/// let client = connect("postgresql://user:pass@localhost:5432/mydb").await?;
95/// let privs = check_target_privileges(&client).await?;
96/// assert!(privs.has_create_db || privs.is_superuser);
97/// # Ok(())
98/// # }
99/// ```
100pub async fn check_target_privileges(client: &Client) -> Result<PrivilegeCheck> {
101    // Same query as source
102    check_source_privileges(client).await
103}
104
105/// Check the wal_level setting on the target database
106///
107/// Queries the current `wal_level` configuration parameter.
108/// For logical replication (subscriptions), `wal_level` must be set to `logical`.
109///
110/// # Arguments
111///
112/// * `client` - Connected PostgreSQL client
113///
114/// # Returns
115///
116/// Returns the current `wal_level` setting as a String (e.g., "replica", "logical").
117///
118/// # Errors
119///
120/// This function will return an error if the database query fails.
121///
122/// # Examples
123///
124/// ```no_run
125/// # use anyhow::Result;
126/// # use database_replicator::postgres::{connect, check_wal_level};
127/// # async fn example() -> Result<()> {
128/// let client = connect("postgresql://user:pass@localhost:5432/mydb").await?;
129/// let wal_level = check_wal_level(&client).await?;
130/// assert_eq!(wal_level, "logical");
131/// # Ok(())
132/// # }
133/// ```
134pub async fn check_wal_level(client: &Client) -> Result<String> {
135    let row = client
136        .query_one("SHOW wal_level", &[])
137        .await
138        .context("Failed to query wal_level setting")?;
139
140    let wal_level: String = row.get(0);
141    Ok(wal_level)
142}
143
/// Outcome of checking whether the current user can read each table.
///
/// Table names are stored as `schema.table` strings.
#[derive(Debug, Clone)]
pub struct TablePermissionCheck {
    /// Tables the current user can read (`SELECT` is available)
    pub accessible_tables: Vec<String>,
    /// Tables the current user cannot read (`SELECT` is not available)
    pub inaccessible_tables: Vec<String>,
}

impl TablePermissionCheck {
    /// `true` when no unreadable table was recorded (also `true` for an empty check).
    pub fn all_accessible(&self) -> bool {
        self.inaccessible_count() == 0
    }

    /// Number of tables recorded as unreadable.
    pub fn inaccessible_count(&self) -> usize {
        self.inaccessible_tables.len()
    }
}
164
165/// Check SELECT permission on all user tables in a database
166///
167/// Queries pg_tables to find all user tables (excluding pg_catalog and
168/// information_schema) and checks if current user has SELECT privilege.
169///
170/// # Arguments
171///
172/// * `client` - Connected PostgreSQL client (must be connected to the target database)
173///
174/// # Returns
175///
176/// Returns `TablePermissionCheck` with lists of accessible and inaccessible tables.
177///
178/// # Errors
179///
180/// Returns an error if the permission query fails.
181///
182/// # Examples
183///
184/// ```no_run
185/// # use anyhow::Result;
186/// # use database_replicator::postgres::{connect, check_table_select_permissions};
187/// # async fn example() -> Result<()> {
188/// let client = connect("postgresql://user:pass@localhost:5432/mydb").await?;
189/// let perms = check_table_select_permissions(&client).await?;
190/// if !perms.all_accessible() {
191///     println!("Cannot read {} tables", perms.inaccessible_count());
192/// }
193/// # Ok(())
194/// # }
195/// ```
196pub async fn check_table_select_permissions(client: &Client) -> Result<TablePermissionCheck> {
197    // Query all user tables and check SELECT permission
198    let query = r#"
199        SELECT
200            schemaname,
201            tablename,
202            has_table_privilege(current_user, quote_ident(schemaname) || '.' || quote_ident(tablename), 'SELECT') as has_select
203        FROM pg_tables
204        WHERE schemaname NOT IN ('pg_catalog', 'information_schema')
205        ORDER BY schemaname, tablename
206    "#;
207
208    let rows = client
209        .query(query, &[])
210        .await
211        .context("Failed to query table permissions")?;
212
213    let mut accessible = Vec::new();
214    let mut inaccessible = Vec::new();
215
216    for row in rows {
217        let schema: String = row.get(0);
218        let table: String = row.get(1);
219        let has_select: bool = row.get(2);
220
221        let full_name = format!("{}.{}", schema, table);
222
223        if has_select {
224            accessible.push(full_name);
225        } else {
226            inaccessible.push(full_name);
227        }
228    }
229
230    Ok(TablePermissionCheck {
231        accessible_tables: accessible,
232        inaccessible_tables: inaccessible,
233    })
234}
235
#[cfg(test)]
mod tests {
    use super::*;
    use crate::postgres::connect;
    use std::collections::HashSet;

    /// Connection URL for the `#[ignore]`d integration tests.
    ///
    /// Panics with the variable's name when it is unset, instead of the opaque
    /// `NotPresent` error a bare `unwrap()` reports.
    fn test_url(var: &str) -> String {
        std::env::var(var).unwrap_or_else(|_| panic!("{} not set", var))
    }

    #[tokio::test]
    #[ignore] // Requires database connection
    async fn test_check_source_privileges() {
        let client = connect(&test_url("TEST_SOURCE_URL"))
            .await
            .expect("failed to connect to source database");

        let privileges = check_source_privileges(&client).await.unwrap();

        assert!(
            privileges.has_replication || privileges.is_superuser,
            "Source user should have REPLICATION privilege or be superuser"
        );
    }

    #[tokio::test]
    #[ignore] // Requires database connection
    async fn test_check_target_privileges() {
        let client = connect(&test_url("TEST_TARGET_URL"))
            .await
            .expect("failed to connect to target database");

        let privileges = check_target_privileges(&client).await.unwrap();

        assert!(
            privileges.has_create_db || privileges.is_superuser,
            "Target user should have CREATE DATABASE privilege or be superuser"
        );
    }

    #[tokio::test]
    #[ignore] // Requires database connection
    async fn test_check_table_select_permissions() {
        let client = connect(&test_url("TEST_SOURCE_URL"))
            .await
            .expect("failed to connect to source database");

        let result = check_table_select_permissions(&client).await.unwrap();

        // Which tables are readable depends on the real grants, but no table may
        // be reported as both readable and unreadable.
        let accessible: HashSet<&String> = result.accessible_tables.iter().collect();
        assert!(
            result
                .inaccessible_tables
                .iter()
                .all(|table| !accessible.contains(table)),
            "a table was reported as both accessible and inaccessible"
        );
        assert_eq!(result.all_accessible(), result.inaccessible_count() == 0);

        println!("Accessible tables: {}", result.accessible_tables.len());
        println!("Inaccessible tables: {}", result.inaccessible_tables.len());
    }

    #[test]
    fn test_table_permission_check_struct() {
        let check = TablePermissionCheck {
            accessible_tables: vec!["public.users".to_string()],
            inaccessible_tables: vec![],
        };
        assert!(check.all_accessible());
        assert_eq!(check.inaccessible_count(), 0);

        let check_with_issues = TablePermissionCheck {
            accessible_tables: vec!["public.users".to_string()],
            inaccessible_tables: vec!["public.secrets".to_string()],
        };
        assert!(!check_with_issues.all_accessible());
        assert_eq!(check_with_issues.inaccessible_count(), 1);

        // An empty database has nothing unreadable.
        let empty = TablePermissionCheck {
            accessible_tables: vec![],
            inaccessible_tables: vec![],
        };
        assert!(empty.all_accessible());
        assert_eq!(empty.inaccessible_count(), 0);
    }
}