use crate::{
CheckConstraint, Column, ForeignKey, Index, IndexColumn, PgType, Result, Schema,
SourceLocation, Table, TriggerCheckConstraint,
};
use indexmap::IndexMap;
#[cfg(test)]
use crate::{NullsOrder, SortOrder};
use tokio_postgres::Client;
/// Construction of a [`Schema`] from a live database connection.
pub trait SchemaIntrospect {
/// Reads the schema through the given `tokio_postgres` client.
///
/// The returned future resolves to the introspected [`Schema`] (or an error)
/// and is required to be `Send`.
fn from_database(client: &Client) -> impl std::future::Future<Output = Result<Schema>> + Send;
}
/// Database-backed construction of [`Schema`].
impl SchemaIntrospect for Schema {
    /// Introspects all tables via `introspect_tables` and wraps them in a `Schema`.
    ///
    /// Any error from the table introspection is returned unchanged.
    async fn from_database(client: &Client) -> Result<Self> {
        introspect_tables(client)
            .await
            .map(|tables| Self { tables })
    }
}
/// Introspects every base table of the `public` schema, keyed by table name.
///
/// Tables whose names start with `_dibs_` or `__dibs_` are skipped. The map is
/// filled in table-name order (the query sorts by `table_name`).
///
/// # Errors
/// Returns the first error raised by the listing query or by `introspect_table`.
async fn introspect_tables(client: &Client) -> Result<IndexMap<String, Table>> {
    // `_` is a single-character wildcard in LIKE; it is escaped here so that only
    // names literally starting with `_dibs_` / `__dibs_` are excluded (an
    // unescaped pattern would also drop names such as `xdibsy...`).
    let rows = client
        .query(
            r#"
SELECT table_name
FROM information_schema.tables
WHERE table_schema = 'public'
AND table_type = 'BASE TABLE'
AND table_name NOT LIKE '\_dibs\_%' ESCAPE '\'
AND table_name NOT LIKE '\_\_dibs\_%' ESCAPE '\'
ORDER BY table_name
"#,
            &[],
        )
        .await?;

    let mut tables = IndexMap::new();
    for row in rows {
        let table_name: String = row.get(0);
        let table = introspect_table(client, &table_name).await?;
        tables.insert(table_name, table);
    }
    Ok(tables)
}
/// Assembles the full [`Table`] description for `table_name`.
///
/// Runs the per-aspect introspection helpers one after another (columns,
/// primary keys, unique constraints, check constraints, trigger checks,
/// foreign keys, indices) and then flags each column as primary-key / unique
/// when its name appears in the corresponding list.
///
/// `source` is the default location and `doc` / `icon` are left unset.
async fn introspect_table(client: &Client, table_name: &str) -> Result<Table> {
    let mut columns = introspect_columns(client, table_name).await?;
    let pk_names = introspect_primary_keys(client, table_name).await?;
    let unique_names = introspect_unique_constraints(client, table_name).await?;
    let check_constraints = introspect_check_constraints(client, table_name).await?;
    let trigger_checks = introspect_trigger_checks(client, table_name).await?;
    let foreign_keys = introspect_foreign_keys(client, table_name).await?;
    let indices = introspect_indices(client, table_name).await?;

    // Key/uniqueness flags are not part of the column query; fill them in here.
    for column in &mut columns {
        column.primary_key = pk_names.contains(&column.name);
        column.unique = unique_names.contains(&column.name);
    }

    let table = Table {
        name: table_name.to_string(),
        columns,
        check_constraints,
        trigger_checks,
        foreign_keys,
        indices,
        source: SourceLocation::default(),
        doc: None,
        icon: None,
    };
    Ok(table)
}
/// Lists the trigger-based checks attached to `table_name` in the `public` schema.
///
/// A trigger qualifies when it is not internal and its function name starts
/// with `trgfn_`. Only the trigger name is recovered: `expr` is left empty and
/// `message` is `None`. Results are ordered by trigger name.
async fn introspect_trigger_checks(
    client: &Client,
    table_name: &str,
) -> Result<Vec<TriggerCheckConstraint>> {
    const QUERY: &str = r#"
SELECT tg.tgname
FROM pg_trigger tg
JOIN pg_class rel ON rel.oid = tg.tgrelid
JOIN pg_namespace nsp ON nsp.oid = rel.relnamespace
JOIN pg_proc pr ON pr.oid = tg.tgfoid
WHERE nsp.nspname = 'public'
AND rel.relname = $1
AND tg.tgisinternal = false
AND pr.proname LIKE 'trgfn\_%' ESCAPE '\'
ORDER BY tg.tgname
"#;

    let rows = client.query(QUERY, &[&table_name]).await?;

    let mut trigger_checks = Vec::with_capacity(rows.len());
    for row in rows {
        let name: String = row.get(0);
        trigger_checks.push(TriggerCheckConstraint {
            name,
            expr: String::new(),
            message: None,
        });
    }
    Ok(trigger_checks)
}
/// Lists the CHECK constraints of `table_name` in the `public` schema.
///
/// Each entry carries the constraint name and the expression text produced by
/// `pg_get_expr`. Results are ordered by constraint name.
async fn introspect_check_constraints(
    client: &Client,
    table_name: &str,
) -> Result<Vec<CheckConstraint>> {
    const QUERY: &str = r#"
SELECT
con.conname,
pg_get_expr(con.conbin, con.conrelid) AS expr
FROM pg_constraint con
JOIN pg_class rel ON rel.oid = con.conrelid
JOIN pg_namespace nsp ON nsp.oid = rel.relnamespace
WHERE nsp.nspname = 'public'
AND rel.relname = $1
AND con.contype = 'c'
ORDER BY con.conname
"#;

    let rows = client.query(QUERY, &[&table_name]).await?;

    let checks = rows
        .into_iter()
        .map(|row| {
            let name = row.get::<_, String>(0);
            let expr = row.get::<_, String>(1);
            CheckConstraint { name, expr }
        })
        .collect();
    Ok(checks)
}
/// Reads the columns of `table_name` (schema `public`) in ordinal order.
///
/// For every column this derives:
/// - `pg_type` from `data_type` / `udt_name` via `pg_type_from_info_schema`,
/// - `nullable` from `is_nullable == "YES"`,
/// - `default` from `column_default`, passed through `clean_default_value`,
/// - `auto_generated` when the column is an identity column or its default is
///   recognised by `is_auto_generated`.
///
/// `primary_key` and `unique` are always `false` here, and the remaining
/// descriptive fields are left empty.
async fn introspect_columns(client: &Client, table_name: &str) -> Result<Vec<Column>> {
    const QUERY: &str = r#"
SELECT
column_name,
data_type,
udt_name,
is_nullable,
column_default,
is_identity
FROM information_schema.columns
WHERE table_schema = 'public' AND table_name = $1
ORDER BY ordinal_position
"#;

    let rows = client.query(QUERY, &[&table_name]).await?;

    let columns = rows
        .into_iter()
        .map(|row| {
            let name = row.get::<_, String>(0);
            let data_type = row.get::<_, String>(1);
            let udt_name = row.get::<_, String>(2);
            let nullable_flag = row.get::<_, String>(3);
            let raw_default = row.get::<_, Option<String>>(4);
            let identity_flag = row.get::<_, String>(5);

            let default = raw_default.map(|raw| clean_default_value(&raw));
            let auto_generated = identity_flag == "YES" || is_auto_generated(&default);

            Column {
                name,
                pg_type: pg_type_from_info_schema(&data_type, &udt_name),
                rust_type: None,
                nullable: nullable_flag == "YES",
                default,
                primary_key: false,
                unique: false,
                auto_generated,
                long: false,
                label: false,
                enum_variants: vec![],
                doc: None,
                lang: None,
                icon: None,
                subtype: None,
            }
        })
        .collect();
    Ok(columns)
}
/// Returns the names of the primary-key columns of `table_name` (schema
/// `public`), ordered by their position within the key.
async fn introspect_primary_keys(client: &Client, table_name: &str) -> Result<Vec<String>> {
    const QUERY: &str = r#"
SELECT kcu.column_name
FROM information_schema.table_constraints tc
JOIN information_schema.key_column_usage kcu
ON tc.constraint_name = kcu.constraint_name
AND tc.table_schema = kcu.table_schema
WHERE tc.constraint_type = 'PRIMARY KEY'
AND tc.table_schema = 'public'
AND tc.table_name = $1
ORDER BY kcu.ordinal_position
"#;

    let rows = client.query(QUERY, &[&table_name]).await?;

    let mut key_columns: Vec<String> = Vec::with_capacity(rows.len());
    for row in &rows {
        key_columns.push(row.get(0));
    }
    Ok(key_columns)
}
/// Returns the names of all columns that take part in a UNIQUE constraint on
/// `table_name` (schema `public`), in no particular order.
///
/// NOTE(review): columns of multi-column UNIQUE constraints are returned as
/// well, one row per column — confirm that callers really want each of those
/// columns treated as individually unique.
async fn introspect_unique_constraints(client: &Client, table_name: &str) -> Result<Vec<String>> {
    const QUERY: &str = r#"
SELECT kcu.column_name
FROM information_schema.table_constraints tc
JOIN information_schema.key_column_usage kcu
ON tc.constraint_name = kcu.constraint_name
AND tc.table_schema = kcu.table_schema
WHERE tc.constraint_type = 'UNIQUE'
AND tc.table_schema = 'public'
AND tc.table_name = $1
"#;

    let rows = client.query(QUERY, &[&table_name]).await?;

    let mut unique_columns: Vec<String> = Vec::with_capacity(rows.len());
    for row in &rows {
        unique_columns.push(row.get(0));
    }
    Ok(unique_columns)
}
/// Lists the foreign keys declared on `table_name` in the `public` schema.
///
/// Foreign keys are returned in constraint-name order. Within each key,
/// `columns[i]` references `references_columns[i]`, in the order the key was
/// declared.
///
/// The data is read from `pg_constraint` rather than `information_schema`:
/// - `conkey` / `confkey` are unnested pairwise, so a composite key yields
///   exactly one row per (local column, referenced column) pair instead of the
///   cross product that joining `key_column_usage` with
///   `constraint_column_usage` produces;
/// - the constraint is matched through its owning table's OID, so a
///   same-named constraint on another table cannot leak into the result.
///
/// # Errors
/// Returns the error raised by the query, if any.
async fn introspect_foreign_keys(client: &Client, table_name: &str) -> Result<Vec<ForeignKey>> {
    let rows = client
        .query(
            r#"
SELECT
con.conname,
att.attname,
frel.relname,
fatt.attname
FROM pg_constraint con
JOIN pg_class rel ON rel.oid = con.conrelid
JOIN pg_namespace nsp ON nsp.oid = rel.relnamespace
JOIN pg_class frel ON frel.oid = con.confrelid
CROSS JOIN LATERAL unnest(con.conkey, con.confkey) WITH ORDINALITY AS k(attnum, fattnum, ord)
JOIN pg_attribute att ON att.attrelid = con.conrelid AND att.attnum = k.attnum
JOIN pg_attribute fatt ON fatt.attrelid = con.confrelid AND fatt.attnum = k.fattnum
WHERE nsp.nspname = 'public'
AND rel.relname = $1
AND con.contype = 'f'
ORDER BY con.conname, k.ord
"#,
            &[&table_name],
        )
        .await?;

    // Rows arrive sorted by constraint name and key position, so all rows of one
    // constraint are adjacent: extend the last group or start a new one.
    let mut grouped: Vec<(String, ForeignKey)> = Vec::new();
    for row in rows {
        let constraint_name: String = row.get(0);
        let column: String = row.get(1);
        let foreign_table: String = row.get(2);
        let foreign_column: String = row.get(3);

        match grouped.last_mut() {
            Some((current, fk)) if *current == constraint_name => {
                fk.columns.push(column);
                fk.references_columns.push(foreign_column);
            }
            _ => grouped.push((
                constraint_name,
                ForeignKey {
                    columns: vec![column],
                    references_table: foreign_table,
                    references_columns: vec![foreign_column],
                },
            )),
        }
    }

    Ok(grouped.into_iter().map(|(_, fk)| fk).collect())
}
/// Lists the indexes of `table_name` in the `public` schema, ordered by name.
///
/// Indexes whose name equals the name of a table constraint in `public` are
/// skipped by the query. For the remaining ones, the key columns and the
/// optional partial-index predicate are parsed out of the `indexdef` text.
///
/// # Errors
/// Returns the error raised by the query, if any.
async fn introspect_indices(client: &Client, table_name: &str) -> Result<Vec<Index>> {
    let rows = client
        .query(
            r#"
SELECT
i.indexname,
i.indexdef
FROM pg_indexes i
WHERE i.schemaname = 'public'
AND i.tablename = $1
AND NOT EXISTS (
SELECT 1 FROM information_schema.table_constraints tc
WHERE tc.constraint_name = i.indexname
AND tc.table_schema = 'public'
)
ORDER BY i.indexname
"#,
            &[&table_name],
        )
        .await?;

    let indices = rows
        .into_iter()
        .map(|row| {
            let name: String = row.get(0);
            let indexdef: String = row.get(1);
            // Only the statement prefix says whether the index is unique; a
            // substring search would also match index, table or column names
            // that merely contain "unique".
            let unique = indexdef.starts_with("CREATE UNIQUE INDEX ");
            Index {
                name,
                columns: parse_index_columns(&indexdef),
                unique,
                where_clause: parse_index_where_clause(&indexdef),
            }
        })
        .collect();
    Ok(indices)
}
/// Extracts the key columns from a `CREATE INDEX` statement.
///
/// `indexdef` is expected to look like the `indexdef` text of `pg_indexes`,
/// e.g. `CREATE INDEX i ON public.t USING btree (a, b DESC) WHERE (c > 0)`.
/// The key list is the first parenthesised group that is not inside a quoted
/// identifier or string literal; each top-level, comma-separated entry of that
/// group is handed to `IndexColumn::parse`.
///
/// Compared to looking for the last `(`…`)` pair, this keeps working when the
/// statement has later parenthesised groups (`INCLUDE (...)`, `WITH (...)`,
/// the `WHERE` predicate) or when a key is an expression containing
/// parentheses or commas, such as `lower(name)` or `coalesce(a, b)`.
///
/// Returns an empty vector when no complete parenthesised group is found.
fn parse_index_columns(indexdef: &str) -> Vec<IndexColumn> {
    let mut columns = Vec::new();
    // Quote character we are currently inside of, if any (`'` or `"`).
    let mut quote: Option<char> = None;
    // Parenthesis nesting; depth 1 means "directly inside the key list".
    let mut depth = 0usize;
    let mut item_start = 0usize;

    for (pos, ch) in indexdef.char_indices() {
        if let Some(open) = quote {
            if ch == open {
                quote = None;
            }
            continue;
        }
        match ch {
            '\'' | '"' => quote = Some(ch),
            '(' => {
                depth += 1;
                if depth == 1 {
                    item_start = pos + 1;
                }
            }
            ',' if depth == 1 => {
                columns.push(IndexColumn::parse(&indexdef[item_start..pos]));
                item_start = pos + 1;
            }
            ')' if depth == 1 => {
                columns.push(IndexColumn::parse(&indexdef[item_start..pos]));
                return columns;
            }
            ')' => depth = depth.saturating_sub(1),
            _ => {}
        }
    }

    // The key list was never opened or never closed.
    Vec::new()
}
/// Extracts the predicate of a partial index from a `CREATE INDEX` statement.
///
/// Returns `None` when the statement has no ` WHERE ` keyword (matched
/// ASCII-case-insensitively). Otherwise returns the trimmed text after the
/// keyword, with one pair of enclosing parentheses removed when — and only
/// when — the leading `(` is closed by the trailing `)`. A predicate such as
/// `(a = 1) OR (b = 2)` is therefore returned unchanged.
fn parse_index_where_clause(indexdef: &str) -> Option<String> {
    const KEYWORD: &str = " WHERE ";
    // ASCII-only upper-casing keeps byte offsets identical to `indexdef`, so the
    // position found here is always a valid index into the original string.
    let where_pos = indexdef.to_ascii_uppercase().find(KEYWORD)?;
    let predicate = indexdef[where_pos + KEYWORD.len()..].trim();

    let predicate = if outer_parens_enclose_all(predicate) {
        &predicate[1..predicate.len() - 1]
    } else {
        predicate
    };
    Some(predicate.to_string())
}

