squawk-linter 2.50.0

Linter for Postgres migrations & SQL
Documentation
use std::sync::OnceLock;

use rustc_hash::FxHashSet;

use squawk_syntax::ast::AstNode;
use squawk_syntax::{Parse, SourceFile, SyntaxKind};
use squawk_syntax::{ast, identifier::Identifier};

use crate::{Linter, Rule, Version, Violation};

fn non_volatile_funcs() -> &'static FxHashSet<Identifier> {
    static NON_VOLATILE_FUNCS: OnceLock<FxHashSet<Identifier>> = OnceLock::new();
    NON_VOLATILE_FUNCS.get_or_init(|| {
        NON_VOLATILE_BUILT_IN_FUNCTIONS
            .split('\n')
            .map(|x| x.trim())
            .filter(|x| !x.is_empty())
            .map(Identifier::new)
            .collect()
    })
}

/// Returns `true` when `expr` is recognised as a constant or as only using
/// non-volatile built-ins. Such a `DEFAULT` can be added without a table rewrite
/// on newer Postgres versions.
///
/// The check is conservative. Anything not recognised here returns `false`,
/// i.e. it is treated as possibly volatile. That includes:
/// - unknown functions,
/// - calls with arguments,
/// - schema-qualified calls,
/// - other expression kinds.
fn is_non_volatile_or_const(expr: &ast::Expr) -> bool {
    match expr {
        ast::Expr::Literal(_) => true,
        // An array constructor is only as non-volatile as its elements:
        // `array[random()]` must still be reported. `array[]` has no element
        // expressions, so it is trivially constant.
        ast::Expr::ArrayExpr(array_expr) => array_expr
            .syntax()
            .children()
            .filter_map(ast::Expr::cast)
            .all(|element| is_non_volatile_or_const(&element)),
        // Both operands must be non-volatile, e.g. `now() - interval '100 years'`.
        ast::Expr::BinExpr(bin_expr) => match (bin_expr.lhs(), bin_expr.rhs()) {
            (Some(lhs), Some(rhs)) => {
                is_non_volatile_or_const(&lhs) && is_non_volatile_or_const(&rhs)
            }
            _ => false,
        },
        ast::Expr::CallExpr(call_expr) => {
            let Some(arg_list) = call_expr.arg_list() else {
                return false;
            };
            // Only zero-argument calls are accepted; arguments are not inspected.
            if arg_list.args().next().is_some() {
                return false;
            }
            // TODO: what about FieldExpr? like, pg_catalog.uuid()
            let Some(ast::Expr::NameRef(name_ref)) = call_expr.expr() else {
                return false;
            };
            non_volatile_funcs().contains(&Identifier::new(name_ref.text().as_str()))
        }
        // array[]::t[] is non-volatile. We don't check for a plain array expr
        // since postgres will reject it as a default unless it's cast to a type.
        ast::Expr::CastExpr(cast_expr) => cast_expr
            .expr()
            .is_some_and(|inner| is_non_volatile_or_const(&inner)),
        // current_timestamp is the same as calling now()
        ast::Expr::NameRef(name_ref) => name_ref
            .syntax()
            .first_child_or_token()
            .is_some_and(|child| child.kind() == SyntaxKind::CURRENT_TIMESTAMP_KW),
        _ => false,
    }
}

// Generated via the following Postgres query:
//      select proname from pg_proc where provolatile <> 'v';
//
// Newline-separated list of built-in function names, embedded at compile time
// from `non_volatile_built_in_functions.txt`. `non_volatile_funcs` trims each
// line and skips blank ones when building its lookup set.
//
// NOTE(review): the query selects only `proname`. A name is therefore listed if
// *any* overload of it is non-volatile; overloads are not distinguished.
const NON_VOLATILE_BUILT_IN_FUNCTIONS: &str = include_str!("non_volatile_built_in_functions.txt");

