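//! pgorm_check/schema_introspect.rs
//!
//! Schema introspection: reads relation and column metadata from the
//! PostgreSQL system catalogs (`pg_class`, `pg_namespace`, `pg_attribute`,
//! `pg_attrdef`) and computes an MD5 fingerprint of that metadata.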

1use crate::client::{CheckClient, RowExt};
2use crate::error::{CheckError, CheckResult};
3use serde::{Deserialize, Serialize};
4
5#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
6#[serde(rename_all = "snake_case")]
7pub enum RelationKind {
8    Table,
9    PartitionedTable,
10    View,
11    MaterializedView,
12    ForeignTable,
13    Other,
14}
15
16impl RelationKind {
17    fn from_relkind(relkind: i8) -> Self {
18        // Postgres stores `relkind` as a "char" internally. tokio-postgres exposes it as i8.
19        match relkind as u8 as char {
20            'r' => Self::Table,
21            'p' => Self::PartitionedTable,
22            'v' => Self::View,
23            'm' => Self::MaterializedView,
24            'f' => Self::ForeignTable,
25            _ => Self::Other,
26        }
27    }
28}
29
30#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
31pub struct ColumnInfo {
32    pub name: String,
33    pub data_type: String,
34    pub not_null: bool,
35    pub default_expr: Option<String>,
36    pub ordinal: i32,
37}
38
39#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
40pub struct TableInfo {
41    pub schema: String,
42    pub name: String,
43    pub kind: RelationKind,
44    pub columns: Vec<ColumnInfo>,
45}
46
47#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
48pub struct DbSchema {
49    pub schemas: Vec<String>,
50    pub tables: Vec<TableInfo>,
51}
52
53impl DbSchema {
54    pub fn find_table(&self, schema: &str, table: &str) -> Option<&TableInfo> {
55        self.tables
56            .iter()
57            .find(|t| t.schema == schema && t.name == table)
58    }
59}
60
61pub async fn schema_fingerprint<C: CheckClient>(
62    client: &C,
63    schemas: &[String],
64) -> CheckResult<String> {
65    let row = client
66        .query_one(
67            r#"
68SELECT
69  md5(
70    COALESCE(
71      string_agg(
72        concat_ws(
73          '|',
74          n.nspname,
75          c.relname,
76          c.relkind::text,
77          a.attnum::text,
78          a.attname,
79          pg_catalog.format_type(a.atttypid, a.atttypmod),
80          a.attnotnull::text,
81          COALESCE(a.attidentity::text, ''),
82          COALESCE(a.attgenerated::text, ''),
83          COALESCE(pg_get_expr(ad.adbin, ad.adrelid), '')
84        ),
85        E'\n' ORDER BY n.nspname, c.relname, a.attnum
86      ),
87      ''
88    )
89  ) AS fingerprint
90FROM pg_catalog.pg_class c
91JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace
92JOIN pg_catalog.pg_attribute a ON a.attrelid = c.oid
93LEFT JOIN pg_catalog.pg_attrdef ad ON ad.adrelid = c.oid AND ad.adnum = a.attnum
94WHERE c.relkind IN ('r', 'p', 'v', 'm', 'f')
95  AND a.attnum > 0
96  AND NOT a.attisdropped
97  AND n.nspname = ANY($1::text[])
98"#,
99            &[&schemas],
100        )
101        .await?;
102
103    row.try_get_column::<String>("fingerprint")
104}
105
106pub async fn load_schema_from_db<C: CheckClient>(
107    client: &C,
108    schemas: &[String],
109) -> CheckResult<(DbSchema, String)> {
110    let fingerprint = schema_fingerprint(client, schemas).await?;
111
112    let rows = client
113        .query(
114            r#"
115SELECT
116  n.nspname AS schema_name,
117  c.relname AS table_name,
118  c.relkind AS relkind,
119  a.attname AS column_name,
120  a.attnum::integer AS ordinal,
121  pg_catalog.format_type(a.atttypid, a.atttypmod) AS data_type,
122  a.attnotnull AS not_null,
123  pg_get_expr(ad.adbin, ad.adrelid) AS default_expr
124FROM pg_catalog.pg_class c
125JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace
126JOIN pg_catalog.pg_attribute a ON a.attrelid = c.oid
127LEFT JOIN pg_catalog.pg_attrdef ad ON ad.adrelid = c.oid AND ad.adnum = a.attnum
128WHERE c.relkind IN ('r', 'p', 'v', 'm', 'f')
129  AND a.attnum > 0
130  AND NOT a.attisdropped
131  AND n.nspname = ANY($1::text[])
132ORDER BY n.nspname, c.relname, a.attnum
133"#,
134            &[&schemas],
135        )
136        .await?;
137
138    use std::collections::BTreeMap;
139    let mut tables: BTreeMap<(String, String), TableInfo> = BTreeMap::new();
140
141    for row in rows {
142        let schema_name: String = row.try_get_column("schema_name")?;
143        let table_name: String = row.try_get_column("table_name")?;
144        let relkind: i8 = row.try_get_column("relkind")?;
145
146        let column_name: String = row.try_get_column("column_name")?;
147        let ordinal: i32 = row.try_get_column("ordinal")?;
148        let data_type: String = row.try_get_column("data_type")?;
149        let not_null: bool = row.try_get_column("not_null")?;
150        let default_expr: Option<String> = row
151            .try_get::<_, Option<String>>("default_expr")
152            .map_err(|e| CheckError::decode("default_expr", e.to_string()))?;
153
154        let key = (schema_name.clone(), table_name.clone());
155
156        let table = tables.entry(key).or_insert_with(|| TableInfo {
157            schema: schema_name,
158            name: table_name,
159            kind: RelationKind::from_relkind(relkind),
160            columns: Vec::new(),
161        });
162
163        table.columns.push(ColumnInfo {
164            name: column_name,
165            data_type,
166            not_null,
167            default_expr,
168            ordinal,
169        });
170    }
171
172    let tables = tables.into_values().collect::<Vec<_>>();
173
174    if tables.is_empty() {
175        return Err(CheckError::Validation(
176            "No tables found in the selected schemas".to_string(),
177        ));
178    }
179
180    Ok((
181        DbSchema {
182            schemas: schemas.to_vec(),
183            tables,
184        },
185        fingerprint,
186    ))
187}