/// Returns `true` when `expr` starts with `(` and that parenthesis is closed by
/// the very last character of `expr`. Parentheses inside single- or
/// double-quoted text are ignored.
fn outer_parens_enclose_all(expr: &str) -> bool {
    if !(expr.starts_with('(') && expr.ends_with(')')) {
        return false;
    }
    let last = expr.len() - 1;
    let mut quote: Option<char> = None;
    let mut depth = 0usize;

    for (pos, ch) in expr.char_indices() {
        if let Some(open) = quote {
            if ch == open {
                quote = None;
            }
            continue;
        }
        match ch {
            '\'' | '"' => quote = Some(ch),
            '(' => depth += 1,
            ')' => {
                depth = depth.saturating_sub(1);
                if depth == 0 {
                    // The opening parenthesis has just been closed.
                    return pos == last;
                }
            }
            _ => {}
        }
    }
    false
}
/// Maps an `information_schema.columns` type description to a [`PgType`].
///
/// `data_type` is compared case-insensitively against the known SQL type
/// names. `udt_name` is consulted for `USER-DEFINED` and `ARRAY` columns and
/// as a fallback for any `data_type` that is not recognised.
///
/// Notable mappings: all character types become `Text`, every timestamp
/// flavour becomes `Timestamptz`, unknown array element types become `Jsonb`,
/// and anything else that is not recognised becomes `Text`.
fn pg_type_from_info_schema(data_type: &str, udt_name: &str) -> PgType {
    let normalized = data_type.to_uppercase();
    match normalized.as_str() {
        "SMALLINT" => PgType::SmallInt,
        "INTEGER" => PgType::Integer,
        "BIGINT" => PgType::BigInt,
        "REAL" => PgType::Real,
        "DOUBLE PRECISION" => PgType::DoublePrecision,
        "NUMERIC" | "DECIMAL" => PgType::Numeric,
        "BOOLEAN" => PgType::Boolean,
        "TEXT" | "CHARACTER VARYING" | "VARCHAR" | "CHAR" | "CHARACTER" => PgType::Text,
        "BYTEA" => PgType::Bytea,
        "DATE" => PgType::Date,
        "TIME WITHOUT TIME ZONE" | "TIME" => PgType::Time,
        "TIMESTAMP WITH TIME ZONE" | "TIMESTAMP WITHOUT TIME ZONE" | "TIMESTAMP" => {
            PgType::Timestamptz
        }
        "UUID" => PgType::Uuid,
        "JSONB" => PgType::Jsonb,
        "USER-DEFINED" => pg_type_for_user_defined(udt_name),
        "ARRAY" => pg_type_for_array(udt_name),
        _ => pg_type_from_udt_name(udt_name),
    }
}

