pgmt 0.4.9

PostgreSQL migration tool that keeps your schema files as the source of truth
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
use anyhow::Result;
use sqlx::postgres::PgConnection;
use sqlx::postgres::types::Oid;
use std::collections::{HashMap, HashSet};
use tracing::info;

use crate::catalog::utils::is_system_schema;
use crate::catalog::{DependsOn, comments::Commentable, id::DbObjectId};

/// Command type for RLS policies.
///
/// Mirrors the `pg_policy.polcmd` codes as parsed by `fetch`:
/// `'*'` = ALL, `'r'` = SELECT, `'a'` = INSERT, `'w'` = UPDATE, `'d'` = DELETE
/// (`fetch` maps any unrecognised code to `All`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PolicyCommand {
    /// `'*'` - applies to all commands
    All,
    /// `'r'` - SELECT only
    Select,
    /// `'a'` - INSERT only
    Insert,
    /// `'w'` - UPDATE only
    Update,
    /// `'d'` - DELETE only
    Delete,
}

/// Represents a PostgreSQL Row-Level Security policy, as read from `pg_policy`
/// by `fetch` (policies on tables in `pg_catalog`, `information_schema`, and
/// `pg_toast` are not fetched).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Policy {
    /// Schema of the table the policy is attached to
    pub schema: String,
    /// Name of the table the policy is attached to
    pub table_name: String,
    /// Policy name (`pg_policy.polname`)
    pub name: String,

    /// Command type this policy applies to
    pub command: PolicyCommand,

    /// true = PERMISSIVE, false = RESTRICTIVE
    pub permissive: bool,

    /// Roles this policy applies to, sorted by role name (empty = PUBLIC)
    pub roles: Vec<String>,

    /// USING expression (for ALL, SELECT, UPDATE, DELETE), deparsed with
    /// `pg_get_expr`; `None` when the policy has no USING clause
    pub using_expr: Option<String>,

    /// WITH CHECK expression (for ALL, INSERT, UPDATE), deparsed with
    /// `pg_get_expr`; `None` when the policy has no WITH CHECK clause
    pub with_check_expr: Option<String>,

    /// Comment on the policy
    pub comment: Option<String>,

    /// Dependencies: always includes the policy's own table, plus columns,
    /// tables, views, functions, or extensions referenced by its expressions
    /// (resolved from `pg_depend`)
    pub depends_on: Vec<DbObjectId>,
}

impl DependsOn for Policy {
    /// Identity of this policy: its schema, the table it is attached to, and its name.
    fn id(&self) -> DbObjectId {
        let Self {
            schema,
            table_name,
            name,
            ..
        } = self;
        DbObjectId::Policy {
            schema: schema.to_owned(),
            table: table_name.to_owned(),
            name: name.to_owned(),
        }
    }

    /// Objects this policy depends on, as collected when the catalog was fetched.
    fn depends_on(&self) -> &[DbObjectId] {
        self.depends_on.as_slice()
    }
}

impl Commentable for Policy {
    fn comment(&self) -> &Option<String> {
        &self.comment
    }
}

/// Row returned from the policy column-dependency query
/// (`fetch_all_policy_dependencies`), minus the policy OID, which is used as
/// the grouping key of the returned map.
struct PolicyDependencyRow {
    /// Name of the referenced column (`pg_attribute.attname` matched on
    /// `refobjsubid`; the query only returns rows with `refobjsubid > 0`)
    column_name: String,
    /// Schema of the table/view that owns the referenced column
    table_schema: String,
    /// Name of the table/view that owns the referenced column
    table_name: String,
    /// `pg_class.relkind` of that relation, as text
    /// ('r'=table, 'p'=partitioned table, 'v'=view, 'm'=materialized view, etc.)
    relkind: String,
}

