database_replicator/postgres/privileges.rs
1// ABOUTME: Privilege checking utilities for migration prerequisites
2// ABOUTME: Validates source and target databases have required permissions
3
4use anyhow::{Context, Result};
5use tokio_postgres::Client;
6
/// Result of privilege check for a PostgreSQL user
///
/// Holds the role attributes of the connected user that matter for migration,
/// as populated by [`check_source_privileges`] / [`check_target_privileges`].
/// All-`false` (the `Default`) means the role has none of these attributes.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub struct PrivilegeCheck {
    /// User has REPLICATION privilege (required for source database)
    pub has_replication: bool,
    /// User has CREATEDB privilege (required for target database)
    pub has_create_db: bool,
    /// User has CREATEROLE privilege (optional, for role migration)
    pub has_create_role: bool,
    /// User is a superuser (bypasses other privilege requirements)
    pub is_superuser: bool,
}
20
21/// Check if connected user has replication privileges (needed for source)
22///
23/// Queries `pg_roles` to determine the privileges of the currently connected user.
24/// For source databases, the user must have REPLICATION privilege (or be a superuser)
25/// to enable logical replication.
26///
27/// # Arguments
28///
29/// * `client` - Connected PostgreSQL client
30///
31/// # Returns
32///
33/// Returns a `PrivilegeCheck` containing the user's privileges.
34///
35/// # Errors
36///
37/// This function will return an error if the database query fails.
38///
39/// # Examples
40///
41/// ```no_run
42/// # use anyhow::Result;
43/// # use database_replicator::postgres::{connect, check_source_privileges};
44/// # async fn example() -> Result<()> {
45/// let client = connect("postgresql://user:pass@localhost:5432/mydb").await?;
46/// let privs = check_source_privileges(&client).await?;
47/// assert!(privs.has_replication || privs.is_superuser);
48/// # Ok(())
49/// # }
50/// ```
51pub async fn check_source_privileges(client: &Client) -> Result<PrivilegeCheck> {
52 let row = client
53 .query_one(
54 "SELECT rolreplication, rolcreatedb, rolcreaterole, rolsuper
55 FROM pg_roles
56 WHERE rolname = current_user",
57 &[],
58 )
59 .await
60 .context("Failed to query user privileges")?;
61
62 Ok(PrivilegeCheck {
63 has_replication: row.get(0),
64 has_create_db: row.get(1),
65 has_create_role: row.get(2),
66 is_superuser: row.get(3),
67 })
68}
69
70/// Check if connected user has sufficient privileges for target database
71///
72/// Queries `pg_roles` to determine the privileges of the currently connected user.
73/// For target databases, the user must have CREATEDB privilege (or be a superuser)
74/// to create new databases during migration.
75///
76/// # Arguments
77///
78/// * `client` - Connected PostgreSQL client
79///
80/// # Returns
81///
82/// Returns a `PrivilegeCheck` containing the user's privileges.
83///
84/// # Errors
85///
86/// This function will return an error if the database query fails.
87///
88/// # Examples
89///
90/// ```no_run
91/// # use anyhow::Result;
92/// # use database_replicator::postgres::{connect, check_target_privileges};
93/// # async fn example() -> Result<()> {
94/// let client = connect("postgresql://user:pass@localhost:5432/mydb").await?;
95/// let privs = check_target_privileges(&client).await?;
96/// assert!(privs.has_create_db || privs.is_superuser);
97/// # Ok(())
98/// # }
99/// ```
100pub async fn check_target_privileges(client: &Client) -> Result<PrivilegeCheck> {
101 // Same query as source
102 check_source_privileges(client).await
103}
104
105/// Check the wal_level setting on the target database
106///
107/// Queries the current `wal_level` configuration parameter.
108/// For logical replication (subscriptions), `wal_level` must be set to `logical`.
109///
110/// # Arguments
111///
112/// * `client` - Connected PostgreSQL client
113///
114/// # Returns
115///
116/// Returns the current `wal_level` setting as a String (e.g., "replica", "logical").
117///
118/// # Errors
119///
120/// This function will return an error if the database query fails.
121///
122/// # Examples
123///
124/// ```no_run
125/// # use anyhow::Result;
126/// # use database_replicator::postgres::{connect, check_wal_level};
127/// # async fn example() -> Result<()> {
128/// let client = connect("postgresql://user:pass@localhost:5432/mydb").await?;
129/// let wal_level = check_wal_level(&client).await?;
130/// assert_eq!(wal_level, "logical");
131/// # Ok(())
132/// # }
133/// ```
134pub async fn check_wal_level(client: &Client) -> Result<String> {
135 let row = client
136 .query_one("SHOW wal_level", &[])
137 .await
138 .context("Failed to query wal_level setting")?;
139
140 let wal_level: String = row.get(0);
141 Ok(wal_level)
142}
143
/// Result of table-level permission check
///
/// Table names are stored as unquoted `schema.table` strings. The `Default`
/// value has both lists empty.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct TablePermissionCheck {
    /// Tables the user CAN read (has SELECT privilege)
    pub accessible_tables: Vec<String>,
    /// Tables the user CANNOT read (missing SELECT privilege)
    pub inaccessible_tables: Vec<String>,
}

impl TablePermissionCheck {
    /// Returns true if user has SELECT on all tables
    ///
    /// Vacuously true when no tables were inspected.
    #[must_use]
    pub fn all_accessible(&self) -> bool {
        self.inaccessible_tables.is_empty()
    }

