1use rustc_hash::FxHashSet;
2use std::fmt;
3
4use enum_iterator::Sequence;
5use enum_iterator::all;
6pub use ignore::Ignore;
7use ignore::find_ignores;
8use ignore_index::IgnoreIndex;
9use rowan::TextRange;
10use rowan::TextSize;
11use serde::Deserialize;
12
13use squawk_syntax::SyntaxNode;
14use squawk_syntax::{Parse, SourceFile};
15
16pub use version::Version;
17
18pub mod analyze;
19pub mod ignore;
20mod ignore_index;
21mod version;
22mod visitors;
23
24mod rules;
25
26#[cfg(test)]
27mod test_utils;
28use rules::adding_field_with_default;
29use rules::adding_foreign_key_constraint;
30use rules::adding_not_null_field;
31use rules::adding_primary_key_constraint;
32use rules::adding_required_field;
33use rules::ban_alter_domain_with_add_constraint;
34use rules::ban_char_field;
35use rules::ban_concurrent_index_creation_in_transaction;
36use rules::ban_create_domain_with_constraint;
37use rules::ban_drop_column;
38use rules::ban_drop_database;
39use rules::ban_drop_not_null;
40use rules::ban_drop_table;
41use rules::ban_truncate_cascade;
42use rules::ban_uncommitted_transaction;
43use rules::changing_column_type;
44use rules::constraint_missing_not_valid;
45use rules::disallow_unique_constraint;
46use rules::prefer_bigint_over_int;
47use rules::prefer_bigint_over_smallint;
48use rules::prefer_identity;
49use rules::prefer_robust_stmts;
50use rules::prefer_text_field;
51use rules::prefer_timestamptz;
52use rules::renaming_column;
53use rules::renaming_table;
54use rules::require_concurrent_index_creation;
55use rules::require_concurrent_index_deletion;
56use rules::require_enum_value_ordering;
57use rules::require_table_schema;
58use rules::require_timeout_settings;
59use rules::transaction_nesting;
60#[derive(Debug, PartialEq, Clone, Copy, Hash, Eq, Sequence)]
63pub enum Rule {
64 RequireConcurrentIndexCreation,
65 RequireConcurrentIndexDeletion,
66 ConstraintMissingNotValid,
67 AddingFieldWithDefault,
68 AddingForeignKeyConstraint,
69 ChangingColumnType,
70 AddingNotNullableField,
71 AddingSerialPrimaryKeyField,
72 RenamingColumn,
73 RenamingTable,
74 DisallowedUniqueConstraint,
75 BanDropDatabase,
76 PreferBigintOverInt,
77 PreferBigintOverSmallint,
78 PreferIdentity,
79 PreferRobustStmts,
80 PreferTextField,
81 PreferTimestampTz,
82 BanCharField,
83 BanDropColumn,
84 BanDropTable,
85 BanDropNotNull,
86 TransactionNesting,
87 AddingRequiredField,
88 BanConcurrentIndexCreationInTransaction,
89 UnusedIgnore,
90 BanCreateDomainWithConstraint,
91 BanAlterDomainWithAddConstraint,
92 BanTruncateCascade,
93 RequireTimeoutSettings,
94 BanUncommittedTransaction,
95 RequireEnumValueOrdering,
96 RequireTableSchema,
97 }
99
100impl Rule {
101 pub fn is_opt_in(&self) -> bool {
104 matches!(self, Rule::RequireTableSchema)
105 }
106}
107
108impl TryFrom<&str> for Rule {
109 type Error = String;
110
111 fn try_from(s: &str) -> Result<Self, Self::Error> {
112 match s {
113 "require-concurrent-index-creation" => Ok(Rule::RequireConcurrentIndexCreation),
114 "require-concurrent-index-deletion" => Ok(Rule::RequireConcurrentIndexDeletion),
115 "constraint-missing-not-valid" => Ok(Rule::ConstraintMissingNotValid),
116 "adding-field-with-default" => Ok(Rule::AddingFieldWithDefault),
117 "adding-foreign-key-constraint" => Ok(Rule::AddingForeignKeyConstraint),
118 "changing-column-type" => Ok(Rule::ChangingColumnType),
119 "adding-not-nullable-field" => Ok(Rule::AddingNotNullableField),
120 "adding-serial-primary-key-field" => Ok(Rule::AddingSerialPrimaryKeyField),
121 "renaming-column" => Ok(Rule::RenamingColumn),
122 "renaming-table" => Ok(Rule::RenamingTable),
123 "disallowed-unique-constraint" => Ok(Rule::DisallowedUniqueConstraint),
124 "ban-drop-database" => Ok(Rule::BanDropDatabase),
125 "prefer-bigint-over-int" => Ok(Rule::PreferBigintOverInt),
126 "prefer-bigint-over-smallint" => Ok(Rule::PreferBigintOverSmallint),
127 "prefer-identity" => Ok(Rule::PreferIdentity),
128 "prefer-robust-stmts" => Ok(Rule::PreferRobustStmts),
129 "prefer-text-field" => Ok(Rule::PreferTextField),
130 "prefer-timestamptz" => Ok(Rule::PreferTimestampTz),
132 "prefer-timestamp-tz" => Ok(Rule::PreferTimestampTz),
133 "ban-char-field" => Ok(Rule::BanCharField),
134 "ban-drop-column" => Ok(Rule::BanDropColumn),
135 "ban-drop-table" => Ok(Rule::BanDropTable),
136 "ban-drop-not-null" => Ok(Rule::BanDropNotNull),
137 "transaction-nesting" => Ok(Rule::TransactionNesting),
138 "adding-required-field" => Ok(Rule::AddingRequiredField),
139 "ban-concurrent-index-creation-in-transaction" => {
140 Ok(Rule::BanConcurrentIndexCreationInTransaction)
141 }
142 "ban-create-domain-with-constraint" => Ok(Rule::BanCreateDomainWithConstraint),
143 "ban-alter-domain-with-add-constraint" => Ok(Rule::BanAlterDomainWithAddConstraint),
144 "ban-truncate-cascade" => Ok(Rule::BanTruncateCascade),
145 "require-timeout-settings" => Ok(Rule::RequireTimeoutSettings),
146 "ban-uncommitted-transaction" => Ok(Rule::BanUncommittedTransaction),
147 "require-enum-value-ordering" => Ok(Rule::RequireEnumValueOrdering),
148 "require-table-schema" => Ok(Rule::RequireTableSchema),
149 _ => Err(format!("Unknown violation name: {s}")),
151 }
152 }
153}
154
/// Error returned when a string does not name any known rule.
///
/// Holds the rejected input so it can be echoed back in the message.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct UnknownRuleName {
    val: String,
}