/// Load every column-level dependency recorded for RLS policies in one round trip.
///
/// PostgreSQL records a policy's reference to a specific column as a `pg_depend`
/// row with `refobjsubid > 0`. Only normal (`deptype = 'n'`) dependencies on
/// relations outside `pg_catalog`, `information_schema`, and `pg_toast` are read.
/// The result maps each policy OID to its dependency rows, kept in the query's
/// `ORDER BY` order (schema, relation, column).
///
/// # Errors
///
/// Returns an error if the catalog query fails.
async fn fetch_all_policy_dependencies(
    conn: &mut PgConnection,
) -> Result<HashMap<Oid, Vec<PolicyDependencyRow>>> {
    let records = sqlx::query!(
        r#"
        SELECT
            p.oid AS "policy_oid!",
            a.attname AS "column_name!",
            n.nspname AS "table_schema!",
            c.relname AS "table_name!",
            c.relkind::text AS "relkind!"
        FROM pg_policy p
        JOIN pg_depend d ON d.objid = p.oid AND d.classid = 'pg_policy'::regclass
        JOIN pg_class c ON d.refobjid = c.oid AND d.refclassid = 'pg_class'::regclass
        JOIN pg_namespace n ON c.relnamespace = n.oid
        JOIN pg_attribute a ON a.attrelid = c.oid AND a.attnum = d.refobjsubid
        WHERE d.refobjsubid > 0
          AND d.deptype = 'n'
          AND n.nspname NOT IN ('pg_catalog', 'information_schema', 'pg_toast')
        ORDER BY p.oid, n.nspname, c.relname, a.attname
        "#
    )
    .fetch_all(&mut *conn)
    .await?;

    // Bucket each row under its owning policy; appending in iteration order keeps
    // every bucket in the query's sort order.
    let grouped = records.into_iter().fold(
        HashMap::<Oid, Vec<PolicyDependencyRow>>::new(),
        |mut by_policy, rec| {
            by_policy
                .entry(rec.policy_oid)
                .or_default()
                .push(PolicyDependencyRow {
                    column_name: rec.column_name,
                    table_schema: rec.table_schema,
                    table_name: rec.table_name,
                    relkind: rec.relkind,
                });
            by_policy
        },
    );

    Ok(grouped)
}

/// Populate policy column-level dependencies using pg_depend.
///
/// PostgreSQL tracks column-level dependencies for RLS policies via `pg_depend`
/// with `refobjsubid > 0`. Recording them as `DbObjectId::Column` enables precise
/// cascade handling: only policies that reference a changed column need to be
/// dropped and recreated.
///
/// The relation that owns each referenced column is also recorded at object level
/// (`View` for relkind 'v'/'m', `Table` for 'r'/'p', except the policy's own table),
/// because PostgreSQL may only record column-level entries for referenced objects
/// without separate object-level entries. Other relkinds contribute only the column
/// entry. Dependencies in system schemas are skipped, and ids already present in
/// `depends_on` are not added again.
///
/// `policies` and `policy_oids` are parallel: `policy_oids[i]` is the OID of `policies[i]`.
fn populate_policy_dependencies(
    policies: &mut [Policy],
    policy_oids: &[Oid],
    deps_map: &HashMap<Oid, Vec<PolicyDependencyRow>>,
) {
    for (policy, oid) in policies.iter_mut().zip(policy_oids) {
        let deps = match deps_map.get(oid) {
            Some(found) => found,
            None => continue,
        };

        // `fetch` already seeds `depends_on` with this table, so it is never re-added.
        let own_table = DbObjectId::Table {
            schema: policy.schema.clone(),
            name: policy.table_name.clone(),
        };

        for dep in deps.iter().filter(|d| !is_system_schema(&d.table_schema)) {
            let column_id = DbObjectId::Column {
                schema: dep.table_schema.clone(),
                table: dep.table_name.clone(),
                column: dep.column_name.clone(),
            };
            if !policy.depends_on.contains(&column_id) {
                policy.depends_on.push(column_id);
            }

            // Object-level id of the relation that owns the column, if it is tracked.
            let owner_id = match dep.relkind.as_str() {
                "v" | "m" => Some(DbObjectId::View {
                    schema: dep.table_schema.clone(),
                    name: dep.table_name.clone(),
                }),
                "r" | "p" => {
                    let table_id = DbObjectId::Table {
                        schema: dep.table_schema.clone(),
                        name: dep.table_name.clone(),
                    };
                    if table_id == own_table {
                        None
                    } else {
                        Some(table_id)
                    }
                }
                _ => None,
            };

            if let Some(owner_id) = owner_id.filter(|id| !policy.depends_on.contains(id)) {
                policy.depends_on.push(owner_id);
            }
        }
    }
}

