rigsql_rules/convention/
cv01.rs1use rigsql_core::{Segment, SegmentType, TokenKind};
2
3use crate::rule::{CrawlType, Rule, RuleContext, RuleGroup};
4use crate::violation::{LintViolation, SourceEdit};
5
6#[derive(Debug)]
12pub struct RuleCV01 {
13 pub preferred: NotEqualStyle,
14}
15
/// Which spelling of the not-equal operator CV01 enforces.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum NotEqualStyle {
    /// No fixed preference: the first not-equal operator in the file sets the style.
    Consistent,
    /// C-style `!=`.
    CStyle,
    /// ANSI-style `<>`.
    AnsiStyle,
}

impl NotEqualStyle {
    /// Operator text this style pins, or `None` for [`NotEqualStyle::Consistent`],
    /// whose target is only known after scanning the file.
    fn as_str(self) -> Option<&'static str> {
        let pinned = match self {
            NotEqualStyle::Consistent => return None,
            NotEqualStyle::AnsiStyle => "<>",
            NotEqualStyle::CStyle => "!=",
        };
        Some(pinned)
    }
}
35
36impl Default for RuleCV01 {
37 fn default() -> Self {
38 Self {
39 preferred: NotEqualStyle::Consistent,
40 }
41 }
42}
43
44impl Rule for RuleCV01 {
45 fn code(&self) -> &'static str {
46 "CV01"
47 }
48 fn name(&self) -> &'static str {
49 "convention.not_equal"
50 }
51 fn description(&self) -> &'static str {
52 "Consistent not-equal operator."
53 }
54 fn explanation(&self) -> &'static str {
55 "SQL has two not-equal operators: '!=' and '<>'. Using one consistently \
56 improves readability. By default, the first style encountered in a file \
57 is preferred; set `preferred_not_equal` to `c_style` or `ansi` to \
58 enforce a specific style."
59 }
60 fn groups(&self) -> &[RuleGroup] {
61 &[RuleGroup::Convention]
62 }
63 fn is_fixable(&self) -> bool {
64 true
65 }
66
67 fn configure(&mut self, settings: &std::collections::HashMap<String, String>) {
68 if let Some(val) = settings.get("preferred_not_equal") {
69 self.preferred = match val.as_str() {
70 "ansi" | "<>" => NotEqualStyle::AnsiStyle,
71 "c_style" | "cstyle" | "!=" => NotEqualStyle::CStyle,
72 _ => NotEqualStyle::Consistent,
73 };
74 }
75 }
76
77 fn crawl_type(&self) -> CrawlType {
78 if self.preferred == NotEqualStyle::Consistent {
79 CrawlType::RootOnly
80 } else {
81 CrawlType::Segment(vec![SegmentType::ComparisonOperator])
82 }
83 }
84
85 fn eval(&self, ctx: &RuleContext) -> Vec<LintViolation> {
86 let target = match self.preferred.as_str() {
87 Some(pinned) => pinned,
88 None => return self.eval_consistent(ctx),
89 };
90
91 let Segment::Token(t) = ctx.segment else {
92 return vec![];
93 };
94 if t.token.kind != TokenKind::Neq {
95 return vec![];
96 }
97
98 if t.token.text.as_str() == target {
99 return vec![];
100 }
101
102 vec![violation_for(self.code(), &t.token, target)]
103 }
104}
105
106impl RuleCV01 {
107 fn eval_consistent(&self, ctx: &RuleContext) -> Vec<LintViolation> {
108 let neq_tokens: Vec<_> = ctx
109 .root
110 .tokens()
111 .into_iter()
112 .filter(|t| t.kind == TokenKind::Neq)
113 .collect();
114
115 let target = match neq_tokens.first() {
116 Some(first) if first.text.as_str() == "<>" => "<>",
117 Some(_) => "!=",
118 None => return vec![],
119 };
120
121 neq_tokens
122 .into_iter()
123 .filter(|t| t.text.as_str() != target)
124 .map(|t| violation_for(self.code(), t, target))
125 .collect()
126 }
127}
128
129fn violation_for(code: &'static str, token: &rigsql_core::Token, target: &str) -> LintViolation {
130 let (msg, key) = if target == "!=" {
131 ("Use '!=' instead of '<>'.", "rules.CV01.msg.use_ne")
132 } else {
133 ("Use '<>' instead of '!='.", "rules.CV01.msg.use_ltgt")
134 };
135 LintViolation::with_fix_and_msg_key(
136 code,
137 msg,
138 token.span,
139 vec![SourceEdit::replace(token.span, target)],
140 key,
141 vec![],
142 )
143}
144
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_utils::lint_sql;
    use std::collections::HashMap;

    /// Configure a default rule with one `preferred_not_equal` value and return
    /// the resulting style.
    fn style_after_configure(value: &str) -> NotEqualStyle {
        let mut rule = RuleCV01::default();
        let settings = HashMap::from([("preferred_not_equal".to_string(), value.to_string())]);
        rule.configure(&settings);
        rule.preferred
    }

    #[test]
    fn test_cv01_consistent_accepts_ansi_only() {
        let violations = lint_sql("SELECT * FROM t WHERE a <> b", RuleCV01::default());
        assert_eq!(violations.len(), 0);
    }

    #[test]
    fn test_cv01_consistent_accepts_cstyle_only() {
        let violations = lint_sql("SELECT * FROM t WHERE a != b", RuleCV01::default());
        assert_eq!(violations.len(), 0);
    }

    #[test]
    fn test_cv01_no_not_equal_operator() {
        let sql = "SELECT * FROM t WHERE a = b";
        assert_eq!(lint_sql(sql, RuleCV01::default()).len(), 0);
        let cstyle = RuleCV01 {
            preferred: NotEqualStyle::CStyle,
        };
        assert_eq!(lint_sql(sql, cstyle).len(), 0);
        let ansi = RuleCV01 {
            preferred: NotEqualStyle::AnsiStyle,
        };
        assert_eq!(lint_sql(sql, ansi).len(), 0);
    }

    #[test]
    fn test_cv01_consistent_flags_mixed_first_ansi_wins() {
        let violations = lint_sql(
            "SELECT * FROM t WHERE a <> b AND c != d",
            RuleCV01::default(),
        );
        assert_eq!(violations.len(), 1);
        assert_eq!(violations[0].fixes[0].new_text, "<>");
    }

    #[test]
    fn test_cv01_consistent_flags_mixed_first_cstyle_wins() {
        let violations = lint_sql(
            "SELECT * FROM t WHERE a != b AND c <> d",
            RuleCV01::default(),
        );
        assert_eq!(violations.len(), 1);
        assert_eq!(violations[0].fixes[0].new_text, "!=");
    }

    #[test]
    fn test_cv01_consistent_flags_multiple_mismatches() {
        let violations = lint_sql(
            "SELECT * FROM t WHERE a <> b AND c != d AND e != f",
            RuleCV01::default(),
        );
        assert_eq!(violations.len(), 2);
        assert!(violations.iter().all(|v| v.fixes[0].new_text == "<>"));
    }

    #[test]
    fn test_cv01_cstyle_policy_flags_ansi() {
        let rule = RuleCV01 {
            preferred: NotEqualStyle::CStyle,
        };
        let violations = lint_sql("SELECT * FROM t WHERE a <> b", rule);
        assert_eq!(violations.len(), 1);
        assert_eq!(violations[0].fixes[0].new_text, "!=");
    }

    #[test]
    fn test_cv01_ansi_policy_flags_cstyle() {
        let rule = RuleCV01 {
            preferred: NotEqualStyle::AnsiStyle,
        };
        let violations = lint_sql("SELECT * FROM t WHERE a != b", rule);
        assert_eq!(violations.len(), 1);
        assert_eq!(violations[0].fixes[0].new_text, "<>");
    }

    #[test]
    fn test_cv01_cstyle_policy_accepts_cstyle() {
        let rule = RuleCV01 {
            preferred: NotEqualStyle::CStyle,
        };
        let violations = lint_sql("SELECT * FROM t WHERE a != b", rule);
        assert_eq!(violations.len(), 0);
    }

    #[test]
    fn test_cv01_ansi_policy_accepts_ansi() {
        let rule = RuleCV01 {
            preferred: NotEqualStyle::AnsiStyle,
        };
        let violations = lint_sql("SELECT * FROM t WHERE a <> b", rule);
        assert_eq!(violations.len(), 0);
    }

    #[test]
    fn test_cv01_pinned_policy_ignores_first_seen_style() {
        // A pinned style wins even when the other spelling appears first.
        let rule = RuleCV01 {
            preferred: NotEqualStyle::CStyle,
        };
        let violations = lint_sql("SELECT * FROM t WHERE a <> b AND c != d", rule);
        assert_eq!(violations.len(), 1);
        assert_eq!(violations[0].fixes[0].new_text, "!=");
    }

    #[test]
    fn test_cv01_configure_maps_values() {
        assert_eq!(style_after_configure("ansi"), NotEqualStyle::AnsiStyle);
        assert_eq!(style_after_configure("<>"), NotEqualStyle::AnsiStyle);
        assert_eq!(style_after_configure("c_style"), NotEqualStyle::CStyle);
        assert_eq!(style_after_configure("cstyle"), NotEqualStyle::CStyle);
        assert_eq!(style_after_configure("!="), NotEqualStyle::CStyle);
        assert_eq!(style_after_configure("consistent"), NotEqualStyle::Consistent);
        assert_eq!(style_after_configure("nonsense"), NotEqualStyle::Consistent);
    }

    #[test]
    fn test_cv01_configure_without_key_keeps_style() {
        let mut rule = RuleCV01 {
            preferred: NotEqualStyle::AnsiStyle,
        };
        rule.configure(&HashMap::new());
        assert_eq!(rule.preferred, NotEqualStyle::AnsiStyle);
    }

    #[test]
    fn test_cv01_crawl_type_follows_mode() {
        assert!(matches!(
            RuleCV01::default().crawl_type(),
            CrawlType::RootOnly
        ));
        let pinned = RuleCV01 {
            preferred: NotEqualStyle::CStyle,
        };
        assert!(matches!(pinned.crawl_type(), CrawlType::Segment(_)));
    }
}