// rigsql_rules/capitalisation/cp01.rs

1use rigsql_core::{Segment, SegmentType, TokenKind};
2use rigsql_lexer::is_keyword;
3
4use crate::rule::{CrawlType, Rule, RuleContext, RuleGroup};
5use crate::utils::check_capitalisation;
6use crate::violation::LintViolation;
7
/// CP01: Keywords must be consistently capitalised.
///
/// By default, expects UPPER case keywords.
#[derive(Debug, Default)]
pub struct RuleCP01 {
    /// Casing every keyword is compared against. `RuleCP01::default()`
    /// yields [`CapitalisationPolicy::Upper`].
    pub policy: CapitalisationPolicy,
}

/// Casing style that keywords can be required to follow.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum CapitalisationPolicy {
    /// Every letter upper case, e.g. `SELECT`. This is the default policy.
    #[default]
    Upper,
    /// Every letter lower case, e.g. `select`.
    Lower,
    /// First letter upper case, the rest lower case, e.g. `Select`.
    Capitalise,
}
30
31impl Rule for RuleCP01 {
32    fn code(&self) -> &'static str {
33        "CP01"
34    }
35    fn name(&self) -> &'static str {
36        "capitalisation.keywords"
37    }
38    fn description(&self) -> &'static str {
39        "Keywords must be consistently capitalised."
40    }
41    fn explanation(&self) -> &'static str {
42        "SQL keywords like SELECT, FROM, WHERE should use consistent capitalisation. \
43         Mixed case reduces readability. Most style guides recommend UPPER case keywords \
44         to distinguish them from identifiers."
45    }
46    fn groups(&self) -> &[RuleGroup] {
47        &[RuleGroup::Capitalisation]
48    }
49    fn is_fixable(&self) -> bool {
50        true
51    }
52
53    fn crawl_type(&self) -> CrawlType {
54        CrawlType::Segment(vec![SegmentType::Keyword, SegmentType::Unparsable])
55    }
56
57    fn configure(&mut self, settings: &std::collections::HashMap<String, String>) {
58        if let Some(policy) = settings.get("capitalisation_policy") {
59            self.policy = match policy.as_str() {
60                "lower" => CapitalisationPolicy::Lower,
61                "capitalise" | "capitalize" => CapitalisationPolicy::Capitalise,
62                _ => CapitalisationPolicy::Upper,
63            };
64        }
65    }
66
67    fn eval(&self, ctx: &RuleContext) -> Vec<LintViolation> {
68        let Segment::Token(t) = ctx.segment else {
69            return vec![];
70        };
71        if t.token.kind != TokenKind::Word || !is_keyword(&t.token.text) {
72            return vec![];
73        }
74
75        let text = t.token.text.as_str();
76        let (expected, policy_name) = match self.policy {
77            CapitalisationPolicy::Upper => (text.to_ascii_uppercase(), "upper"),
78            CapitalisationPolicy::Lower => (text.to_ascii_lowercase(), "lower"),
79            CapitalisationPolicy::Capitalise => (capitalise(text), "capitalised"),
80        };
81
82        check_capitalisation(
83            self.code(),
84            "Keywords",
85            text,
86            &expected,
87            policy_name,
88            t.token.span,
89        )
90        .into_iter()
91        .collect()
92    }
93}
94
/// Returns `s` with its first character upper-cased and every following
/// character lower-cased, using ASCII-only case mapping.
///
/// ASCII mapping is used so this agrees with the `to_ascii_uppercase` /
/// `to_ascii_lowercase` calls used for the upper and lower policies: non-ASCII
/// characters are left untouched and the result always has the same byte
/// length as the input (Unicode mapping can expand, e.g. `ß` -> `SS`).
///
/// An empty input yields an empty `String`.
fn capitalise(s: &str) -> String {
    let mut chars = s.chars();
    let Some(first) = chars.next() else {
        return String::new();
    };

    // ASCII case mapping never changes the encoded length, so one allocation suffices.
    let mut out = String::with_capacity(s.len());
    out.push(first.to_ascii_uppercase());
    out.extend(chars.map(|c| c.to_ascii_lowercase()));
    out
}
102
/// Unit tests for CP01: linting through `lint_sql`, policy configuration, and
/// the `capitalise` helper.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_utils::lint_sql;
    use std::collections::HashMap;

    /// Runs `configure` on a default rule with only `capitalisation_policy`
    /// set to `value` and returns the resulting policy.
    fn configured(value: &str) -> CapitalisationPolicy {
        let mut rule = RuleCP01::default();
        let settings = HashMap::from([(
            "capitalisation_policy".to_string(),
            value.to_string(),
        )]);
        rule.configure(&settings);
        rule.policy
    }

    #[test]
    fn test_cp01_flags_lowercase_keyword() {
        let violations = lint_sql("select 1", RuleCP01::default());
        assert_eq!(violations.len(), 1);
    }

    #[test]
    fn test_cp01_accepts_uppercase_keyword() {
        let violations = lint_sql("SELECT 1", RuleCP01::default());
        assert_eq!(violations.len(), 0);
    }

    #[test]
    fn test_cp01_fix_replaces_to_upper() {
        let violations = lint_sql("select 1", RuleCP01::default());
        assert_eq!(violations.len(), 1);
        assert_eq!(violations[0].fixes.len(), 1);
        assert_eq!(violations[0].fixes[0].new_text, "SELECT");
    }

    #[test]
    fn test_cp01_lower_policy() {
        let rule = RuleCP01 {
            policy: CapitalisationPolicy::Lower,
        };
        let violations = lint_sql("SELECT 1", rule);
        assert_eq!(violations.len(), 1);
    }

    #[test]
    fn test_cp01_capitalise_policy() {
        let rule = RuleCP01 {
            policy: CapitalisationPolicy::Capitalise,
        };
        let violations = lint_sql("select 1", rule);
        assert_eq!(violations.len(), 1);
        assert_eq!(violations[0].fixes[0].new_text, "Select");
    }

    #[test]
    fn test_cp01_multiple_keywords() {
        let violations = lint_sql("select * from users where id = 1", RuleCP01::default());
        let codes: Vec<&str> = violations.iter().map(|v| v.rule_code).collect();
        assert!(codes.iter().all(|&c| c == "CP01"));
        assert!(violations.len() >= 3);
        let fix_texts: Vec<&str> = violations
            .iter()
            .map(|v| v.fixes[0].new_text.as_str())
            .collect();
        assert!(fix_texts.contains(&"SELECT"));
        assert!(fix_texts.contains(&"FROM"));
        assert!(fix_texts.contains(&"WHERE"));
    }

    #[test]
    fn test_cp01_configure_sets_policy() {
        assert_eq!(configured("lower"), CapitalisationPolicy::Lower);
        assert_eq!(configured("capitalise"), CapitalisationPolicy::Capitalise);
        assert_eq!(configured("capitalize"), CapitalisationPolicy::Capitalise);
        assert_eq!(configured("upper"), CapitalisationPolicy::Upper);
    }

    #[test]
    fn test_cp01_configure_without_key_keeps_policy() {
        let mut rule = RuleCP01 {
            policy: CapitalisationPolicy::Lower,
        };
        rule.configure(&HashMap::new());
        assert_eq!(rule.policy, CapitalisationPolicy::Lower);
    }

    #[test]
    fn test_capitalise_helper() {
        assert_eq!(capitalise("select"), "Select");
        assert_eq!(capitalise("SELECT"), "Select");
        assert_eq!(capitalise("sElEcT"), "Select");
        assert_eq!(capitalise(""), "");
    }
}