mod classify;
mod columns;
mod location;
use std::fmt;
use pg_query::protobuf::node::Node;
use classify::{imperative_reason, warn_create_table_missing_if_not_exists};
use columns::{TableDef, check_index_columns, check_view_columns, collect_create_stmt};
use location::{LineIndex, StmtLoc, stmt_start_offset};
/// Severity level attached to a [`LintError`].
///
/// Within this file, only `Error`-severity diagnostics make
/// `lint_declarative_schema` return `Err`; `Warning`s on their own do not.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LintSeverity {
/// A diagnostic that causes the lint to fail. Displayed as `error`.
Error,
/// A diagnostic that is collected but does not fail the lint by itself.
/// Displayed as `warning`.
Warning,
}
/// Renders the severity as its lowercase label (`error` / `warning`).
impl fmt::Display for LintSeverity {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Pick the label first so there is a single write site.
        let label = match *self {
            Self::Error => "error",
            Self::Warning => "warning",
        };
        f.write_str(label)
    }
}
/// A single diagnostic produced by the schema linter.
///
/// Its `Display` output has the shape `source:line:column: severity: message`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct LintError {
/// Line the diagnostic points at. Parse failures are reported at line 1;
/// otherwise the value comes from `LineIndex::position`.
pub line: u32,
/// Column the diagnostic points at. Parse failures are reported at column 1;
/// otherwise the value comes from `LineIndex::position`.
pub column: u32,
/// Whether this diagnostic is an error or a warning.
pub severity: LintSeverity,
/// Human-readable description of the problem.
pub message: String,
/// Caller-supplied label for the linted input, printed first in the
/// `Display` output (presumably a file path — confirm against callers).
pub source: String,
}
/// Formats the diagnostic as `source:line:column: severity: message`.
impl fmt::Display for LintError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Exhaustive destructuring: adding a field to `LintError` forces a
        // decision here about whether it belongs in the rendered text.
        let Self {
            line,
            column,
            severity,
            message,
            source,
        } = self;
        write!(f, "{source}:{line}:{column}: {severity}: {message}")
    }
}
/// Returns the names of the tables defined by `CREATE TABLE` statements in `sql`.
///
/// Names appear in statement order. If `sql` does not parse, an empty vector
/// is returned. Statements for which `collect_create_stmt` yields nothing are
/// skipped.
#[must_use]
pub fn created_table_names(sql: &str) -> Vec<String> {
    let Ok(parsed) = pg_query::parse(sql) else {
        return Vec::new();
    };

    let mut names = Vec::new();
    for raw in &parsed.protobuf.stmts {
        // Only statements carrying a `CreateStmt` node are of interest.
        let Some(Node::CreateStmt(create)) = raw.stmt.as_ref().and_then(|s| s.node.as_ref())
        else {
            continue;
        };
        if let Some(table) = collect_create_stmt(create) {
            names.push(table.name().to_string());
        }
    }
    names
}
pub fn lint_declarative_schema(sql: &str, source: &str) -> Result<(), Vec<LintError>> {
let parsed = match pg_query::parse(sql) {
Ok(p) => p,
Err(e) => {
return Err(vec![LintError {
line: 1,
column: 1,
severity: LintSeverity::Error,
message: format!("SQL parse failed: {e}"),
source: source.to_string(),
}]);
},
};
let line_index = LineIndex::new(sql);
let stmts = &parsed.protobuf.stmts;
let (tables, mut errors) = classify_pass(stmts, sql, &line_index, source);
errors.extend(column_ref_pass(stmts, sql, &line_index, &tables, source));
if errors.iter().any(|e| e.severity == LintSeverity::Error) {
return Err(errors);
}
Ok(())
}
fn classify_pass(
stmts: &[pg_query::protobuf::RawStmt],
sql: &str,
line_index: &LineIndex,
source: &str,
) -> (Vec<TableDef>, Vec<LintError>) {
let mut errors: Vec<LintError> = Vec::new();
let mut tables: Vec<TableDef> = Vec::new();
for raw in stmts {
let location = stmt_start_offset(sql, raw.stmt_location.max(0) as usize);
let (line, col) = line_index.position(location);
let loc = StmtLoc { line, col, source };
let Some(stmt) = raw.stmt.as_ref() else {
continue;
};
let Some(node) = stmt.node.as_ref() else {
continue;
};
match node {
Node::CreateStmt(create) => {
if let Some(table) = collect_create_stmt(create) {
tables.push(table);
}
if let Some(warn) = warn_create_table_missing_if_not_exists(create, &loc) {
errors.push(warn);
}
},
Node::IndexStmt(_)
| Node::CreateFunctionStmt(_)
| Node::ViewStmt(_)
| Node::CreateTrigStmt(_)
| Node::CompositeTypeStmt(_)
| Node::CreateEnumStmt(_)
| Node::CommentStmt(_) => {},
Node::CreateExtensionStmt(ext) => {
if !ext.if_not_exists {
errors.push(LintError {
line,
column: col,
severity: LintSeverity::Warning,
message: "CREATE EXTENSION without IF NOT EXISTS".into(),
source: source.to_string(),
});
}
},
other => {
if let Some(reason) = imperative_reason(other) {
errors.push(LintError {
line,
column: col,
severity: LintSeverity::Error,
message: format!(
"imperative SQL in declarative schema: {reason} — move to \
schema/migrations/NNN_<name>.sql"
),
source: source.to_string(),
});
}
},
}
}
(tables, errors)
}
/// Second lint pass: checks index and view statements against the tables
/// collected by the first pass.
///
/// Index statements are handed to `check_index_columns` and view statements
/// to `check_view_columns`, each with the statement's resolved position and
/// `source` label. All other statement kinds, and statements without a parsed
/// node, are ignored. Returns the diagnostics those checks appended.
fn column_ref_pass(
    stmts: &[pg_query::protobuf::RawStmt],
    sql: &str,
    line_index: &LineIndex,
    tables: &[TableDef],
    source: &str,
) -> Vec<LintError> {
    let mut findings: Vec<LintError> = Vec::new();

    // Pair every statement with its parsed node, dropping those without one.
    let with_nodes = stmts.iter().filter_map(|raw| {
        let node = raw.stmt.as_ref()?.node.as_ref()?;
        Some((raw, node))
    });

    for (raw, node) in with_nodes {
        // A negative `stmt_location` is clamped to 0 before the cast.
        let offset = stmt_start_offset(sql, raw.stmt_location.max(0) as usize);
        let (line, col) = line_index.position(offset);
        let loc = StmtLoc { line, col, source };

        if let Node::IndexStmt(idx) = node {
            check_index_columns(idx, tables, &loc, &mut findings);
        } else if let Node::ViewStmt(view) = node {
            check_view_columns(view, tables, &loc, &mut findings);
        }
    }

    findings
}