use rustc_hash::FxHashSet;
use squawk_syntax::{
Parse, SourceFile,
ast::{self, AstNode},
};
use crate::{Edit, Fix, Linter, Rule, Violation};
use std::sync::OnceLock;
use crate::visitors::{check_not_allowed_types, is_not_valid_int_type};
fn serial_types() -> &'static FxHashSet<&'static str> {
static SERIAL_TYPES: OnceLock<FxHashSet<&'static str>> = OnceLock::new();
SERIAL_TYPES.get_or_init(|| {
FxHashSet::from_iter([
"serial",
"serial2",
"serial4",
"serial8",
"smallserial",
"bigserial",
])
})
}
/// Maps a serial pseudo-type name to the equivalent identity column type.
///
/// Matching is ASCII case-insensitive. Unquoted SQL identifiers are
/// case-insensitive, so `BIGSERIAL`, `BigSerial` and `bigserial` all name the
/// same type and must all map to `bigint`. Any unrecognised name falls back to
/// the `integer` form, the same as plain `serial`.
fn replace_serial(serial_type: &str) -> &'static str {
    match serial_type.to_ascii_lowercase().as_str() {
        "serial" | "serial4" => "integer generated by default as identity",
        "serial2" | "smallserial" => "smallint generated by default as identity",
        "serial8" | "bigserial" => "bigint generated by default as identity",
        _ => "integer generated by default as identity",
    }
}
/// Builds a quick-fix that replaces a serial type name with an identity column type.
///
/// Array types are unwrapped, and the fix is built for their element type. For
/// a path type, only the name token is replaced, using `replace_serial`.
/// Returns `None` in two cases: for every other kind of type, and when a path
/// type has no resolvable name.
///
/// NOTE(review): for an array such as `serial[]`, only the element name is
/// replaced. The `[]` suffix stays after the inserted
/// `... generated by default as identity` text. Confirm this is intended.
fn create_identity_fix(ty: &ast::Type) -> Option<Fix> {
    match ty {
        ast::Type::ArrayType(array_type) => {
            let element_ty = array_type.ty()?;
            create_identity_fix(&element_ty)
        }
        ast::Type::PathType(path_type) => {
            let type_name = path_type.path()?.segment()?.name_ref()?;
            let replacement = replace_serial(&type_name.text());
            let rename = Edit::replace(type_name.syntax().text_range(), replacement);
            Some(Fix::new("Replace with IDENTITY column", vec![rename]))
        }
        // These type forms are never serial pseudo-types, so no fix is offered.
        ast::Type::BitType(_)
        | ast::Type::CharType(_)
        | ast::Type::DoubleType(_)
        | ast::Type::ExprType(_)
        | ast::Type::PercentType(_)
        | ast::Type::TimeType(_)
        | ast::Type::IntervalType(_) => None,
    }
}
/// Reports a `PreferIdentity` violation when `ty` names a serial pseudo-type.
///
/// Absent types (`None`) are ignored. The reported violation carries the
/// identity-column fix from `create_identity_fix`, which may be `None`.
fn check_ty_for_serial(ctx: &mut Linter, ty: Option<ast::Type>) {
    let Some(column_ty) = ty else {
        return;
    };
    if !is_not_valid_int_type(&column_ty, serial_types()) {
        return;
    }

    let identity_fix = create_identity_fix(&column_ty);
    let message = "Serial types make schema, dependency, and permission management difficult.";
    let violation = Violation::for_node(Rule::PreferIdentity, message.into(), column_ty.syntax())
        .help("Use an `IDENTITY` column instead.")
        .fix(identity_fix);
    ctx.report(violation);
}
/// Runs the `PreferIdentity` rule over a parsed SQL file.
///
/// The syntax tree is passed to `check_not_allowed_types`, which calls
/// `check_ty_for_serial` for the types it visits.
pub(crate) fn prefer_identity(ctx: &mut Linter, parse: &Parse<SourceFile>) {
    check_not_allowed_types(ctx, &parse.tree(), check_ty_for_serial);
}
// Tests for the PreferIdentity rule: inline-snapshot checks of the auto-fix
// output, plus file snapshots of the reported violations.
#[cfg(test)]
mod test {
use insta::assert_snapshot;
use crate::{
Rule,
test_utils::{fix_sql, lint_errors, lint_ok},
};
/// Applies the PreferIdentity fixes to `sql` and returns the resulting SQL text.
#[must_use]
fn fix(sql: &str) -> String {
fix_sql(sql, Rule::PreferIdentity)
}
// Each serial alias is rewritten to the identity type of matching width.
#[test]
fn fix_serial_types() {
assert_snapshot!(fix("create table users (id serial);"), @"create table users (id integer generated by default as identity);");
assert_snapshot!(fix("create table users (id serial2);"), @"create table users (id smallint generated by default as identity);");
assert_snapshot!(fix("create table users (id serial4);"), @"create table users (id integer generated by default as identity);");
assert_snapshot!(fix("create table users (id serial8);"), @"create table users (id bigint generated by default as identity);");
assert_snapshot!(fix("create table users (id smallserial);"), @"create table users (id smallint generated by default as identity);");
assert_snapshot!(fix("create table users (id bigserial);"), @"create table users (id bigint generated by default as identity);");
}
// Unquoted type names in any letter case must map to the same replacement as
// their lowercase spelling.
#[test]
fn fix_mixed_case() {
assert_snapshot!(fix("create table users (id BIGSERIAL);"), @"create table users (id bigint generated by default as identity);");
assert_snapshot!(fix("create table users (id Serial);"), @"create table users (id integer generated by default as identity);");
}
// Every serial alias, including an uppercase spelling, should be reported.
// The expected output lives in an external insta snapshot keyed by this
// test's name, so the SQL below must not be edited casually.
#[test]
fn err() {
let sql = r#"
create table users (
id serial
);
create table users (
id serial2
);
create table users (
id serial4
);
create table users (
id serial8
);
create table users (
id smallserial
);
create table users (
id bigserial
);
create table users (
id BIGSERIAL
);
"#;
assert_snapshot!(lint_errors(sql, Rule::PreferIdentity));
}
// Quoted identifiers such as "serial" are not the serial pseudo-type.
// NOTE(review): named `ok_*` but snapshots `lint_errors` rather than calling
// `lint_ok`; confirm the stored snapshot records no violations.
#[test]
fn ok_when_quoted() {
let sql = r#"
create table users (
id "serial"
);
create table users (
id "bigserial"
);
"#;
assert_snapshot!(lint_errors(sql, Rule::PreferIdentity));
}
// Columns already declared as identity columns produce no violations.
#[test]
fn ok() {
let sql = r#"
create table users (
id bigint generated by default as identity primary key
);
create table users (
id bigint generated always as identity primary key
);
"#;
lint_ok(sql, Rule::PreferIdentity);
}
}