use rustc_hash::FxHashSet;
use squawk_syntax::{
Parse, SourceFile,
ast::{self, AstNode},
identifier::Identifier,
};
use crate::{Linter, Rule, Violation};
/// Collects the names of tables created by `CREATE TABLE` while inside a
/// transaction.
///
/// Transaction state starts at `assume_in_transaction`. It is switched on by
/// `BEGIN` and off by `COMMIT`. Only the final path segment of the table name
/// is recorded. A `CREATE TABLE` whose name cannot be extracted is ignored.
/// Names are never removed, so a table created in an earlier transaction is
/// still reported after that transaction's `COMMIT`.
pub fn tables_created_in_transaction(
    assume_in_transaction: bool,
    file: &ast::SourceFile,
) -> FxHashSet<Identifier> {
    let mut in_txn = assume_in_transaction;
    let mut created = FxHashSet::default();

    for stmt in file.stmts() {
        match stmt {
            ast::Stmt::Begin(_) => in_txn = true,
            ast::Stmt::Commit(_) => in_txn = false,
            ast::Stmt::CreateTable(create_table) => {
                // Tables created outside a transaction are not tracked.
                if !in_txn {
                    continue;
                }
                let name = create_table
                    .path()
                    .and_then(|path| path.segment())
                    .and_then(|segment| segment.name());
                if let Some(name) = name {
                    created.insert(Identifier::new(&name.text()));
                }
            }
            _ => {}
        }
    }

    created
}
/// Reports `ALTER TABLE ... VALIDATE CONSTRAINT` statements that validate a
/// constraint added with `NOT VALID` earlier in the same transaction.
///
/// Transaction state starts at `assume_in_transaction`. It is switched on by
/// `BEGIN` and off by `COMMIT`. Constraint names are matched by name only; the
/// table they belong to is not compared. The remembered `NOT VALID` names are
/// discarded when a `BEGIN` opens a new transaction (a `BEGIN` seen while not
/// already inside one).
fn not_valid_validate_in_transaction(
    ctx: &mut Linter,
    assume_in_transaction: bool,
    file: &ast::SourceFile,
) {
    let mut in_txn = assume_in_transaction;
    let mut pending_not_valid: FxHashSet<Identifier> = FxHashSet::default();

    for stmt in file.stmts() {
        let alter_table = match stmt {
            ast::Stmt::Begin(_) => {
                // A BEGIN while already in a transaction keeps the current names.
                if !in_txn {
                    pending_not_valid.clear();
                }
                in_txn = true;
                continue;
            }
            ast::Stmt::Commit(_) => {
                in_txn = false;
                continue;
            }
            ast::Stmt::AlterTable(alter_table) => alter_table,
            _ => continue,
        };

        for action in alter_table.actions() {
            match action {
                ast::AlterTableAction::AddConstraint(add_constraint) => {
                    // Remember only constraints explicitly marked `NOT VALID`.
                    if add_constraint.not_valid().is_none() {
                        continue;
                    }
                    let added_name = add_constraint
                        .constraint()
                        .and_then(|constraint| constraint.constraint_name())
                        .and_then(|constraint_name| constraint_name.name());
                    if let Some(added_name) = added_name {
                        pending_not_valid.insert(Identifier::new(&added_name.text()));
                    }
                }
                ast::AlterTableAction::ValidateConstraint(validate_constraint) => {
                    let Some(validated_name) =
                        validate_constraint.name_ref().map(|x| x.text().to_string())
                    else {
                        continue;
                    };
                    let same_transaction = in_txn
                        && pending_not_valid.contains(&Identifier::new(&validated_name));
                    if same_transaction {
                        let violation = Violation::for_node(
                            Rule::ConstraintMissingNotValid,
                            "Using `NOT VALID` and `VALIDATE CONSTRAINT` in the same transaction will block all reads while the constraint is validated.".into(),
                            validate_constraint.syntax(),
                        )
                        .help("Add constraint as `NOT VALID` in one transaction and `VALIDATE CONSTRAINT` in a separate transaction.");
                        ctx.report(violation);
                    }
                }
                _ => {}
            }
        }
    }
}
/// Entry point for `Rule::ConstraintMissingNotValid`.
///
/// It reports two kinds of problem:
/// - `VALIDATE CONSTRAINT` in the same transaction as the
///   `ADD CONSTRAINT ... NOT VALID` it validates. This check is delegated to
///   `not_valid_validate_in_transaction`.
/// - `ALTER TABLE ... ADD CONSTRAINT` without `NOT VALID` on a table that is not
///   in the set returned by `tables_created_in_transaction`.
///
/// `UNIQUE` and `PRIMARY KEY` constraints added `USING INDEX` are exempt from
/// the second check.
pub(crate) fn constraint_missing_not_valid(ctx: &mut Linter, parse: &Parse<SourceFile>) {
    let file = parse.tree();
    let assume_in_transaction = ctx.settings.assume_in_transaction;

    not_valid_validate_in_transaction(ctx, assume_in_transaction, &file);

    let tables_created = tables_created_in_transaction(assume_in_transaction, &file);

    for stmt in file.stmts() {
        let ast::Stmt::AlterTable(alter_table) = stmt else {
            continue;
        };
        let Some(table_name) = alter_table
            .relation_name()
            .and_then(|x| x.path())
            .and_then(|x| x.segment())
            .and_then(|x| x.name_ref())
            .map(|x| x.text().to_string())
        else {
            continue;
        };
        // Constraints on tables created within a transaction in this file are
        // not flagged. This depends only on the table, not on the individual
        // action, so check it once per statement.
        if tables_created.contains(&Identifier::new(&table_name)) {
            continue;
        }

        for action in alter_table.actions() {
            let ast::AlterTableAction::AddConstraint(add_constraint) = action else {
                continue;
            };
            if add_constraint.not_valid().is_some() {
                continue;
            }
            // A single lookup of the constraint decides the `USING INDEX` exemption.
            let uses_existing_index = match add_constraint.constraint() {
                Some(ast::Constraint::UniqueConstraint(uc)) => uc.using_index().is_some(),
                Some(ast::Constraint::PrimaryKeyConstraint(pk)) => pk.using_index().is_some(),
                _ => false,
            };
            if uses_existing_index {
                continue;
            }
            ctx.report(
                Violation::for_node(
                    Rule::ConstraintMissingNotValid,
                    "By default new constraints require a table scan and block writes to the table while that scan occurs.".into(),
                    add_constraint.syntax(),
                )
                .help("Use `NOT VALID` with a later `VALIDATE CONSTRAINT` call."),
            );
        }
    }
}
// Tests for the `constraint-missing-not-valid` rule. `*_err` tests compare the
// reported violations against stored snapshots; `*_ok` tests expect no
// violations for this rule.
#[cfg(test)]
mod test {
    use insta::assert_snapshot;
    use crate::test_utils::{lint_errors, lint_ok};
    use crate::{LinterSettings, Rule};