/// Reports column defaults and generated columns that may force a table rewrite.
///
/// Scope: `ALTER TABLE ... ADD COLUMN` actions.
///
/// - A `DEFAULT` constraint is reported unless the configured Postgres version
///   is greater than 11 *and* its expression passes `is_non_volatile_or_const`.
///   A `DEFAULT` with no parsed expression is skipped.
/// - A `GeneratedConstraint` is always reported.
///
/// Each violation points at the default expression, or at the whole generated
/// constraint, and carries the same message and help text.
pub(crate) fn adding_field_with_default(ctx: &mut Linter, parse: &Parse<SourceFile>) {
    let message = "Adding a generated column requires a table rewrite with an `ACCESS EXCLUSIVE` lock. In Postgres versions 11+, non-VOLATILE DEFAULTs can be added without a rewrite.";
    let help = "Add the column as nullable, backfill existing rows, and add a trigger to update the column on write instead.";
    // NOTE(review): strict `>` means a configured version of exactly "11" still
    // reports (the `docs_example_error_on_pg_11` test pins this), although the
    // message says 11+. Confirm this is intended.
    let fast_defaults_allowed = ctx.settings.pg_version > Version::new(11, None, None);

    // TODO: use match_ast! like in #api_walkthrough
    let file = parse.tree();
    for stmt in file.stmts() {
        let ast::Stmt::AlterTable(alter_table) = stmt else {
            continue;
        };
        for action in alter_table.actions() {
            let ast::AlterTableAction::AddColumn(add_column) = action else {
                continue;
            };
            for constraint in add_column.constraints() {
                // Choose the node to report on, or move on if this constraint is fine.
                let offending_node = match constraint {
                    ast::Constraint::DefaultConstraint(default) => {
                        let Some(expr) = default.expr() else {
                            continue;
                        };
                        if fast_defaults_allowed && is_non_volatile_or_const(&expr) {
                            continue;
                        }
                        expr.syntax().clone()
                    }
                    ast::Constraint::GeneratedConstraint(generated) => generated.syntax().clone(),
                    _ => continue,
                };
                ctx.report(
                    Violation::for_node(
                        Rule::AddingFieldWithDefault,
                        message.into(),
                        &offending_node,
                    )
                    .help(help),
                );
            }
        }
    }
}

// Tests for `Rule::AddingFieldWithDefault`. Error cases compare the rendered
// violations against insta snapshots; OK cases go through `lint_ok`.
// Test function names are also the snapshot names, so renaming a test requires
// renaming its snapshot file.
#[cfg(test)]
mod test {
    use insta::assert_snapshot;

    use crate::test_utils::{lint_errors, lint_ok};
    use crate::{LinterSettings, Rule};

    /// Runs only `Rule::AddingFieldWithDefault` over `sql` with caller-supplied
    /// `settings` (e.g. a specific Postgres version) and returns the rendered errors.
    fn lint_errors_with(sql: &str, settings: LinterSettings) -> String {
        crate::test_utils::lint_errors_with(sql, settings, Rule::AddingFieldWithDefault)
    }

    // A constant integer default is accepted under the default settings.
    // NOTE(review): relies on the default `LinterSettings` Postgres version being
    // greater than 11 — confirm if that default changes.
    #[test]
    fn docs_example_ok_post_pg_11() {
        let sql = r#"
-- instead of
ALTER TABLE "core_recipe" ADD COLUMN "foo" integer DEFAULT 10;
        "#;

        lint_ok(sql, Rule::AddingFieldWithDefault);
    }

    // The recommended pattern: add the column without a default, then set the
    // default via ALTER COLUMN (not an ADD COLUMN constraint, so not checked).
    #[test]
    fn docs_example_ok() {
        let sql = r#"
-- use
ALTER TABLE "core_recipe" ADD COLUMN "foo" integer;
ALTER TABLE "core_recipe" ALTER COLUMN "foo" SET DEFAULT 10;
-- backfill
-- remove nullability
            "#;

        lint_ok(sql, Rule::AddingFieldWithDefault);
    }

    // ADD COLUMN is found even when it is not the first action in the statement.
    #[test]
    fn default_uuid_error_multi_stmt() {
        let sql = r#"
alter table t set logged, add column c integer default uuid();
        "#;

        assert_snapshot!(lint_errors(sql, Rule::AddingFieldWithDefault));
    }

    // `uuid()` is reported as a default.
    #[test]
    fn default_uuid_error() {
        let sql = r#"
ALTER TABLE "core_recipe" ADD COLUMN "foo" integer DEFAULT uuid();
        "#;

        assert_snapshot!(lint_errors(sql, Rule::AddingFieldWithDefault));
    }

    // A volatile built-in (`random()`) is reported.
    #[test]
    fn default_volatile_func_err() {
        let sql = r#"
    -- VOLATILE
ALTER TABLE "core_recipe" ADD COLUMN "foo" boolean DEFAULT random();
            "#;

        assert_snapshot!(lint_errors(sql, Rule::AddingFieldWithDefault));
    }

    // Boolean literal default is accepted.
    #[test]
    fn default_bool_ok() {
        let sql = r#"
    -- NON-VOLATILE
ALTER TABLE "core_recipe" ADD COLUMN "foo" boolean DEFAULT true;
            "#;

        lint_ok(sql, Rule::AddingFieldWithDefault);
    }

