rigsql_rules/capitalisation/
cp01.rs1use rigsql_core::{Segment, SegmentType, TokenKind};
2use rigsql_lexer::is_keyword;
3
4use crate::rule::{CrawlType, Rule, RuleContext, RuleGroup};
5use crate::utils::check_capitalisation;
6use crate::violation::LintViolation;
7
8#[derive(Debug)]
12pub struct RuleCP01 {
13 pub policy: CapitalisationPolicy,
14}
15
/// Case style that the CP01 rule enforces on SQL keywords.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CapitalisationPolicy {
    /// Every letter upper case, e.g. `SELECT`.
    Upper,
    /// Every letter lower case, e.g. `select`.
    Lower,
    /// First letter upper case and the remainder lower case, e.g. `Select`.
    Capitalise,
}
22
23impl Default for RuleCP01 {
24 fn default() -> Self {
25 Self {
26 policy: CapitalisationPolicy::Upper,
27 }
28 }
29}
30
31impl Rule for RuleCP01 {
32 fn code(&self) -> &'static str {
33 "CP01"
34 }
35 fn name(&self) -> &'static str {
36 "capitalisation.keywords"
37 }
38 fn description(&self) -> &'static str {
39 "Keywords must be consistently capitalised."
40 }
41 fn explanation(&self) -> &'static str {
42 "SQL keywords like SELECT, FROM, WHERE should use consistent capitalisation. \
43 Mixed case reduces readability. Most style guides recommend UPPER case keywords \
44 to distinguish them from identifiers."
45 }
46 fn groups(&self) -> &[RuleGroup] {
47 &[RuleGroup::Capitalisation]
48 }
49 fn is_fixable(&self) -> bool {
50 true
51 }
52
53 fn crawl_type(&self) -> CrawlType {
54 CrawlType::Segment(vec![SegmentType::Keyword, SegmentType::Unparsable])
55 }
56
57 fn configure(&mut self, settings: &std::collections::HashMap<String, String>) {
58 if let Some(policy) = settings.get("capitalisation_policy") {
59 self.policy = match policy.as_str() {
60 "lower" => CapitalisationPolicy::Lower,
61 "capitalise" | "capitalize" => CapitalisationPolicy::Capitalise,
62 _ => CapitalisationPolicy::Upper,
63 };
64 }
65 }
66
67 fn eval(&self, ctx: &RuleContext) -> Vec<LintViolation> {
68 let Segment::Token(t) = ctx.segment else {
69 return vec![];
70 };
71 if t.token.kind != TokenKind::Word || !is_keyword(&t.token.text) {
72 return vec![];
73 }
74
75 let text = t.token.text.as_str();
76 let (expected, policy_name) = match self.policy {
77 CapitalisationPolicy::Upper => (text.to_ascii_uppercase(), "upper"),
78 CapitalisationPolicy::Lower => (text.to_ascii_lowercase(), "lower"),
79 CapitalisationPolicy::Capitalise => (capitalise(text), "capitalised"),
80 };
81
82 check_capitalisation(
83 self.code(),
84 "Keywords",
85 text,
86 &expected,
87 policy_name,
88 t.token.span,
89 )
90 .into_iter()
91 .collect()
92 }
93}
94
/// Returns `s` with its first character upper-cased and the rest lower-cased.
///
/// Uses Unicode case mapping, so the result can be longer than the input
/// (e.g. a leading `ß` becomes `SS`). An empty input yields an empty string.
fn capitalise(s: &str) -> String {
    let mut out = String::with_capacity(s.len());
    let mut rest = s.chars();
    if let Some(head) = rest.next() {
        out.extend(head.to_uppercase());
        out.push_str(&rest.as_str().to_lowercase());
    }
    out
}
102
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_utils::lint_sql;
    use std::collections::HashMap;

    /// Builds a settings map holding only the `capitalisation_policy` key.
    fn policy_settings(value: &str) -> HashMap<String, String> {
        HashMap::from([("capitalisation_policy".to_string(), value.to_string())])
    }

    #[test]
    fn test_cp01_flags_lowercase_keyword() {
        let violations = lint_sql("select 1", RuleCP01::default());
        assert_eq!(violations.len(), 1);
    }

    #[test]
    fn test_cp01_accepts_uppercase_keyword() {
        let violations = lint_sql("SELECT 1", RuleCP01::default());
        assert_eq!(violations.len(), 0);
    }

    #[test]
    fn test_cp01_fix_replaces_to_upper() {
        let violations = lint_sql("select 1", RuleCP01::default());
        assert_eq!(violations.len(), 1);
        assert_eq!(violations[0].fixes.len(), 1);
        assert_eq!(violations[0].fixes[0].new_text, "SELECT");
    }

    #[test]
    fn test_cp01_lower_policy() {
        let rule = RuleCP01 {
            policy: CapitalisationPolicy::Lower,
        };
        let violations = lint_sql("SELECT 1", rule);
        assert_eq!(violations.len(), 1);
    }

    #[test]
    fn test_cp01_lower_policy_fix_replaces_to_lower() {
        let rule = RuleCP01 {
            policy: CapitalisationPolicy::Lower,
        };
        let violations = lint_sql("SELECT 1", rule);
        assert_eq!(violations.len(), 1);
        assert_eq!(violations[0].fixes.len(), 1);
        assert_eq!(violations[0].fixes[0].new_text, "select");
    }

    #[test]
    fn test_cp01_capitalise_policy() {
        let rule = RuleCP01 {
            policy: CapitalisationPolicy::Capitalise,
        };
        let violations = lint_sql("select 1", rule);
        assert_eq!(violations.len(), 1);
        assert_eq!(violations[0].fixes.len(), 1);
        assert_eq!(violations[0].fixes[0].new_text, "Select");

        let rule = RuleCP01 {
            policy: CapitalisationPolicy::Capitalise,
        };
        assert_eq!(lint_sql("Select 1", rule).len(), 0);
    }

    #[test]
    fn test_cp01_multiple_keywords() {
        let violations = lint_sql("select * from users where id = 1", RuleCP01::default());
        let codes: Vec<&str> = violations.iter().map(|v| v.rule_code).collect();
        assert!(codes.iter().all(|&c| c == "CP01"));
        assert!(violations.len() >= 3);
        let fix_texts: Vec<&str> = violations
            .iter()
            .map(|v| v.fixes[0].new_text.as_str())
            .collect();
        assert!(fix_texts.contains(&"SELECT"));
        assert!(fix_texts.contains(&"FROM"));
        assert!(fix_texts.contains(&"WHERE"));
    }

    #[test]
    fn test_cp01_configure_recognised_policies() {
        let cases = [
            ("lower", CapitalisationPolicy::Lower),
            ("upper", CapitalisationPolicy::Upper),
            ("capitalise", CapitalisationPolicy::Capitalise),
            ("capitalize", CapitalisationPolicy::Capitalise),
        ];
        for (value, expected) in cases {
            let mut rule = RuleCP01::default();
            rule.configure(&policy_settings(value));
            assert_eq!(rule.policy, expected, "value {value:?}");
        }
    }

    #[test]
    fn test_cp01_configure_unknown_falls_back_to_upper() {
        let mut rule = RuleCP01 {
            policy: CapitalisationPolicy::Lower,
        };
        rule.configure(&policy_settings("bogus"));
        assert_eq!(rule.policy, CapitalisationPolicy::Upper);
    }

    #[test]
    fn test_cp01_configure_without_key_keeps_policy() {
        let mut rule = RuleCP01 {
            policy: CapitalisationPolicy::Lower,
        };
        rule.configure(&HashMap::new());
        assert_eq!(rule.policy, CapitalisationPolicy::Lower);
    }
}