rigsql_rules/capitalisation/
cp02.rs1use rigsql_core::{Segment, SegmentType, TokenKind};
2use rigsql_lexer::is_keyword;
3
4use super::CapitalisationPolicy;
5use crate::rule::{CrawlType, Rule, RuleContext, RuleGroup};
6use crate::utils::{check_capitalisation, collect_matching_tokens, determine_majority_case};
7use crate::violation::{LintViolation, SourceEdit};
8
9#[derive(Debug)]
13pub struct RuleCP02 {
14 pub policy: CapitalisationPolicy,
15}
16
17impl Default for RuleCP02 {
18 fn default() -> Self {
19 Self {
20 policy: CapitalisationPolicy::Consistent,
21 }
22 }
23}
24
25impl RuleCP02 {
26 fn should_skip(seg: &Segment, parent: Option<&Segment>) -> bool {
28 let Segment::Token(t) = seg else {
29 return true;
30 };
31 if t.token.kind != TokenKind::Word {
32 return true;
33 }
34 if is_keyword(&t.token.text) {
35 return true;
36 }
37 if let Some(p) = parent {
38 if p.segment_type() == SegmentType::FunctionCall {
39 return true;
40 }
41 }
42 if !t.token.text.is_ascii() {
43 return true;
44 }
45 false
46 }
47
48 fn eval_consistent(&self, ctx: &RuleContext) -> Vec<LintViolation> {
49 let mut tokens = Vec::new();
50 collect_matching_tokens(
51 ctx.root,
52 &|seg| {
53 if let Segment::Token(t) = seg {
54 if t.segment_type == SegmentType::Identifier
55 && t.token.kind == TokenKind::Word
56 && !is_keyword(&t.token.text)
57 && t.token.text.is_ascii()
58 {
59 return Some((t.token.text.to_string(), t.token.span));
60 }
61 }
62 None
63 },
64 &mut tokens,
65 );
66
67 if tokens.is_empty() {
68 return vec![];
69 }
70
71 let majority = determine_majority_case(&tokens);
72 let mut violations = Vec::new();
73 for (text, span) in &tokens {
74 let expected = match majority {
75 "upper" => text.to_ascii_uppercase(),
76 _ => text.to_ascii_lowercase(),
77 };
78 if let Some(v) = check_capitalisation(
79 self.code(),
80 "Unquoted identifiers",
81 text,
82 &expected,
83 majority,
84 *span,
85 ) {
86 violations.push(v);
87 }
88 }
89 violations
90 }
91}
92
93impl Rule for RuleCP02 {
94 fn code(&self) -> &'static str {
95 "CP02"
96 }
97 fn name(&self) -> &'static str {
98 "capitalisation.identifiers"
99 }
100 fn description(&self) -> &'static str {
101 "Unquoted identifiers must be consistently capitalised."
102 }
103 fn explanation(&self) -> &'static str {
104 "Unquoted identifiers (table names, column names) should use consistent capitalisation. \
105 Most SQL style guides recommend lower_snake_case for identifiers."
106 }
107 fn groups(&self) -> &[RuleGroup] {
108 &[RuleGroup::Capitalisation]
109 }
110 fn is_fixable(&self) -> bool {
111 true
112 }
113
114 fn crawl_type(&self) -> CrawlType {
115 if self.policy == CapitalisationPolicy::Consistent {
116 CrawlType::RootOnly
117 } else {
118 CrawlType::Segment(vec![SegmentType::Identifier])
119 }
120 }
121
122 fn configure(&mut self, settings: &std::collections::HashMap<String, String>) {
123 if let Some(policy) = settings.get("capitalisation_policy") {
124 self.policy = CapitalisationPolicy::from_config(policy);
125 }
126 }
127
128 fn eval(&self, ctx: &RuleContext) -> Vec<LintViolation> {
129 if self.policy == CapitalisationPolicy::Consistent {
130 return self.eval_consistent(ctx);
131 }
132
133 if Self::should_skip(ctx.segment, ctx.parent) {
134 return vec![];
135 }
136
137 let Segment::Token(t) = ctx.segment else {
138 return vec![];
139 };
140 let text = t.token.text.as_str();
141
142 let (expected, policy_name) = match self.policy {
143 CapitalisationPolicy::Upper => (text.to_ascii_uppercase(), "upper"),
144 CapitalisationPolicy::Lower => (text.to_ascii_lowercase(), "lower"),
145 CapitalisationPolicy::Capitalise => (crate::utils::capitalise(text), "capitalised"),
146 CapitalisationPolicy::Consistent => unreachable!(),
147 };
148
149 if text != expected {
150 vec![LintViolation::with_fix_and_msg_key(
151 self.code(),
152 format!(
153 "Unquoted identifiers must be {} case. Found '{}'.",
154 policy_name, text
155 ),
156 t.token.span,
157 vec![SourceEdit::replace(t.token.span, expected.clone())],
158 "rules.CP02.msg",
159 vec![
160 ("policy".to_string(), policy_name.to_string()),
161 ("found".to_string(), text.to_string()),
162 ],
163 )]
164 } else {
165 vec![]
166 }
167 }
168}
169
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_utils::lint_sql;

    /// Builds a CP02 rule that enforces `policy`.
    fn rule_with(policy: CapitalisationPolicy) -> RuleCP02 {
        RuleCP02 { policy }
    }

    /// `Lower` policy reports a mixed-case identifier.
    #[test]
    fn test_cp02_lower_policy_flags_upper() {
        let found = lint_sql("SELECT Users FROM t", rule_with(CapitalisationPolicy::Lower));
        assert!(!found.is_empty());
    }

    /// Upper-case keywords are not reported under the `Lower` policy.
    #[test]
    fn test_cp02_skips_keywords() {
        let found = lint_sql("SELECT id FROM users", rule_with(CapitalisationPolicy::Lower));
        assert!(found.is_empty());
    }

    /// `Lower` policy reports nothing for a query containing a function call.
    #[test]
    fn test_cp02_skips_function_parent() {
        let found = lint_sql(
            "SELECT COUNT(id) FROM users",
            rule_with(CapitalisationPolicy::Lower),
        );
        assert!(found.is_empty());
    }

    /// `Consistent` policy accepts identifiers that all share lower case.
    #[test]
    fn test_cp02_consistent_all_lower_no_violation() {
        let found = lint_sql(
            "SELECT id, name FROM users",
            rule_with(CapitalisationPolicy::Consistent),
        );
        assert!(found.is_empty());
    }

    /// With a lower-case majority, the single upper-case identifier is
    /// reported and fixed to lower case.
    #[test]
    fn test_cp02_consistent_flags_minority() {
        let found = lint_sql(
            "SELECT id, name, AGE FROM users",
            rule_with(CapitalisationPolicy::Consistent),
        );
        assert_eq!(found.len(), 1);
        assert_eq!(found[0].fixes[0].new_text, "age");
    }

    /// With an upper-case majority, the single lower-case identifier is
    /// reported and fixed to upper case.
    #[test]
    fn test_cp02_consistent_majority_upper() {
        let found = lint_sql(
            "SELECT ID, NAME, age FROM USERS",
            rule_with(CapitalisationPolicy::Consistent),
        );
        assert_eq!(found.len(), 1);
        assert_eq!(found[0].fixes[0].new_text, "AGE");
    }
}