// rigsql_rules/convention/cv01.rs

use rigsql_core::{Segment, SegmentType, TokenKind};

use crate::rule::{CrawlType, Rule, RuleContext, RuleGroup};
use crate::violation::{LintViolation, SourceEdit};

6/// CV01: Use consistent not-equal operator.
7///
8/// By default, flag inconsistent use within a file. When mixed styles are
9/// present, the first occurrence wins. Users can pin a specific style via
10/// the `preferred_not_equal` setting.
11#[derive(Debug)]
12pub struct RuleCV01 {
13    pub preferred: NotEqualStyle,
14}
15
/// Which not-equal spelling CV01 enforces.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum NotEqualStyle {
    /// No fixed spelling: follow whichever operator occurs first in the file.
    Consistent,
    /// Always require the C-style operator `!=`.
    CStyle,
    /// Always require the ANSI-style operator `<>`.
    AnsiStyle,
}

26impl NotEqualStyle {
27    fn as_str(self) -> Option<&'static str> {
28        match self {
29            NotEqualStyle::CStyle => Some("!="),
30            NotEqualStyle::AnsiStyle => Some("<>"),
31            NotEqualStyle::Consistent => None,
32        }
33    }
34}
35
36impl Default for RuleCV01 {
37    fn default() -> Self {
38        Self {
39            preferred: NotEqualStyle::Consistent,
40        }
41    }
42}
43
44impl Rule for RuleCV01 {
45    fn code(&self) -> &'static str {
46        "CV01"
47    }
48    fn name(&self) -> &'static str {
49        "convention.not_equal"
50    }
51    fn description(&self) -> &'static str {
52        "Consistent not-equal operator."
53    }
54    fn explanation(&self) -> &'static str {
55        "SQL has two not-equal operators: '!=' and '<>'. Using one consistently \
56         improves readability. By default, the first style encountered in a file \
57         is preferred; set `preferred_not_equal` to `c_style` or `ansi` to \
58         enforce a specific style."
59    }
60    fn groups(&self) -> &[RuleGroup] {
61        &[RuleGroup::Convention]
62    }
63    fn is_fixable(&self) -> bool {
64        true
65    }
66
67    fn configure(&mut self, settings: &std::collections::HashMap<String, String>) {
68        if let Some(val) = settings.get("preferred_not_equal") {
69            self.preferred = match val.as_str() {
70                "ansi" | "<>" => NotEqualStyle::AnsiStyle,
71                "c_style" | "cstyle" | "!=" => NotEqualStyle::CStyle,
72                _ => NotEqualStyle::Consistent,
73            };
74        }
75    }
76
77    fn crawl_type(&self) -> CrawlType {
78        if self.preferred == NotEqualStyle::Consistent {
79            CrawlType::RootOnly
80        } else {
81            CrawlType::Segment(vec![SegmentType::ComparisonOperator])
82        }
83    }
84
85    fn eval(&self, ctx: &RuleContext) -> Vec<LintViolation> {
86        let target = match self.preferred.as_str() {
87            Some(pinned) => pinned,
88            None => return self.eval_consistent(ctx),
89        };
90
91        let Segment::Token(t) = ctx.segment else {
92            return vec![];
93        };
94        if t.token.kind != TokenKind::Neq {
95            return vec![];
96        }
97
98        if t.token.text.as_str() == target {
99            return vec![];
100        }
101
102        vec![violation_for(self.code(), &t.token, target)]
103    }
104}
105
106impl RuleCV01 {
107    fn eval_consistent(&self, ctx: &RuleContext) -> Vec<LintViolation> {
108        let neq_tokens: Vec<_> = ctx
109            .root
110            .tokens()
111            .into_iter()
112            .filter(|t| t.kind == TokenKind::Neq)
113            .collect();
114
115        let target = match neq_tokens.first() {
116            Some(first) if first.text.as_str() == "<>" => "<>",
117            Some(_) => "!=",
118            None => return vec![],
119        };
120
121        neq_tokens
122            .into_iter()
123            .filter(|t| t.text.as_str() != target)
124            .map(|t| violation_for(self.code(), t, target))
125            .collect()
126    }
127}
128
129fn violation_for(code: &'static str, token: &rigsql_core::Token, target: &str) -> LintViolation {
130    let (msg, key) = if target == "!=" {
131        ("Use '!=' instead of '<>'.", "rules.CV01.msg.use_ne")
132    } else {
133        ("Use '<>' instead of '!='.", "rules.CV01.msg.use_ltgt")
134    };
135    LintViolation::with_fix_and_msg_key(
136        code,
137        msg,
138        token.span,
139        vec![SourceEdit::replace(token.span, target)],
140        key,
141        vec![],
142    )
143}
144
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_utils::lint_sql;

    /// Builds a default rule and applies `preferred_not_equal = value` to it.
    fn configured(value: &str) -> RuleCV01 {
        let mut rule = RuleCV01::default();
        let mut settings = std::collections::HashMap::new();
        settings.insert("preferred_not_equal".to_string(), value.to_string());
        rule.configure(&settings);
        rule
    }

    #[test]
    fn test_cv01_consistent_accepts_ansi_only() {
        let violations = lint_sql("SELECT * FROM t WHERE a <> b", RuleCV01::default());
        assert_eq!(violations.len(), 0);
    }

    #[test]
    fn test_cv01_consistent_accepts_cstyle_only() {
        let violations = lint_sql("SELECT * FROM t WHERE a != b", RuleCV01::default());
        assert_eq!(violations.len(), 0);
    }

    #[test]
    fn test_cv01_no_not_equal_operator() {
        let violations = lint_sql("SELECT a FROM t WHERE a = b", RuleCV01::default());
        assert_eq!(violations.len(), 0);
    }

    #[test]
    fn test_cv01_consistent_flags_mixed_first_ansi_wins() {
        let violations = lint_sql(
            "SELECT * FROM t WHERE a <> b AND c != d",
            RuleCV01::default(),
        );
        assert_eq!(violations.len(), 1);
        assert_eq!(violations[0].fixes[0].new_text, "<>");
    }

    #[test]
    fn test_cv01_consistent_flags_mixed_first_cstyle_wins() {
        let violations = lint_sql(
            "SELECT * FROM t WHERE a != b AND c <> d",
            RuleCV01::default(),
        );
        assert_eq!(violations.len(), 1);
        assert_eq!(violations[0].fixes[0].new_text, "!=");
    }

    #[test]
    fn test_cv01_cstyle_policy_flags_ansi() {
        let rule = RuleCV01 {
            preferred: NotEqualStyle::CStyle,
        };
        let violations = lint_sql("SELECT * FROM t WHERE a <> b", rule);
        assert_eq!(violations.len(), 1);
        assert_eq!(violations[0].fixes[0].new_text, "!=");
    }

    #[test]
    fn test_cv01_cstyle_policy_accepts_cstyle() {
        let rule = RuleCV01 {
            preferred: NotEqualStyle::CStyle,
        };
        let violations = lint_sql("SELECT * FROM t WHERE a != b", rule);
        assert_eq!(violations.len(), 0);
    }

    #[test]
    fn test_cv01_ansi_policy_flags_cstyle() {
        let rule = RuleCV01 {
            preferred: NotEqualStyle::AnsiStyle,
        };
        let violations = lint_sql("SELECT * FROM t WHERE a != b", rule);
        assert_eq!(violations.len(), 1);
        assert_eq!(violations[0].fixes[0].new_text, "<>");
    }

    #[test]
    fn test_cv01_ansi_policy_accepts_ansi() {
        let rule = RuleCV01 {
            preferred: NotEqualStyle::AnsiStyle,
        };
        let violations = lint_sql("SELECT * FROM t WHERE a <> b", rule);
        assert_eq!(violations.len(), 0);
    }

    #[test]
    fn test_cv01_consistent_flags_multiple_mismatches() {
        let violations = lint_sql(
            "SELECT * FROM t WHERE a <> b AND c != d AND e != f",
            RuleCV01::default(),
        );
        assert_eq!(violations.len(), 2);
        assert!(violations.iter().all(|v| v.fixes[0].new_text == "<>"));
    }

    #[test]
    fn test_cv01_configure_maps_values() {
        assert_eq!(configured("ansi").preferred, NotEqualStyle::AnsiStyle);
        assert_eq!(configured("<>").preferred, NotEqualStyle::AnsiStyle);
        assert_eq!(configured("c_style").preferred, NotEqualStyle::CStyle);
        assert_eq!(configured("cstyle").preferred, NotEqualStyle::CStyle);
        assert_eq!(configured("!=").preferred, NotEqualStyle::CStyle);
        assert_eq!(configured("whatever").preferred, NotEqualStyle::Consistent);
    }

    #[test]
    fn test_cv01_configure_without_key_keeps_style() {
        let mut rule = RuleCV01 {
            preferred: NotEqualStyle::CStyle,
        };
        rule.configure(&std::collections::HashMap::new());
        assert_eq!(rule.preferred, NotEqualStyle::CStyle);
    }

    #[test]
    fn test_cv01_crawl_type_follows_style() {
        assert!(matches!(
            RuleCV01::default().crawl_type(),
            CrawlType::RootOnly
        ));
        let pinned = RuleCV01 {
            preferred: NotEqualStyle::AnsiStyle,
        };
        assert!(matches!(pinned.crawl_type(), CrawlType::Segment(_)));
    }
}
209}