Skip to main content

squawk_linter/
lib.rs

1use rustc_hash::FxHashSet;
2use std::fmt;
3
4use enum_iterator::Sequence;
5use enum_iterator::all;
6pub use ignore::Ignore;
7use ignore::find_ignores;
8use ignore_index::IgnoreIndex;
9use rowan::TextRange;
10use rowan::TextSize;
11use serde::Deserialize;
12
13use squawk_syntax::SyntaxNode;
14use squawk_syntax::{Parse, SourceFile};
15
16pub use version::Version;
17
18pub mod analyze;
19pub mod ignore;
20mod ignore_index;
21mod version;
22mod visitors;
23
24mod rules;
25
26#[cfg(test)]
27mod test_utils;
28use rules::adding_field_with_default;
29use rules::adding_foreign_key_constraint;
30use rules::adding_not_null_field;
31use rules::adding_primary_key_constraint;
32use rules::adding_required_field;
33use rules::ban_alter_domain_with_add_constraint;
34use rules::ban_char_field;
35use rules::ban_concurrent_index_creation_in_transaction;
36use rules::ban_create_domain_with_constraint;
37use rules::ban_drop_column;
38use rules::ban_drop_database;
39use rules::ban_drop_not_null;
40use rules::ban_drop_table;
41use rules::ban_truncate_cascade;
42use rules::ban_uncommitted_transaction;
43use rules::changing_column_type;
44use rules::constraint_missing_not_valid;
45use rules::disallow_unique_constraint;
46use rules::prefer_bigint_over_int;
47use rules::prefer_bigint_over_smallint;
48use rules::prefer_identity;
49use rules::prefer_robust_stmts;
50use rules::prefer_text_field;
51use rules::prefer_timestamptz;
52use rules::renaming_column;
53use rules::renaming_table;
54use rules::require_concurrent_index_creation;
55use rules::require_concurrent_index_deletion;
56use rules::require_enum_value_ordering;
57use rules::require_table_schema;
58use rules::require_timeout_settings;
59use rules::transaction_nesting;
60// xtask:new-rule:rule-import
61
62#[derive(Debug, PartialEq, Clone, Copy, Hash, Eq, Sequence)]
63pub enum Rule {
64    RequireConcurrentIndexCreation,
65    RequireConcurrentIndexDeletion,
66    ConstraintMissingNotValid,
67    AddingFieldWithDefault,
68    AddingForeignKeyConstraint,
69    ChangingColumnType,
70    AddingNotNullableField,
71    AddingSerialPrimaryKeyField,
72    RenamingColumn,
73    RenamingTable,
74    DisallowedUniqueConstraint,
75    BanDropDatabase,
76    PreferBigintOverInt,
77    PreferBigintOverSmallint,
78    PreferIdentity,
79    PreferRobustStmts,
80    PreferTextField,
81    PreferTimestampTz,
82    BanCharField,
83    BanDropColumn,
84    BanDropTable,
85    BanDropNotNull,
86    TransactionNesting,
87    AddingRequiredField,
88    BanConcurrentIndexCreationInTransaction,
89    UnusedIgnore,
90    BanCreateDomainWithConstraint,
91    BanAlterDomainWithAddConstraint,
92    BanTruncateCascade,
93    RequireTimeoutSettings,
94    BanUncommittedTransaction,
95    RequireEnumValueOrdering,
96    RequireTableSchema,
97    // xtask:new-rule:error-name
98}
99
100impl Rule {
101    /// Rules that are opt-in are not enabled by default.
102    /// They must be explicitly included via configuration.
103    pub fn is_opt_in(&self) -> bool {
104        matches!(self, Rule::RequireTableSchema)
105    }
106}
107
108impl TryFrom<&str> for Rule {
109    type Error = String;
110
111    fn try_from(s: &str) -> Result<Self, Self::Error> {
112        match s {
113            "require-concurrent-index-creation" => Ok(Rule::RequireConcurrentIndexCreation),
114            "require-concurrent-index-deletion" => Ok(Rule::RequireConcurrentIndexDeletion),
115            "constraint-missing-not-valid" => Ok(Rule::ConstraintMissingNotValid),
116            "adding-field-with-default" => Ok(Rule::AddingFieldWithDefault),
117            "adding-foreign-key-constraint" => Ok(Rule::AddingForeignKeyConstraint),
118            "changing-column-type" => Ok(Rule::ChangingColumnType),
119            "adding-not-nullable-field" => Ok(Rule::AddingNotNullableField),
120            "adding-serial-primary-key-field" => Ok(Rule::AddingSerialPrimaryKeyField),
121            "renaming-column" => Ok(Rule::RenamingColumn),
122            "renaming-table" => Ok(Rule::RenamingTable),
123            "disallowed-unique-constraint" => Ok(Rule::DisallowedUniqueConstraint),
124            "ban-drop-database" => Ok(Rule::BanDropDatabase),
125            "prefer-bigint-over-int" => Ok(Rule::PreferBigintOverInt),
126            "prefer-bigint-over-smallint" => Ok(Rule::PreferBigintOverSmallint),
127            "prefer-identity" => Ok(Rule::PreferIdentity),
128            "prefer-robust-stmts" => Ok(Rule::PreferRobustStmts),
129            "prefer-text-field" => Ok(Rule::PreferTextField),
130            // this is typo'd so we just support both
131            "prefer-timestamptz" => Ok(Rule::PreferTimestampTz),
132            "prefer-timestamp-tz" => Ok(Rule::PreferTimestampTz),
133            "ban-char-field" => Ok(Rule::BanCharField),
134            "ban-drop-column" => Ok(Rule::BanDropColumn),
135            "ban-drop-table" => Ok(Rule::BanDropTable),
136            "ban-drop-not-null" => Ok(Rule::BanDropNotNull),
137            "transaction-nesting" => Ok(Rule::TransactionNesting),
138            "adding-required-field" => Ok(Rule::AddingRequiredField),
139            "ban-concurrent-index-creation-in-transaction" => {
140                Ok(Rule::BanConcurrentIndexCreationInTransaction)
141            }
142            "ban-create-domain-with-constraint" => Ok(Rule::BanCreateDomainWithConstraint),
143            "ban-alter-domain-with-add-constraint" => Ok(Rule::BanAlterDomainWithAddConstraint),
144            "ban-truncate-cascade" => Ok(Rule::BanTruncateCascade),
145            "require-timeout-settings" => Ok(Rule::RequireTimeoutSettings),
146            "ban-uncommitted-transaction" => Ok(Rule::BanUncommittedTransaction),
147            "require-enum-value-ordering" => Ok(Rule::RequireEnumValueOrdering),
148            "require-table-schema" => Ok(Rule::RequireTableSchema),
149            // xtask:new-rule:str-name
150            _ => Err(format!("Unknown violation name: {s}")),
151        }
152    }
153}
154
/// Error produced when a string does not name any known rule.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct UnknownRuleName {
    /// The rejected input, kept verbatim for the error message.
    val: String,
}

