use std::collections::{HashMap, HashSet};
use crate::contract::{HasSchema, ModelSchema};
use crate::orm::Db;
/// Overall verdict for a single table's schema check.
///
/// Reports built by this module carry `Error` when any error was recorded,
/// otherwise `Warning` when any warning was recorded, otherwise `Ok`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ReportStatus {
    /// Database and contract agree; nothing was reported.
    Ok,
    /// Only non-fatal issues were found.
    Warning,
    /// At least one error was found.
    Error,
}
/// Category of a single schema discrepancy.
///
/// Marked `#[non_exhaustive]` so new categories can be added without
/// breaking downstream `match` expressions.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[non_exhaustive]
pub enum IssueKind {
    /// No columns for the table were listed in `information_schema.columns`.
    MissingTable,
    /// A column declared in the Rust contract is absent from the database.
    MissingColumn,
    /// The Rust type is compatible with neither the column's `data_type`
    /// nor its `udt_name`.
    TypeMismatch,
    /// The contract's nullability disagrees with the column's `is_nullable`.
    NullabilityMismatch,
    /// The database primary key is not exactly the contract's single key column.
    WrongPrimaryKey,
    /// A database column is not declared in the contract (reported as a warning).
    ExtraDbColumn,
    /// A metadata query against the database failed.
    QueryFailed,
}
/// One discrepancy found while comparing a Rust contract with the database.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct SchemaIssue {
    /// Column the issue concerns; `None` for table-level issues such as a
    /// missing table or a failed metadata query.
    pub column: Option<String>,
    /// Category of the issue.
    pub kind: IssueKind,
    /// Human-readable description of the issue.
    pub message: String,
    /// Contract-side value, when one is meaningful for this issue.
    pub expected: Option<String>,
    /// Database-side value, when one is meaningful for this issue.
    pub actual: Option<String>,
}
/// Result of validating one table against its Rust contract.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct SchemaReport {
    /// Table name taken from the contract.
    pub table: String,
    /// Overall verdict derived from `errors` and `warnings`.
    pub status: ReportStatus,
    /// Issues that make the schema invalid.
    pub errors: Vec<SchemaIssue>,
    /// Non-fatal issues (e.g. database columns the contract does not declare).
    pub warnings: Vec<SchemaIssue>,
}
impl SchemaReport {
    /// Returns `true` when `status` is [`ReportStatus::Ok`].
    ///
    /// Reports built by this module only get that status when they carry
    /// neither errors nor warnings.
    pub fn is_ok(&self) -> bool {
        self.status == ReportStatus::Ok
    }

