Skip to main content

rigsql_rules/capitalisation/
cp02.rs

1use rigsql_core::{Segment, SegmentType, TokenKind};
2use rigsql_lexer::is_keyword;
3
4use crate::rule::{CrawlType, Rule, RuleContext, RuleGroup};
5use crate::violation::{LintViolation, SourceEdit};
6
7/// CP02: Identifiers (non-keywords) must be consistently capitalised.
8///
9/// By default, expects lower case identifiers.
10#[derive(Debug)]
11pub struct RuleCP02 {
12    pub policy: IdentifierPolicy,
13}
14
/// Capitalisation policy that CP02 applies to unquoted identifiers.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum IdentifierPolicy {
    /// Identifiers are expected to equal their ASCII lower-case form.
    Lower,
    /// Identifiers are expected to equal their ASCII upper-case form.
    Upper,
    /// Identifiers are expected to share one capitalisation.
    ///
    /// NOTE: the rule's `eval` currently reports nothing under this policy;
    /// tracking the first-seen case is still a TODO there.
    Consistent,
}
21
22impl Default for RuleCP02 {
23    fn default() -> Self {
24        Self {
25            policy: IdentifierPolicy::Consistent,
26        }
27    }
28}
29
30impl Rule for RuleCP02 {
31    fn code(&self) -> &'static str {
32        "CP02"
33    }
34    fn name(&self) -> &'static str {
35        "capitalisation.identifiers"
36    }
37    fn description(&self) -> &'static str {
38        "Unquoted identifiers must be consistently capitalised."
39    }
40    fn explanation(&self) -> &'static str {
41        "Unquoted identifiers (table names, column names) should use consistent capitalisation. \
42         Most SQL style guides recommend lower_snake_case for identifiers."
43    }
44    fn groups(&self) -> &[RuleGroup] {
45        &[RuleGroup::Capitalisation]
46    }
47    fn is_fixable(&self) -> bool {
48        true
49    }
50
51    fn crawl_type(&self) -> CrawlType {
52        CrawlType::Segment(vec![SegmentType::Identifier])
53    }
54
55    fn eval(&self, ctx: &RuleContext) -> Vec<LintViolation> {
56        let Segment::Token(t) = ctx.segment else {
57            return vec![];
58        };
59        if t.token.kind != TokenKind::Word {
60            return vec![];
61        }
62        // Skip if it's actually a keyword
63        if is_keyword(&t.token.text) {
64            return vec![];
65        }
66        // Skip if parent is a FunctionCall (function names are handled by CP03)
67        if let Some(parent) = ctx.parent {
68            if parent.segment_type() == rigsql_core::SegmentType::FunctionCall {
69                return vec![];
70            }
71        }
72
73        let text = t.token.text.as_str();
74
75        // Skip identifiers containing non-ASCII characters (e.g. Japanese column names)
76        // — ascii case conversion would produce broken results
77        if !text.is_ascii() {
78            return vec![];
79        }
80
81        let expected = match self.policy {
82            IdentifierPolicy::Lower => text.to_ascii_lowercase(),
83            IdentifierPolicy::Upper => text.to_ascii_uppercase(),
84            IdentifierPolicy::Consistent => return vec![], // TODO: track first-seen case
85        };
86
87        if text != expected {
88            let policy_name = match self.policy {
89                IdentifierPolicy::Lower => "lower",
90                IdentifierPolicy::Upper => "upper",
91                IdentifierPolicy::Consistent => "consistent",
92            };
93            vec![LintViolation::with_fix_and_msg_key(
94                self.code(),
95                format!(
96                    "Unquoted identifiers must be {} case. Found '{}'.",
97                    policy_name, text
98                ),
99                t.token.span,
100                vec![SourceEdit::replace(t.token.span, expected.clone())],
101                "rules.CP02.msg",
102                vec![
103                    ("policy".to_string(), policy_name.to_string()),
104                    ("found".to_string(), text.to_string()),
105                ],
106            )]
107        } else {
108            vec![]
109        }
110    }
111}
112
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_utils::lint_sql;

    /// Builds a CP02 rule that enforces `policy`.
    fn rule_with(policy: IdentifierPolicy) -> RuleCP02 {
        RuleCP02 { policy }
    }

    /// The default (`Consistent`) policy reports nothing for a mixed-case identifier.
    #[test]
    fn test_cp02_consistent_default_no_violation() {
        let found = lint_sql("SELECT Users FROM t", RuleCP02::default());
        assert!(found.is_empty());
    }

    /// Under the `Lower` policy a capitalised identifier is reported.
    #[test]
    fn test_cp02_lower_policy_flags_upper() {
        let found = lint_sql("SELECT Users FROM t", rule_with(IdentifierPolicy::Lower));
        assert!(!found.is_empty());
    }

    /// Upper-case keywords are not reported under the `Lower` policy.
    #[test]
    fn test_cp02_skips_keywords() {
        let found = lint_sql("SELECT id FROM users", rule_with(IdentifierPolicy::Lower));
        assert!(found.is_empty());
    }

    /// An upper-case function name is not reported under the `Lower` policy.
    #[test]
    fn test_cp02_skips_function_parent() {
        let found = lint_sql(
            "SELECT COUNT(id) FROM users",
            rule_with(IdentifierPolicy::Lower),
        );
        assert!(found.is_empty());
    }
}