// squawk_linter/lib.rs

1use rustc_hash::{FxBuildHasher, FxHashSet};
2use std::fmt;
3
4use enum_iterator::Sequence;
5use enum_iterator::all;
6pub use ignore::Ignore;
7use ignore::find_ignores;
8use ignore_index::IgnoreIndex;
9use rowan::TextRange;
10use rowan::TextSize;
11use serde::Deserialize;
12
13use squawk_syntax::SyntaxNode;
14use squawk_syntax::{Parse, SourceFile};
15
16pub use version::Version;
17
18pub mod analyze;
19pub mod ignore;
20mod ignore_index;
21mod version;
22mod visitors;
23
24mod rules;
25
26#[cfg(test)]
27mod test_utils;
28use rules::adding_field_with_default;
29use rules::adding_foreign_key_constraint;
30use rules::adding_not_null_field;
31use rules::adding_primary_key_constraint;
32use rules::adding_required_field;
33use rules::ban_alter_domain_with_add_constraint;
34use rules::ban_char_field;
35use rules::ban_concurrent_index_creation_in_transaction;
36use rules::ban_create_domain_with_constraint;
37use rules::ban_drop_column;
38use rules::ban_drop_database;
39use rules::ban_drop_not_null;
40use rules::ban_drop_table;
41use rules::ban_truncate_cascade;
42use rules::ban_uncommitted_transaction;
43use rules::changing_column_type;
44use rules::constraint_missing_not_valid;
45use rules::disallow_unique_constraint;
46use rules::prefer_bigint_over_int;
47use rules::prefer_bigint_over_smallint;
48use rules::prefer_identity;
49use rules::prefer_robust_stmts;
50use rules::prefer_text_field;
51use rules::prefer_timestamptz;
52use rules::renaming_column;
53use rules::renaming_table;
54use rules::require_concurrent_index_creation;
55use rules::require_concurrent_index_deletion;
56use rules::require_enum_value_ordering;
57use rules::require_timeout_settings;
58use rules::transaction_nesting;
59// xtask:new-rule:rule-import
60
61#[derive(Debug, PartialEq, Clone, Copy, Hash, Eq, Sequence)]
62pub enum Rule {
63    RequireConcurrentIndexCreation,
64    RequireConcurrentIndexDeletion,
65    ConstraintMissingNotValid,
66    AddingFieldWithDefault,
67    AddingForeignKeyConstraint,
68    ChangingColumnType,
69    AddingNotNullableField,
70    AddingSerialPrimaryKeyField,
71    RenamingColumn,
72    RenamingTable,
73    DisallowedUniqueConstraint,
74    BanDropDatabase,
75    PreferBigintOverInt,
76    PreferBigintOverSmallint,
77    PreferIdentity,
78    PreferRobustStmts,
79    PreferTextField,
80    PreferTimestampTz,
81    BanCharField,
82    BanDropColumn,
83    BanDropTable,
84    BanDropNotNull,
85    TransactionNesting,
86    AddingRequiredField,
87    BanConcurrentIndexCreationInTransaction,
88    UnusedIgnore,
89    BanCreateDomainWithConstraint,
90    BanAlterDomainWithAddConstraint,
91    BanTruncateCascade,
92    RequireTimeoutSettings,
93    BanUncommittedTransaction,
94    RequireEnumValueOrdering,
95    // xtask:new-rule:error-name
96}
97
98impl TryFrom<&str> for Rule {
99    type Error = String;
100
101    fn try_from(s: &str) -> Result<Self, Self::Error> {
102        match s {
103            "require-concurrent-index-creation" => Ok(Rule::RequireConcurrentIndexCreation),
104            "require-concurrent-index-deletion" => Ok(Rule::RequireConcurrentIndexDeletion),
105            "constraint-missing-not-valid" => Ok(Rule::ConstraintMissingNotValid),
106            "adding-field-with-default" => Ok(Rule::AddingFieldWithDefault),
107            "adding-foreign-key-constraint" => Ok(Rule::AddingForeignKeyConstraint),
108            "changing-column-type" => Ok(Rule::ChangingColumnType),
109            "adding-not-nullable-field" => Ok(Rule::AddingNotNullableField),
110            "adding-serial-primary-key-field" => Ok(Rule::AddingSerialPrimaryKeyField),
111            "renaming-column" => Ok(Rule::RenamingColumn),
112            "renaming-table" => Ok(Rule::RenamingTable),
113            "disallowed-unique-constraint" => Ok(Rule::DisallowedUniqueConstraint),
114            "ban-drop-database" => Ok(Rule::BanDropDatabase),
115            "prefer-bigint-over-int" => Ok(Rule::PreferBigintOverInt),
116            "prefer-bigint-over-smallint" => Ok(Rule::PreferBigintOverSmallint),
117            "prefer-identity" => Ok(Rule::PreferIdentity),
118            "prefer-robust-stmts" => Ok(Rule::PreferRobustStmts),
119            "prefer-text-field" => Ok(Rule::PreferTextField),
120            // this is typo'd so we just support both
121            "prefer-timestamptz" => Ok(Rule::PreferTimestampTz),
122            "prefer-timestamp-tz" => Ok(Rule::PreferTimestampTz),
123            "ban-char-field" => Ok(Rule::BanCharField),
124            "ban-drop-column" => Ok(Rule::BanDropColumn),
125            "ban-drop-table" => Ok(Rule::BanDropTable),
126            "ban-drop-not-null" => Ok(Rule::BanDropNotNull),
127            "transaction-nesting" => Ok(Rule::TransactionNesting),
128            "adding-required-field" => Ok(Rule::AddingRequiredField),
129            "ban-concurrent-index-creation-in-transaction" => {
130                Ok(Rule::BanConcurrentIndexCreationInTransaction)
131            }
132            "ban-create-domain-with-constraint" => Ok(Rule::BanCreateDomainWithConstraint),
133            "ban-alter-domain-with-add-constraint" => Ok(Rule::BanAlterDomainWithAddConstraint),
134            "ban-truncate-cascade" => Ok(Rule::BanTruncateCascade),
135            "require-timeout-settings" => Ok(Rule::RequireTimeoutSettings),
136            "ban-uncommitted-transaction" => Ok(Rule::BanUncommittedTransaction),
137            "require-enum-value-ordering" => Ok(Rule::RequireEnumValueOrdering),
138            // xtask:new-rule:str-name
139            _ => Err(format!("Unknown violation name: {s}")),
140        }
141    }
142}
143
/// Error produced when a string does not name any known [`Rule`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct UnknownRuleName {
    /// The rejected input, kept verbatim for the error message.
    val: String,
}