    // Helpers that run only `Rule::ConstraintMissingNotValid` with custom settings.
    fn lint_ok_with(sql: &str, settings: LinterSettings) {
        crate::test_utils::lint_ok_with(sql, settings, Rule::ConstraintMissingNotValid);
    }

    fn lint_errors_with(sql: &str, settings: LinterSettings) -> String {
        crate::test_utils::lint_errors_with(sql, settings, Rule::ConstraintMissingNotValid)
    }

    // `NOT VALID` followed by `VALIDATE CONSTRAINT` inside one explicit
    // BEGIN/COMMIT is reported.
    #[test]
    fn not_valid_validate_transaction_err() {
        let sql = r#"
BEGIN;
ALTER TABLE "app_email" ADD CONSTRAINT "fk_user" FOREIGN KEY (user_id) REFERENCES "app_user" (id) NOT VALID;
ALTER TABLE "app_email" VALIDATE CONSTRAINT "fk_user";
COMMIT;
"#;
        assert_snapshot!(lint_errors(sql, Rule::ConstraintMissingNotValid));
    }

    // Same as above, but the transaction comes from `assume_in_transaction`
    // rather than an explicit BEGIN.
    #[test]
    fn not_valid_validate_assume_transaction_err() {
        let sql = r#"
ALTER TABLE "app_email" ADD CONSTRAINT "fk_user" FOREIGN KEY (user_id) REFERENCES "app_user" (id) NOT VALID;
ALTER TABLE "app_email" VALIDATE CONSTRAINT "fk_user";
"#;
        assert_snapshot!(lint_errors_with(
            sql,
            LinterSettings {
                assume_in_transaction: true,
                ..Default::default()
            },
        ));
    }

    // The COMMIT only comes after the VALIDATE, so both statements are still in
    // the assumed transaction and the VALIDATE is reported.
    #[test]
    fn not_valid_validate_with_assume_in_transaction_with_explicit_commit_err() {
        let sql = r#"
ALTER TABLE "app_email" ADD CONSTRAINT "fk_user" FOREIGN KEY (user_id) REFERENCES "app_user" (id) NOT VALID;
ALTER TABLE "app_email" VALIDATE CONSTRAINT "fk_user";
COMMIT;
"#;
        assert_snapshot!(lint_errors_with(
            sql,
            LinterSettings {
                assume_in_transaction: true,
                ..Default::default()
            },
        ));
    }

