use rustc_hash::FxHashSet;
use std::fmt;
use enum_iterator::Sequence;
use enum_iterator::all;
pub use ignore::Ignore;
use ignore::find_ignores;
use ignore_index::IgnoreIndex;
use rowan::TextRange;
use rowan::TextSize;
use serde::Deserialize;
use squawk_syntax::SyntaxNode;
use squawk_syntax::{Parse, SourceFile};
pub use version::Version;
pub mod analyze;
pub mod ignore;
mod ignore_index;
mod version;
mod visitors;
mod rules;
#[cfg(test)]
mod test_utils;
use rules::adding_field_with_default;
use rules::adding_foreign_key_constraint;
use rules::adding_not_null_field;
use rules::adding_primary_key_constraint;
use rules::adding_required_field;
use rules::ban_alter_domain_with_add_constraint;
use rules::ban_char_field;
use rules::ban_concurrent_index_creation_in_transaction;
use rules::ban_create_domain_with_constraint;
use rules::ban_drop_column;
use rules::ban_drop_database;
use rules::ban_drop_not_null;
use rules::ban_drop_table;
use rules::ban_truncate_cascade;
use rules::ban_uncommitted_transaction;
use rules::changing_column_type;
use rules::constraint_missing_not_valid;
use rules::disallow_unique_constraint;
use rules::identifier_too_long;
use rules::prefer_bigint_over_int;
use rules::prefer_bigint_over_smallint;
use rules::prefer_identity;
use rules::prefer_repack;
use rules::prefer_robust_stmts;
use rules::prefer_text_field;
use rules::prefer_timestamptz;
use rules::renaming_column;
use rules::renaming_table;
use rules::require_concurrent_index_creation;
use rules::require_concurrent_index_deletion;
use rules::require_concurrent_partition_detach;
use rules::require_concurrent_reindex;
use rules::require_enum_value_ordering;
use rules::require_table_schema;
use rules::require_timeout_settings;
use rules::transaction_nesting;
/// Every lint rule known to the linter.
///
/// Each variant has a kebab-case name (see the `Display` impl) that is used in
/// configuration and ignore comments. The derived `Sequence` enumerates the
/// variants in declaration order, which is how `all::<Rule>()` builds the
/// default rule set; opt-in rules are filtered out via `Rule::is_opt_in`.
#[derive(Debug, PartialEq, Clone, Copy, Hash, Eq, Sequence)]
pub enum Rule {
    RequireConcurrentIndexCreation,
    RequireConcurrentIndexDeletion,
    ConstraintMissingNotValid,
    AddingFieldWithDefault,
    AddingForeignKeyConstraint,
    ChangingColumnType,
    AddingNotNullableField,
    AddingSerialPrimaryKeyField,
    RenamingColumn,
    RenamingTable,
    DisallowedUniqueConstraint,
    BanDropDatabase,
    PreferBigintOverInt,
    PreferBigintOverSmallint,
    PreferIdentity,
    PreferRepack,
    PreferRobustStmts,
    PreferTextField,
    PreferTimestampTz,
    BanCharField,
    BanDropColumn,
    BanDropTable,
    BanDropNotNull,
    TransactionNesting,
    AddingRequiredField,
    BanConcurrentIndexCreationInTransaction,
    // No rule function is dispatched for this variant in `Linter::lint`;
    // presumably it is reported by the ignore-comment handling — confirm in `ignore`.
    UnusedIgnore,
    BanCreateDomainWithConstraint,
    BanAlterDomainWithAddConstraint,
    BanTruncateCascade,
    RequireTimeoutSettings,
    BanUncommittedTransaction,
    RequireEnumValueOrdering,
    // Opt-in: excluded from the default rule set (see `Rule::is_opt_in`).
    RequireTableSchema,
    IdentifierTooLong,
    RequireConcurrentPartitionDetach,
    RequireConcurrentReindex,
}
impl Rule {
    /// Returns `true` for rules that are disabled unless explicitly requested.
    ///
    /// `Linter::default_rules` leaves these rules out of the default set; they
    /// can still be enabled through the `include` list of `Linter::with_rules`.
    pub fn is_opt_in(&self) -> bool {
        const OPT_IN: &[Rule] = &[Rule::RequireTableSchema];
        OPT_IN.contains(self)
    }
}
impl TryFrom<&str> for Rule {
    type Error = String;

