use serde::{Deserialize, Serialize};
use std::collections::HashMap;
#[cfg(feature = "postgres")]
use sqlx::{PgPool, Row};
/// Catalog query listing every single-column foreign-key constraint outside the
/// `pg_catalog` and `information_schema` schemas.
///
/// Reads `pg_constraint` rows with `contype = 'f'` and joins them to
/// `pg_class`, `pg_namespace` and `pg_attribute` to resolve schema, table and
/// column names on both the referencing ("from") and referenced ("to") side.
/// The `array_length(con.conkey, 1) = 1` filter skips composite keys, which is
/// why only the first element of `conkey` / `confkey` is joined.
///
/// Output columns are aliased to the field names read by
/// `SchemaCache::load_from_database`: `constraint_name`, `from_schema`,
/// `from_table`, `from_column`, `to_schema`, `to_table`, `to_column`.
/// Rows are ordered by source schema, then source table, then constraint name.
///
/// NOTE(review): also compiled under the `wasm` feature, presumably so a
/// non-sqlx client can run the same query — confirm where it is consumed.
/// NOTE(review): constraints inherited by partitions are not filtered out
/// (nothing checks `conparentid`); confirm per-partition rows are acceptable.
#[cfg(any(feature = "postgres", feature = "wasm"))]
pub const FK_INTROSPECTION_QUERY: &str = r#"
SELECT
con.conname AS constraint_name,
sn.nspname AS from_schema,
sc.relname AS from_table,
sa.attname AS from_column,
tn.nspname AS to_schema,
tc.relname AS to_table,
ta.attname AS to_column
FROM pg_constraint con
JOIN pg_class sc ON sc.oid = con.conrelid
JOIN pg_namespace sn ON sn.oid = sc.relnamespace
JOIN pg_class tc ON tc.oid = con.confrelid
JOIN pg_namespace tn ON tn.oid = tc.relnamespace
JOIN pg_attribute sa ON sa.attrelid = sc.oid AND sa.attnum = con.conkey[1]
JOIN pg_attribute ta ON ta.attrelid = tc.oid AND ta.attnum = con.confkey[1]
WHERE con.contype = 'f'
AND sn.nspname NOT IN ('pg_catalog', 'information_schema')
AND array_length(con.conkey, 1) = 1
ORDER BY sn.nspname, sc.relname, con.conname
"#;
/// A single-column foreign-key constraint: `from_table.from_column`
/// references `to_table.to_column`.
///
/// Serde (de)serializes the fields under their Rust names (no renames are
/// configured), and those names match the column aliases in
/// `FK_INTROSPECTION_QUERY`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct ForeignKey {
/// Schema of the referencing table.
pub from_schema: String,
/// Referencing table, the one that holds the key column.
pub from_table: String,
/// Column in `from_table` that holds the key.
pub from_column: String,
/// Schema of the referenced table.
pub to_schema: String,
/// Referenced table.
pub to_table: String,
/// Column in `to_table` that `from_column` points at.
pub to_column: String,
/// Constraint name (`pg_constraint.conname` when loaded from the database).
pub constraint_name: String,
}
impl ForeignKey {
    /// Builds a foreign key in the `public` schema for unit tests.
    ///
    /// The constraint name follows PostgreSQL's default
    /// `<from_table>_<from_col>_fkey` pattern.
    #[cfg(test)]
    pub fn test(from_table: &str, from_col: &str, to_table: &str, to_col: &str) -> Self {
        Self {
            from_schema: "public".to_string(),
            from_table: from_table.to_string(),
            from_column: from_col.to_string(),
            to_schema: "public".to_string(),
            to_table: to_table.to_string(),
            to_column: to_col.to_string(),
            constraint_name: format!("{}_{}_fkey", from_table, from_col),
        }
    }

    /// Returns `true` if this key points from `from_table` to `to_table`.
    ///
    /// Only table names are compared. Schemas are ignored, so same-named
    /// tables in different schemas are not distinguished.
    pub fn links(&self, from_table: &str, to_table: &str) -> bool {
        self.from_table == from_table && self.to_table == to_table
    }

    /// Renders the SQL equality predicate for joining along this key.
    ///
    /// The result has the form `"<from_alias>"."<from_column>" = "<to_alias>"."<to_column>"`.
    /// Every identifier is emitted as a PostgreSQL delimited identifier with
    /// embedded `"` doubled. A name containing a double quote therefore yields
    /// valid SQL and cannot terminate the identifier early and inject text.
    pub fn join_condition(&self, from_alias: &str, to_alias: &str) -> String {
        format!(
            "{}.{} = {}.{}",
            Self::quote_ident(from_alias),
            Self::quote_ident(&self.from_column),
            Self::quote_ident(to_alias),
            Self::quote_ident(&self.to_column),
        )
    }