impl std::fmt::Display for UnknownRuleName {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        f.write_fmt(format_args!("invalid rule name {}", self.val))
    }
}

impl std::error::Error for UnknownRuleName {}
167
168impl std::str::FromStr for Rule {
169    type Err = UnknownRuleName;
170    fn from_str(s: &str) -> Result<Self, Self::Err> {
171        Rule::try_from(s).map_err(|_| UnknownRuleName { val: s.to_string() })
172    }
173}
174
175impl fmt::Display for Rule {
176    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
177        let val = match &self {
178            Rule::RequireConcurrentIndexCreation => "require-concurrent-index-creation",
179            Rule::RequireConcurrentIndexDeletion => "require-concurrent-index-deletion",
180            Rule::ConstraintMissingNotValid => "constraint-missing-not-valid",
181            Rule::AddingFieldWithDefault => "adding-field-with-default",
182            Rule::AddingForeignKeyConstraint => "adding-foreign-key-constraint",
183            Rule::ChangingColumnType => "changing-column-type",
184            Rule::AddingNotNullableField => "adding-not-nullable-field",
185            Rule::AddingSerialPrimaryKeyField => "adding-serial-primary-key-field",
186            Rule::RenamingColumn => "renaming-column",
187            Rule::RenamingTable => "renaming-table",
188            Rule::DisallowedUniqueConstraint => "disallowed-unique-constraint",
189            Rule::BanDropDatabase => "ban-drop-database",
190            Rule::PreferBigintOverInt => "prefer-bigint-over-int",
191            Rule::PreferBigintOverSmallint => "prefer-bigint-over-smallint",
192            Rule::PreferIdentity => "prefer-identity",
193            Rule::PreferRobustStmts => "prefer-robust-stmts",
194            Rule::PreferTextField => "prefer-text-field",
195            Rule::PreferTimestampTz => "prefer-timestamp-tz",
196            Rule::BanCharField => "ban-char-field",
197            Rule::BanDropColumn => "ban-drop-column",
198            Rule::BanDropTable => "ban-drop-table",
199            Rule::BanDropNotNull => "ban-drop-not-null",
200            Rule::TransactionNesting => "transaction-nesting",
201            Rule::AddingRequiredField => "adding-required-field",
202            Rule::BanConcurrentIndexCreationInTransaction => {
203                "ban-concurrent-index-creation-in-transaction"
204            }
205            Rule::BanCreateDomainWithConstraint => "ban-create-domain-with-constraint",
206            Rule::UnusedIgnore => "unused-ignore",
207            Rule::BanAlterDomainWithAddConstraint => "ban-alter-domain-with-add-constraint",
208            Rule::BanTruncateCascade => "ban-truncate-cascade",
209            Rule::RequireTimeoutSettings => "require-timeout-settings",
210            Rule::BanUncommittedTransaction => "ban-uncommitted-transaction",
211            Rule::RequireEnumValueOrdering => "require-enum-value-ordering",
212            Rule::RequireTableSchema => "require-table-schema",
213            // xtask:new-rule:variant-to-name
214        };
215        write!(f, "{val}")
216    }
217}
218
219impl<'de> Deserialize<'de> for Rule {
220    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
221    where
222        D: serde::Deserializer<'de>,
223    {
224        let s = String::deserialize(deserializer)?;
225        s.parse().map_err(serde::de::Error::custom)
226    }
227}
228
229#[derive(Debug, Clone, PartialEq, Eq)]
230pub struct Fix {
231    pub title: String,
232    pub edits: Vec<Edit>,
233}
234
235impl Fix {
236    fn new<T: Into<String>>(title: T, edits: Vec<Edit>) -> Fix {
237        Fix {
238            title: title.into(),
239            edits,
240        }
241    }
242}
243
244#[derive(Debug, Clone, PartialEq, Eq)]
245pub struct Edit {
246    pub text_range: TextRange,
247    // TODO: does this need to be an Option?
248    pub text: Option<String>,
249}
250impl Edit {
251    pub fn insert<T: Into<String>>(text: T, at: TextSize) -> Self {
252        Self {
253            text_range: TextRange::new(at, at),
254            text: Some(text.into()),
255        }
256    }
257    pub fn replace<T: Into<String>>(text_range: TextRange, text: T) -> Self {
258        Self {
259            text_range,
260            text: Some(text.into()),
261        }
262    }
263    pub fn delete(text_range: TextRange) -> Self {
264        Self {
265            text_range,
266            text: None,
267        }
268    }
269}
270
271#[derive(Debug, Clone, PartialEq, Eq)]
272pub struct Violation {
273    // TODO: should this be String instead?
274    pub code: Rule,
275    pub message: String,
276    pub text_range: TextRange,
277    pub help: Option<String>,
278    pub fix: Option<Fix>,
279}
280
281impl Violation {
282    #[must_use]
283    pub fn for_node(code: Rule, message: String, node: &SyntaxNode) -> Self {
284        let range = node.text_range();
285
286        let start = node
287            .children_with_tokens()
288            .find(|x| !x.kind().is_trivia())
289            .map(|x| x.text_range().start())
290            // Not sure we actually hit this, but just being safe
291            .unwrap_or_else(|| range.start());
292
293        Self {
294            code,
295            text_range: TextRange::new(start, range.end()),
296            message,
297            help: None,
298            fix: None,
299        }
300    }
301
302    #[must_use]
303    pub fn for_range(code: Rule, message: String, text_range: TextRange) -> Self {
304        Self {
305            code,
306            text_range,
307            message,
308            help: None,
309            fix: None,
310        }
311    }
312
313    fn fix(mut self, fix: Option<Fix>) -> Violation {
314        self.fix = fix;
315        self
316    }
317    fn help(mut self, help: impl Into<String>) -> Violation {
318        self.help = Some(help.into());
319        self
320    }
321}
322
323#[derive(Default)]
324pub struct LinterSettings {
325    pub pg_version: Version,
326    pub assume_in_transaction: bool,
327}
328
329pub struct Linter {
330    errors: Vec<Violation>,
331    ignores: Vec<Ignore>,
332    pub rules: FxHashSet<Rule>,
333    pub settings: LinterSettings,
334}
335
336impl Linter {
337    fn report(&mut self, error: Violation) {
338        self.errors.push(error);
339    }
340
341    fn ignore(&mut self, ignore: Ignore) {
342        self.ignores.push(ignore);
343    }
344
345    #[must_use]
346    pub fn lint(&mut self, file: &Parse<SourceFile>, text: &str) -> Vec<Violation> {
347        if self.rules.contains(&Rule::AddingFieldWithDefault) {
348            adding_field_with_default(self, file);
349        }
350        if self.rules.contains(&Rule::AddingForeignKeyConstraint) {
351            adding_foreign_key_constraint(self, file);
352        }
353        if self.rules.contains(&Rule::AddingNotNullableField) {
354            adding_not_null_field(self, file);
355        }
356        if self.rules.contains(&Rule::AddingSerialPrimaryKeyField) {
357            adding_primary_key_constraint(self, file);
358        }
359        if self.rules.contains(&Rule::AddingRequiredField) {
360            adding_required_field(self, file);
361        }
362        if self.rules.contains(&Rule::BanDropDatabase) {
363            ban_drop_database(self, file);
364        }
365        if self.rules.contains(&Rule::BanCharField) {
366            ban_char_field(self, file);
367        }
368        if self
369            .rules
370            .contains(&Rule::BanConcurrentIndexCreationInTransaction)
371        {
372            ban_concurrent_index_creation_in_transaction(self, file);
373        }
374        if self.rules.contains(&Rule::BanDropColumn) {
375            ban_drop_column(self, file);
376        }
377        if self.rules.contains(&Rule::BanDropNotNull) {
378            ban_drop_not_null(self, file);
379        }
380        if self.rules.contains(&Rule::BanDropTable) {
381            ban_drop_table(self, file);
382        }
383        if self.rules.contains(&Rule::ChangingColumnType) {
384            changing_column_type(self, file);
385        }
386        if self.rules.contains(&Rule::ConstraintMissingNotValid) {
387            constraint_missing_not_valid(self, file);
388        }
389        if self.rules.contains(&Rule::DisallowedUniqueConstraint) {
390            disallow_unique_constraint(self, file);
391        }
392        if self.rules.contains(&Rule::PreferBigintOverInt) {
393            prefer_bigint_over_int(self, file);
394        }
395        if self.rules.contains(&Rule::PreferBigintOverSmallint) {
396            prefer_bigint_over_smallint(self, file);
397        }
398        if self.rules.contains(&Rule::PreferIdentity) {
399            prefer_identity(self, file);
400        }
401        if self.rules.contains(&Rule::PreferRobustStmts) {
402            prefer_robust_stmts(self, file);
403        }
404        if self.rules.contains(&Rule::PreferTextField) {
405            prefer_text_field(self, file);
406        }
407        if self.rules.contains(&Rule::PreferTimestampTz) {
408            prefer_timestamptz(self, file);
409        }
410        if self.rules.contains(&Rule::RenamingColumn) {
411            renaming_column(self, file);
412        }
413        if self.rules.contains(&Rule::RenamingTable) {
414            renaming_table(self, file);
415        }
416        if self.rules.contains(&Rule::RequireConcurrentIndexCreation) {
417            require_concurrent_index_creation(self, file);
418        }
419        if self.rules.contains(&Rule::RequireConcurrentIndexDeletion) {
420            require_concurrent_index_deletion(self, file);
421        }
422        if self.rules.contains(&Rule::BanCreateDomainWithConstraint) {
423            ban_create_domain_with_constraint(self, file);
424        }
425        if self.rules.contains(&Rule::BanAlterDomainWithAddConstraint) {
426            ban_alter_domain_with_add_constraint(self, file);
427        }
428        if self.rules.contains(&Rule::TransactionNesting) {
429            transaction_nesting(self, file);
430        }
431        if self.rules.contains(&Rule::BanTruncateCascade) {
432            ban_truncate_cascade(self, file);
433        }
434        if self.rules.contains(&Rule::RequireTimeoutSettings) {
435            require_timeout_settings(self, file);
436        }
437        if self.rules.contains(&Rule::BanUncommittedTransaction) {
438            ban_uncommitted_transaction(self, file);
439        }
440        if self.rules.contains(&Rule::RequireEnumValueOrdering) {
441            require_enum_value_ordering(self, file);
442        }
443        if self.rules.contains(&Rule::RequireTableSchema) {
444            require_table_schema(self, file);
445        }
446        // xtask:new-rule:rule-call
447
448        // locate any ignores in the file
449        find_ignores(self, &file.syntax_node());
450
451        self.errors(text)
452    }
453
454    fn errors(&mut self, text: &str) -> Vec<Violation> {
455        let ignore_index = IgnoreIndex::new(text, &self.ignores);
456        let mut errors: Vec<Violation> = self
457            .errors
458            .iter()
459            // TODO: we should have errors for when there was an ignore but that
460            // ignore didn't actually ignore anything
461            .filter(|err| !ignore_index.contains(err.text_range, err.code))
462            .cloned()
463            .collect::<Vec<_>>();
464        // ensure we order them by where they appear in the file
465        errors.sort_by_key(|x| x.text_range.start());
466        errors
467    }
468
469    fn default_rules() -> FxHashSet<Rule> {
470        all::<Rule>()
471            .filter(|r| !r.is_opt_in())
472            .collect::<FxHashSet<_>>()
473    }
474
475    pub fn with_default_rules() -> Self {
476        let rules = Linter::default_rules();
477        Linter::from(rules)
478    }
479
480    pub fn with_rules(include: &[Rule], exclude: &[Rule]) -> Self {
481        let mut default_rules = Linter::default_rules();
482
483        for rule in include {
484            default_rules.insert(*rule);
485        }
486
487        for rule in exclude {
488            default_rules.remove(rule);
489        }
490
491        Linter::from(default_rules)
492    }
493
494    pub fn from(rules: impl IntoIterator<Item = Rule>) -> Self {
495        Self {
496            errors: vec![],
497            ignores: vec![],
498            rules: rules.into_iter().collect(),
499            settings: Default::default(),
500        }
501    }
502}
503
#[cfg(test)]
mod tests {
    use super::*;