    /// Parses a rule from its kebab-case name.
    ///
    /// Every name produced by the `Display` impl is accepted, including
    /// `"unused-ignore"`, so `rule.to_string()` always parses back to `rule`.
    /// `"prefer-timestamptz"` is also accepted as an alias for
    /// [`Rule::PreferTimestampTz`].
    ///
    /// # Errors
    ///
    /// Returns `Err` with a message naming the input when it is not a known
    /// rule name.
    fn try_from(s: &str) -> Result<Self, Self::Error> {
        match s {
            "require-concurrent-index-creation" => Ok(Rule::RequireConcurrentIndexCreation),
            "require-concurrent-index-deletion" => Ok(Rule::RequireConcurrentIndexDeletion),
            "constraint-missing-not-valid" => Ok(Rule::ConstraintMissingNotValid),
            "adding-field-with-default" => Ok(Rule::AddingFieldWithDefault),
            "adding-foreign-key-constraint" => Ok(Rule::AddingForeignKeyConstraint),
            "changing-column-type" => Ok(Rule::ChangingColumnType),
            "adding-not-nullable-field" => Ok(Rule::AddingNotNullableField),
            "adding-serial-primary-key-field" => Ok(Rule::AddingSerialPrimaryKeyField),
            "renaming-column" => Ok(Rule::RenamingColumn),
            "renaming-table" => Ok(Rule::RenamingTable),
            "disallowed-unique-constraint" => Ok(Rule::DisallowedUniqueConstraint),
            "ban-drop-database" => Ok(Rule::BanDropDatabase),
            "prefer-bigint-over-int" => Ok(Rule::PreferBigintOverInt),
            "prefer-bigint-over-smallint" => Ok(Rule::PreferBigintOverSmallint),
            "prefer-identity" => Ok(Rule::PreferIdentity),
            "prefer-repack" => Ok(Rule::PreferRepack),
            "prefer-robust-stmts" => Ok(Rule::PreferRobustStmts),
            "prefer-text-field" => Ok(Rule::PreferTextField),
            // Both spellings are accepted; `Display` emits "prefer-timestamp-tz".
            "prefer-timestamptz" => Ok(Rule::PreferTimestampTz),
            "prefer-timestamp-tz" => Ok(Rule::PreferTimestampTz),
            "ban-char-field" => Ok(Rule::BanCharField),
            "ban-drop-column" => Ok(Rule::BanDropColumn),
            "ban-drop-table" => Ok(Rule::BanDropTable),
            "ban-drop-not-null" => Ok(Rule::BanDropNotNull),
            "transaction-nesting" => Ok(Rule::TransactionNesting),
            "adding-required-field" => Ok(Rule::AddingRequiredField),
            "ban-concurrent-index-creation-in-transaction" => {
                Ok(Rule::BanConcurrentIndexCreationInTransaction)
            }
            // Previously missing: `Display` emits this name, so it must parse back.
            "unused-ignore" => Ok(Rule::UnusedIgnore),
            "ban-create-domain-with-constraint" => Ok(Rule::BanCreateDomainWithConstraint),
            "ban-alter-domain-with-add-constraint" => Ok(Rule::BanAlterDomainWithAddConstraint),
            "ban-truncate-cascade" => Ok(Rule::BanTruncateCascade),
            "require-timeout-settings" => Ok(Rule::RequireTimeoutSettings),
            "ban-uncommitted-transaction" => Ok(Rule::BanUncommittedTransaction),
            "require-enum-value-ordering" => Ok(Rule::RequireEnumValueOrdering),
            "require-table-schema" => Ok(Rule::RequireTableSchema),
            "identifier-too-long" => Ok(Rule::IdentifierTooLong),
            "require-concurrent-partition-detach" => Ok(Rule::RequireConcurrentPartitionDetach),
            "require-concurrent-reindex" => Ok(Rule::RequireConcurrentReindex),
            _ => Err(format!("Unknown violation name: {s}")),
        }
    }
}
/// Error returned by `Rule::from_str` (and therefore by deserialization) when
/// the input is not a recognized rule name.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct UnknownRuleName {
    /// The rejected input, exactly as given.
    val: String,
}
impl std::fmt::Display for UnknownRuleName {
    /// Renders as `invalid rule name <input>`.
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        f.write_str("invalid rule name ")?;
        f.write_str(&self.val)
    }
}
/// Marker impl so the error can be boxed as `dyn std::error::Error`; it has no
/// underlying source.
impl std::error::Error for UnknownRuleName {}
impl std::str::FromStr for Rule {
    type Err = UnknownRuleName;