impl fmt::Display for UnknownRuleName {
    /// Writes `invalid rule name ` followed by the rejected input.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("invalid rule name ")?;
        f.write_str(&self.val)
    }
}

impl std::error::Error for UnknownRuleName {}
167
168impl std::str::FromStr for Rule {
169 type Err = UnknownRuleName;
170 fn from_str(s: &str) -> Result<Self, Self::Err> {
171 Rule::try_from(s).map_err(|_| UnknownRuleName { val: s.to_string() })
172 }
173}
174
175impl fmt::Display for Rule {
176 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
177 let val = match &self {
178 Rule::RequireConcurrentIndexCreation => "require-concurrent-index-creation",
179 Rule::RequireConcurrentIndexDeletion => "require-concurrent-index-deletion",
180 Rule::ConstraintMissingNotValid => "constraint-missing-not-valid",
181 Rule::AddingFieldWithDefault => "adding-field-with-default",
182 Rule::AddingForeignKeyConstraint => "adding-foreign-key-constraint",
183 Rule::ChangingColumnType => "changing-column-type",
184 Rule::AddingNotNullableField => "adding-not-nullable-field",
185 Rule::AddingSerialPrimaryKeyField => "adding-serial-primary-key-field",
186 Rule::RenamingColumn => "renaming-column",
187 Rule::RenamingTable => "renaming-table",
188 Rule::DisallowedUniqueConstraint => "disallowed-unique-constraint",
189 Rule::BanDropDatabase => "ban-drop-database",
190 Rule::PreferBigintOverInt => "prefer-bigint-over-int",
191 Rule::PreferBigintOverSmallint => "prefer-bigint-over-smallint",
192 Rule::PreferIdentity => "prefer-identity",
193 Rule::PreferRobustStmts => "prefer-robust-stmts",
194 Rule::PreferTextField => "prefer-text-field",
195 Rule::PreferTimestampTz => "prefer-timestamp-tz",
196 Rule::BanCharField => "ban-char-field",
197 Rule::BanDropColumn => "ban-drop-column",
198 Rule::BanDropTable => "ban-drop-table",
199 Rule::BanDropNotNull => "ban-drop-not-null",
200 Rule::TransactionNesting => "transaction-nesting",
201 Rule::AddingRequiredField => "adding-required-field",
202 Rule::BanConcurrentIndexCreationInTransaction => {
203 "ban-concurrent-index-creation-in-transaction"
204 }
205 Rule::BanCreateDomainWithConstraint => "ban-create-domain-with-constraint",
206 Rule::UnusedIgnore => "unused-ignore",
207 Rule::BanAlterDomainWithAddConstraint => "ban-alter-domain-with-add-constraint",
208 Rule::BanTruncateCascade => "ban-truncate-cascade",
209 Rule::RequireTimeoutSettings => "require-timeout-settings",
210 Rule::BanUncommittedTransaction => "ban-uncommitted-transaction",
211 Rule::RequireEnumValueOrdering => "require-enum-value-ordering",
212 Rule::RequireTableSchema => "require-table-schema",
213 };
215 write!(f, "{val}")
216 }
217}
218
219impl<'de> Deserialize<'de> for Rule {
220 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
221 where
222 D: serde::Deserializer<'de>,
223 {
224 let s = String::deserialize(deserializer)?;
225 s.parse().map_err(serde::de::Error::custom)
226 }
227}
228
229#[derive(Debug, Clone, PartialEq, Eq)]
230pub struct Fix {
231 pub title: String,
232 pub edits: Vec<Edit>,
233}
234
235impl Fix {
236 fn new<T: Into<String>>(title: T, edits: Vec<Edit>) -> Fix {
237 Fix {
238 title: title.into(),
239 edits,
240 }
241 }
242}
243
244#[derive(Debug, Clone, PartialEq, Eq)]
245pub struct Edit {
246 pub text_range: TextRange,
247 pub text: Option<String>,
249}
250impl Edit {
251 pub fn insert<T: Into<String>>(text: T, at: TextSize) -> Self {
252 Self {
253 text_range: TextRange::new(at, at),
254 text: Some(text.into()),
255 }
256 }
257 pub fn replace<T: Into<String>>(text_range: TextRange, text: T) -> Self {
258 Self {
259 text_range,
260 text: Some(text.into()),
261 }
262 }
263 pub fn delete(text_range: TextRange) -> Self {
264 Self {
265 text_range,
266 text: None,
267 }
268 }
269}
270
271#[derive(Debug, Clone, PartialEq, Eq)]
272pub struct Violation {
273 pub code: Rule,
275 pub message: String,
276 pub text_range: TextRange,
277 pub help: Option<String>,
278 pub fix: Option<Fix>,
279}
280
281impl Violation {
282 #[must_use]
283 pub fn for_node(code: Rule, message: String, node: &SyntaxNode) -> Self {
284 let range = node.text_range();
285
286 let start = node
287 .children_with_tokens()
288 .find(|x| !x.kind().is_trivia())
289 .map(|x| x.text_range().start())
290 .unwrap_or_else(|| range.start());
292
293 Self {
294 code,
295 text_range: TextRange::new(start, range.end()),
296 message,
297 help: None,
298 fix: None,
299 }
300 }
301
302 #[must_use]
303 pub fn for_range(code: Rule, message: String, text_range: TextRange) -> Self {
304 Self {
305 code,
306 text_range,
307 message,
308 help: None,
309 fix: None,
310 }
311 }
312
313 fn fix(mut self, fix: Option<Fix>) -> Violation {
314 self.fix = fix;
315 self
316 }
317 fn help(mut self, help: impl Into<String>) -> Violation {
318 self.help = Some(help.into());
319 self
320 }
321}
322
323#[derive(Default)]
324pub struct LinterSettings {
325 pub pg_version: Version,
326 pub assume_in_transaction: bool,
327}
328
329pub struct Linter {
330 errors: Vec<Violation>,
331 ignores: Vec<Ignore>,
332 pub rules: FxHashSet<Rule>,
333 pub settings: LinterSettings,
334}
335
336impl Linter {
337 fn report(&mut self, error: Violation) {
338 self.errors.push(error);
339 }
340
341 fn ignore(&mut self, ignore: Ignore) {
342 self.ignores.push(ignore);
343 }
344
345 #[must_use]
346 pub fn lint(&mut self, file: &Parse<SourceFile>, text: &str) -> Vec<Violation> {
347 if self.rules.contains(&Rule::AddingFieldWithDefault) {
348 adding_field_with_default(self, file);
349 }
350 if self.rules.contains(&Rule::AddingForeignKeyConstraint) {
351 adding_foreign_key_constraint(self, file);
352 }
353 if self.rules.contains(&Rule::AddingNotNullableField) {
354 adding_not_null_field(self, file);
355 }
356 if self.rules.contains(&Rule::AddingSerialPrimaryKeyField) {
357 adding_primary_key_constraint(self, file);
358 }
359 if self.rules.contains(&Rule::AddingRequiredField) {
360 adding_required_field(self, file);
361 }
362 if self.rules.contains(&Rule::BanDropDatabase) {
363 ban_drop_database(self, file);
364 }
365 if self.rules.contains(&Rule::BanCharField) {
366 ban_char_field(self, file);
367 }
368 if self
369 .rules
370 .contains(&Rule::BanConcurrentIndexCreationInTransaction)
371 {
372 ban_concurrent_index_creation_in_transaction(self, file);
373 }
374 if self.rules.contains(&Rule::BanDropColumn) {
375 ban_drop_column(self, file);
376 }
377 if self.rules.contains(&Rule::BanDropNotNull) {
378 ban_drop_not_null(self, file);
379 }
380 if self.rules.contains(&Rule::BanDropTable) {
381 ban_drop_table(self, file);
382 }
383 if self.rules.contains(&Rule::ChangingColumnType) {
384 changing_column_type(self, file);
385 }
386 if self.rules.contains(&Rule::ConstraintMissingNotValid) {
387 constraint_missing_not_valid(self, file);
388 }
389 if self.rules.contains(&Rule::DisallowedUniqueConstraint) {
390 disallow_unique_constraint(self, file);
391 }
392 if self.rules.contains(&Rule::PreferBigintOverInt) {
393 prefer_bigint_over_int(self, file);
394 }
395 if self.rules.contains(&Rule::PreferBigintOverSmallint) {
396 prefer_bigint_over_smallint(self, file);
397 }
398 if self.rules.contains(&Rule::PreferIdentity) {
399 prefer_identity(self, file);
400 }
401 if self.rules.contains(&Rule::PreferRobustStmts) {
402 prefer_robust_stmts(self, file);
403 }
404 if self.rules.contains(&Rule::PreferTextField) {
405 prefer_text_field(self, file);
406 }
407 if self.rules.contains(&Rule::PreferTimestampTz) {
408 prefer_timestamptz(self, file);
409 }
410 if self.rules.contains(&Rule::RenamingColumn) {
411 renaming_column(self, file);
412 }
413 if self.rules.contains(&Rule::RenamingTable) {
414 renaming_table(self, file);
415 }
416 if self.rules.contains(&Rule::RequireConcurrentIndexCreation) {
417 require_concurrent_index_creation(self, file);
418 }
419 if self.rules.contains(&Rule::RequireConcurrentIndexDeletion) {
420 require_concurrent_index_deletion(self, file);
421 }
422 if self.rules.contains(&Rule::BanCreateDomainWithConstraint) {
423 ban_create_domain_with_constraint(self, file);
424 }
425 if self.rules.contains(&Rule::BanAlterDomainWithAddConstraint) {
426 ban_alter_domain_with_add_constraint(self, file);
427 }
428 if self.rules.contains(&Rule::TransactionNesting) {
429 transaction_nesting(self, file);
430 }
431 if self.rules.contains(&Rule::BanTruncateCascade) {
432 ban_truncate_cascade(self, file);
433 }
434 if self.rules.contains(&Rule::RequireTimeoutSettings) {
435 require_timeout_settings(self, file);
436 }
437 if self.rules.contains(&Rule::BanUncommittedTransaction) {
438 ban_uncommitted_transaction(self, file);
439 }
440 if self.rules.contains(&Rule::RequireEnumValueOrdering) {
441 require_enum_value_ordering(self, file);
442 }
443 if self.rules.contains(&Rule::RequireTableSchema) {
444 require_table_schema(self, file);
445 }
446 find_ignores(self, &file.syntax_node());
450
451 self.errors(text)
452 }
453
454 fn errors(&mut self, text: &str) -> Vec<Violation> {
455 let ignore_index = IgnoreIndex::new(text, &self.ignores);
456 let mut errors: Vec<Violation> = self
457 .errors
458 .iter()
459 .filter(|err| !ignore_index.contains(err.text_range, err.code))
462 .cloned()
463 .collect::<Vec<_>>();
464 errors.sort_by_key(|x| x.text_range.start());
466 errors
467 }
468
469 fn default_rules() -> FxHashSet<Rule> {
470 all::<Rule>()
471 .filter(|r| !r.is_opt_in())
472 .collect::<FxHashSet<_>>()
473 }
474
475 pub fn with_default_rules() -> Self {
476 let rules = Linter::default_rules();
477 Linter::from(rules)
478 }
479
480 pub fn with_rules(include: &[Rule], exclude: &[Rule]) -> Self {
481 let mut default_rules = Linter::default_rules();
482
483 for rule in include {
484 default_rules.insert(*rule);
485 }
486
487 for rule in exclude {
488 default_rules.remove(rule);
489 }
490
491 Linter::from(default_rules)
492 }
493
494 pub fn from(rules: impl IntoIterator<Item = Rule>) -> Self {
495 Self {
496 errors: vec![],
497 ignores: vec![],
498 rules: rules.into_iter().collect(),
499 settings: Default::default(),
500 }
501 }
502}
503
#[cfg(test)]
mod tests {
    use insta::assert_debug_snapshot;