    /// Returns `true` when at least one error was recorded; warnings are ignored.
    pub fn has_errors(&self) -> bool {
        let no_errors = self.errors.is_empty();
        !no_errors
    }
}
/// Validates the database table backing model `M` against `M::SCHEMA`.
///
/// Never fails: every problem, including failed metadata queries, is
/// recorded inside the returned [`SchemaReport`].
pub async fn validate_schema<M: HasSchema>(db: &Db) -> SchemaReport {
    validate_one(db, &M::SCHEMA).await
}
/// Validates every schema in `schemas`, one after another.
///
/// Returns one [`SchemaReport`] per input schema, in the same order as the
/// input slice.
pub async fn validate_all(
    db: &Db,
    schemas: &[&'static ModelSchema],
) -> Vec<SchemaReport> {
    let mut reports = Vec::with_capacity(schemas.len());
    for &schema in schemas.iter() {
        let report = validate_one(db, schema).await;
        reports.push(report);
    }
    reports
}
async fn validate_one(db: &Db, schema: &ModelSchema) -> SchemaReport {
let mut errors: Vec<SchemaIssue> = Vec::new();
let mut warnings: Vec<SchemaIssue> = Vec::new();
let table = schema.table;
let db_cols = match query_columns(db, table).await {
Ok(cols) => cols,
Err(e) => {
errors.push(SchemaIssue {
column: None,
kind: IssueKind::QueryFailed,
message: format!(
"could not query information_schema.columns for table `{table}`: {e}"
),
expected: None,
actual: None,
});
return finalize(table.to_string(), errors, warnings);
}
};
if db_cols.is_empty() {
errors.push(SchemaIssue {
column: None,
kind: IssueKind::MissingTable,
message: format!(
"table `{table}` declared in Rust contract not found in database (schema 'public')"
),
expected: Some(table.to_string()),
actual: None,
});
return finalize(table.to_string(), errors, warnings);
}
let db_map: HashMap<&str, &DbColumn> = db_cols
.iter()
.map(|c| (c.column_name.as_str(), c))
.collect();
for rc in schema.columns {
let dc = match db_map.get(rc.name) {
Some(c) => c,
None => {
errors.push(SchemaIssue {
column: Some(rc.name.to_string()),
kind: IssueKind::MissingColumn,
message: format!(
"column `{table}.{}` declared in Rust contract not present in database",
rc.name
),
expected: Some(rc.sql_decl.to_string()),
actual: None,
});
continue;
}
};
let type_ok = rc.rust_type.is_compatible_with(&dc.data_type)
|| rc.rust_type.is_compatible_with(&dc.udt_name);
if !type_ok {
errors.push(SchemaIssue {
column: Some(rc.name.to_string()),
kind: IssueKind::TypeMismatch,
message: format!(
"column `{table}.{}`: Rust type {:?} is not compatible with PG type `{}` (udt: `{}`)",
rc.name, rc.rust_type, dc.data_type, dc.udt_name
),
expected: Some(format!(
"{:?} (compatible with one of {:?})",
rc.rust_type,
rc.rust_type.pg_compatible()
)),
actual: Some(dc.data_type.clone()),
});
}
let pg_nullable = dc.is_nullable.eq_ignore_ascii_case("YES");
if pg_nullable != rc.nullable {
errors.push(SchemaIssue {
column: Some(rc.name.to_string()),
kind: IssueKind::NullabilityMismatch,
message: format!(
"column `{table}.{}`: contract says nullable={}, DB says nullable={}",
rc.name, rc.nullable, pg_nullable
),
expected: Some(format!("nullable = {}", rc.nullable)),
actual: Some(format!("nullable = {pg_nullable}")),
});
}
}
match query_primary_key(db, table).await {
Ok(pk_cols) => {
let mismatch = pk_cols.len() != 1 || pk_cols[0] != schema.primary_key;
if mismatch {
errors.push(SchemaIssue {
column: Some(schema.primary_key.to_string()),
kind: IssueKind::WrongPrimaryKey,
message: format!(
"primary key drift on `{table}`: contract expects `{}`, DB has [{}]",
schema.primary_key,
pk_cols.join(", ")
),
expected: Some(schema.primary_key.to_string()),
actual: Some(if pk_cols.is_empty() {
"<no primary key>".to_string()
} else {
pk_cols.join(", ")
}),
});
}
}
Err(e) => {
errors.push(SchemaIssue {
column: None,
kind: IssueKind::QueryFailed,
message: format!(
"could not query primary-key constraints for `{table}`: {e}"
),
expected: None,
actual: None,
});
}
}
let rust_names: HashSet<&str> = schema.columns.iter().map(|c| c.name).collect();
for dc in &db_cols {
if !rust_names.contains(dc.column_name.as_str()) {
warnings.push(SchemaIssue {
column: Some(dc.column_name.clone()),
kind: IssueKind::ExtraDbColumn,
message: format!(
"DB column `{table}.{}` not declared in Rust contract (could be deliberate)",
dc.column_name
),
expected: None,
actual: Some(dc.data_type.clone()),
});
}
}
finalize(table.to_string(), errors, warnings)
}
fn finalize(
table: String,
errors: Vec<SchemaIssue>,
warnings: Vec<SchemaIssue>,
) -> SchemaReport {
let status = if !errors.is_empty() {
ReportStatus::Error
} else if !warnings.is_empty() {
ReportStatus::Warning
} else {
ReportStatus::Ok
};
SchemaReport {
table,
status,
errors,
warnings,
}
}
/// One row of `information_schema.columns`, as fetched by `query_columns`.
#[derive(Debug)]
struct DbColumn {
    /// Column name.
    column_name: String,
    /// Value of the `data_type` column (generic SQL type name).
    data_type: String,
    /// Value of the `udt_name` column (underlying type name).
    udt_name: String,
    /// Raw `is_nullable` text; a case-insensitive `"YES"` is treated as nullable.
    is_nullable: String,
}
async fn query_columns(db: &Db, table: &str) -> Result<Vec<DbColumn>, sqlx::Error> {
use sqlx::Row;
let rows = sqlx::query(
"SELECT column_name, data_type, udt_name, is_nullable
FROM information_schema.columns
WHERE table_schema = 'public' AND table_name = $1
ORDER BY ordinal_position",
)
.bind(table)
.fetch_all(db.pool())
.await?;
let mut out = Vec::with_capacity(rows.len());
for row in rows {
out.push(DbColumn {
column_name: row.try_get("column_name")?,
data_type: row.try_get("data_type")?,
udt_name: row.try_get("udt_name")?,
is_nullable: row.try_get("is_nullable")?,
});
}
Ok(out)
}
/// Returns the primary-key column names of `table` in the `public` schema,
/// ordered by their position within the key. Empty when the table has no
/// primary key.
///
/// # Errors
/// Propagates any `sqlx::Error` from executing the query or decoding a row.
async fn query_primary_key(db: &Db, table: &str) -> Result<Vec<String>, sqlx::Error> {
    use sqlx::Row;
    // Join on table as well as constraint name. PostgreSQL only requires
    // non-index constraint names (e.g. foreign keys) to be unique per table, so
    // a same-named constraint on another table could otherwise contribute its
    // columns to this table's key. `::text` avoids decoding the
    // `sql_identifier` domain type directly.
    let rows = sqlx::query(
        "SELECT kcu.column_name::text AS column_name
           FROM information_schema.table_constraints tc
           JOIN information_schema.key_column_usage kcu
             ON tc.constraint_schema = kcu.constraint_schema
            AND tc.constraint_name = kcu.constraint_name
            AND tc.table_schema = kcu.table_schema
            AND tc.table_name = kcu.table_name
          WHERE tc.constraint_type = 'PRIMARY KEY'
            AND tc.table_schema = 'public'
            AND tc.table_name = $1
          ORDER BY kcu.ordinal_position",
    )
    .bind(table)
    .fetch_all(db.pool())
    .await?;
    rows.iter()
        .map(|row| row.try_get::<String, _>("column_name"))
        .collect()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a minimal issue for status-derivation tests.
    fn issue(kind: IssueKind, column: Option<&str>, message: &str) -> SchemaIssue {
        SchemaIssue {
            column: column.map(str::to_string),
            kind,
            message: message.to_string(),
            expected: None,
            actual: None,
        }
    }

    #[test]
    fn finalize_no_issues_is_ok() {
        let r = finalize("t".into(), vec![], vec![]);
        assert_eq!(r.table, "t");
        assert_eq!(r.status, ReportStatus::Ok);
        assert!(r.is_ok());
        assert!(!r.has_errors());
        assert!(r.errors.is_empty());
        assert!(r.warnings.is_empty());
    }

    #[test]
    fn finalize_only_warnings_is_warning() {
        let warn = issue(IssueKind::ExtraDbColumn, Some("extra"), "x");
        let r = finalize("t".into(), vec![], vec![warn.clone()]);
        assert_eq!(r.status, ReportStatus::Warning);
        assert!(!r.is_ok());
        assert!(!r.has_errors());
        assert_eq!(r.warnings, vec![warn]);
    }

    #[test]
    fn finalize_any_error_is_error() {
        let err = issue(IssueKind::TypeMismatch, Some("c"), "x");
        let warn = issue(IssueKind::ExtraDbColumn, None, "y");
        let r = finalize("t".into(), vec![err.clone()], vec![warn.clone()]);
        assert_eq!(r.status, ReportStatus::Error);
        assert!(!r.is_ok());
        assert!(r.has_errors());
        // Issues are carried through unchanged, each in its own list.
        assert_eq!(r.errors, vec![err]);
        assert_eq!(r.warnings, vec![warn]);
    }

    #[test]
    fn finalize_errors_without_warnings_is_error() {
        let err = issue(IssueKind::MissingTable, None, "gone");
        let r = finalize("users".into(), vec![err], vec![]);
        assert_eq!(r.table, "users");
        assert_eq!(r.status, ReportStatus::Error);
        assert!(r.has_errors());
        assert!(r.warnings.is_empty());
    }
}