use std::collections::HashMap;
/// Compile-time table metadata for a model type.
///
/// `SchemaRegistry::register::<T>()` reads these associated functions to build a
/// `TableSchema` for `T`.
pub trait TableMeta {
    /// Schema that owns the table. Unless overridden this is `"public"`.
    fn schema_name() -> &'static str {
        "public"
    }

    /// Unqualified name of the table.
    fn table_name() -> &'static str;

    /// Names of the table's columns.
    fn columns() -> &'static [&'static str];

    /// Name of the primary-key column. Unless overridden there is none (`None`).
    fn primary_key() -> Option<&'static str> {
        None
    }
}
/// One column of a registered table.
#[derive(Clone, Debug)]
pub struct ColumnMeta {
    /// Column name as written in SQL.
    pub name: String,
    /// `true` when this column is the table's primary key.
    pub is_primary_key: bool,
}

/// A table known to the registry: its schema, name and columns.
#[derive(Clone, Debug)]
pub struct TableSchema {
    /// Schema that owns the table (for example `"public"`).
    pub schema: String,
    /// Unqualified table name.
    pub name: String,
    /// Columns in the order they were added.
    pub columns: Vec<ColumnMeta>,
}

impl TableSchema {
    /// Creates a table with the given schema and name and no columns.
    pub fn new(schema: impl Into<String>, name: impl Into<String>) -> Self {
        let (schema, name) = (schema.into(), name.into());
        Self {
            schema,
            name,
            columns: Vec::new(),
        }
    }

    /// Appends a column. Duplicate names are not rejected.
    pub fn add_column(&mut self, name: impl Into<String>, is_primary_key: bool) {
        let column = ColumnMeta {
            name: name.into(),
            is_primary_key,
        };
        self.columns.push(column);
    }

    /// Builder form: appends one non-primary-key column per entry of `columns`, in order.
    pub fn with_columns(mut self, columns: &[&str]) -> Self {
        self.columns.extend(columns.iter().map(|&column_name| ColumnMeta {
            name: column_name.to_owned(),
            is_primary_key: false,
        }));
        self
    }

    /// Builder form: makes `pk` the only primary-key column.
    ///
    /// The flag is cleared on every other column. If no column is named `pk`, a new
    /// primary-key column with that name is appended.
    pub fn with_primary_key(mut self, pk: &str) -> Self {
        let mut present = false;
        for column in self.columns.iter_mut() {
            column.is_primary_key = column.name == pk;
            present |= column.is_primary_key;
        }
        if !present {
            self.columns.push(ColumnMeta {
                name: pk.to_owned(),
                is_primary_key: true,
            });
        }
        self
    }

    /// Returns `true` if a column called `name` exists (exact, case-sensitive match).
    pub fn has_column(&self, name: &str) -> bool {
        self.columns.iter().any(|column| column.name == name)
    }
}
/// Registry of known table schemas, keyed by `(schema, table name)`.
///
/// With the `check` feature enabled, the registry also owns a SQL parse cache. The cache
/// is held in an `Arc`, so clones of a registry share it.
#[derive(Debug, Clone)]
pub struct SchemaRegistry {
    tables: HashMap<(String, String), TableSchema>,
    #[cfg(feature = "check")]
    parse_cache: std::sync::Arc<pgorm_check::SqlParseCache>,
}

impl SchemaRegistry {
    /// Creates an empty registry. Equivalent to `SchemaRegistry::default()`.
    pub fn new() -> Self {
        Self::default()
    }

    /// Registers the table described by `T`'s `TableMeta` implementation.
    ///
    /// Columns are added in the order of `T::columns()`. The column named by
    /// `T::primary_key()`, if any, is flagged as the primary key. An existing entry with
    /// the same schema and table name is replaced.
    pub fn register<T: TableMeta>(&mut self) {
        let pk = T::primary_key();
        let mut table = TableSchema::new(T::schema_name(), T::table_name());
        for &col in T::columns() {
            table.add_column(col, pk == Some(col));
        }
        self.register_table(table);
    }