impl std::fmt::Display for UnknownRuleName {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        f.write_str("invalid rule name ")?;
        f.write_str(&self.val)
    }
}

impl std::error::Error for UnknownRuleName {}
156
157impl std::str::FromStr for Rule {
158    type Err = UnknownRuleName;
159    fn from_str(s: &str) -> Result<Self, Self::Err> {
160        Rule::try_from(s).map_err(|_| UnknownRuleName { val: s.to_string() })
161    }
162}
163
164impl fmt::Display for Rule {
165    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
166        let val = match &self {
167            Rule::RequireConcurrentIndexCreation => "require-concurrent-index-creation",
168            Rule::RequireConcurrentIndexDeletion => "require-concurrent-index-deletion",
169            Rule::ConstraintMissingNotValid => "constraint-missing-not-valid",
170            Rule::AddingFieldWithDefault => "adding-field-with-default",
171            Rule::AddingForeignKeyConstraint => "adding-foreign-key-constraint",
172            Rule::ChangingColumnType => "changing-column-type",
173            Rule::AddingNotNullableField => "adding-not-nullable-field",
174            Rule::AddingSerialPrimaryKeyField => "adding-serial-primary-key-field",
175            Rule::RenamingColumn => "renaming-column",
176            Rule::RenamingTable => "renaming-table",
177            Rule::DisallowedUniqueConstraint => "disallowed-unique-constraint",
178            Rule::BanDropDatabase => "ban-drop-database",
179            Rule::PreferBigintOverInt => "prefer-bigint-over-int",
180            Rule::PreferBigintOverSmallint => "prefer-bigint-over-smallint",
181            Rule::PreferIdentity => "prefer-identity",
182            Rule::PreferRobustStmts => "prefer-robust-stmts",
183            Rule::PreferTextField => "prefer-text-field",
184            Rule::PreferTimestampTz => "prefer-timestamp-tz",
185            Rule::BanCharField => "ban-char-field",
186            Rule::BanDropColumn => "ban-drop-column",
187            Rule::BanDropTable => "ban-drop-table",
188            Rule::BanDropNotNull => "ban-drop-not-null",
189            Rule::TransactionNesting => "transaction-nesting",
190            Rule::AddingRequiredField => "adding-required-field",
191            Rule::BanConcurrentIndexCreationInTransaction => {
192                "ban-concurrent-index-creation-in-transaction"
193            }
194            Rule::BanCreateDomainWithConstraint => "ban-create-domain-with-constraint",
195            Rule::UnusedIgnore => "unused-ignore",
196            Rule::BanAlterDomainWithAddConstraint => "ban-alter-domain-with-add-constraint",
197            Rule::BanTruncateCascade => "ban-truncate-cascade",
198            Rule::RequireTimeoutSettings => "require-timeout-settings",
199            Rule::BanUncommittedTransaction => "ban-uncommitted-transaction",
200            Rule::RequireEnumValueOrdering => "require-enum-value-ordering",
201            // xtask:new-rule:variant-to-name
202        };
203        write!(f, "{val}")
204    }
205}
206
207impl<'de> Deserialize<'de> for Rule {
208    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
209    where
210        D: serde::Deserializer<'de>,
211    {
212        let s = String::deserialize(deserializer)?;
213        s.parse().map_err(serde::de::Error::custom)
214    }
215}
216
217#[derive(Debug, Clone, PartialEq, Eq)]
218pub struct Fix {
219    pub title: String,
220    pub edits: Vec<Edit>,
221}
222
223impl Fix {
224    fn new<T: Into<String>>(title: T, edits: Vec<Edit>) -> Fix {
225        Fix {
226            title: title.into(),
227            edits,
228        }
229    }
230}
231
232#[derive(Debug, Clone, PartialEq, Eq)]
233pub struct Edit {
234    pub text_range: TextRange,
235    // TODO: does this need to be an Option?
236    pub text: Option<String>,
237}
238impl Edit {
239    pub fn insert<T: Into<String>>(text: T, at: TextSize) -> Self {
240        Self {
241            text_range: TextRange::new(at, at),
242            text: Some(text.into()),
243        }
244    }
245    pub fn replace<T: Into<String>>(text_range: TextRange, text: T) -> Self {
246        Self {
247            text_range,
248            text: Some(text.into()),
249        }
250    }
251    pub fn delete(text_range: TextRange) -> Self {
252        Self {
253            text_range,
254            text: None,
255        }
256    }
257}
258
259#[derive(Debug, Clone, PartialEq, Eq)]
260pub struct Violation {
261    // TODO: should this be String instead?
262    pub code: Rule,
263    pub message: String,
264    pub text_range: TextRange,
265    pub help: Option<String>,
266    pub fix: Option<Fix>,
267}
268
269impl Violation {
270    #[must_use]
271    pub fn for_node(code: Rule, message: String, node: &SyntaxNode) -> Self {
272        let range = node.text_range();
273
274        let start = node
275            .children_with_tokens()
276            .find(|x| !x.kind().is_trivia())
277            .map(|x| x.text_range().start())
278            // Not sure we actually hit this, but just being safe
279            .unwrap_or_else(|| range.start());
280
281        Self {
282            code,
283            text_range: TextRange::new(start, range.end()),
284            message,
285            help: None,
286            fix: None,
287        }
288    }
289
290    #[must_use]
291    pub fn for_range(code: Rule, message: String, text_range: TextRange) -> Self {
292        Self {
293            code,
294            text_range,
295            message,
296            help: None,
297            fix: None,
298        }
299    }
300
301    fn fix(mut self, fix: Option<Fix>) -> Violation {
302        self.fix = fix;
303        self
304    }
305    fn help(mut self, help: impl Into<String>) -> Violation {
306        self.help = Some(help.into());
307        self
308    }
309}
310
311#[derive(Default)]
312pub struct LinterSettings {
313    pub pg_version: Version,
314    pub assume_in_transaction: bool,
315}
316
317pub struct Linter {
318    errors: Vec<Violation>,
319    ignores: Vec<Ignore>,
320    pub rules: FxHashSet<Rule>,
321    pub settings: LinterSettings,
322}
323
324impl Linter {
325    fn report(&mut self, error: Violation) {
326        self.errors.push(error);
327    }
328
329    fn ignore(&mut self, ignore: Ignore) {
330        self.ignores.push(ignore);
331    }
332
333    #[must_use]
334    pub fn lint(&mut self, file: &Parse<SourceFile>, text: &str) -> Vec<Violation> {
335        if self.rules.contains(&Rule::AddingFieldWithDefault) {
336            adding_field_with_default(self, file);
337        }
338        if self.rules.contains(&Rule::AddingForeignKeyConstraint) {
339            adding_foreign_key_constraint(self, file);
340        }
341        if self.rules.contains(&Rule::AddingNotNullableField) {
342            adding_not_null_field(self, file);
343        }
344        if self.rules.contains(&Rule::AddingSerialPrimaryKeyField) {
345            adding_primary_key_constraint(self, file);
346        }
347        if self.rules.contains(&Rule::AddingRequiredField) {
348            adding_required_field(self, file);
349        }
350        if self.rules.contains(&Rule::BanDropDatabase) {
351            ban_drop_database(self, file);
352        }
353        if self.rules.contains(&Rule::BanCharField) {
354            ban_char_field(self, file);
355        }
356        if self
357            .rules
358            .contains(&Rule::BanConcurrentIndexCreationInTransaction)
359        {
360            ban_concurrent_index_creation_in_transaction(self, file);
361        }
362        if self.rules.contains(&Rule::BanDropColumn) {
363            ban_drop_column(self, file);
364        }
365        if self.rules.contains(&Rule::BanDropNotNull) {
366            ban_drop_not_null(self, file);
367        }
368        if self.rules.contains(&Rule::BanDropTable) {
369            ban_drop_table(self, file);
370        }
371        if self.rules.contains(&Rule::ChangingColumnType) {
372            changing_column_type(self, file);
373        }
374        if self.rules.contains(&Rule::ConstraintMissingNotValid) {
375            constraint_missing_not_valid(self, file);
376        }
377        if self.rules.contains(&Rule::DisallowedUniqueConstraint) {
378            disallow_unique_constraint(self, file);
379        }
380        if self.rules.contains(&Rule::PreferBigintOverInt) {
381            prefer_bigint_over_int(self, file);
382        }
383        if self.rules.contains(&Rule::PreferBigintOverSmallint) {
384            prefer_bigint_over_smallint(self, file);
385        }
386        if self.rules.contains(&Rule::PreferIdentity) {
387            prefer_identity(self, file);
388        }
389        if self.rules.contains(&Rule::PreferRobustStmts) {
390            prefer_robust_stmts(self, file);
391        }
392        if self.rules.contains(&Rule::PreferTextField) {
393            prefer_text_field(self, file);
394        }
395        if self.rules.contains(&Rule::PreferTimestampTz) {
396            prefer_timestamptz(self, file);
397        }
398        if self.rules.contains(&Rule::RenamingColumn) {
399            renaming_column(self, file);
400        }
401        if self.rules.contains(&Rule::RenamingTable) {
402            renaming_table(self, file);
403        }
404        if self.rules.contains(&Rule::RequireConcurrentIndexCreation) {
405            require_concurrent_index_creation(self, file);
406        }
407        if self.rules.contains(&Rule::RequireConcurrentIndexDeletion) {
408            require_concurrent_index_deletion(self, file);
409        }
410        if self.rules.contains(&Rule::BanCreateDomainWithConstraint) {
411            ban_create_domain_with_constraint(self, file);
412        }
413        if self.rules.contains(&Rule::BanAlterDomainWithAddConstraint) {
414            ban_alter_domain_with_add_constraint(self, file);
415        }
416        if self.rules.contains(&Rule::TransactionNesting) {
417            transaction_nesting(self, file);
418        }
419        if self.rules.contains(&Rule::BanTruncateCascade) {
420            ban_truncate_cascade(self, file);
421        }
422        if self.rules.contains(&Rule::RequireTimeoutSettings) {
423            require_timeout_settings(self, file);
424        }
425        if self.rules.contains(&Rule::BanUncommittedTransaction) {
426            ban_uncommitted_transaction(self, file);
427        }
428        if self.rules.contains(&Rule::RequireEnumValueOrdering) {
429            require_enum_value_ordering(self, file);
430        }
431        // xtask:new-rule:rule-call
432
433        // locate any ignores in the file
434        find_ignores(self, &file.syntax_node());
435
436        self.errors(text)
437    }
438
439    fn errors(&mut self, text: &str) -> Vec<Violation> {
440        let ignore_index = IgnoreIndex::new(text, &self.ignores);
441        let mut errors: Vec<Violation> = self
442            .errors
443            .iter()
444            // TODO: we should have errors for when there was an ignore but that
445            // ignore didn't actually ignore anything
446            .filter(|err| !ignore_index.contains(err.text_range, err.code))
447            .cloned()
448            .collect::<Vec<_>>();
449        // ensure we order them by where they appear in the file
450        errors.sort_by_key(|x| x.text_range.start());
451        errors
452    }
453
454    pub fn with_all_rules() -> Self {
455        let rules = all::<Rule>().collect::<FxHashSet<_>>();
456        Linter::from(rules)
457    }
458
459    pub fn without_rules(exclude: &[Rule]) -> Self {
460        let all_rules = all::<Rule>().collect::<FxHashSet<_>>();
461        let mut exclude_set = FxHashSet::with_capacity_and_hasher(exclude.len(), FxBuildHasher);
462        for e in exclude {
463            exclude_set.insert(e);
464        }
465
466        let rules = all_rules
467            .into_iter()
468            .filter(|x| !exclude_set.contains(x))
469            .collect::<FxHashSet<_>>();
470
471        Linter::from(rules)
472    }
473
474    pub fn from(rules: impl IntoIterator<Item = Rule>) -> Self {
475        Self {
476            errors: vec![],
477            ignores: vec![],
478            rules: rules.into_iter().collect(),
479            settings: Default::default(),
480        }
481    }
482}
483
#[cfg(test)]
mod tests {
    use super::*;
    use insta::assert_debug_snapshot;

    /// Both accepted spellings of the timestamptz rule resolve to the same variant.
    #[test]
    fn prefer_timestamp_aliases() {
        let dashed = "prefer-timestamp-tz".parse::<Rule>().unwrap();
        let joined = "prefer-timestamptz".parse::<Rule>().unwrap();
        assert_eq!(dashed, joined);
        assert_debug_snapshot!(dashed, @"PreferTimestampTz");
    }

    /// A name outside the known set is rejected.
    #[test]
    fn invalid_rule_name() {
        let parsed = "invalid-rule-name".parse::<Rule>();
        assert!(parsed.is_err(), "expected an error, got {parsed:?}");
    }
}