Skip to main content

rigsql_rules/capitalisation/
cp05.rs

1use rigsql_core::{Segment, SegmentType};
2
3use super::CapitalisationPolicy;
4use crate::rule::{CrawlType, Rule, RuleContext, RuleGroup};
5use crate::utils::{check_capitalisation, determine_majority_case};
6use crate::violation::LintViolation;
7
8/// CP05: Data type names must be consistently capitalised.
9///
10/// By default expects upper case (INT, VARCHAR, etc.).
11#[derive(Debug)]
12pub struct RuleCP05 {
13    pub policy: CapitalisationPolicy,
14}
15
16impl Default for RuleCP05 {
17    fn default() -> Self {
18        Self {
19            policy: CapitalisationPolicy::Upper,
20        }
21    }
22}
23
/// Returns `true` when `text` contains no alphabetic character and therefore
/// has no letter case to check (e.g. "255", "(", ")", ",", " ", "[]").
///
/// Callers use this to skip tokens inside a data type segment. Besides digits,
/// parentheses and commas, any other letter-free token (whitespace, brackets,
/// dots, ...) is skipped too: such a token cannot be mis-capitalised, so it
/// should neither be checked nor be handed to the majority-case computation
/// of the `Consistent` policy.
///
/// The empty string is letter-free and yields `true`.
fn is_numeric_or_paren(text: &str) -> bool {
    !text.chars().any(char::is_alphabetic)
}
29
30impl RuleCP05 {
31    fn eval_consistent(&self, ctx: &RuleContext) -> Vec<LintViolation> {
32        let mut tokens = Vec::new();
33        Self::collect_datatype_tokens(ctx.root, &mut tokens);
34
35        if tokens.is_empty() {
36            return vec![];
37        }
38
39        let majority = determine_majority_case(&tokens);
40        let mut violations = Vec::new();
41        for (text, span) in &tokens {
42            let expected = match majority {
43                "upper" => text.to_ascii_uppercase(),
44                _ => text.to_ascii_lowercase(),
45            };
46            if let Some(v) =
47                check_capitalisation(self.code(), "Data type", text, &expected, majority, *span)
48            {
49                violations.push(v);
50            }
51        }
52        violations
53    }
54
55    fn collect_datatype_tokens(segment: &Segment, out: &mut Vec<(String, rigsql_core::Span)>) {
56        if segment.segment_type() == SegmentType::DataType {
57            for token in segment.tokens() {
58                let text = token.text.as_str();
59                if !is_numeric_or_paren(text) {
60                    out.push((text.to_string(), token.span));
61                }
62            }
63        }
64        for child in segment.children() {
65            Self::collect_datatype_tokens(child, out);
66        }
67    }
68}
69
70impl Rule for RuleCP05 {
71    fn code(&self) -> &'static str {
72        "CP05"
73    }
74    fn name(&self) -> &'static str {
75        "capitalisation.types"
76    }
77    fn description(&self) -> &'static str {
78        "Data type names must be consistently capitalised."
79    }
80    fn explanation(&self) -> &'static str {
81        "Data type names (INT, VARCHAR, TEXT, etc.) should use consistent capitalisation. \
82         Most style guides recommend upper case for data types to distinguish them from column names."
83    }
84    fn groups(&self) -> &[RuleGroup] {
85        &[RuleGroup::Capitalisation]
86    }
87    fn is_fixable(&self) -> bool {
88        true
89    }
90
91    fn configure(&mut self, settings: &std::collections::HashMap<String, String>) {
92        if let Some(policy) = settings.get("capitalisation_policy") {
93            self.policy = CapitalisationPolicy::from_config(policy);
94        }
95    }
96
97    fn crawl_type(&self) -> CrawlType {
98        if self.policy == CapitalisationPolicy::Consistent {
99            CrawlType::RootOnly
100        } else {
101            CrawlType::Segment(vec![SegmentType::DataType])
102        }
103    }
104
105    fn eval(&self, ctx: &RuleContext) -> Vec<LintViolation> {
106        if self.policy == CapitalisationPolicy::Consistent {
107            return self.eval_consistent(ctx);
108        }
109
110        // DataType node may contain keyword tokens (INT, VARCHAR, etc.)
111        let tokens = ctx.segment.tokens();
112        let mut violations = Vec::new();
113
114        for token in tokens {
115            let text = token.text.as_str();
116            if is_numeric_or_paren(text) {
117                continue;
118            }
119
120            let (expected, policy_name) = match self.policy {
121                CapitalisationPolicy::Upper => (text.to_ascii_uppercase(), "upper"),
122                CapitalisationPolicy::Lower => (text.to_ascii_lowercase(), "lower"),
123                CapitalisationPolicy::Capitalise => (crate::utils::capitalise(text), "capitalised"),
124                CapitalisationPolicy::Consistent => unreachable!(),
125            };
126
127            if let Some(v) = check_capitalisation(
128                self.code(),
129                "Data type",
130                text,
131                &expected,
132                policy_name,
133                token.span,
134            ) {
135                violations.push(v);
136            }
137        }
138
139        violations
140    }
141}
142
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_utils::lint_sql;

    /// Convenience constructor for a rule with an explicit policy.
    fn rule_with(policy: CapitalisationPolicy) -> RuleCP05 {
        RuleCP05 { policy }
    }

    #[test]
    fn test_cp05_flags_lowercase_type() {
        // Default policy is upper, so lower-case `int` must be flagged.
        let found = lint_sql("SELECT CAST(1 AS int)", RuleCP05::default());
        assert_eq!(found.len(), 1);
    }

    #[test]
    fn test_cp05_accepts_uppercase_type() {
        let found = lint_sql("SELECT CAST(1 AS INT)", RuleCP05::default());
        assert_eq!(found.len(), 0);
    }

    #[test]
    fn test_cp05_lower_policy() {
        // With the lower policy an upper-case `INT` must be flagged.
        let found = lint_sql(
            "SELECT CAST(1 AS INT)",
            rule_with(CapitalisationPolicy::Lower),
        );
        assert_eq!(found.len(), 1);
    }

    #[test]
    fn test_cp05_consistent_all_same_no_violation() {
        let found = lint_sql(
            "SELECT CAST(1 AS INT), CAST(2 AS VARCHAR)",
            rule_with(CapitalisationPolicy::Consistent),
        );
        assert_eq!(found.len(), 0);
    }

    #[test]
    fn test_cp05_consistent_flags_minority() {
        // 2 upper (INT, VARCHAR) vs 1 lower (text) → majority upper, flag "text"
        let found = lint_sql(
            "SELECT CAST(1 AS INT), CAST(2 AS VARCHAR), CAST(3 AS text)",
            rule_with(CapitalisationPolicy::Consistent),
        );
        assert_eq!(found.len(), 1);
        assert_eq!(found[0].fixes[0].new_text, "TEXT");
    }

    #[test]
    fn test_cp05_consistent_majority_lower() {
        // 2 lower (int, varchar) vs 1 upper (TEXT) → majority lower, flag "TEXT"
        let found = lint_sql(
            "SELECT CAST(1 AS int), CAST(2 AS varchar), CAST(3 AS TEXT)",
            rule_with(CapitalisationPolicy::Consistent),
        );
        assert_eq!(found.len(), 1);
        assert_eq!(found[0].fixes[0].new_text, "text");
    }
}