    /// Count of inaccessible tables
    #[must_use]
    pub fn inaccessible_count(&self) -> usize {
        self.inaccessible_tables.len()
    }
}
163}
164
165/// Check SELECT permission on all user tables in a database
166///
167/// Queries pg_tables to find all user tables (excluding pg_catalog and
168/// information_schema) and checks if current user has SELECT privilege.
169///
170/// # Arguments
171///
172/// * `client` - Connected PostgreSQL client (must be connected to the target database)
173///
174/// # Returns
175///
176/// Returns `TablePermissionCheck` with lists of accessible and inaccessible tables.
177///
178/// # Errors
179///
180/// Returns an error if the permission query fails.
181///
182/// # Examples
183///
184/// ```no_run
185/// # use anyhow::Result;
186/// # use database_replicator::postgres::{connect, check_table_select_permissions};
187/// # async fn example() -> Result<()> {
188/// let client = connect("postgresql://user:pass@localhost:5432/mydb").await?;
189/// let perms = check_table_select_permissions(&client).await?;
190/// if !perms.all_accessible() {
191/// println!("Cannot read {} tables", perms.inaccessible_count());
192/// }
193/// # Ok(())
194/// # }
195/// ```
196pub async fn check_table_select_permissions(client: &Client) -> Result<TablePermissionCheck> {
197 // Query all user tables and check SELECT permission
198 let query = r#"
199 SELECT
200 schemaname,
201 tablename,
202 has_table_privilege(current_user, quote_ident(schemaname) || '.' || quote_ident(tablename), 'SELECT') as has_select
203 FROM pg_tables
204 WHERE schemaname NOT IN ('pg_catalog', 'information_schema')
205 ORDER BY schemaname, tablename
206 "#;
207
208 let rows = client
209 .query(query, &[])
210 .await
211 .context("Failed to query table permissions")?;
212
213 let mut accessible = Vec::new();
214 let mut inaccessible = Vec::new();
215
216 for row in rows {
217 let schema: String = row.get(0);
218 let table: String = row.get(1);
219 let has_select: bool = row.get(2);
220
221 let full_name = format!("{}.{}", schema, table);
222
223 if has_select {
224 accessible.push(full_name);
225 } else {
226 inaccessible.push(full_name);
227 }
228 }
229
230 Ok(TablePermissionCheck {
231 accessible_tables: accessible,
232 inaccessible_tables: inaccessible,
233 })
234}
235
#[cfg(test)]
mod tests {
    use super::*;
    use crate::postgres::connect;

    #[tokio::test]
    #[ignore] // Requires database connection (TEST_SOURCE_URL)
    async fn test_check_source_privileges() {
        let url = std::env::var("TEST_SOURCE_URL").expect("TEST_SOURCE_URL not set");
        let client = connect(&url)
            .await
            .expect("failed to connect to source database");

        let privileges = check_source_privileges(&client)
            .await
            .expect("source privilege query failed");

        // Logical replication needs REPLICATION (or superuser) on the source
        assert!(
            privileges.has_replication || privileges.is_superuser,
            "Source user should have REPLICATION privilege or be superuser"
        );
    }

    #[tokio::test]
    #[ignore] // Requires database connection (TEST_TARGET_URL)
    async fn test_check_target_privileges() {
        let url = std::env::var("TEST_TARGET_URL").expect("TEST_TARGET_URL not set");
        let client = connect(&url)
            .await
            .expect("failed to connect to target database");

        let privileges = check_target_privileges(&client)
            .await
            .expect("target privilege query failed");

        // Should have create privileges or be superuser
        assert!(
            privileges.has_create_db || privileges.is_superuser,
            "Target user should have CREATE DATABASE privilege or be superuser"
        );
    }

    #[tokio::test]
    #[ignore] // Requires database connection (TEST_SOURCE_URL)
    async fn test_check_table_select_permissions() {
        let url = std::env::var("TEST_SOURCE_URL").expect("TEST_SOURCE_URL not set");
        let client = connect(&url)
            .await
            .expect("failed to connect to source database");

        let result = check_table_select_permissions(&client)
            .await
            .expect("table permission query failed");

        // Exact results depend on the database's grants, so only check
        // invariants that must hold for any database.
        assert_eq!(result.all_accessible(), result.inaccessible_count() == 0);
        for name in result
            .accessible_tables
            .iter()
            .chain(&result.inaccessible_tables)
        {
            assert!(name.contains('.'), "expected schema-qualified name, got {}", name);
        }

        println!("Accessible tables: {}", result.accessible_tables.len());
        println!("Inaccessible tables: {}", result.inaccessible_tables.len());
    }

    #[test]
    fn test_table_permission_check_struct() {
        let check = TablePermissionCheck {
            accessible_tables: vec!["public.users".to_string()],
            inaccessible_tables: vec![],
        };
        assert!(check.all_accessible());
        assert_eq!(check.inaccessible_count(), 0);

        let check_with_issues = TablePermissionCheck {
            accessible_tables: vec!["public.users".to_string()],
            inaccessible_tables: vec!["public.secrets".to_string()],
        };
        assert!(!check_with_issues.all_accessible());
        assert_eq!(check_with_issues.inaccessible_count(), 1);
    }
}
301}