use rustc_hash::FxHashSet;
use rowan::TextRange;
use squawk_syntax::{
Parse, SourceFile, TokenText,
ast::{self, AstNode},
identifier::Identifier,
};
use crate::visitors::check_not_allowed_types;
use crate::{Edit, Fix, Linter, Rule, Violation};
use std::sync::OnceLock;
/// Returns the set of type names this rule treats as the `char` family.
///
/// The set contains `char`, `character`, and `bpchar`. It is stored in a
/// function-local `static` `OnceLock`, so it is built on first use and the same
/// instance is returned on every later call.
fn char_types() -> &'static FxHashSet<Identifier> {
    static CHAR_TYPES: OnceLock<FxHashSet<Identifier>> = OnceLock::new();
    CHAR_TYPES.get_or_init(|| {
        ["char", "character", "bpchar"]
            .into_iter()
            .map(Identifier::new)
            .collect()
    })
}
/// Reports whether the token text `x` names one of the types in [`char_types`].
///
/// The comparison is done on [`Identifier`] values rather than raw text.
/// The tests in this file rely on `CHAR`, `Char`, and quoted `"char"` all
/// matching, so `Identifier` presumably normalizes case and quoting.
fn is_char_type(x: TokenText<'_>) -> bool {
    let candidate = Identifier::new(x.as_ref());
    char_types().contains(&candidate)
}
/// Builds the autofix for a flagged `char`-family type.
///
/// * With an argument list (`char(100)`), only the text from `range.start()`
///   up to the start of the argument list is replaced with `varchar`, so the
///   length argument is kept: `char(100)` -> `varchar(100)`.
/// * Without one (`char`), the whole `range` is replaced with `text`.
fn create_fix(range: TextRange, args: Option<ast::ArgList>) -> Fix {
    match args {
        Some(arg_list) => {
            let name_end = arg_list.syntax().text_range().start();
            let name_range = TextRange::new(range.start(), name_end);
            let replacement = Edit::replace(name_range, "varchar");
            Fix::new("Replace with `varchar`".to_string(), vec![replacement])
        }
        None => {
            let replacement = Edit::replace(range, "text");
            Fix::new("Replace with `text`".to_string(), vec![replacement])
        }
    }
}
fn check_path_type(ctx: &mut Linter, path_type: ast::PathType) {
if let Some(name_ref) = path_type
.path()
.and_then(|x| x.segment())
.and_then(|x| x.name_ref())
&& is_char_type(name_ref.text())
{
let fix = create_fix(name_ref.syntax().text_range(), path_type.arg_list());
ctx.report(Violation::for_node(
Rule::BanCharField,
"Using `character` is likely a mistake and should almost always be replaced by `text` or `varchar`.".into(),
path_type.syntax(),
).fix(fix));
}
}
fn check_char_type(ctx: &mut Linter, char_type: ast::CharType) {
if is_char_type(char_type.text()) {
let fix = create_fix(char_type.syntax().text_range(), char_type.arg_list());
ctx.report(Violation::for_node(
Rule::BanCharField,
"Using `character` is likely a mistake and should almost always be replaced by `text` or `varchar`.".into(),
char_type.syntax(),
).fix(fix));
}
}
/// Inspects one column type and reports it if it is a `char`-family type.
///
/// Array types are unwrapped recursively. `char[]` is therefore checked exactly
/// as before, and if the parser represents multi-dimensional declarations such
/// as `char[][]` as nested `ArrayType` nodes, the element type is still reached
/// at any depth instead of being silently skipped. Any other type shape is
/// ignored.
fn check_ty(ctx: &mut Linter, ty: Option<ast::Type>) {
    match ty {
        // Recurse rather than matching a single level of nesting, so every
        // array wrapper is peeled off before the element type is examined.
        Some(ast::Type::ArrayType(array_type)) => check_ty(ctx, array_type.ty()),
        Some(ast::Type::PathType(path_type)) => check_path_type(ctx, path_type),
        Some(ast::Type::CharType(char_type)) => check_char_type(ctx, char_type),
        _ => (),
    }
}
/// Entry point for [`Rule::BanCharField`].
///
/// Hands the parsed file to `check_not_allowed_types` and uses `check_ty` as
/// the per-type callback. Which type positions get visited is decided by that
/// visitor; the tests below exercise `CREATE TABLE` and
/// `ALTER TABLE ... ADD COLUMN` column types.
pub(crate) fn ban_char_field(ctx: &mut Linter, parse: &Parse<SourceFile>) {
    check_not_allowed_types(ctx, &parse.tree(), check_ty);
}
// Tests for `ban_char_field`: autofix output is pinned with inline insta
// snapshots; diagnostics are pinned with file-based snapshots (named after the
// test functions, so renaming a test orphans its snapshot).
#[cfg(test)]
mod test {
use insta::assert_snapshot;
use crate::{
Rule,
test_utils::{fix_sql, lint_errors, lint_ok},
};
// Runs the `BanCharField` autofix over `sql` via `fix_sql` and returns the
// rewritten SQL text (as the inline snapshots below show).
fn fix(sql: &str) -> String {
fix_sql(sql, Rule::BanCharField)
}
// Bare char-family types (no length) are rewritten to `text`.
#[test]
fn fix_char_without_length() {
assert_snapshot!(fix("CREATE TABLE t (c char);"), @"CREATE TABLE t (c text);");
assert_snapshot!(fix("CREATE TABLE t (c character);"), @"CREATE TABLE t (c text);");
assert_snapshot!(fix("CREATE TABLE t (c bpchar);"), @"CREATE TABLE t (c text);");
}
// Char-family types with a length become `varchar(n)`, keeping the argument,
// for both unquoted and quoted type names.
// NOTE(review): in PostgreSQL the quoted `"char"` names a distinct single-byte
// internal type rather than `character(1)`; these cases pin that it is still
// rewritten — confirm that is intended.
#[test]
fn fix_char_with_length() {
assert_snapshot!(fix("CREATE TABLE t (c char(100));"), @"CREATE TABLE t (c varchar(100));");
assert_snapshot!(fix("CREATE TABLE t (c character(255));"), @"CREATE TABLE t (c varchar(255));");
assert_snapshot!(fix("CREATE TABLE t (c bpchar(50));"), @"CREATE TABLE t (c varchar(50));");
assert_snapshot!(fix(r#"CREATE TABLE t (c "char"(100));"#), @"CREATE TABLE t (c varchar(100));");
assert_snapshot!(fix(r#"CREATE TABLE t (c "character"(255));"#), @"CREATE TABLE t (c varchar(255));");
assert_snapshot!(fix(r#"CREATE TABLE t (c "bpchar"(50));"#), @"CREATE TABLE t (c varchar(50));");
}
// Type-name matching is case-insensitive; the replacement is always lowercase.
#[test]
fn fix_mixed_case() {
assert_snapshot!(fix("CREATE TABLE t (c CHAR);"), @"CREATE TABLE t (c text);");
assert_snapshot!(fix("CREATE TABLE t (c CHARACTER(100));"), @"CREATE TABLE t (c varchar(100));");
assert_snapshot!(fix("CREATE TABLE t (c Char(50));"), @"CREATE TABLE t (c varchar(50));");
}
// The element type of an array is rewritten and the `[]` suffix is preserved.
#[test]
fn fix_array_types() {
assert_snapshot!(fix("CREATE TABLE t (c char[]);"), @"CREATE TABLE t (c text[]);");
assert_snapshot!(fix("CREATE TABLE t (c character(100)[]);"), @"CREATE TABLE t (c varchar(100)[]);");
}
// Columns added via ALTER TABLE are fixed the same way as in CREATE TABLE.
#[test]
fn fix_alter_table() {
assert_snapshot!(fix("ALTER TABLE t ADD COLUMN c char;"), @"ALTER TABLE t ADD COLUMN c text;");
assert_snapshot!(fix("ALTER TABLE t ADD COLUMN c character(100);"), @"ALTER TABLE t ADD COLUMN c varchar(100);");
}
// Several offending columns in one statement are all fixed in a single pass.
#[test]
fn fix_multiple_columns() {
assert_snapshot!(fix("CREATE TABLE t (a char, b character(100), c bpchar(50));"), @"CREATE TABLE t (a text, b varchar(100), c varchar(50));");
}
// Diagnostics for a table mixing `char`/`character`, with and without length.
#[test]
fn creating_table_with_char_errors() {
let sql = r#"
CREATE TABLE "core_bar" (
"id" serial NOT NULL PRIMARY KEY,
"alpha" char(100) NOT NULL,
"beta" character(100) NOT NULL,
"charlie" char NOT NULL,
"delta" character NOT NULL
);
"#;
assert_snapshot!(lint_errors(sql, Rule::BanCharField));
}
// `varchar` and `text` are the recommended replacements and must not be flagged.
#[test]
fn creating_table_with_var_char_and_text_okay() {
let sql = r#"
CREATE TABLE "core_bar" (
"id" serial NOT NULL PRIMARY KEY,
"alpha" varchar(100) NOT NULL,
"beta" text NOT NULL
);
"#;
lint_ok(sql, Rule::BanCharField);
}
// A broad mix of type shapes (multi-word, array, user-defined `bar(10)`,
// schema-qualified `pg_catalog.char`) to pin which ones the rule reports.
#[test]
fn all_the_types() {
let sql = r#"
create table t (
a serial not null primary key,
b char(100),
c character(100),
d char,
e character,
f double precision,
g time with time zone,
h interval,
j int[5][10],
k bar(10),
l bit varying,
m int array[],
o pg_catalog.char,
p char[]
);
"#;
assert_snapshot!(lint_errors(sql, Rule::BanCharField));
}
// Mixed-case `Char` is still reported.
#[test]
fn case_insensitive() {
let sql = r#"
create table t (
a Char
);
"#;
assert_snapshot!(lint_errors(sql, Rule::BanCharField));
}
// An array whose element type is `char` is reported.
#[test]
fn array_char_type_err() {
let sql = r#"
create table t (
a char[]
);
"#;
assert_snapshot!(lint_errors(sql, Rule::BanCharField));
}
// A `char` column added via ALTER TABLE is reported.
#[test]
fn alter_table_err() {
let sql = r#"
alter table t add column c char;
"#;
assert_snapshot!(lint_errors(sql, Rule::BanCharField));
}
}