Skip to main content

rigsql_rules/capitalisation/
cp01.rs

1use rigsql_core::{Segment, SegmentType, TokenKind};
2use rigsql_lexer::is_keyword;
3
4use super::CapitalisationPolicy;
5use crate::rule::{CrawlType, Rule, RuleContext, RuleGroup};
6use crate::utils::{check_capitalisation, collect_matching_tokens, determine_majority_case};
7use crate::violation::LintViolation;
8
9/// CP01: Keywords must be consistently capitalised.
10///
11/// By default, expects UPPER case keywords.
12#[derive(Debug)]
13pub struct RuleCP01 {
14    pub policy: CapitalisationPolicy,
15}
16
17impl Default for RuleCP01 {
18    fn default() -> Self {
19        Self {
20            policy: CapitalisationPolicy::Upper,
21        }
22    }
23}
24
25impl Rule for RuleCP01 {
26    fn code(&self) -> &'static str {
27        "CP01"
28    }
29    fn name(&self) -> &'static str {
30        "capitalisation.keywords"
31    }
32    fn description(&self) -> &'static str {
33        "Keywords must be consistently capitalised."
34    }
35    fn explanation(&self) -> &'static str {
36        "SQL keywords like SELECT, FROM, WHERE should use consistent capitalisation. \
37         Mixed case reduces readability. Most style guides recommend UPPER case keywords \
38         to distinguish them from identifiers."
39    }
40    fn groups(&self) -> &[RuleGroup] {
41        &[RuleGroup::Capitalisation]
42    }
43    fn is_fixable(&self) -> bool {
44        true
45    }
46
47    fn crawl_type(&self) -> CrawlType {
48        if self.policy == CapitalisationPolicy::Consistent {
49            CrawlType::RootOnly
50        } else {
51            CrawlType::Segment(vec![SegmentType::Keyword, SegmentType::Unparsable])
52        }
53    }
54
55    fn configure(&mut self, settings: &std::collections::HashMap<String, String>) {
56        if let Some(policy) = settings.get("capitalisation_policy") {
57            self.policy = CapitalisationPolicy::from_config(policy);
58        }
59    }
60
61    fn eval(&self, ctx: &RuleContext) -> Vec<LintViolation> {
62        if self.policy == CapitalisationPolicy::Consistent {
63            return self.eval_consistent(ctx);
64        }
65
66        let Segment::Token(t) = ctx.segment else {
67            return vec![];
68        };
69        if t.token.kind != TokenKind::Word || !is_keyword(&t.token.text) {
70            return vec![];
71        }
72
73        let text = t.token.text.as_str();
74        let (expected, policy_name) = match self.policy {
75            CapitalisationPolicy::Upper => (text.to_ascii_uppercase(), "upper"),
76            CapitalisationPolicy::Lower => (text.to_ascii_lowercase(), "lower"),
77            CapitalisationPolicy::Capitalise => (crate::utils::capitalise(text), "capitalised"),
78            CapitalisationPolicy::Consistent => unreachable!(),
79        };
80
81        check_capitalisation(
82            self.code(),
83            "Keywords",
84            text,
85            &expected,
86            policy_name,
87            t.token.span,
88        )
89        .into_iter()
90        .collect()
91    }
92}
93
94impl RuleCP01 {
95    fn eval_consistent(&self, ctx: &RuleContext) -> Vec<LintViolation> {
96        let mut tokens = Vec::new();
97        collect_matching_tokens(
98            ctx.root,
99            &|seg| {
100                if let Segment::Token(t) = seg {
101                    if t.segment_type == SegmentType::Keyword
102                        && t.token.kind == TokenKind::Word
103                        && is_keyword(&t.token.text)
104                    {
105                        return Some((t.token.text.to_string(), t.token.span));
106                    }
107                }
108                None
109            },
110            &mut tokens,
111        );
112
113        if tokens.is_empty() {
114            return vec![];
115        }
116
117        let majority = determine_majority_case(&tokens);
118        let mut violations = Vec::new();
119        for (text, span) in &tokens {
120            let expected = match majority {
121                "upper" => text.to_ascii_uppercase(),
122                _ => text.to_ascii_lowercase(),
123            };
124            if let Some(v) =
125                check_capitalisation(self.code(), "Keywords", text, &expected, majority, *span)
126            {
127                violations.push(v);
128            }
129        }
130        violations
131    }
132}
133
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_utils::lint_sql;

    /// Default (upper) policy: a lowercase keyword is reported.
    #[test]
    fn test_cp01_flags_lowercase_keyword() {
        let violations = lint_sql("select 1", RuleCP01::default());
        assert_eq!(violations.len(), 1);
    }

    /// Default (upper) policy: an uppercase keyword is accepted.
    #[test]
    fn test_cp01_accepts_uppercase_keyword() {
        let violations = lint_sql("SELECT 1", RuleCP01::default());
        assert_eq!(violations.len(), 0);
    }

    /// The violation carries a single fix that upper-cases the keyword.
    #[test]
    fn test_cp01_fix_replaces_to_upper() {
        let violations = lint_sql("select 1", RuleCP01::default());
        assert_eq!(violations.len(), 1);
        assert_eq!(violations[0].fixes.len(), 1);
        assert_eq!(violations[0].fixes[0].new_text, "SELECT");
    }

    /// Lower policy: an uppercase keyword is reported.
    #[test]
    fn test_cp01_lower_policy() {
        let rule = RuleCP01 {
            policy: CapitalisationPolicy::Lower,
        };
        let violations = lint_sql("SELECT 1", rule);
        assert_eq!(violations.len(), 1);
    }

    #[test]
    fn test_cp01_consistent_flags_minority() {
        // 3 upper (SELECT, FROM, AND) vs 1 lower (where) → majority upper, flag "where"
        let rule = RuleCP01 {
            policy: CapitalisationPolicy::Consistent,
        };
        let violations = lint_sql("SELECT id FROM users where id = 1 AND name = 'a'", rule);
        assert_eq!(violations.len(), 1);
        assert_eq!(violations[0].fixes[0].new_text, "WHERE");
    }

    /// Consistent policy: uniformly cased keywords produce no violations.
    #[test]
    fn test_cp01_consistent_all_same_no_violation() {
        let rule = RuleCP01 {
            policy: CapitalisationPolicy::Consistent,
        };
        let violations = lint_sql("SELECT id FROM users WHERE id = 1", rule);
        assert_eq!(violations.len(), 0);
    }

    #[test]
    fn test_cp01_consistent_majority_lower() {
        // 3 lower (select, from, where) vs 1 upper (AND) → majority lower, flag "AND"
        let rule = RuleCP01 {
            policy: CapitalisationPolicy::Consistent,
        };
        let violations = lint_sql("select id from users where id = 1 AND name = 'a'", rule);
        assert_eq!(violations.len(), 1);
        assert_eq!(violations[0].fixes[0].new_text, "and");
    }

    /// Every lowercase keyword in a statement is reported under code CP01,
    /// each with an upper-casing fix.
    #[test]
    fn test_cp01_multiple_keywords() {
        let violations = lint_sql("select * from users where id = 1", RuleCP01::default());
        let codes: Vec<&str> = violations.iter().map(|v| v.rule_code).collect();
        assert!(codes.iter().all(|&c| c == "CP01"));
        assert!(violations.len() >= 3);
        let fix_texts: Vec<&str> = violations
            .iter()
            .map(|v| v.fixes[0].new_text.as_str())
            .collect();
        assert!(fix_texts.contains(&"SELECT"));
        assert!(fix_texts.contains(&"FROM"));
        assert!(fix_texts.contains(&"WHERE"));
    }
}