    // A foreign key added without `NOT VALID` is reported.
    #[test]
    fn adding_fk_err() {
        let sql = r#"
-- instead of
ALTER TABLE distributors ADD CONSTRAINT distfk FOREIGN KEY (address) REFERENCES addresses (address);
"#;
        assert_snapshot!(lint_errors(sql, Rule::ConstraintMissingNotValid));
    }

    // `NOT VALID` plus a later VALIDATE, with no transaction, is accepted.
    #[test]
    fn adding_fk_not_valid_ok() {
        let sql = r#"
-- use `NOT VALID`
ALTER TABLE distributors ADD CONSTRAINT distfk FOREIGN KEY (address) REFERENCES addresses (address) NOT VALID;
ALTER TABLE distributors VALIDATE CONSTRAINT distfk;
"#;
        lint_ok(sql, Rule::ConstraintMissingNotValid);
    }

    // `PRIMARY KEY USING INDEX` is exempt from the NOT VALID requirement.
    #[test]
    fn adding_using_index_ok() {
        let sql = r#"
ALTER TABLE account ADD CONSTRAINT account_pk PRIMARY KEY USING INDEX account_pk_idx;
"#;
        lint_ok(sql, Rule::ConstraintMissingNotValid);
    }

    // A CHECK constraint without `NOT VALID` on a table not created in this file
    // is reported.
    #[test]
    fn adding_check_constraint_err() {
        let sql = r#"
-- instead of
ALTER TABLE "accounts" ADD CONSTRAINT "positive_balance" CHECK ("balance" >= 0);
"#;
        assert_snapshot!(lint_errors_with(
            sql,
            LinterSettings {
                assume_in_transaction: true,
                ..Default::default()
            },
        ));
    }

    #[test]
    fn adding_check_constraint_ok() {
        let sql = r#"
-- use `NOT VALID`
ALTER TABLE "accounts" ADD CONSTRAINT "positive_balance" CHECK ("balance" >= 0) NOT VALID;
ALTER TABLE accounts VALIDATE CONSTRAINT positive_balance;
"#;
        lint_ok(sql, Rule::ConstraintMissingNotValid);
    }

    // A constraint on a table created in the same explicit transaction is not reported.
    #[test]
    fn new_table_with_transaction_ok() {
        let sql = r#"
BEGIN;
CREATE TABLE "core_foo" (
"id" serial NOT NULL PRIMARY KEY,
"age" integer NOT NULL
);
ALTER TABLE "core_foo" ADD CONSTRAINT "age_restriction" CHECK ("age" >= 25);
COMMIT;
"#;
        lint_ok(sql, Rule::ConstraintMissingNotValid);
    }

    // Same as above, with the transaction coming from `assume_in_transaction`.
    #[test]
    fn new_table_assume_transaction_ok() {
        let sql = r#"
CREATE TABLE "core_foo" (
"id" serial NOT NULL PRIMARY KEY,
"age" integer NOT NULL
);
ALTER TABLE "core_foo" ADD CONSTRAINT "age_restriction" CHECK ("age" >= 25);
"#;
        lint_ok_with(
            sql,
            LinterSettings {
                assume_in_transaction: true,
                ..Default::default()
            },
        );
    }

    // NOTE(review): the SQL and settings here are identical to
    // `new_table_assume_transaction_ok`. The name suggests an index-related
    // regression case was intended; confirm and either differentiate or remove.
    #[test]
    fn regression_with_indexing_ok() {
        let sql = r#"
CREATE TABLE "core_foo" (
"id" serial NOT NULL PRIMARY KEY,
"age" integer NOT NULL
);
ALTER TABLE "core_foo" ADD CONSTRAINT "age_restriction" CHECK ("age" >= 25);
"#;
        lint_ok_with(
            sql,
            LinterSettings {
                assume_in_transaction: true,
                ..Default::default()
            },
        );
    }

    // `UNIQUE USING INDEX` is exempt from the NOT VALID requirement.
    #[test]
    fn using_unique_index_ok() {
        let sql = r#"
ALTER TABLE "app_email" ADD CONSTRAINT "email_uniq" UNIQUE USING INDEX "email_idx";
"#;
        lint_ok(sql, Rule::ConstraintMissingNotValid);
    }
}