rigsql_rules/capitalisation/
cp01.rs1use rigsql_core::{Segment, SegmentType, TokenKind};
2use rigsql_lexer::is_keyword;
3
4use super::CapitalisationPolicy;
5use crate::rule::{CrawlType, Rule, RuleContext, RuleGroup};
6use crate::utils::{check_capitalisation, collect_matching_tokens, determine_majority_case};
7use crate::violation::LintViolation;
8
9#[derive(Debug)]
13pub struct RuleCP01 {
14 pub policy: CapitalisationPolicy,
15}
16
17impl Default for RuleCP01 {
18 fn default() -> Self {
19 Self {
20 policy: CapitalisationPolicy::Upper,
21 }
22 }
23}
24
25impl Rule for RuleCP01 {
26 fn code(&self) -> &'static str {
27 "CP01"
28 }
29 fn name(&self) -> &'static str {
30 "capitalisation.keywords"
31 }
32 fn description(&self) -> &'static str {
33 "Keywords must be consistently capitalised."
34 }
35 fn explanation(&self) -> &'static str {
36 "SQL keywords like SELECT, FROM, WHERE should use consistent capitalisation. \
37 Mixed case reduces readability. Most style guides recommend UPPER case keywords \
38 to distinguish them from identifiers."
39 }
40 fn groups(&self) -> &[RuleGroup] {
41 &[RuleGroup::Capitalisation]
42 }
43 fn is_fixable(&self) -> bool {
44 true
45 }
46
47 fn crawl_type(&self) -> CrawlType {
48 if self.policy == CapitalisationPolicy::Consistent {
49 CrawlType::RootOnly
50 } else {
51 CrawlType::Segment(vec![SegmentType::Keyword, SegmentType::Unparsable])
52 }
53 }
54
55 fn configure(&mut self, settings: &std::collections::HashMap<String, String>) {
56 if let Some(policy) = settings.get("capitalisation_policy") {
57 self.policy = CapitalisationPolicy::from_config(policy);
58 }
59 }
60
61 fn eval(&self, ctx: &RuleContext) -> Vec<LintViolation> {
62 if self.policy == CapitalisationPolicy::Consistent {
63 return self.eval_consistent(ctx);
64 }
65
66 let Segment::Token(t) = ctx.segment else {
67 return vec![];
68 };
69 if t.token.kind != TokenKind::Word || !is_keyword(&t.token.text) {
70 return vec![];
71 }
72
73 let text = t.token.text.as_str();
74 let (expected, policy_name) = match self.policy {
75 CapitalisationPolicy::Upper => (text.to_ascii_uppercase(), "upper"),
76 CapitalisationPolicy::Lower => (text.to_ascii_lowercase(), "lower"),
77 CapitalisationPolicy::Capitalise => (crate::utils::capitalise(text), "capitalised"),
78 CapitalisationPolicy::Consistent => unreachable!(),
79 };
80
81 check_capitalisation(
82 self.code(),
83 "Keywords",
84 text,
85 &expected,
86 policy_name,
87 t.token.span,
88 )
89 .into_iter()
90 .collect()
91 }
92}
93
94impl RuleCP01 {
95 fn eval_consistent(&self, ctx: &RuleContext) -> Vec<LintViolation> {
96 let mut tokens = Vec::new();
97 collect_matching_tokens(
98 ctx.root,
99 &|seg| {
100 if let Segment::Token(t) = seg {
101 if t.segment_type == SegmentType::Keyword
102 && t.token.kind == TokenKind::Word
103 && is_keyword(&t.token.text)
104 {
105 return Some((t.token.text.to_string(), t.token.span));
106 }
107 }
108 None
109 },
110 &mut tokens,
111 );
112
113 if tokens.is_empty() {
114 return vec![];
115 }
116
117 let majority = determine_majority_case(&tokens);
118 let mut violations = Vec::new();
119 for (text, span) in &tokens {
120 let expected = match majority {
121 "upper" => text.to_ascii_uppercase(),
122 _ => text.to_ascii_lowercase(),
123 };
124 if let Some(v) =
125 check_capitalisation(self.code(), "Keywords", text, &expected, majority, *span)
126 {
127 violations.push(v);
128 }
129 }
130 violations
131 }
132}
133
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_utils::lint_sql;

    /// Builds a rule instance with an explicit policy.
    fn rule_with(policy: CapitalisationPolicy) -> RuleCP01 {
        RuleCP01 { policy }
    }

    #[test]
    fn test_cp01_flags_lowercase_keyword() {
        let found = lint_sql("select 1", RuleCP01::default());
        assert_eq!(found.len(), 1);
    }

    #[test]
    fn test_cp01_accepts_uppercase_keyword() {
        let found = lint_sql("SELECT 1", RuleCP01::default());
        assert_eq!(found.len(), 0);
    }

    #[test]
    fn test_cp01_fix_replaces_to_upper() {
        let found = lint_sql("select 1", RuleCP01::default());
        assert_eq!(found.len(), 1);
        let fixes = &found[0].fixes;
        assert_eq!(fixes.len(), 1);
        assert_eq!(fixes[0].new_text, "SELECT");
    }

    #[test]
    fn test_cp01_lower_policy() {
        let found = lint_sql("SELECT 1", rule_with(CapitalisationPolicy::Lower));
        assert_eq!(found.len(), 1);
    }

    #[test]
    fn test_cp01_consistent_flags_minority() {
        let found = lint_sql(
            "SELECT id FROM users where id = 1 AND name = 'a'",
            rule_with(CapitalisationPolicy::Consistent),
        );
        assert_eq!(found.len(), 1);
        assert_eq!(found[0].fixes[0].new_text, "WHERE");
    }

    #[test]
    fn test_cp01_consistent_all_same_no_violation() {
        let found = lint_sql(
            "SELECT id FROM users WHERE id = 1",
            rule_with(CapitalisationPolicy::Consistent),
        );
        assert_eq!(found.len(), 0);
    }

    #[test]
    fn test_cp01_consistent_majority_lower() {
        let found = lint_sql(
            "select id from users where id = 1 AND name = 'a'",
            rule_with(CapitalisationPolicy::Consistent),
        );
        assert_eq!(found.len(), 1);
        assert_eq!(found[0].fixes[0].new_text, "and");
    }

    #[test]
    fn test_cp01_multiple_keywords() {
        let found = lint_sql("select * from users where id = 1", RuleCP01::default());
        assert!(found.iter().all(|v| v.rule_code == "CP01"));
        assert!(found.len() >= 3);
        let first_fixes: Vec<&str> = found
            .iter()
            .map(|v| v.fixes[0].new_text.as_str())
            .collect();
        for keyword in ["SELECT", "FROM", "WHERE"] {
            assert!(first_fixes.contains(&keyword));
        }
    }
}