/// Load every object-level dependency recorded for RLS policies in one round trip.
///
/// Reads normal (`deptype = 'n'`) `pg_depend` rows with `refobjsubid = 0` whose
/// referenced object is either a relation (`pg_class`) or a function (`pg_proc`).
/// For functions, the owning extension's name is included when the function is
/// an extension member (`deptype = 'e'`). The result maps each policy OID to its
/// dependency rows.
///
/// # Errors
///
/// Returns an error if the catalog query fails.
async fn fetch_all_policy_object_dependencies(
    conn: &mut PgConnection,
) -> Result<HashMap<Oid, Vec<PolicyObjectDependencyRow>>> {
    let records = sqlx::query!(
        r#"
        SELECT
            p.oid AS "policy_oid!",
            cls.relkind::text AS "cls_relkind?",
            cls_n.nspname AS "cls_schema?",
            cls.relname AS "cls_name?",
            proc.proname AS "proc_name?",
            proc_n.nspname AS "proc_schema?",
            pg_catalog.pg_get_function_identity_arguments(proc.oid) AS "proc_args?",
            ext_procs.extname AS "proc_extension_name?"
        FROM pg_policy p
        JOIN pg_depend d ON d.objid = p.oid AND d.classid = 'pg_policy'::regclass
        LEFT JOIN pg_class cls ON d.refclassid = 'pg_class'::regclass AND d.refobjid = cls.oid
        LEFT JOIN pg_namespace cls_n ON cls.relnamespace = cls_n.oid
        LEFT JOIN pg_proc proc ON d.refclassid = 'pg_proc'::regclass AND d.refobjid = proc.oid
        LEFT JOIN pg_namespace proc_n ON proc.pronamespace = proc_n.oid
        LEFT JOIN (
            SELECT DISTINCT dep.objid AS proc_oid, e.extname
            FROM pg_depend dep
            JOIN pg_extension e ON dep.refobjid = e.oid
            WHERE dep.deptype = 'e'
        ) ext_procs ON ext_procs.proc_oid = proc.oid
        WHERE d.refobjsubid = 0
          AND d.deptype = 'n'
          AND (cls.oid IS NOT NULL OR proc.oid IS NOT NULL)
        ORDER BY p.oid
        "#
    )
    .fetch_all(&mut *conn)
    .await?;

    let mut grouped: HashMap<Oid, Vec<PolicyObjectDependencyRow>> = HashMap::new();
    for rec in records {
        let policy_oid = rec.policy_oid;
        let dep = PolicyObjectDependencyRow {
            cls_relkind: rec.cls_relkind,
            cls_schema: rec.cls_schema,
            cls_name: rec.cls_name,
            proc_name: rec.proc_name,
            proc_schema: rec.proc_schema,
            proc_args: rec.proc_args,
            proc_extension_name: rec.proc_extension_name,
        };
        grouped.entry(policy_oid).or_default().push(dep);
    }
    Ok(grouped)
}

/// Row returned from the policy object-level dependencies query
/// (`fetch_all_policy_object_dependencies`), minus the policy OID, which is used
/// as the grouping key of the returned map.
///
/// The query joins the referenced object against either `pg_class` or `pg_proc`
/// depending on `pg_depend.refclassid`, so a row carries either the `cls_*`
/// fields (relation reference) or the `proc_*` fields (function reference).
struct PolicyObjectDependencyRow {
    /// `pg_class.relkind` of a referenced relation, as text
    cls_relkind: Option<String>,
    /// Schema of a referenced relation
    cls_schema: Option<String>,
    /// Name of a referenced relation
    cls_name: Option<String>,
    /// Name of a referenced function
    proc_name: Option<String>,
    /// Schema of a referenced function
    proc_schema: Option<String>,
    /// Identity argument list of a referenced function
    /// (`pg_get_function_identity_arguments`)
    proc_args: Option<String>,
    /// Extension that owns the referenced function (via a `deptype = 'e'`
    /// dependency), if any
    proc_extension_name: Option<String>,
}

/// Populate policy object-level dependencies (views, tables, functions) using pg_depend.
///
/// This complements `populate_policy_dependencies` (which handles column-level deps)
/// by resolving object-level references from USING/WITH CHECK expressions. Without this,
/// policies that reference views or functions would not have those dependencies tracked,
/// causing incorrect ordering of DROP operations.
fn populate_policy_object_dependencies(
    policies: &mut [Policy],
    policy_oids: &[Oid],
    deps_map: &HashMap<Oid, Vec<PolicyObjectDependencyRow>>,
) {
    for (policy, oid) in policies.iter_mut().zip(policy_oids.iter()) {
        let Some(deps) = deps_map.get(oid) else {
            continue;
        };

        // The parent table is already in depends_on; track it to skip duplicates
        let parent_table = DbObjectId::Table {
            schema: policy.schema.clone(),
            name: policy.table_name.clone(),
        };

        for dep in deps {
            // Resolve table/view references
            if let (Some(relkind), Some(schema), Some(name)) =
                (&dep.cls_relkind, &dep.cls_schema, &dep.cls_name)
            {
                if is_system_schema(schema) {
                    continue;
                }
                let dep_id = match relkind.as_str() {
                    "v" | "m" => DbObjectId::View {
                        schema: schema.clone(),
                        name: name.clone(),
                    },
                    "r" | "p" => {
                        let table_dep = DbObjectId::Table {
                            schema: schema.clone(),
                            name: name.clone(),
                        };
                        // Skip the policy's own parent table (already hard-coded)
                        if table_dep == parent_table {
                            continue;
                        }
                        table_dep
                    }
                    _ => continue,
                };
                if !policy.depends_on.contains(&dep_id) {
                    policy.depends_on.push(dep_id);
                }
                continue;
            }

            // Resolve function references
            if let (Some(name), Some(schema), Some(args)) =
                (&dep.proc_name, &dep.proc_schema, &dep.proc_args)
            {
                if is_system_schema(schema) {
                    continue;
                }
                let dep_id = if let Some(ext_name) = &dep.proc_extension_name {
                    DbObjectId::Extension {
                        name: ext_name.clone(),
                    }
                } else {
                    DbObjectId::Function {
                        schema: schema.clone(),
                        name: name.clone(),
                        arguments: args.clone(),
                    }
                };
                if !policy.depends_on.contains(&dep_id) {
                    policy.depends_on.push(dep_id);
                }
            }
        }
    }

    // Deduplicate dependencies for each policy
    for policy in policies.iter_mut() {
        let unique_deps: HashSet<_> = policy.depends_on.drain(..).collect();
        policy.depends_on.extend(unique_deps);
    }
}