    /// Registers `table` under its own `(schema, name)`, replacing any previous entry.
    pub fn register_table(&mut self, table: TableSchema) {
        let key = (table.schema.clone(), table.name.clone());
        self.tables.insert(key, table);
    }

    /// Looks up a table by exact schema and name.
    pub fn get_table(&self, schema: &str, name: &str) -> Option<&TableSchema> {
        self.tables.get(&(schema.to_string(), name.to_string()))
    }

    /// Looks up a table by name alone.
    ///
    /// `public.<name>` wins if it exists. Otherwise, among the tables with that name, the
    /// one whose schema name sorts first is returned. Picking it this way keeps the
    /// result independent of `HashMap` iteration order, which is randomized per map.
    pub fn find_table(&self, name: &str) -> Option<&TableSchema> {
        self.get_table("public", name).or_else(|| {
            self.tables
                .values()
                .filter(|t| t.name == name)
                .min_by(|a, b| a.schema.cmp(&b.schema))
        })
    }

    /// Returns `true` if a table with exactly this schema and name is registered.
    pub fn has_table(&self, schema: &str, name: &str) -> bool {
        self.get_table(schema, name).is_some()
    }

    /// Iterates over all registered tables, in no particular order.
    pub fn tables(&self) -> impl Iterator<Item = &TableSchema> {
        self.tables.values()
    }

    /// Number of registered tables.
    pub fn len(&self) -> usize {
        self.tables.len()
    }

    /// Returns `true` if no table is registered.
    pub fn is_empty(&self) -> bool {
        self.tables.is_empty()
    }
}

impl Default for SchemaRegistry {
    /// An empty registry (and, with the `check` feature, a default parse cache).
    fn default() -> Self {
        Self {
            tables: HashMap::new(),
            #[cfg(feature = "check")]
            parse_cache: std::sync::Arc::new(pgorm_check::SqlParseCache::default()),
        }
    }
}
/// Severity of a `SchemaIssue`.
///
/// Variants are declared from least to most severe, so the derived ordering is
/// `Info < Warning < Error`. That allows filters such as
/// `issue.level >= SchemaIssueLevel::Warning`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum SchemaIssueLevel {
    /// Informational note.
    Info,
    /// A possible problem, or something that could not be fully checked.
    Warning,
    /// A problem found during checking.
    Error,
}

/// Category of a `SchemaIssue`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SchemaIssueKind {
    /// The SQL failed to parse.
    ParseError,
    /// A referenced table, or table/alias qualifier, is not known.
    MissingTable,
    /// A referenced column is not present on the resolved table.
    MissingColumn,
    /// An unqualified column matches more than one referenced table.
    AmbiguousColumn,
    /// A construct the checker does not handle.
    Unsupported,
}

/// One problem reported by schema checking.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SchemaIssue {
    /// How severe the problem is.
    pub level: SchemaIssueLevel,
    /// What kind of problem it is.
    pub kind: SchemaIssueKind,
    /// Human-readable description.
    pub message: String,
}

impl std::fmt::Display for SchemaIssue {
    /// Formats as `<level> <kind>: <message>`, using the `Debug` names of the enums.
    /// Example: `Error MissingTable: Table not found in registry: products`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{:?} {:?}: {}", self.level, self.kind, self.message)
    }
}
#[cfg(feature = "check")]
#[allow(unused_imports)]
pub use pgorm_check::{
CheckClient,
CheckError,
CheckResult,
ColumnInfo,
ColumnRef,
ColumnRefFull,
DbSchema,
InsertAnalysis,
LintIssue,
LintLevel,
LintResult,
OnConflictAnalysis,
ParseResult,
RelationKind,
SchemaCache,
SchemaCacheConfig,
SchemaCacheLoad,
SqlAnalysis,
SqlCheckIssue,
SqlCheckIssueKind,
SqlCheckLevel,
SqlParseCache,
StatementKind,
TableInfo,
TargetColumn,
UpdateAnalysis,
check_sql,
check_sql_analysis,
check_sql_cached,
delete_has_where,
detect_statement_kind,
get_column_refs,
get_table_names,
is_valid_sql,
lint_select_many,
lint_sql,
schema_introspect::load_schema_from_db,
select_has_limit,
select_has_star,
update_has_where,
};
#[cfg(feature = "check")]
impl SchemaRegistry {
    /// Parses and analyzes `sql` through this registry's shared parse cache.
    pub(crate) fn analyze_sql(&self, sql: &str) -> std::sync::Arc<SqlAnalysis> {
        self.parse_cache.analyze(sql)
    }