    /// Parses a rule name, accepting the same spellings as `Rule::try_from`.
    ///
    /// # Errors
    ///
    /// Returns [`UnknownRuleName`] holding the rejected input; the message
    /// from `try_from` is discarded.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match Rule::try_from(s) {
            Ok(rule) => Ok(rule),
            Err(_) => Err(UnknownRuleName { val: s.to_owned() }),
        }
    }
}
impl fmt::Display for Rule {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let val = match &self {
Rule::RequireConcurrentIndexCreation => "require-concurrent-index-creation",
Rule::RequireConcurrentIndexDeletion => "require-concurrent-index-deletion",
Rule::ConstraintMissingNotValid => "constraint-missing-not-valid",
Rule::AddingFieldWithDefault => "adding-field-with-default",
Rule::AddingForeignKeyConstraint => "adding-foreign-key-constraint",
Rule::ChangingColumnType => "changing-column-type",
Rule::AddingNotNullableField => "adding-not-nullable-field",
Rule::AddingSerialPrimaryKeyField => "adding-serial-primary-key-field",
Rule::RenamingColumn => "renaming-column",
Rule::RenamingTable => "renaming-table",
Rule::DisallowedUniqueConstraint => "disallowed-unique-constraint",
Rule::BanDropDatabase => "ban-drop-database",
Rule::PreferBigintOverInt => "prefer-bigint-over-int",
Rule::PreferBigintOverSmallint => "prefer-bigint-over-smallint",
Rule::PreferIdentity => "prefer-identity",
Rule::PreferRepack => "prefer-repack",
Rule::PreferRobustStmts => "prefer-robust-stmts",
Rule::PreferTextField => "prefer-text-field",
Rule::PreferTimestampTz => "prefer-timestamp-tz",
Rule::BanCharField => "ban-char-field",
Rule::BanDropColumn => "ban-drop-column",
Rule::BanDropTable => "ban-drop-table",
Rule::BanDropNotNull => "ban-drop-not-null",
Rule::TransactionNesting => "transaction-nesting",
Rule::AddingRequiredField => "adding-required-field",
Rule::BanConcurrentIndexCreationInTransaction => {
"ban-concurrent-index-creation-in-transaction"
}
Rule::BanCreateDomainWithConstraint => "ban-create-domain-with-constraint",
Rule::UnusedIgnore => "unused-ignore",
Rule::BanAlterDomainWithAddConstraint => "ban-alter-domain-with-add-constraint",
Rule::BanTruncateCascade => "ban-truncate-cascade",
Rule::RequireTimeoutSettings => "require-timeout-settings",
Rule::BanUncommittedTransaction => "ban-uncommitted-transaction",
Rule::RequireEnumValueOrdering => "require-enum-value-ordering",
Rule::RequireTableSchema => "require-table-schema",
Rule::IdentifierTooLong => "identifier-too-long",
Rule::RequireConcurrentPartitionDetach => "require-concurrent-partition-detach",
Rule::RequireConcurrentReindex => "require-concurrent-reindex",
};
write!(f, "{val}")
}
}
impl<'de> Deserialize<'de> for Rule {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
let s = String::deserialize(deserializer)?;
s.parse().map_err(serde::de::Error::custom)
}
}
/// A suggested fix attached to a [`Violation`]: a title plus the edits that
/// implement it.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Fix {
    /// Short human-readable description of the fix.
    pub title: String,
    /// The text edits making up the fix.
    pub edits: Vec<Edit>,
}
impl Fix {
    /// Builds a fix from a title (anything convertible to `String`) and its edits.
    fn new<T: Into<String>>(title: T, edits: Vec<Edit>) -> Fix {
        let title: String = title.into();
        Fix { edits, title }
    }
}
/// A single text edit over `text_range`.
///
/// `text: Some(_)` replaces the range with the given text. An empty range with
/// `Some(_)` is an insertion (see `Edit::insert`). `text: None` deletes the
/// range (see `Edit::delete`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Edit {
    /// The source range being edited.
    pub text_range: TextRange,
    /// Replacement text, or `None` to delete the range.
    pub text: Option<String>,
}
impl Edit {
pub fn insert<T: Into<String>>(text: T, at: TextSize) -> Self {
Self {
text_range: TextRange::new(at, at),
text: Some(text.into()),
}
}
pub fn replace<T: Into<String>>(text_range: TextRange, text: T) -> Self {
Self {
text_range,
text: Some(text.into()),
}
}
pub fn delete(text_range: TextRange) -> Self {
Self {
text_range,
text: None,
}
}
}
/// A diagnostic reported by a lint rule.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Violation {
    /// The rule that produced this violation.
    pub code: Rule,
    /// Human-readable description of the problem.
    pub message: String,
    /// Source range the violation covers. `Linter` uses it, together with
    /// `code`, to match ignore directives, and its start to sort results.
    pub text_range: TextRange,
    /// Optional extra guidance for the user.
    pub help: Option<String>,
    /// Optional suggested fix.
    pub fix: Option<Fix>,
}
impl Violation {
#[must_use]
pub fn for_node(code: Rule, message: String, node: &SyntaxNode) -> Self {
let range = node.text_range();
let start = node
.children_with_tokens()
.find(|x| !x.kind().is_trivia())
.map(|x| x.text_range().start())
.unwrap_or_else(|| range.start());
Self {
code,
text_range: TextRange::new(start, range.end()),
message,
help: None,
fix: None,
}
}
#[must_use]
pub fn for_range(code: Rule, message: String, text_range: TextRange) -> Self {
Self {
code,
text_range,
message,
help: None,
fix: None,
}
}
fn fix<F: Into<Option<Fix>>>(mut self, fix: F) -> Violation {
self.fix = fix.into();
self
}
fn help(mut self, help: impl Into<String>) -> Violation {
self.help = Some(help.into());
self
}
}
/// Linter-wide settings, stored on `Linter::settings`.
///
/// `Default` yields `Version::default()` and `assume_in_transaction = false`.
#[derive(Clone, Default)]
pub struct LinterSettings {
    /// Presumably the target PostgreSQL version that version-dependent rules
    /// compare against — TODO confirm in `version`.
    pub pg_version: Version,
    /// Presumably tells rules to treat the file as already running inside a
    /// transaction — TODO confirm against the rule implementations.
    pub assume_in_transaction: bool,
}
/// Runs a configurable set of lint rules over a parsed SQL file and returns the
/// resulting violations.
pub struct Linter {
    /// Violations pushed by rules via `Linter::report`.
    errors: Vec<Violation>,
    /// Ignore directives pushed via `Linter::ignore`; `lint` calls `find_ignores`
    /// to gather them from the syntax tree.
    ignores: Vec<Ignore>,
    /// The rules `lint` will run.
    pub rules: FxHashSet<Rule>,
    /// Linter-wide settings; `Linter::from` initializes them to their defaults.
    pub settings: LinterSettings,
}
impl Linter {
    /// Records a violation reported by a rule.
    fn report(&mut self, error: Violation) {
        self.errors.push(error);
    }