/// Resolves a `USER-DEFINED` column from its `udt_name`; unknown names map to `Text`.
fn pg_type_for_user_defined(udt_name: &str) -> PgType {
    match udt_name {
        "uuid" => PgType::Uuid,
        "jsonb" => PgType::Jsonb,
        _ => PgType::Text,
    }
}

/// Resolves an `ARRAY` column from its `udt_name`; unknown element types map to `Jsonb`.
fn pg_type_for_array(udt_name: &str) -> PgType {
    match udt_name {
        "_text" | "_varchar" => PgType::TextArray,
        "_int8" => PgType::BigIntArray,
        "_int4" => PgType::IntegerArray,
        _ => PgType::Jsonb,
    }
}

/// Fallback for an unrecognised `data_type`: decide from `udt_name` alone,
/// defaulting to `Text`.
fn pg_type_from_udt_name(udt_name: &str) -> PgType {
    match udt_name {
        "int2" => PgType::SmallInt,
        "int4" => PgType::Integer,
        "int8" => PgType::BigInt,
        "float4" => PgType::Real,
        "float8" => PgType::DoublePrecision,
        "numeric" => PgType::Numeric,
        "bool" => PgType::Boolean,
        "bytea" => PgType::Bytea,
        "timestamptz" | "timestamp" => PgType::Timestamptz,
        "date" => PgType::Date,
        "time" => PgType::Time,
        "uuid" => PgType::Uuid,
        "jsonb" => PgType::Jsonb,
        // "text", "varchar", "bpchar" and every unknown name.
        _ => PgType::Text,
    }
}
/// Normalises a column default by trimming whitespace and dropping a type cast.
///
/// The text is cut at the first `::` that is *not* inside a quoted string, a
/// quoted identifier, parentheses or brackets, so `'foo'::text` becomes
/// `'foo'` and `0::bigint` becomes `0`. Casts nested inside an expression are
/// kept, which leaves defaults such as `nextval('users_id_seq'::regclass)`
/// intact instead of truncating them in the middle of the call, and a `::`
/// inside a string literal (`'a::b'::text`) is not mistaken for a cast.
fn clean_default_value(default: &str) -> String {
    let s = default.trim();
    // Quote character we are currently inside of, if any (`'` or `"`).
    let mut quote: Option<char> = None;
    // Nesting depth of `(`…`)` and `[`…`]`.
    let mut depth = 0usize;

    for (pos, ch) in s.char_indices() {
        if let Some(open) = quote {
            if ch == open {
                quote = None;
            }
            continue;
        }
        match ch {
            '\'' | '"' => quote = Some(ch),
            '(' | '[' => depth += 1,
            ')' | ']' => depth = depth.saturating_sub(1),
            ':' if depth == 0 && s[pos..].starts_with("::") => {
                return s[..pos].to_string();
            }
            _ => {}
        }
    }
    s.to_string()
}
/// Reports whether a column default looks like a database-generated value.
///
/// The default is lower-cased and searched for any of these fragments:
/// `nextval(`, `gen_random_uuid()`, `uuid_generate_v`, `now()`,
/// `current_timestamp`. A missing default is never auto-generated.
fn is_auto_generated(default: &Option<String>) -> bool {
    // Fragments that mark a sequence, UUID or timestamp generator.
    const MARKERS: [&str; 5] = [
        "nextval(",
        "gen_random_uuid()",
        "uuid_generate_v",
        "now()",
        "current_timestamp",
    ];

    match default {
        None => false,
        Some(expr) => {
            let lowered = expr.to_lowercase();
            MARKERS.iter().any(|marker| lowered.contains(marker))
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Expected column with an explicit sort order and nulls placement.
    fn ordered(name: &str, order: SortOrder, nulls: NullsOrder) -> IndexColumn {
        IndexColumn {
            name: name.to_string(),
            order,
            nulls,
        }
    }

    /// Key columns are parsed for plain, unique, composite, partial and
    /// explicitly ordered indexes.
    #[test]
    fn test_parse_index_columns() {
        let cases: Vec<(&str, Vec<IndexColumn>)> = vec![
            (
                "CREATE INDEX idx_users_name ON public.users USING btree (name)",
                vec![IndexColumn::new("name")],
            ),
            (
                "CREATE UNIQUE INDEX idx_users_email ON public.users USING btree (email)",
                vec![IndexColumn::new("email")],
            ),
            (
                "CREATE INDEX idx_posts_author ON public.posts USING btree (author_id, created_at)",
                vec![
                    IndexColumn::new("author_id"),
                    IndexColumn::new("created_at"),
                ],
            ),
            (
                "CREATE UNIQUE INDEX uq_product_primary ON public.product_category USING btree (product_id) WHERE (is_primary = true)",
                vec![IndexColumn::new("product_id")],
            ),
            (
                "CREATE INDEX idx_versions ON public.product_version USING btree (product_id, synced_at DESC)",
                vec![
                    IndexColumn::new("product_id"),
                    ordered("synced_at", SortOrder::Desc, NullsOrder::Default),
                ],
            ),
            (
                "CREATE INDEX idx_test ON public.test USING btree (col1 ASC, col2 DESC)",
                vec![
                    IndexColumn::new("col1"),
                    ordered("col2", SortOrder::Desc, NullsOrder::Default),
                ],
            ),
            (
                "CREATE INDEX idx_reminder ON public.cart USING btree (reminder_sent_at NULLS FIRST)",
                vec![ordered(
                    "reminder_sent_at",
                    SortOrder::Asc,
                    NullsOrder::First,
                )],
            ),
            (
                "CREATE INDEX idx_test ON public.test USING btree (col DESC NULLS LAST)",
                vec![ordered("col", SortOrder::Desc, NullsOrder::Last)],
            ),
        ];

        for (indexdef, expected) in cases {
            assert_eq!(
                parse_index_columns(indexdef),
                expected,
                "indexdef: {indexdef}"
            );
        }
    }

    /// The predicate is extracted and one enclosing pair of parentheses removed.
    #[test]
    fn test_parse_index_where_clause() {
        let cases: [(&str, Option<&str>); 4] = [
            (
                "CREATE INDEX idx_users_name ON public.users USING btree (name)",
                None,
            ),
            (
                "CREATE UNIQUE INDEX uq_product_primary ON public.product_category USING btree (product_id) WHERE (is_primary = true)",
                Some("is_primary = true"),
            ),
            (
                "CREATE UNIQUE INDEX uq_discount_applied ON public.discount_redemption USING btree (order_id) WHERE ((status)::text = 'applied'::text)",
                Some("(status)::text = 'applied'::text"),
            ),
            (
                "CREATE UNIQUE INDEX uq_test ON public.test USING btree (col) WHERE is_active",
                Some("is_active"),
            ),
        ];

        for (indexdef, expected) in cases {
            assert_eq!(
                parse_index_where_clause(indexdef),
                expected.map(str::to_string),
                "indexdef: {indexdef}"
            );
        }
    }

    /// Casts are dropped and surrounding whitespace is trimmed.
    #[test]
    fn test_clean_default_value() {
        let cases = [
            ("'foo'::text", "'foo'"),
            ("0::bigint", "0"),
            ("now()", "now()"),
            (" 42 ", "42"),
        ];

        for (raw, cleaned) in cases {
            assert_eq!(clean_default_value(raw), cleaned, "default: {raw:?}");
        }
    }

    /// A sample of `data_type` / `udt_name` pairs maps to the expected `PgType`.
    #[test]
    fn test_pg_type_from_info_schema() {
        let cases = [
            ("BIGINT", "int8", PgType::BigInt),
            ("TEXT", "text", PgType::Text),
            ("BOOLEAN", "bool", PgType::Boolean),
            ("USER-DEFINED", "uuid", PgType::Uuid),
            ("CHARACTER VARYING", "varchar", PgType::Text),
        ];

        for (data_type, udt_name, expected) in cases {
            assert_eq!(
                pg_type_from_info_schema(data_type, udt_name),
                expected,
                "data_type: {data_type}, udt_name: {udt_name}"
            );
        }
    }
}