    /// Replaces the parse cache with `SqlParseCache::new(capacity)`.
    ///
    /// The cache is held in an `Arc`. Registry clones taken before this call keep sharing
    /// the previous cache.
    pub fn with_parse_cache_capacity(mut self, capacity: usize) -> Self {
        self.parse_cache = std::sync::Arc::new(SqlParseCache::new(capacity));
        self
    }

    /// Checks the tables and columns referenced by `sql` against the registry.
    ///
    /// Returns every issue found; an empty vector means nothing was flagged. A syntax
    /// error short-circuits with a single `ParseError` issue.
    ///
    /// Checks performed:
    /// - Every relation in the analysis's `range_vars`, other than a CTE reference, must
    ///   be registered.
    /// - INSERT, UPDATE and ON CONFLICT target columns must exist on the target table.
    /// - An unqualified column must match exactly one referenced table.
    /// - `qualifier.col` must resolve through a known alias or table name.
    /// - `schema.table.col` and `catalog.schema.table.col` are looked up directly.
    ///
    /// PostgreSQL system columns (`ctid`, `xmin`, `xmax`, `cmin`, `cmax`, `tableoid`)
    /// are always accepted.
    ///
    /// Limitations:
    /// - All relations of the statement share one flat namespace; subquery scopes are
    ///   not modelled.
    /// - Columns of CTEs and of unregistered relations are unknown. References through
    ///   such a qualifier are not reported.
    /// - An unqualified column is reported as missing only when every source resolved
    ///   against the registry.
    /// - Each missing table name is reported once.
    pub fn check_sql(&self, sql: &str) -> Vec<SchemaIssue> {
        // PostgreSQL system columns exist on every table and are never registered.
        fn is_system_column(col: &str) -> bool {
            matches!(col, "ctid" | "xmin" | "xmax" | "cmin" | "cmax" | "tableoid")
        }

        // Builds an `Error`-level issue.
        fn error(kind: SchemaIssueKind, message: String) -> SchemaIssue {
            SchemaIssue {
                level: SchemaIssueLevel::Error,
                kind,
                message,
            }
        }

        // Renders a relation as `schema.name`, or just `name` when unqualified.
        fn qualified_name(schema: Option<&str>, name: &str) -> String {
            match schema {
                Some(s) => format!("{s}.{name}"),
                None => name.to_string(),
            }
        }

        // Exact lookup for schema-qualified names; otherwise `find_table` (public first).
        fn resolve_table<'a>(
            registry: &'a SchemaRegistry,
            schema: Option<&str>,
            name: &str,
        ) -> Option<&'a TableSchema> {
            match schema {
                Some(s) => registry.get_table(s, name),
                None => registry.find_table(name),
            }
        }

        // Pushes a MissingTable issue unless the same rendered name was already reported.
        fn report_missing_table(
            issues: &mut Vec<SchemaIssue>,
            reported: &mut std::collections::HashSet<String>,
            name: String,
        ) {
            if !reported.contains(&name) {
                issues.push(error(
                    SchemaIssueKind::MissingTable,
                    format!("Table not found in registry: {name}"),
                ));
                reported.insert(name);
            }
        }

        // Reports each non-system column of `columns` that `table` lacks. `context`
        // names the clause, for example "INSERT target table".
        fn check_target_columns<'c>(
            issues: &mut Vec<SchemaIssue>,
            table: &TableSchema,
            columns: impl IntoIterator<Item = &'c str>,
            context: &str,
        ) {
            for col in columns {
                if is_system_column(col) || table.has_column(col) {
                    continue;
                }
                issues.push(error(
                    SchemaIssueKind::MissingColumn,
                    format!(
                        "Column not found: {}.{} ({} '{}')",
                        table.name, col, context, table.name
                    ),
                ));
            }
        }

        let mut issues = Vec::new();
        let analysis = self.analyze_sql(sql);
        if !analysis.parse_result.valid {
            issues.push(error(
                SchemaIssueKind::ParseError,
                format!(
                    "SQL syntax error: {}",
                    analysis.parse_result.error.clone().unwrap_or_default()
                ),
            ));
            return issues;
        }

        let mut reported_missing: std::collections::HashSet<String> =
            std::collections::HashSet::new();
        let mut qualifier_to_table: HashMap<String, &TableSchema> = HashMap::new();
        let mut visible_tables: Vec<&TableSchema> = Vec::new();
        // Qualifiers bound to CTE references or to unregistered relations. Their columns
        // are unknown, so references through them cannot be verified.
        let mut opaque_qualifiers: std::collections::HashSet<String> =
            std::collections::HashSet::new();

        for rv in &analysis.range_vars {
            let rel_schema = rv.schema.as_deref();
            let rel_name = rv.table.as_str();
            let qualifier = rv.alias.as_deref().unwrap_or(rel_name);
            if analysis.cte_names.contains(&rv.table) {
                opaque_qualifiers.insert(qualifier.to_string());
                continue;
            }
            match resolve_table(self, rel_schema, rel_name) {
                Some(t) => {
                    // Only the first binding of a qualifier adds a visible source.
                    // Repeated references to the same qualifier then do not make
                    // unqualified columns look ambiguous.
                    if qualifier_to_table
                        .insert(qualifier.to_string(), t)
                        .is_none()
                    {
                        visible_tables.push(t);
                    }
                }
                None => {
                    opaque_qualifiers.insert(qualifier.to_string());
                    report_missing_table(
                        &mut issues,
                        &mut reported_missing,
                        qualified_name(rel_schema, rel_name),
                    );
                }
            }
        }

        if let Some(insert) = &analysis.insert {
            if let Some(target) = &insert.target {
                let target_schema = target.schema.as_deref();
                match resolve_table(self, target_schema, &target.table) {
                    Some(t) => {
                        check_target_columns(
                            &mut issues,
                            t,
                            insert.columns.iter().map(|c| c.name.as_str()),
                            "INSERT target table",
                        );
                        if let Some(oc) = &insert.on_conflict {
                            if oc.has_inference_expressions {
                                issues.push(SchemaIssue {
                                    level: SchemaIssueLevel::Warning,
                                    kind: SchemaIssueKind::Unsupported,
                                    message: "ON CONFLICT inference uses expressions; only simple column targets are checked".to_string(),
                                });
                            }
                            check_target_columns(
                                &mut issues,
                                t,
                                oc.inference_columns.iter().map(|c| c.name.as_str()),
                                "ON CONFLICT target table",
                            );
                            check_target_columns(
                                &mut issues,
                                t,
                                oc.update_set_columns.iter().map(|c| c.name.as_str()),
                                "ON CONFLICT DO UPDATE SET on table",
                            );
                        }
                    }
                    None => report_missing_table(
                        &mut issues,
                        &mut reported_missing,
                        qualified_name(target_schema, &target.table),
                    ),
                }
            }
        }

        if let Some(update) = &analysis.update {
            if let Some(target) = &update.target {
                let target_schema = target.schema.as_deref();
                match resolve_table(self, target_schema, &target.table) {
                    Some(t) => check_target_columns(
                        &mut issues,
                        t,
                        update.set_columns.iter().map(|c| c.name.as_str()),
                        "UPDATE target table",
                    ),
                    None => report_missing_table(
                        &mut issues,
                        &mut reported_missing,
                        qualified_name(target_schema, &target.table),
                    ),
                }
            }
        }

        let has_opaque_source = !opaque_qualifiers.is_empty();
        for c in &analysis.column_refs {
            if c.has_star || c.parts.is_empty() {
                continue;
            }
            match &c.parts[..] {
                [col] => {
                    if is_system_column(col) {
                        continue;
                    }
                    let matches = visible_tables.iter().filter(|t| t.has_column(col)).count();
                    match matches {
                        0 => {
                            // The column may belong to a CTE or an unregistered relation,
                            // so report it only when every source was resolved.
                            if !visible_tables.is_empty() && !has_opaque_source {
                                issues.push(error(
                                    SchemaIssueKind::MissingColumn,
                                    format!(
                                        "Column not found: {col} (not in any referenced tables)"
                                    ),
                                ));
                            }
                        }
                        1 => {}
                        _ => issues.push(error(
                            SchemaIssueKind::AmbiguousColumn,
                            format!(
                                "Ambiguous column reference: {col} (found in multiple tables)"
                            ),
                        )),
                    }
                }
                [qualifier, col] => {
                    if is_system_column(col) {
                        continue;
                    }
                    if let Some(t) = qualifier_to_table.get(qualifier.as_str()) {
                        if !t.has_column(col) {
                            issues.push(error(
                                SchemaIssueKind::MissingColumn,
                                format!(
                                    "Column not found: {qualifier}.{col} (table resolved to '{}')",
                                    t.name
                                ),
                            ));
                        }
                    } else if !opaque_qualifiers.contains(qualifier.as_str()) {
                        issues.push(error(
                            SchemaIssueKind::MissingTable,
                            format!("Unknown table/alias qualifier: {qualifier}"),
                        ));
                    }
                }
                // `schema.table.col`, or `catalog.schema.table.col` with the catalog ignored.
                [schema_part, table_part, col_part]
                | [_, schema_part, table_part, col_part] => {
                    if is_system_column(col_part) {
                        continue;
                    }
                    match self.get_table(schema_part, table_part) {
                        Some(t) => {
                            if !t.has_column(col_part) {
                                issues.push(error(
                                    SchemaIssueKind::MissingColumn,
                                    format!(
                                        "Column not found: {schema_part}.{table_part}.{col_part}"
                                    ),
                                ));
                            }
                        }
                        None => report_missing_table(
                            &mut issues,
                            &mut reported_missing,
                            qualified_name(Some(schema_part.as_str()), table_part),
                        ),
                    }
                }
                parts => issues.push(SchemaIssue {
                    level: SchemaIssueLevel::Warning,
                    kind: SchemaIssueKind::Unsupported,
                    message: format!(
                        "Unsupported column reference form ({} parts): {}",
                        parts.len(),
                        parts.join(".")
                    ),
                }),
            }
        }
        issues
    }

    /// Lints `sql` with `lint_sql`. The registered tables are not consulted.
    pub fn lint(&self, sql: &str) -> LintResult {
        lint_sql(sql)
    }

    /// Returns whether `sql` parses, according to `is_valid_sql`. The registered tables
    /// are not consulted.
    pub fn is_valid(&self, sql: &str) -> bool {
        is_valid_sql(sql).valid
    }
}
/// Runs `<Model>::check_schema(&registry)` for every listed model.
///
/// Evaluates to a `Vec<(&'static str, HashMap<&'static str, Vec<SchemaIssue>>)>` of
/// `(model name, issues keyed by SQL name)` pairs, in argument order. The `$registry`
/// expression is evaluated exactly once, however many models are listed.
#[macro_export]
macro_rules! check_models {
    ($registry:expr, $($model:ty),+ $(,)?) => {{
        // Bind once: `$registry` may be an arbitrary (even side-effecting) expression.
        let registry = &$registry;
        let results: Vec<(&'static str, std::collections::HashMap<&'static str, Vec<$crate::SchemaIssue>>)> = vec![
            $( (stringify!($model), <$model>::check_schema(registry)) ),+
        ];
        results
    }};
}
/// Panics if any listed model's `check_schema(&registry)` reports at least one issue.
///
/// The panic message lists each failing model, followed by its issues as
/// `<sql name>: <message>`. Issues are sorted by SQL name so the message is stable across
/// runs. A model whose map contains only empty issue lists counts as valid. The
/// `$registry` expression is evaluated exactly once.
#[macro_export]
macro_rules! assert_models_valid {
    ($registry:expr, $($model:ty),+ $(,)?) => {{
        // Bind once: `$registry` may be an arbitrary (even side-effecting) expression.
        let registry = &$registry;
        let mut all_issues: Vec<(&'static str, Vec<String>)> = Vec::new();
        $(
            let issues = <$model>::check_schema(registry);
            // Keep only SQL entries that actually carry issues, in a deterministic order
            // (HashMap iteration order is randomized).
            let mut entries: Vec<_> = issues
                .iter()
                .filter(|(_, issue_list)| !issue_list.is_empty())
                .collect();
            entries.sort_unstable_by(|a, b| a.0.cmp(b.0));
            if !entries.is_empty() {
                let messages: Vec<String> = entries
                    .iter()
                    .flat_map(|(sql_name, issue_list)| {
                        issue_list.iter().map(move |i| format!("{}: {}", sql_name, i.message))
                    })
                    .collect();
                all_issues.push((stringify!($model), messages));
            }
        )+
        if !all_issues.is_empty() {
            let mut msg = String::from("Schema validation failed:\n");
            for (model, issues) in &all_issues {
                msg.push_str(&format!("\n{}:\n", model));
                for issue in issues {
                    msg.push_str(&format!(" - {}\n", issue));
                }
            }
            panic!("{}", msg);
        }
    }};
}
/// Prints a per-model validation summary for `<Model>::check_schema(&registry)`.
///
/// A model is shown as valid (✓) when its check reports zero issues in total. Otherwise
/// it is shown as ✗ with its issue count and each issue, ordered by SQL name. Evaluates
/// to `true` only when every model is valid. The `$registry` expression is evaluated
/// exactly once.
#[macro_export]
macro_rules! print_model_check {
    ($registry:expr, $($model:ty),+ $(,)?) => {{
        // Bind once: `$registry` may be an arbitrary (even side-effecting) expression.
        let registry = &$registry;
        println!("Model Schema Validation:");
        let mut all_valid = true;
        $(
            let issues = <$model>::check_schema(registry);
            // Judge validity by the issue count, so entries with empty lists don't fail.
            let total: usize = issues.values().map(|v| v.len()).sum();
            if total == 0 {
                println!(" ✓ {}", stringify!($model));
            } else {
                all_valid = false;
                println!(" ✗ {} ({} issues)", stringify!($model), total);
                // Deterministic output order (HashMap iteration order is randomized).
                let mut entries: Vec<_> = issues.iter().collect();
                entries.sort_unstable_by(|a, b| a.0.cmp(b.0));
                for (sql_name, issue_list) in entries {
                    for issue in issue_list {
                        println!(" {}: {:?} - {}", sql_name, issue.kind, issue.message);
                    }
                }
            }
        )+
        all_valid
    }};
}
/// Checks each listed model against the live database schema.
///
/// Expands to an `async` block that must be `.await`ed. The block:
/// 1. loads the schema with `$client.load_db_schema().await?`;
/// 2. runs `ModelCheckResult::check::<Model>(&schema)` for every model, in argument order.
///
/// The future resolves to `Result<Vec<ModelCheckResult>, OrmError>`. An error from
/// `load_db_schema` is propagated via `?`.
#[macro_export]
macro_rules! check_models_db {
    ($client:expr, $($model:ty),+ $(,)?) => {{
        async {
            let live_schema = $client.load_db_schema().await?;
            let checked: Vec<$crate::ModelCheckResult> = vec![
                $( $crate::ModelCheckResult::check::<$model>(&live_schema) ),+
            ];
            Ok::<_, $crate::OrmError>(checked)
        }
    }};
}
/// Checks each listed model against the live database schema and prints the results.
///
/// Expands to an `async` block that must be `.await`ed. It loads the schema with
/// `$client.load_db_schema().await?`, prints a header, then for each model runs
/// `ModelCheckResult::check::<Model>` and calls `print()` on the result. The future
/// resolves to `Result<bool, OrmError>`, which is `true` only when every result's
/// `is_valid()` was true.
#[macro_export]
macro_rules! print_models_db_check {
    ($client:expr, $($model:ty),+ $(,)?) => {{
        async {
            let live_schema = $client.load_db_schema().await?;
            println!("Model Database Validation:");
            let mut every_model_valid = true;
            $(
                let outcome = $crate::ModelCheckResult::check::<$model>(&live_schema);
                every_model_valid &= outcome.is_valid();
                outcome.print();
            )+
            Ok::<_, $crate::OrmError>(every_model_valid)
        }
    }};
}
/// Asserts that every listed model matches the live database schema.
///
/// Expands to an `async` block that must be `.await`ed. It loads the schema with
/// `$client.load_db_schema().await?` and checks each model with
/// `ModelCheckResult::check::<Model>`. A model fails when its table was not found, or
/// (if the table was found) when `missing_in_db` is non-empty.
///
/// # Panics
/// Panics with one line per failing model if any model fails.
/// Otherwise the future resolves to `Ok(())`. A `load_db_schema` error is returned as
/// `Err` via `?`.
#[macro_export]
macro_rules! assert_models_db_valid {
    ($client:expr, $($model:ty),+ $(,)?) => {{
        async {
            let live_schema = $client.load_db_schema().await?;
            let mut failures: Vec<String> = Vec::new();
            $(
                let outcome = $crate::ModelCheckResult::check::<$model>(&live_schema);
                let failure = if !outcome.table_found {
                    Some(format!("{}: table '{}' not found", outcome.model, outcome.table))
                } else if !outcome.missing_in_db.is_empty() {
                    Some(format!("{}: missing columns {:?}", outcome.model, outcome.missing_in_db))
                } else {
                    None
                };
                failures.extend(failure);
            )+
            assert!(
                failures.is_empty(),
                "Schema validation failed:\n {}",
                failures.join("\n ")
            );
            Ok::<_, $crate::OrmError>(())
        }
    }};
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Fixture model: `public.users` with `id` as primary key.
    struct TestUser;

    impl TableMeta for TestUser {
        fn table_name() -> &'static str {
            "users"
        }

        fn columns() -> &'static [&'static str] {
            &["id", "name", "email", "created_at"]
        }

        fn primary_key() -> Option<&'static str> {
            Some("id")
        }
    }

    /// Fixture model: `public.orders` with `id` as primary key.
    struct TestOrder;

    impl TableMeta for TestOrder {
        fn table_name() -> &'static str {
            "orders"
        }

        fn columns() -> &'static [&'static str] {
            &["id", "user_id", "total", "status"]
        }

        fn primary_key() -> Option<&'static str> {
            Some("id")
        }
    }

    /// A registry holding only the `users` fixture.
    fn users_only() -> SchemaRegistry {
        let mut reg = SchemaRegistry::new();
        reg.register::<TestUser>();
        reg
    }

    /// A registry holding the `users` and `orders` fixtures.
    fn users_and_orders() -> SchemaRegistry {
        let mut reg = users_only();
        reg.register::<TestOrder>();
        reg
    }

    #[test]
    fn test_register_table() {
        let reg = users_and_orders();
        assert_eq!(reg.len(), 2);
        for present in ["users", "orders"] {
            assert!(reg.has_table("public", present));
        }
        assert!(!reg.has_table("public", "products"));
    }

    #[test]
    fn test_find_table() {
        let reg = users_only();
        let users = reg.find_table("users").expect("users should be registered");
        assert_eq!(users.name, "users");
        assert!(["id", "name"].iter().all(|c| users.has_column(c)));
        assert!(!users.has_column("nonexistent"));
    }

    #[test]
    fn test_table_schema_builder() {
        let products = TableSchema::new("public", "products")
            .with_columns(&["id", "name", "price"])
            .with_primary_key("id");
        assert_eq!(products.name, "products");
        for col in ["id", "name", "price"] {
            assert!(products.has_column(col));
        }
        let pk = products
            .columns
            .iter()
            .find(|c| c.is_primary_key)
            .map(|c| c.name.as_str());
        assert_eq!(pk, Some("id"));
    }

    #[cfg(feature = "check")]
    mod check_tests {
        use super::*;

        /// Returns whether any issue in `issues` has the given kind.
        fn has_kind(issues: &[SchemaIssue], kind: SchemaIssueKind) -> bool {
            issues.iter().any(|i| i.kind == kind)
        }

        #[test]
        fn test_is_valid_sql() {
            assert!(is_valid_sql("SELECT * FROM users").valid);
            assert!(!is_valid_sql("SELEC * FROM users").valid);
        }

        #[test]
        fn test_detect_statement_kind() {
            let cases = [
                ("SELECT * FROM users", StatementKind::Select),
                ("DELETE FROM users", StatementKind::Delete),
                ("UPDATE users SET name = 'foo'", StatementKind::Update),
            ];
            for (sql, expected) in cases {
                assert_eq!(detect_statement_kind(sql), Some(expected));
            }
        }

        #[test]
        fn test_lint_sql() {
            assert!(lint_sql("DELETE FROM users").has_errors());
            assert!(!lint_sql("DELETE FROM users WHERE id = 1").has_errors());
        }

        #[test]
        fn test_check_sql_schema() {
            let reg = users_and_orders();
            assert!(reg.check_sql("SELECT * FROM users").is_empty());
            let issues = reg.check_sql("SELECT * FROM products");
            assert_eq!(
                issues.first().map(|i| i.kind),
                Some(SchemaIssueKind::MissingTable)
            );
        }

        #[test]
        fn test_check_sql_alias_and_ambiguous_column() {
            let reg = users_and_orders();
            assert!(reg
                .check_sql(
                    "SELECT u.id FROM users u JOIN orders o ON u.id = o.user_id WHERE o.status = 'paid'",
                )
                .is_empty());
            let issues =
                reg.check_sql("SELECT id FROM users u JOIN orders o ON u.id = o.user_id");
            assert!(has_kind(&issues, SchemaIssueKind::AmbiguousColumn));
        }

        #[test]
        fn test_check_sql_insert_update_on_conflict_columns() {
            let reg = users_only();
            for sql in [
                "INSERT INTO users (id, missing_col) VALUES (1, 'x')",
                "UPDATE users SET missing_col = 1 WHERE id = 1",
                "INSERT INTO users (id, name) VALUES (1, 'a') ON CONFLICT (id) DO UPDATE SET missing_col = EXCLUDED.name",
            ] {
                assert!(
                    has_kind(&reg.check_sql(sql), SchemaIssueKind::MissingColumn),
                    "expected MissingColumn for: {}",
                    sql
                );
            }
        }

        #[test]
        fn test_check_sql_allows_system_columns() {
            let reg = users_only();
            for sql in [
                "SELECT ctid FROM users",
                "INSERT INTO users (ctid) VALUES ('(0,0)')",
                "UPDATE users SET ctid = ctid",
            ] {
                assert!(reg.check_sql(sql).is_empty(), "unexpected issues for: {}", sql);
            }
        }
    }
}