    /// Records an ignore directive to apply when results are collected.
    fn ignore(&mut self, ignore: Ignore) {
        self.ignores.push(ignore);
    }

    /// Lints `file` and returns its violations that are not suppressed by an
    /// ignore directive, sorted by start offset.
    ///
    /// Every rule in [`Linter::rules`] is run against `file`, then ignore
    /// directives are gathered from its syntax tree. `text` is the source text
    /// `file` was parsed from and is handed to `IgnoreIndex::new`.
    ///
    /// Reported violations and gathered ignores are drained before returning.
    /// Reusing a `Linter` for another file therefore does not re-report, or
    /// apply ignores from, earlier calls.
    #[must_use]
    pub fn lint(&mut self, file: &Parse<SourceFile>, text: &str) -> Vec<Violation> {
        if self.rules.contains(&Rule::AddingFieldWithDefault) {
            adding_field_with_default(self, file);
        }
        if self.rules.contains(&Rule::AddingForeignKeyConstraint) {
            adding_foreign_key_constraint(self, file);
        }
        if self.rules.contains(&Rule::AddingNotNullableField) {
            adding_not_null_field(self, file);
        }
        if self.rules.contains(&Rule::AddingSerialPrimaryKeyField) {
            adding_primary_key_constraint(self, file);
        }
        if self.rules.contains(&Rule::AddingRequiredField) {
            adding_required_field(self, file);
        }
        if self.rules.contains(&Rule::BanDropDatabase) {
            ban_drop_database(self, file);
        }
        if self.rules.contains(&Rule::BanCharField) {
            ban_char_field(self, file);
        }
        if self
            .rules
            .contains(&Rule::BanConcurrentIndexCreationInTransaction)
        {
            ban_concurrent_index_creation_in_transaction(self, file);
        }
        if self.rules.contains(&Rule::BanDropColumn) {
            ban_drop_column(self, file);
        }
        if self.rules.contains(&Rule::BanDropNotNull) {
            ban_drop_not_null(self, file);
        }
        if self.rules.contains(&Rule::BanDropTable) {
            ban_drop_table(self, file);
        }
        if self.rules.contains(&Rule::ChangingColumnType) {
            changing_column_type(self, file);
        }
        if self.rules.contains(&Rule::ConstraintMissingNotValid) {
            constraint_missing_not_valid(self, file);
        }
        if self.rules.contains(&Rule::DisallowedUniqueConstraint) {
            disallow_unique_constraint(self, file);
        }
        if self.rules.contains(&Rule::PreferBigintOverInt) {
            prefer_bigint_over_int(self, file);
        }
        if self.rules.contains(&Rule::PreferBigintOverSmallint) {
            prefer_bigint_over_smallint(self, file);
        }
        if self.rules.contains(&Rule::PreferIdentity) {
            prefer_identity(self, file);
        }
        if self.rules.contains(&Rule::PreferRepack) {
            prefer_repack(self, file);
        }
        if self.rules.contains(&Rule::PreferRobustStmts) {
            prefer_robust_stmts(self, file);
        }
        if self.rules.contains(&Rule::PreferTextField) {
            prefer_text_field(self, file);
        }
        if self.rules.contains(&Rule::PreferTimestampTz) {
            prefer_timestamptz(self, file);
        }
        if self.rules.contains(&Rule::RenamingColumn) {
            renaming_column(self, file);
        }
        if self.rules.contains(&Rule::RenamingTable) {
            renaming_table(self, file);
        }
        if self.rules.contains(&Rule::RequireConcurrentIndexCreation) {
            require_concurrent_index_creation(self, file);
        }
        if self.rules.contains(&Rule::RequireConcurrentIndexDeletion) {
            require_concurrent_index_deletion(self, file);
        }
        if self.rules.contains(&Rule::BanCreateDomainWithConstraint) {
            ban_create_domain_with_constraint(self, file);
        }
        if self.rules.contains(&Rule::BanAlterDomainWithAddConstraint) {
            ban_alter_domain_with_add_constraint(self, file);
        }
        if self.rules.contains(&Rule::TransactionNesting) {
            transaction_nesting(self, file);
        }
        if self.rules.contains(&Rule::BanTruncateCascade) {
            ban_truncate_cascade(self, file);
        }
        if self.rules.contains(&Rule::RequireTimeoutSettings) {
            require_timeout_settings(self, file);
        }
        if self.rules.contains(&Rule::BanUncommittedTransaction) {
            ban_uncommitted_transaction(self, file);
        }
        if self.rules.contains(&Rule::RequireEnumValueOrdering) {
            require_enum_value_ordering(self, file);
        }
        if self.rules.contains(&Rule::RequireTableSchema) {
            require_table_schema(self, file);
        }
        if self.rules.contains(&Rule::IdentifierTooLong) {
            identifier_too_long(self, file);
        }
        if self.rules.contains(&Rule::RequireConcurrentPartitionDetach) {
            require_concurrent_partition_detach(self, file);
        }
        if self.rules.contains(&Rule::RequireConcurrentReindex) {
            require_concurrent_reindex(self, file);
        }
        find_ignores(self, &file.syntax_node());
        self.errors(text)
    }

