use std::sync::OnceLock;
use rustc_hash::FxHashSet;
use squawk_syntax::ast::AstNode;
use squawk_syntax::{Parse, SourceFile, SyntaxKind};
use squawk_syntax::{ast, identifier::Identifier};
use crate::{Linter, Rule, Version, Violation};
fn non_volatile_funcs() -> &'static FxHashSet<Identifier> {
static NON_VOLATILE_FUNCS: OnceLock<FxHashSet<Identifier>> = OnceLock::new();
NON_VOLATILE_FUNCS.get_or_init(|| {
NON_VOLATILE_BUILT_IN_FUNCTIONS
.split('\n')
.map(|x| x.trim())
.filter(|x| !x.is_empty())
.map(Identifier::new)
.collect()
})
}
fn is_non_volatile_or_const(expr: &ast::Expr) -> bool {
match expr {
ast::Expr::Literal(_) => true,
ast::Expr::ArrayExpr(_) => true,
ast::Expr::BinExpr(bin_expr) => {
if let Some(lhs) = bin_expr.lhs()
&& let Some(rhs) = bin_expr.rhs()
{
return is_non_volatile_or_const(&lhs) && is_non_volatile_or_const(&rhs);
}
false
}
ast::Expr::CallExpr(call_expr) => {
if let Some(arglist) = call_expr.arg_list() {
let no_args = arglist.args().count() == 0;
let Some(ast::Expr::NameRef(name_ref)) = call_expr.expr() else {
return false;
};
let non_volatile_name =
non_volatile_funcs().contains(&Identifier::new(name_ref.text().as_str()));
no_args && non_volatile_name
} else {
false
}
}
ast::Expr::CastExpr(cast_expr) => {
if let Some(inner_expr) = cast_expr.expr() {
is_non_volatile_or_const(&inner_expr)
} else {
false
}
}
ast::Expr::NameRef(name_ref) => {
if let Some(child) = name_ref.syntax().first_child_or_token() {
if child.kind() == SyntaxKind::CURRENT_TIMESTAMP_KW {
return true;
}
}
false
}
_ => false,
}
}
const NON_VOLATILE_BUILT_IN_FUNCTIONS: &str = include_str!("non_volatile_built_in_functions.txt");
pub(crate) fn adding_field_with_default(ctx: &mut Linter, parse: &Parse<SourceFile>) {
let message = "Adding a generated column requires a table rewrite with an `ACCESS EXCLUSIVE` lock. In Postgres versions 11+, non-VOLATILE DEFAULTs can be added without a rewrite.";
let help = "Add the column as nullable, backfill existing rows, and add a trigger to update the column on write instead.";
let file = parse.tree();
for stmt in file.stmts() {
if let ast::Stmt::AlterTable(alter_table) = stmt {
for action in alter_table.actions() {
if let ast::AlterTableAction::AddColumn(add_column) = action {
for constraint in add_column.constraints() {
match constraint {
ast::Constraint::DefaultConstraint(default) => {
let Some(expr) = default.expr() else {
continue;
};
if ctx.settings.pg_version > Version::new(11, None, None)
&& is_non_volatile_or_const(&expr)
{
continue;
}
ctx.report(
Violation::for_node(
Rule::AddingFieldWithDefault,
message.into(),
expr.syntax(),
)
.help(help),
)
}
ast::Constraint::GeneratedConstraint(generated) => {
ctx.report(
Violation::for_node(
Rule::AddingFieldWithDefault,
message.into(),
generated.syntax(),
)
.help(help),
);
}
_ => (),
}
}
}
}
}
}
}
#[cfg(test)]
mod test {
use insta::assert_snapshot;
use crate::test_utils::{lint_errors, lint_ok};
use crate::{LinterSettings, Rule};
fn lint_errors_with(sql: &str, settings: LinterSettings) -> String {
crate::test_utils::lint_errors_with(sql, settings, Rule::AddingFieldWithDefault)
}
#[test]
fn docs_example_ok_post_pg_11() {
let sql = r#"
-- instead of
ALTER TABLE "core_recipe" ADD COLUMN "foo" integer DEFAULT 10;
"#;
lint_ok(sql, Rule::AddingFieldWithDefault);
}
#[test]
fn docs_example_ok() {
let sql = r#"
-- use
ALTER TABLE "core_recipe" ADD COLUMN "foo" integer;
ALTER TABLE "core_recipe" ALTER COLUMN "foo" SET DEFAULT 10;
-- backfill
-- remove nullability
"#;
lint_ok(sql, Rule::AddingFieldWithDefault);
}
#[test]
fn default_uuid_error_multi_stmt() {
let sql = r#"
alter table t set logged, add column c integer default uuid();
"#;
assert_snapshot!(lint_errors(sql, Rule::AddingFieldWithDefault));
}
#[test]
fn default_uuid_error() {
let sql = r#"
ALTER TABLE "core_recipe" ADD COLUMN "foo" integer DEFAULT uuid();
"#;
assert_snapshot!(lint_errors(sql, Rule::AddingFieldWithDefault));
}
#[test]
fn default_volatile_func_err() {
let sql = r#"
-- VOLATILE
ALTER TABLE "core_recipe" ADD COLUMN "foo" boolean DEFAULT random();
"#;
assert_snapshot!(lint_errors(sql, Rule::AddingFieldWithDefault));
}
#[test]
fn default_bool_ok() {
let sql = r#"
-- NON-VOLATILE
ALTER TABLE "core_recipe" ADD COLUMN "foo" boolean DEFAULT true;
"#;
lint_ok(sql, Rule::AddingFieldWithDefault);
}
#[test]
fn default_empty_array_ok() {
let sql = r#"
alter table t add column a double precision[] default array[]::double precision[];
alter table t add column b bigint[] default cast(array[] as bigint[]);
alter table t add column c text[] default array['foo', 'bar']::text[];
"#;
lint_ok(sql, Rule::AddingFieldWithDefault);
}
#[test]
fn default_with_const_bin_expr() {
let sql = r#"
ALTER TABLE assessments
ADD COLUMN statistics_last_updated_at timestamptz NOT NULL DEFAULT now() - interval '100 years';
"#;
lint_ok(sql, Rule::AddingFieldWithDefault);
}
#[test]
fn default_str_ok() {
let sql = r#"
-- NON-VOLATILE
ALTER TABLE "core_recipe" ADD COLUMN "foo" text DEFAULT 'some-str';
"#;
lint_ok(sql, Rule::AddingFieldWithDefault);
}
#[test]
fn default_enum_ok() {
let sql = r#"
-- NON-VOLATILE
ALTER TABLE "core_recipe" ADD COLUMN "foo" some_enum_type DEFAULT 'my-enum-variant';
"#;
lint_ok(sql, Rule::AddingFieldWithDefault);
}
#[test]
fn default_jsonb_ok() {
let sql = r#"
-- NON-VOLATILE
ALTER TABLE "core_recipe" ADD COLUMN "foo" jsonb DEFAULT '{}'::jsonb;
"#;
lint_ok(sql, Rule::AddingFieldWithDefault);
}
#[test]
fn arbitrary_func_err() {
let sql = r#"
-- NON-VOLATILE
ALTER TABLE "core_recipe" ADD COLUMN "foo" jsonb DEFAULT myjsonb();
"#;
assert_snapshot!(lint_errors(sql, Rule::AddingFieldWithDefault));
}
#[test]
fn default_random_with_args_err() {
let sql = r#"
-- NON-VOLATILE
ALTER TABLE "core_recipe" ADD COLUMN "foo" timestamptz DEFAULT now(123);
"#;
assert_snapshot!(lint_errors(sql, Rule::AddingFieldWithDefault));
}
#[test]
fn default_func_now_ok() {
let sql = r#"
-- NON-VOLATILE
ALTER TABLE "core_recipe" ADD COLUMN "foo" timestamptz DEFAULT now();
"#;
lint_ok(sql, Rule::AddingFieldWithDefault);
}
#[test]
fn default_func_current_timestamp_ok() {
let sql = r#"
alter table t add column c timestamptz default current_timestamp;
"#;
lint_ok(sql, Rule::AddingFieldWithDefault);
}
#[test]
fn add_numbers_ok() {
let sql = r#"
alter table account_metadata add column blah integer default 2 + 2;
"#;
lint_ok(sql, Rule::AddingFieldWithDefault);
}
#[test]
fn generated_stored_err() {
let sql = r#"
ALTER TABLE foo
ADD COLUMN bar numeric GENERATED ALWAYS AS (bar + baz) STORED;
"#;
assert_snapshot!(lint_errors(sql, Rule::AddingFieldWithDefault));
}
#[test]
fn docs_example_error_on_pg_11() {
let sql = r#"
-- instead of
ALTER TABLE "core_recipe" ADD COLUMN "foo" integer DEFAULT 10;
"#;
assert_snapshot!(lint_errors_with(
sql,
LinterSettings {
pg_version: "11".parse().expect("Invalid PostgreSQL version"),
..Default::default()
},
));
}
}