    /// Wraps `ident` in double quotes, escaping each embedded `"` as `""`.
    fn quote_ident(ident: &str) -> String {
        format!("\"{}\"", ident.replace('"', "\"\""))
    }
}
/// Cardinality of a relationship between two tables, seen from the table the
/// lookup starts at.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum RelationType {
/// The starting table holds the foreign key: many of its rows can point at
/// one row of the target table.
ManyToOne,
/// The target table holds a foreign key pointing back at the starting table.
OneToMany,
/// Tables linked through an intermediate junction table.
/// NOTE(review): `SchemaCache::find_relationship` never produces this
/// variant; confirm where (or whether) it is constructed.
ManyToMany {
/// Name of the junction table (presumably holding keys to both sides).
junction_table: String,
},
}
/// A resolved link between `from_table` and `to_table`, as produced by
/// `SchemaCache::find_relationship`.
#[derive(Debug, Clone)]
pub struct Relationship {
/// Table the lookup started from.
pub from_table: String,
/// Table the lookup was asked to reach.
pub to_table: String,
/// Constraint backing the link.
///
/// Its own `from_table`/`to_table` follow the constraint's direction, not the
/// lookup's. For `OneToMany`, the key's `from_table` is this relationship's
/// `to_table`.
pub foreign_key: ForeignKey,
/// Cardinality of the link from `from_table`'s point of view.
pub relation_type: RelationType,
}
/// In-memory index of foreign keys that can be queried from either end of each key.
#[derive(Debug, Clone, Default)]
pub struct SchemaCache {
/// Keys grouped by referencing table, keyed by `(from_schema, from_table)`.
foreign_keys: HashMap<(String, String), Vec<ForeignKey>>,
/// Keys grouped by referenced table, keyed by `(to_schema, to_table)`.
/// `from_foreign_keys` inserts every key into both maps.
reverse_fks: HashMap<(String, String), Vec<ForeignKey>>,
}
impl SchemaCache {
pub fn new() -> Self {
Self::default()
}
pub fn from_foreign_keys(fks: Vec<ForeignKey>) -> Self {
let mut cache = Self::new();
for fk in fks {
cache
.foreign_keys
.entry((fk.from_schema.clone(), fk.from_table.clone()))
.or_default()
.push(fk.clone());
cache
.reverse_fks
.entry((fk.to_schema.clone(), fk.to_table.clone()))
.or_default()
.push(fk);
}
cache
}
#[cfg(feature = "postgres")]
pub async fn load_from_database(pool: &PgPool) -> Result<Self, sqlx::Error> {
let rows = sqlx::query(FK_INTROSPECTION_QUERY)
.fetch_all(pool)
.await?;
let fks: Vec<ForeignKey> = rows
.iter()
.map(|row| ForeignKey {
from_schema: row.get("from_schema"),
from_table: row.get("from_table"),
from_column: row.get("from_column"),
to_schema: row.get("to_schema"),
to_table: row.get("to_table"),
to_column: row.get("to_column"),
constraint_name: row.get("constraint_name"),
})
.collect();
Ok(Self::from_foreign_keys(fks))
}
pub fn find_relationship(
&self,
from_schema: &str,
from_table: &str,
to_table: &str,
) -> Option<Relationship> {
if let Some(fks) = self
.foreign_keys
.get(&(from_schema.to_string(), from_table.to_string()))
{
if let Some(fk) = fks.iter().find(|fk| fk.to_table == to_table) {
return Some(Relationship {
from_table: from_table.to_string(),
to_table: to_table.to_string(),
foreign_key: fk.clone(),
relation_type: RelationType::ManyToOne,
});
}
}
if let Some(fks) = self
.reverse_fks
.get(&(from_schema.to_string(), from_table.to_string()))
{
if let Some(fk) = fks.iter().find(|fk| fk.from_table == to_table) {
return Some(Relationship {
from_table: from_table.to_string(),
to_table: to_table.to_string(),
foreign_key: fk.clone(),
relation_type: RelationType::OneToMany,
});
}
}
None
}
pub fn get_foreign_keys(&self, schema: &str, table: &str) -> Vec<&ForeignKey> {
self.foreign_keys
.get(&(schema.to_string(), table.to_string()))
.map(|fks| fks.iter().collect())
.unwrap_or_default()
}
pub fn get_referencing_tables(&self, schema: &str, table: &str) -> Vec<&ForeignKey> {
self.reverse_fks
.get(&(schema.to_string(), table.to_string()))
.map(|fks| fks.iter().collect())
.unwrap_or_default()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Cache holding the single key `orders.customer_id -> customers.id`.
    fn orders_to_customers() -> SchemaCache {
        SchemaCache::from_foreign_keys(vec![ForeignKey::test(
            "orders",
            "customer_id",
            "customers",
            "id",
        )])
    }

    #[test]
    fn test_foreign_key_links() {
        let fk = ForeignKey::test("orders", "customer_id", "customers", "id");
        assert!(fk.links("orders", "customers"));
        assert!(!fk.links("customers", "orders"));
    }

    #[test]
    fn test_join_condition() {
        let fk = ForeignKey::test("orders", "customer_id", "customers", "id");
        let condition = fk.join_condition("orders", "customers");
        assert_eq!(condition, r#""orders"."customer_id" = "customers"."id""#);
    }

    #[test]
    fn test_schema_cache_empty() {
        let cache = SchemaCache::new();
        assert_eq!(cache.get_foreign_keys("public", "users").len(), 0);
    }

    #[test]
    fn test_from_foreign_keys_indexes_forward_and_reverse() {
        let fks = vec![
            ForeignKey::test("orders", "customer_id", "customers", "id"),
            ForeignKey::test("order_items", "order_id", "orders", "id"),
        ];
        let cache = SchemaCache::from_foreign_keys(fks);

        let order_fks = cache.get_foreign_keys("public", "orders");
        assert_eq!(order_fks.len(), 1);
        assert_eq!(order_fks[0].to_table, "customers");

        let item_fks = cache.get_foreign_keys("public", "order_items");
        assert_eq!(item_fks.len(), 1);
        assert_eq!(item_fks[0].to_table, "orders");

        let cust_refs = cache.get_referencing_tables("public", "customers");
        assert_eq!(cust_refs.len(), 1);
        assert_eq!(cust_refs[0].from_table, "orders");

        let order_refs = cache.get_referencing_tables("public", "orders");
        assert_eq!(order_refs.len(), 1);
        assert_eq!(order_refs[0].from_table, "order_items");
    }

    #[test]
    fn test_from_foreign_keys_empty() {
        let cache = SchemaCache::from_foreign_keys(vec![]);
        assert_eq!(cache.get_foreign_keys("public", "anything").len(), 0);
        assert_eq!(cache.get_referencing_tables("public", "anything").len(), 0);
    }

    #[test]
    fn test_from_foreign_keys_multiple_fks_same_table() {
        let fks = vec![
            ForeignKey::test("transfers", "from_account_id", "accounts", "id"),
            ForeignKey::test("transfers", "to_account_id", "accounts", "id"),
        ];
        let cache = SchemaCache::from_foreign_keys(fks);
        assert_eq!(cache.get_foreign_keys("public", "transfers").len(), 2);
        assert_eq!(cache.get_referencing_tables("public", "accounts").len(), 2);
    }

    #[test]
    fn test_find_relationship_many_to_one() {
        let rel = orders_to_customers()
            .find_relationship("public", "orders", "customers")
            .expect("orders -> customers should resolve via customer_id");
        assert_eq!(rel.relation_type, RelationType::ManyToOne);
        assert_eq!(rel.foreign_key.from_column, "customer_id");
    }

    #[test]
    fn test_find_relationship_one_to_many() {
        let rel = orders_to_customers()
            .find_relationship("public", "customers", "orders")
            .expect("customers -> orders should resolve via the reverse index");
        assert_eq!(rel.relation_type, RelationType::OneToMany);
        assert_eq!(rel.from_table, "customers");
        // The key keeps the constraint's own direction, not the lookup's.
        assert_eq!(rel.foreign_key.from_table, "orders");
    }

    #[test]
    fn test_find_relationship_not_found() {
        assert!(orders_to_customers()
            .find_relationship("public", "customers", "products")
            .is_none());
    }

    #[test]
    fn test_lookups_are_schema_scoped() {
        let cache = orders_to_customers();
        assert!(cache.get_foreign_keys("sales", "orders").is_empty());
        assert!(cache.get_referencing_tables("sales", "customers").is_empty());
        assert!(cache.find_relationship("sales", "orders", "customers").is_none());
    }

    #[test]
    fn test_find_relationship_self_reference() {
        let cache = SchemaCache::from_foreign_keys(vec![ForeignKey::test(
            "employees",
            "manager_id",
            "employees",
            "id",
        )]);
        let rel = cache
            .find_relationship("public", "employees", "employees")
            .expect("self-referencing key should resolve");
        // The forward index is consulted first, so a self-reference is many-to-one.
        assert_eq!(rel.relation_type, RelationType::ManyToOne);
        assert_eq!(rel.foreign_key.from_column, "manager_id");
    }
}