Skip to main content

rigsql_rules/capitalisation/cp02.rs

1use rigsql_core::{Segment, SegmentType, TokenKind};
2use rigsql_lexer::is_keyword;
3
4use crate::rule::{CrawlType, Rule, RuleContext, RuleGroup};
5use crate::violation::{LintViolation, SourceEdit};
6
7/// CP02: Identifiers (non-keywords) must be consistently capitalised.
8///
9/// By default, expects lower case identifiers.
10#[derive(Debug)]
11pub struct RuleCP02 {
12    pub policy: IdentifierPolicy,
13}
14
/// Capitalisation style that [`RuleCP02`] checks unquoted identifiers against.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum IdentifierPolicy {
    /// Identifiers must contain no upper-case ASCII letters (e.g. `user_id`).
    Lower,
    /// Identifiers must contain no lower-case ASCII letters (e.g. `USER_ID`).
    Upper,
    /// Intended to require identifiers to follow the first-seen case.
    ///
    /// NOTE(review): not enforced yet — the rule currently reports no
    /// violations under this policy.
    Consistent,
}
21
22impl Default for RuleCP02 {
23    fn default() -> Self {
24        Self {
25            policy: IdentifierPolicy::Consistent,
26        }
27    }
28}
29
30impl Rule for RuleCP02 {
31    fn code(&self) -> &'static str {
32        "CP02"
33    }
34    fn name(&self) -> &'static str {
35        "capitalisation.identifiers"
36    }
37    fn description(&self) -> &'static str {
38        "Unquoted identifiers must be consistently capitalised."
39    }
40    fn explanation(&self) -> &'static str {
41        "Unquoted identifiers (table names, column names) should use consistent capitalisation. \
42         Most SQL style guides recommend lower_snake_case for identifiers."
43    }
44    fn groups(&self) -> &[RuleGroup] {
45        &[RuleGroup::Capitalisation]
46    }
47    fn is_fixable(&self) -> bool {
48        true
49    }
50
51    fn crawl_type(&self) -> CrawlType {
52        CrawlType::Segment(vec![SegmentType::Identifier])
53    }
54
55    fn eval(&self, ctx: &RuleContext) -> Vec<LintViolation> {
56        let Segment::Token(t) = ctx.segment else {
57            return vec![];
58        };
59        if t.token.kind != TokenKind::Word {
60            return vec![];
61        }
62        // Skip if it's actually a keyword
63        if is_keyword(&t.token.text) {
64            return vec![];
65        }
66        // Skip if parent is a FunctionCall (function names are handled by CP03)
67        if let Some(parent) = ctx.parent {
68            if parent.segment_type() == rigsql_core::SegmentType::FunctionCall {
69                return vec![];
70            }
71        }
72
73        let text = t.token.text.as_str();
74
75        // Skip identifiers containing non-ASCII characters (e.g. Japanese column names)
76        // — ascii case conversion would produce broken results
77        if !text.is_ascii() {
78            return vec![];
79        }
80
81        let expected = match self.policy {
82            IdentifierPolicy::Lower => text.to_ascii_lowercase(),
83            IdentifierPolicy::Upper => text.to_ascii_uppercase(),
84            IdentifierPolicy::Consistent => return vec![], // TODO: track first-seen case
85        };
86
87        if text != expected {
88            vec![LintViolation::with_fix(
89                self.code(),
90                format!(
91                    "Unquoted identifiers must be {} case. Found '{}'.",
92                    match self.policy {
93                        IdentifierPolicy::Lower => "lower",
94                        IdentifierPolicy::Upper => "upper",
95                        IdentifierPolicy::Consistent => "consistent",
96                    },
97                    text
98                ),
99                t.token.span,
100                vec![SourceEdit::replace(t.token.span, expected.clone())],
101            )]
102        } else {
103            vec![]
104        }
105    }
106}
107
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_utils::lint_sql;

    /// Builds a rule with an explicit policy.
    fn rule_with(policy: IdentifierPolicy) -> RuleCP02 {
        RuleCP02 { policy }
    }

    #[test]
    fn test_cp02_consistent_default_no_violation() {
        let violations = lint_sql("SELECT Users FROM t", RuleCP02::default());
        assert_eq!(violations.len(), 0);
    }

    #[test]
    fn test_cp02_lower_policy_flags_upper() {
        let violations = lint_sql("SELECT Users FROM t", rule_with(IdentifierPolicy::Lower));
        assert!(!violations.is_empty());
    }

    #[test]
    fn test_cp02_upper_policy_flags_lower() {
        let violations = lint_sql("SELECT users FROM t", rule_with(IdentifierPolicy::Upper));
        assert!(!violations.is_empty());
    }

    #[test]
    fn test_cp02_upper_policy_accepts_upper() {
        // No lower-case ASCII letters anywhere, so nothing can be flagged.
        let violations = lint_sql("SELECT USERS FROM T", rule_with(IdentifierPolicy::Upper));
        assert_eq!(violations.len(), 0);
    }

    #[test]
    fn test_cp02_skips_keywords() {
        let violations = lint_sql("SELECT id FROM users", rule_with(IdentifierPolicy::Lower));
        assert_eq!(violations.len(), 0);
    }

    #[test]
    fn test_cp02_skips_function_parent() {
        let violations = lint_sql(
            "SELECT COUNT(id) FROM users",
            rule_with(IdentifierPolicy::Lower),
        );
        assert_eq!(violations.len(), 0);
    }
}