    use insta::assert_debug_snapshot;

    #[test]
    fn prefer_timestamp_aliases() {
        let [dashed, joined] = ["prefer-timestamp-tz", "prefer-timestamptz"]
            .map(|name| name.parse::<Rule>().unwrap());
        assert_eq!(dashed, joined);
        assert_debug_snapshot!(dashed, @"PreferTimestampTz");
    }

    #[test]
    fn invalid_rule_name() {
        assert!("invalid-rule-name".parse::<Rule>().is_err());
    }

    #[test]
    fn with_rules_opt_in_disabled_by_default() {
        let linter = Linter::with_rules(&[], &[]);
        let enabled = &linter.rules;
        assert!(!enabled.contains(&Rule::RequireTableSchema));
    }

    #[test]
    fn with_rules_opt_in_enabled_via_include() {
        let include = [Rule::RequireTableSchema];
        let linter = Linter::with_rules(&include, &[]);
        assert!(linter.rules.contains(&Rule::RequireTableSchema));
    }

    #[test]
    fn with_rules_exclude_takes_precedence_over_include() {
        let both = [Rule::RequireTableSchema];
        let linter = Linter::with_rules(&both, &both);
        assert!(!linter.rules.contains(&Rule::RequireTableSchema));
    }

    #[test]
    fn with_rules_exclude_removes_default_rule() {
        let exclude = [Rule::BanDropTable];
        let linter = Linter::with_rules(&[], &exclude);
        assert!(!linter.rules.contains(&Rule::BanDropTable));
    }
}