rigsql_rules/capitalisation/
cp02.rs1use rigsql_core::{Segment, SegmentType, TokenKind};
2use rigsql_lexer::is_keyword;
3
4use crate::rule::{CrawlType, Rule, RuleContext, RuleGroup};
5use crate::violation::{LintViolation, SourceEdit};
6
/// Lint rule CP02 (`capitalisation.identifiers`): checks the capitalisation of
/// unquoted identifiers against the configured [`IdentifierPolicy`].
///
/// `RuleCP02::default()` uses [`IdentifierPolicy::Consistent`].
#[derive(Debug, Default)]
pub struct RuleCP02 {
    /// Capitalisation policy this rule enforces.
    pub policy: IdentifierPolicy,
}

/// Capitalisation policy applied to unquoted identifiers.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum IdentifierPolicy {
    /// Identifiers are expected to be entirely lower case.
    Lower,
    /// Identifiers are expected to be entirely upper case.
    Upper,
    /// Identifiers are expected to be capitalised consistently; this is the
    /// default policy.
    #[default]
    Consistent,
}
29
30impl Rule for RuleCP02 {
31 fn code(&self) -> &'static str {
32 "CP02"
33 }
34 fn name(&self) -> &'static str {
35 "capitalisation.identifiers"
36 }
37 fn description(&self) -> &'static str {
38 "Unquoted identifiers must be consistently capitalised."
39 }
40 fn explanation(&self) -> &'static str {
41 "Unquoted identifiers (table names, column names) should use consistent capitalisation. \
42 Most SQL style guides recommend lower_snake_case for identifiers."
43 }
44 fn groups(&self) -> &[RuleGroup] {
45 &[RuleGroup::Capitalisation]
46 }
47 fn is_fixable(&self) -> bool {
48 true
49 }
50
51 fn crawl_type(&self) -> CrawlType {
52 CrawlType::Segment(vec![SegmentType::Identifier])
53 }
54
55 fn eval(&self, ctx: &RuleContext) -> Vec<LintViolation> {
56 let Segment::Token(t) = ctx.segment else {
57 return vec![];
58 };
59 if t.token.kind != TokenKind::Word {
60 return vec![];
61 }
62 if is_keyword(&t.token.text) {
64 return vec![];
65 }
66 if let Some(parent) = ctx.parent {
68 if parent.segment_type() == rigsql_core::SegmentType::FunctionCall {
69 return vec![];
70 }
71 }
72
73 let text = t.token.text.as_str();
74
75 if !text.is_ascii() {
78 return vec![];
79 }
80
81 let expected = match self.policy {
82 IdentifierPolicy::Lower => text.to_ascii_lowercase(),
83 IdentifierPolicy::Upper => text.to_ascii_uppercase(),
84 IdentifierPolicy::Consistent => return vec![], };
86
87 if text != expected {
88 let policy_name = match self.policy {
89 IdentifierPolicy::Lower => "lower",
90 IdentifierPolicy::Upper => "upper",
91 IdentifierPolicy::Consistent => "consistent",
92 };
93 vec![LintViolation::with_fix_and_msg_key(
94 self.code(),
95 format!(
96 "Unquoted identifiers must be {} case. Found '{}'.",
97 policy_name, text
98 ),
99 t.token.span,
100 vec![SourceEdit::replace(t.token.span, expected.clone())],
101 "rules.CP02.msg",
102 vec![
103 ("policy".to_string(), policy_name.to_string()),
104 ("found".to_string(), text.to_string()),
105 ],
106 )]
107 } else {
108 vec![]
109 }
110 }
111}
112
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_utils::lint_sql;

    /// Lints `sql` with CP02 configured for `policy` and returns how many
    /// violations were reported.
    fn violation_count(sql: &str, policy: IdentifierPolicy) -> usize {
        lint_sql(sql, RuleCP02 { policy }).len()
    }

    /// The default rule reports nothing for a mixed-case identifier.
    #[test]
    fn test_cp02_consistent_default_no_violation() {
        let found = lint_sql("SELECT Users FROM t", RuleCP02::default()).len();
        assert_eq!(found, 0);
    }

    /// Under the lower policy a capitalised identifier is reported.
    #[test]
    fn test_cp02_lower_policy_flags_upper() {
        let found = violation_count("SELECT Users FROM t", IdentifierPolicy::Lower);
        assert_ne!(found, 0);
    }

    /// Upper-case keywords do not produce violations under the lower policy.
    #[test]
    fn test_cp02_skips_keywords() {
        let found = violation_count("SELECT id FROM users", IdentifierPolicy::Lower);
        assert_eq!(found, 0);
    }

    /// A query containing an upper-case function call stays clean under the
    /// lower policy.
    #[test]
    fn test_cp02_skips_function_parent() {
        let found = violation_count("SELECT COUNT(id) FROM users", IdentifierPolicy::Lower);
        assert_eq!(found, 0);
    }
}