Skip to main content

rigsql_rules/capitalisation/
cp02.rs

1use rigsql_core::{Segment, SegmentType, TokenKind};
2use rigsql_lexer::is_keyword;
3
4use super::CapitalisationPolicy;
5use crate::rule::{CrawlType, Rule, RuleContext, RuleGroup};
6use crate::utils::{check_capitalisation, collect_matching_tokens, determine_majority_case};
7use crate::violation::{LintViolation, SourceEdit};
8
9/// CP02: Identifiers (non-keywords) must be consistently capitalised.
10///
11/// By default, expects consistent case identifiers.
12#[derive(Debug)]
13pub struct RuleCP02 {
14    pub policy: CapitalisationPolicy,
15}
16
17impl Default for RuleCP02 {
18    fn default() -> Self {
19        Self {
20            policy: CapitalisationPolicy::Consistent,
21        }
22    }
23}
24
25impl RuleCP02 {
26    /// Check if an identifier token should be skipped.
27    fn should_skip(seg: &Segment, parent: Option<&Segment>) -> bool {
28        let Segment::Token(t) = seg else {
29            return true;
30        };
31        if t.token.kind != TokenKind::Word {
32            return true;
33        }
34        if is_keyword(&t.token.text) {
35            return true;
36        }
37        if let Some(p) = parent {
38            if p.segment_type() == SegmentType::FunctionCall {
39                return true;
40            }
41        }
42        if !t.token.text.is_ascii() {
43            return true;
44        }
45        false
46    }
47
48    fn eval_consistent(&self, ctx: &RuleContext) -> Vec<LintViolation> {
49        let mut tokens = Vec::new();
50        collect_matching_tokens(
51            ctx.root,
52            &|seg| {
53                if let Segment::Token(t) = seg {
54                    if t.segment_type == SegmentType::Identifier
55                        && t.token.kind == TokenKind::Word
56                        && !is_keyword(&t.token.text)
57                        && t.token.text.is_ascii()
58                    {
59                        return Some((t.token.text.to_string(), t.token.span));
60                    }
61                }
62                None
63            },
64            &mut tokens,
65        );
66
67        if tokens.is_empty() {
68            return vec![];
69        }
70
71        let majority = determine_majority_case(&tokens);
72        let mut violations = Vec::new();
73        for (text, span) in &tokens {
74            let expected = match majority {
75                "upper" => text.to_ascii_uppercase(),
76                _ => text.to_ascii_lowercase(),
77            };
78            if let Some(v) = check_capitalisation(
79                self.code(),
80                "Unquoted identifiers",
81                text,
82                &expected,
83                majority,
84                *span,
85            ) {
86                violations.push(v);
87            }
88        }
89        violations
90    }
91}
92
93impl Rule for RuleCP02 {
94    fn code(&self) -> &'static str {
95        "CP02"
96    }
97    fn name(&self) -> &'static str {
98        "capitalisation.identifiers"
99    }
100    fn description(&self) -> &'static str {
101        "Unquoted identifiers must be consistently capitalised."
102    }
103    fn explanation(&self) -> &'static str {
104        "Unquoted identifiers (table names, column names) should use consistent capitalisation. \
105         Most SQL style guides recommend lower_snake_case for identifiers."
106    }
107    fn groups(&self) -> &[RuleGroup] {
108        &[RuleGroup::Capitalisation]
109    }
110    fn is_fixable(&self) -> bool {
111        true
112    }
113
114    fn crawl_type(&self) -> CrawlType {
115        if self.policy == CapitalisationPolicy::Consistent {
116            CrawlType::RootOnly
117        } else {
118            CrawlType::Segment(vec![SegmentType::Identifier])
119        }
120    }
121
122    fn configure(&mut self, settings: &std::collections::HashMap<String, String>) {
123        if let Some(policy) = settings.get("capitalisation_policy") {
124            self.policy = CapitalisationPolicy::from_config(policy);
125        }
126    }
127
128    fn eval(&self, ctx: &RuleContext) -> Vec<LintViolation> {
129        if self.policy == CapitalisationPolicy::Consistent {
130            return self.eval_consistent(ctx);
131        }
132
133        if Self::should_skip(ctx.segment, ctx.parent) {
134            return vec![];
135        }
136
137        let Segment::Token(t) = ctx.segment else {
138            return vec![];
139        };
140        let text = t.token.text.as_str();
141
142        let (expected, policy_name) = match self.policy {
143            CapitalisationPolicy::Upper => (text.to_ascii_uppercase(), "upper"),
144            CapitalisationPolicy::Lower => (text.to_ascii_lowercase(), "lower"),
145            CapitalisationPolicy::Capitalise => (crate::utils::capitalise(text), "capitalised"),
146            CapitalisationPolicy::Consistent => unreachable!(),
147        };
148
149        if text != expected {
150            vec![LintViolation::with_fix_and_msg_key(
151                self.code(),
152                format!(
153                    "Unquoted identifiers must be {} case. Found '{}'.",
154                    policy_name, text
155                ),
156                t.token.span,
157                vec![SourceEdit::replace(t.token.span, expected.clone())],
158                "rules.CP02.msg",
159                vec![
160                    ("policy".to_string(), policy_name.to_string()),
161                    ("found".to_string(), text.to_string()),
162                ],
163            )]
164        } else {
165            vec![]
166        }
167    }
168}
169
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_utils::lint_sql;

    /// Builds a CP02 rule enforcing the given policy.
    fn rule_with(policy: CapitalisationPolicy) -> RuleCP02 {
        RuleCP02 { policy }
    }

    /// A mixed-case identifier is reported under the `Lower` policy.
    #[test]
    fn test_cp02_lower_policy_flags_upper() {
        let rule = rule_with(CapitalisationPolicy::Lower);
        let violations = lint_sql("SELECT Users FROM t", rule);
        assert!(!violations.is_empty());
    }

    /// Upper-case keywords are not reported when identifiers are lower case.
    #[test]
    fn test_cp02_skips_keywords() {
        let rule = rule_with(CapitalisationPolicy::Lower);
        let violations = lint_sql("SELECT id FROM users", rule);
        assert!(violations.is_empty());
    }

    /// An upper-case function name produces no violation under `Lower`.
    #[test]
    fn test_cp02_skips_function_parent() {
        let rule = rule_with(CapitalisationPolicy::Lower);
        let violations = lint_sql("SELECT COUNT(id) FROM users", rule);
        assert!(violations.is_empty());
    }

    /// Uniformly lower-case identifiers satisfy the `Consistent` policy.
    #[test]
    fn test_cp02_consistent_all_lower_no_violation() {
        let rule = rule_with(CapitalisationPolicy::Consistent);
        let violations = lint_sql("SELECT id, name FROM users", rule);
        assert!(violations.is_empty());
    }

    /// Three lower-case identifiers (id, name, users) outvote one upper-case
    /// identifier (AGE), so only "AGE" is flagged and fixed to "age".
    #[test]
    fn test_cp02_consistent_flags_minority() {
        let rule = rule_with(CapitalisationPolicy::Consistent);
        let violations = lint_sql("SELECT id, name, AGE FROM users", rule);
        assert_eq!(violations.len(), 1);
        assert_eq!(violations[0].fixes[0].new_text, "age");
    }

    /// Three upper-case identifiers (ID, NAME, USERS) outvote one lower-case
    /// identifier (age), so only "age" is flagged and fixed to "AGE".
    #[test]
    fn test_cp02_consistent_majority_upper() {
        let rule = rule_with(CapitalisationPolicy::Consistent);
        let violations = lint_sql("SELECT ID, NAME, age FROM USERS", rule);
        assert_eq!(violations.len(), 1);
        assert_eq!(violations[0].fixes[0].new_text, "AGE");
    }
}