    /// Drains the recorded violations and ignores, drops every violation
    /// matched by an ignore, and returns the rest sorted by start offset.
    fn errors(&mut self, text: &str) -> Vec<Violation> {
        // Both vectors hold per-file state. Taking them (instead of borrowing and
        // cloning) leaves the linter empty for the next `lint` call. Otherwise
        // stale violations, whose ranges point into a previous file's text,
        // would be returned again.
        let ignores = std::mem::take(&mut self.ignores);
        let ignore_index = IgnoreIndex::new(text, &ignores);
        let mut errors: Vec<Violation> = std::mem::take(&mut self.errors)
            .into_iter()
            .filter(|err| !ignore_index.contains(err.text_range, err.code))
            .collect();
        // `sort_by_key` is stable, so violations at the same offset keep the
        // order in which they were reported.
        errors.sort_by_key(|x| x.text_range.start());
        errors
    }

    /// Every rule except the opt-in ones (see [`Rule::is_opt_in`]).
    fn default_rules() -> FxHashSet<Rule> {
        all::<Rule>()
            .filter(|r| !r.is_opt_in())
            .collect::<FxHashSet<_>>()
    }

    /// Builds a linter running the default rule set.
    pub fn with_default_rules() -> Self {
        let rules = Linter::default_rules();
        Linter::from(rules)
    }

