1use crate::client::{CheckClient, RowExt};
2use crate::error::{CheckError, CheckResult};
3use serde::{Deserialize, Serialize};
4
/// Category of a catalog relation, decoded from PostgreSQL's `pg_class.relkind`.
///
/// Serialized in `snake_case` (e.g. `partitioned_table`, `materialized_view`).
/// Variant order is kept stable because serde encodes unit variants by index in
/// non-self-describing formats.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum RelationKind {
    /// `relkind = 'r'`: an ordinary table.
    Table,
    /// `relkind = 'p'`: a partitioned table.
    PartitionedTable,
    /// `relkind = 'v'`: a view.
    View,
    /// `relkind = 'm'`: a materialized view.
    MaterializedView,
    /// `relkind = 'f'`: a foreign table.
    ForeignTable,
    /// Any other `relkind` code. The loader's query only selects
    /// `r`/`p`/`v`/`m`/`f`, so `load_schema_from_db` never produces this.
    Other,
}
15
16impl RelationKind {
17 fn from_relkind(relkind: i8) -> Self {
18 match relkind as u8 as char {
20 'r' => Self::Table,
21 'p' => Self::PartitionedTable,
22 'v' => Self::View,
23 'm' => Self::MaterializedView,
24 'f' => Self::ForeignTable,
25 _ => Self::Other,
26 }
27 }
28}
29
/// One user column of a relation, as read from `pg_attribute` (plus its
/// `pg_attrdef` default, if any).
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct ColumnInfo {
    /// Column name (`pg_attribute.attname`).
    pub name: String,
    /// Type text produced by `pg_catalog.format_type(atttypid, atttypmod)`,
    /// so it includes any type modifier (e.g. `character varying(20)`).
    pub data_type: String,
    /// Value of `pg_attribute.attnotnull`.
    pub not_null: bool,
    /// Default expression deparsed with `pg_get_expr`; `None` when the column
    /// has no `pg_attrdef` row.
    /// NOTE(review): generated columns also store their generation expression
    /// in `pg_attrdef`, so it appears here; identity columns have no entry and
    /// show `None`. Confirm consumers account for both.
    pub default_expr: Option<String>,
    /// The column's `attnum`: 1-based for user columns, and it may skip numbers
    /// previously used by dropped columns.
    pub ordinal: i32,
}
38
/// A table-like relation together with its user columns.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct TableInfo {
    /// Schema (namespace) name, as stored in the catalog (not quoted).
    pub schema: String,
    /// Relation name, as stored in the catalog (not quoted).
    pub name: String,
    /// What kind of relation this is (table, view, ...).
    pub kind: RelationKind,
    /// Columns, in ascending `ordinal` order when built by `load_schema_from_db`.
    pub columns: Vec<ColumnInfo>,
}
46
/// Column-level description of the relations found in a set of schemas.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct DbSchema {
    /// The schema names that were requested, in the caller's order. This can
    /// include schemas that contained no matching relations.
    pub schemas: Vec<String>,
    /// Relations found in those schemas.
    pub tables: Vec<TableInfo>,
}
52
53impl DbSchema {
54 pub fn find_table(&self, schema: &str, table: &str) -> Option<&TableInfo> {
55 self.tables
56 .iter()
57 .find(|t| t.schema == schema && t.name == table)
58 }
59}
60
/// Computes a fingerprint of the column-level catalog state for `schemas`.
///
/// The query builds one `|`-separated line per user column (`attnum > 0`, not
/// dropped) of every relation whose `relkind` is `r`, `p`, `v`, `m` or `f` in
/// the given schemas. The lines are joined with `\n` in (schema, relation,
/// attnum) order, and the function returns PostgreSQL's `md5()` of the result
/// as text.
///
/// Each line covers:
/// - schema and relation name, and `relkind`;
/// - `attnum` and column name;
/// - formatted type and the NOT NULL flag;
/// - the identity and generated flags;
/// - the deparsed default expression.
///
/// A change to any of these therefore changes the fingerprint. Requested
/// schemas that contain no matching relations do not affect the value. When
/// nothing matches at all, the digest is `md5('')`.
///
/// NOTE(review): `pg_attribute.attgenerated` only exists on PostgreSQL 12 and
/// later. On older servers this query fails; confirm the minimum supported
/// server version.
///
/// # Errors
///
/// Propagates errors from `query_one` and from decoding the `fingerprint` column.
pub async fn schema_fingerprint<C: CheckClient>(
    client: &C,
    schemas: &[String],
) -> CheckResult<String> {
    // `string_agg` yields NULL over zero rows; the outer COALESCE turns that
    // into '' so the query always returns exactly one non-NULL digest.
    let row = client
        .query_one(
            r#"
SELECT
    md5(
        COALESCE(
            string_agg(
                concat_ws(
                    '|',
                    n.nspname,
                    c.relname,
                    c.relkind::text,
                    a.attnum::text,
                    a.attname,
                    pg_catalog.format_type(a.atttypid, a.atttypmod),
                    a.attnotnull::text,
                    COALESCE(a.attidentity::text, ''),
                    COALESCE(a.attgenerated::text, ''),
                    COALESCE(pg_get_expr(ad.adbin, ad.adrelid), '')
                ),
                E'\n' ORDER BY n.nspname, c.relname, a.attnum
            ),
            ''
        )
    ) AS fingerprint
FROM pg_catalog.pg_class c
JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace
JOIN pg_catalog.pg_attribute a ON a.attrelid = c.oid
LEFT JOIN pg_catalog.pg_attrdef ad ON ad.adrelid = c.oid AND ad.adnum = a.attnum
WHERE c.relkind IN ('r', 'p', 'v', 'm', 'f')
  AND a.attnum > 0
  AND NOT a.attisdropped
  AND n.nspname = ANY($1::text[])
"#,
            &[&schemas],
        )
        .await?;

    row.try_get_column::<String>("fingerprint")
}
105
106pub async fn load_schema_from_db<C: CheckClient>(
107 client: &C,
108 schemas: &[String],
109) -> CheckResult<(DbSchema, String)> {
110 let fingerprint = schema_fingerprint(client, schemas).await?;
111
112 let rows = client
113 .query(
114 r#"
115SELECT
116 n.nspname AS schema_name,
117 c.relname AS table_name,
118 c.relkind AS relkind,
119 a.attname AS column_name,
120 a.attnum::integer AS ordinal,
121 pg_catalog.format_type(a.atttypid, a.atttypmod) AS data_type,
122 a.attnotnull AS not_null,
123 pg_get_expr(ad.adbin, ad.adrelid) AS default_expr
124FROM pg_catalog.pg_class c
125JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace
126JOIN pg_catalog.pg_attribute a ON a.attrelid = c.oid
127LEFT JOIN pg_catalog.pg_attrdef ad ON ad.adrelid = c.oid AND ad.adnum = a.attnum
128WHERE c.relkind IN ('r', 'p', 'v', 'm', 'f')
129 AND a.attnum > 0
130 AND NOT a.attisdropped
131 AND n.nspname = ANY($1::text[])
132ORDER BY n.nspname, c.relname, a.attnum
133"#,
134 &[&schemas],
135 )
136 .await?;
137
138 use std::collections::BTreeMap;
139 let mut tables: BTreeMap<(String, String), TableInfo> = BTreeMap::new();
140
141 for row in rows {
142 let schema_name: String = row.try_get_column("schema_name")?;
143 let table_name: String = row.try_get_column("table_name")?;
144 let relkind: i8 = row.try_get_column("relkind")?;
145
146 let column_name: String = row.try_get_column("column_name")?;
147 let ordinal: i32 = row.try_get_column("ordinal")?;
148 let data_type: String = row.try_get_column("data_type")?;
149 let not_null: bool = row.try_get_column("not_null")?;
150 let default_expr: Option<String> = row
151 .try_get::<_, Option<String>>("default_expr")
152 .map_err(|e| CheckError::decode("default_expr", e.to_string()))?;
153
154 let key = (schema_name.clone(), table_name.clone());
155
156 let table = tables.entry(key).or_insert_with(|| TableInfo {
157 schema: schema_name,
158 name: table_name,
159 kind: RelationKind::from_relkind(relkind),
160 columns: Vec::new(),
161 });
162
163 table.columns.push(ColumnInfo {
164 name: column_name,
165 data_type,
166 not_null,
167 default_expr,
168 ordinal,
169 });
170 }
171
172 let tables = tables.into_values().collect::<Vec<_>>();
173
174 if tables.is_empty() {
175 return Err(CheckError::Validation(
176 "No tables found in the selected schemas".to_string(),
177 ));
178 }
179
180 Ok((
181 DbSchema {
182 schemas: schemas.to_vec(),
183 tables,
184 },
185 fingerprint,
186 ))
187}