    use super::*;

    /// Both accepted spellings of the timestamptz rule parse to the same variant.
    #[test]
    fn prefer_timestamp_aliases() {
        let hyphenated = "prefer-timestamp-tz".parse::<Rule>().unwrap();
        let compact = "prefer-timestamptz".parse::<Rule>().unwrap();
        assert_eq!(hyphenated, compact);
        assert_debug_snapshot!(hyphenated, @"PreferTimestampTz");
    }

    /// A string that is not a rule name is rejected by `parse`.
    #[test]
    fn invalid_rule_name() {
        assert!("invalid-rule-name".parse::<Rule>().is_err());
    }

    /// With nothing included or excluded, the opt-in rule stays off.
    #[test]
    fn with_rules_opt_in_disabled_by_default() {
        let enabled = Linter::with_rules(&[], &[]).rules;
        assert!(!enabled.contains(&Rule::RequireTableSchema));
    }

    /// Listing the opt-in rule in `include` turns it on.
    #[test]
    fn with_rules_opt_in_enabled_via_include() {
        let include = [Rule::RequireTableSchema];
        let enabled = Linter::with_rules(&include, &[]).rules;
        assert!(enabled.contains(&Rule::RequireTableSchema));
    }

    /// A rule present in both lists ends up disabled.
    #[test]
    fn with_rules_exclude_takes_precedence_over_include() {
        let both = [Rule::RequireTableSchema];
        let enabled = Linter::with_rules(&both, &both).rules;
        assert!(!enabled.contains(&Rule::RequireTableSchema));
    }

    /// Excluding a default rule removes it from the set.
    #[test]
    fn with_rules_exclude_removes_default_rule() {
        let exclude = [Rule::BanDropTable];
        let enabled = Linter::with_rules(&[], &exclude).rules;
        assert!(!enabled.contains(&Rule::BanDropTable));
    }
}