    /// Builds a linter from the default rules plus `include`, minus `exclude`.
    ///
    /// Exclusions are applied after inclusions, so a rule listed in both is
    /// disabled. `include` is how opt-in rules are enabled.
    pub fn with_rules(include: &[Rule], exclude: &[Rule]) -> Self {
        let mut rules = Linter::default_rules();
        rules.extend(include.iter().copied());
        for rule in exclude {
            rules.remove(rule);
        }
        Linter::from(rules)
    }

    /// Builds a linter that runs exactly `rules`, with no recorded state and
    /// default settings.
    pub fn from(rules: impl IntoIterator<Item = Rule>) -> Self {
        Self {
            errors: vec![],
            ignores: vec![],
            rules: rules.into_iter().collect(),
            settings: Default::default(),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use insta::assert_debug_snapshot;

    #[test]
    fn prefer_timestamp_aliases() {
        let hyphenated = "prefer-timestamp-tz".parse::<Rule>().unwrap();
        let compact = "prefer-timestamptz".parse::<Rule>().unwrap();
        assert_eq!(hyphenated, compact);
        assert_debug_snapshot!(hyphenated, @"PreferTimestampTz");
    }

    #[test]
    fn invalid_rule_name() {
        assert!("invalid-rule-name".parse::<Rule>().is_err());
    }

    #[test]
    fn with_rules_opt_in_disabled_by_default() {
        let enabled = Linter::with_rules(&[], &[]).rules;
        assert!(!enabled.contains(&Rule::RequireTableSchema));
    }

    #[test]
    fn with_rules_opt_in_enabled_via_include() {
        let enabled = Linter::with_rules(&[Rule::RequireTableSchema], &[]).rules;
        assert!(enabled.contains(&Rule::RequireTableSchema));
    }

    #[test]
    fn with_rules_exclude_takes_precedence_over_include() {
        let both = [Rule::RequireTableSchema];
        let enabled = Linter::with_rules(&both, &both).rules;
        assert!(!enabled.contains(&Rule::RequireTableSchema));
    }

    #[test]
    fn with_rules_exclude_removes_default_rule() {
        let enabled = Linter::with_rules(&[], &[Rule::BanDropTable]).rules;
        assert!(!enabled.contains(&Rule::BanDropTable));
    }
}