    // Casted array constructors (`::` and `cast(...)`) with constant elements are accepted.
    #[test]
    fn default_empty_array_ok() {
        let sql = r#"
alter table t add column a double precision[] default array[]::double precision[];

alter table t add column b bigint[] default cast(array[] as bigint[]);

alter table t add column c text[] default array['foo', 'bar']::text[];
            "#;

        lint_ok(sql, Rule::AddingFieldWithDefault);
    }

    // A binary expression whose operands are both non-volatile is accepted.
    #[test]
    fn default_with_const_bin_expr() {
        let sql = r#"
ALTER TABLE assessments
ADD COLUMN statistics_last_updated_at timestamptz NOT NULL DEFAULT now() - interval '100 years';
            "#;

        lint_ok(sql, Rule::AddingFieldWithDefault);
    }

    // String literal default is accepted.
    #[test]
    fn default_str_ok() {
        let sql = r#"
-- NON-VOLATILE
ALTER TABLE "core_recipe" ADD COLUMN "foo" text DEFAULT 'some-str';
            "#;

        lint_ok(sql, Rule::AddingFieldWithDefault);
    }

    // A string literal is accepted regardless of the column's (enum) type.
    #[test]
    fn default_enum_ok() {
        let sql = r#"
-- NON-VOLATILE
ALTER TABLE "core_recipe" ADD COLUMN "foo" some_enum_type DEFAULT 'my-enum-variant';
        "#;

        lint_ok(sql, Rule::AddingFieldWithDefault);
    }

    // A literal cast to jsonb is accepted.
    #[test]
    fn default_jsonb_ok() {
        let sql = r#"
-- NON-VOLATILE
ALTER TABLE "core_recipe" ADD COLUMN "foo" jsonb DEFAULT '{}'::jsonb;
        "#;

        lint_ok(sql, Rule::AddingFieldWithDefault);
    }

    // A user-defined function is unknown to the built-in list, so it is reported
    // (the SQL comment inside the string is just part of the input).
    #[test]
    fn arbitrary_func_err() {
        let sql = r#"
-- NON-VOLATILE
ALTER TABLE "core_recipe" ADD COLUMN "foo" jsonb DEFAULT myjsonb();
        "#;

        assert_snapshot!(lint_errors(sql, Rule::AddingFieldWithDefault));
    }

    // Calls with arguments are reported even for a listed name.
    // NOTE(review): despite the test name, the input uses `now(123)`, not `random`.
    #[test]
    fn default_random_with_args_err() {
        let sql = r#"
-- NON-VOLATILE
ALTER TABLE "core_recipe" ADD COLUMN "foo" timestamptz DEFAULT now(123);
        "#;

        assert_snapshot!(lint_errors(sql, Rule::AddingFieldWithDefault));
    }

    // Zero-argument `now()` is accepted.
    #[test]
    fn default_func_now_ok() {
        let sql = r#"
-- NON-VOLATILE
ALTER TABLE "core_recipe" ADD COLUMN "foo" timestamptz DEFAULT now();
        "#;

        lint_ok(sql, Rule::AddingFieldWithDefault);
    }

    // The `current_timestamp` keyword form is accepted.
    #[test]
    fn default_func_current_timestamp_ok() {
        let sql = r#"
alter table t add column c timestamptz default current_timestamp;
        "#;

        lint_ok(sql, Rule::AddingFieldWithDefault);
    }

    // Arithmetic on literals is accepted.
    #[test]
    fn add_numbers_ok() {
        let sql = r#"
alter table account_metadata add column blah integer default 2 + 2;
        "#;

        lint_ok(sql, Rule::AddingFieldWithDefault);
    }

    // Stored generated columns are always reported.
    #[test]
    fn generated_stored_err() {
        let sql = r#"
ALTER TABLE foo
ADD COLUMN bar numeric GENERATED ALWAYS AS (bar + baz) STORED;
        "#;

        assert_snapshot!(lint_errors(sql, Rule::AddingFieldWithDefault));
    }

    // With the Postgres version configured as "11", even a constant default is
    // reported, because the rule only exempts versions greater than 11.
    #[test]
    fn docs_example_error_on_pg_11() {
        let sql = r#"
-- instead of
ALTER TABLE "core_recipe" ADD COLUMN "foo" integer DEFAULT 10;
        "#;

        assert_snapshot!(lint_errors_with(
            sql,
            LinterSettings {
                pg_version: "11".parse().expect("Invalid PostgreSQL version"),
                ..Default::default()
            },
        ));
    }
}