/// Fetch all RLS policies from the database.
///
/// Reads `pg_policy` joined to its table, schema, and comment. Policies on tables
/// in `pg_catalog`, `information_schema`, or `pg_toast` are skipped. Results are
/// ordered by schema, table, and policy name. Each policy starts with a dependency
/// on its own table. Column- and object-level dependencies are then filled in from
/// `pg_depend` by `populate_policy_dependencies` and
/// `populate_policy_object_dependencies`.
///
/// # Errors
///
/// Returns an error if any of the catalog queries fail.
pub async fn fetch(conn: &mut PgConnection) -> Result<Vec<Policy>> {
    info!("Fetching RLS policies...");

    // `pg_description` is keyed by (objoid, classoid, objsubid). OIDs are only
    // unique within a single catalog, so the join must also pin `classoid` to
    // `pg_policy`. Otherwise a comment on some other object that happens to share
    // the OID could be attached to (or duplicate) a policy row.
    let rows = sqlx::query!(
        r#"
        SELECT
            p.oid AS "policy_oid!",
            n.nspname AS schema_name,
            c.relname AS table_name,
            p.polname AS policy_name,
            p.polcmd::text AS "command!",
            p.polpermissive AS "permissive!",
            COALESCE(
                ARRAY(
                    SELECT rolname FROM pg_roles
                    WHERE oid = ANY(p.polroles)
                    ORDER BY rolname
                ),
                '{}'::text[]
            ) AS "roles!: Vec<String>",
            pg_get_expr(p.polqual, p.polrelid) AS "using_expr?",
            pg_get_expr(p.polwithcheck, p.polrelid) AS "with_check_expr?",
            d.description AS "comment?"
        FROM pg_policy p
        JOIN pg_class c ON p.polrelid = c.oid
        JOIN pg_namespace n ON c.relnamespace = n.oid
        LEFT JOIN pg_description d
            ON d.objoid = p.oid
            AND d.classoid = 'pg_policy'::regclass
            AND d.objsubid = 0
        WHERE n.nspname NOT IN ('pg_catalog', 'information_schema', 'pg_toast')
        ORDER BY n.nspname, c.relname, p.polname
        "#
    )
    .fetch_all(&mut *conn)
    .await?;

    let mut result = Vec::with_capacity(rows.len());
    // Parallel to `result`: policy_oids[i] is the OID of result[i], used to look up
    // pg_depend rows in phase 2.
    let mut policy_oids = Vec::with_capacity(rows.len());

    for row in rows {
        // Parse command type (pg_policy.polcmd)
        let command = match row.command.as_str() {
            "*" => PolicyCommand::All,
            "r" => PolicyCommand::Select,
            "a" => PolicyCommand::Insert,
            "w" => PolicyCommand::Update,
            "d" => PolicyCommand::Delete,
            _ => PolicyCommand::All, // Default fallback
        };

        // Policies always depend on the table they are attached to.
        let depends_on = vec![DbObjectId::Table {
            schema: row.schema_name.clone(),
            name: row.table_name.clone(),
        }];

        policy_oids.push(row.policy_oid);
        // The row is owned here, so its fields are moved rather than cloned.
        result.push(Policy {
            schema: row.schema_name,
            table_name: row.table_name,
            name: row.policy_name,
            command,
            permissive: row.permissive,
            roles: row.roles,
            using_expr: row.using_expr,
            with_check_expr: row.with_check_expr,
            comment: row.comment,
            depends_on,
        });
    }

    // Phase 2: Populate dependencies using pg_depend
    if !result.is_empty() {
        info!("Fetching policy dependencies...");
        let col_deps_map = fetch_all_policy_dependencies(&mut *conn).await?;
        populate_policy_dependencies(&mut result, &policy_oids, &col_deps_map);

        let obj_deps_map = fetch_all_policy_object_dependencies(&mut *conn).await?;
        populate_policy_object_dependencies(&mut result, &policy_oids, &obj_deps_map);
    }

    Ok(result)
}