Skip to main content

rigsql_rules/capitalisation/
cp03.rs

1use rigsql_core::{Segment, SegmentType, TokenKind};
2
3use super::CapitalisationPolicy;
4use crate::rule::{CrawlType, Rule, RuleContext, RuleGroup};
5use crate::utils::{check_capitalisation, determine_majority_case};
6use crate::violation::LintViolation;
7
/// Upper-case names of the built-in SQL functions that CP03 checks.
///
/// Invariant: entries are upper-case and in strictly ascending byte order.
/// Callers upper-case the token text and look it up with `binary_search`, so
/// an out-of-order or lower-case entry would silently never match.
/// Entries are grouped by initial letter, so rustfmt is told not to reflow them.
#[rustfmt::skip]
const BUILTIN_FUNCTIONS: &[&str] = &[
    "ABS", "ACOS", "APP_NAME", "ASCII", "ASIN", "ATAN", "ATAN2", "AVG",
    "CAST", "CEILING", "CHAR", "CHARINDEX", "CHOOSE", "COALESCE", "CONCAT", "CONCAT_WS",
    "CONVERT", "COS", "COT", "COUNT", "COUNT_BIG", "CUME_DIST", "CURRENT_TIMESTAMP",
    "CURRENT_USER", "CURSOR_STATUS",
    "DATALENGTH", "DATEADD", "DATEDIFF", "DATEDIFF_BIG", "DATEFROMPARTS", "DATENAME",
    "DATEPART", "DATETIME2FROMPARTS", "DATETIMEFROMPARTS", "DAY", "DB_ID", "DB_NAME",
    "DENSE_RANK", "DIFFERENCE",
    "EOMONTH", "ERROR_LINE", "ERROR_MESSAGE", "ERROR_NUMBER", "ERROR_PROCEDURE",
    "ERROR_SEVERITY", "ERROR_STATE", "EXP",
    "FIRST_VALUE", "FLOOR", "FORMAT",
    "GETDATE", "GETUTCDATE", "GREATEST", "GROUPING", "GROUPING_ID",
    "HAS_PERMS_BY_NAME", "HOST_NAME",
    "IDENTITY", "IDENT_CURRENT", "IFNULL", "IIF", "ISJSON", "ISNULL", "ISNUMERIC",
    "JSON_ARRAY", "JSON_MODIFY", "JSON_OBJECT", "JSON_QUERY", "JSON_VALUE",
    "LAG", "LAST_VALUE", "LEAD", "LEAST", "LEFT", "LEN", "LENGTH", "LOG", "LOG10", "LOWER",
    "LTRIM",
    "MAX", "MIN", "MONTH",
    "NCHAR", "NEWID", "NTILE", "NULLIF", "NVL", "NVL2",
    "OBJECT_ID", "OBJECT_NAME",
    "PARSENAME", "PATINDEX", "PERCENT_RANK", "PI", "POWER",
    "QUOTENAME",
    "RAND", "RANK", "REPLACE", "REPLICATE", "REVERSE", "RIGHT", "ROUND", "ROW_NUMBER",
    "RTRIM",
    "SCHEMA_NAME", "SCOPE_IDENTITY", "SIGN", "SIN", "SOUNDEX", "SPACE", "SQRT", "SQUARE",
    "STR", "STRING_AGG", "STRING_SPLIT", "STUFF", "SUBSTRING", "SUM", "SUSER_SNAME",
    "SWITCHOFFSET", "SYSDATETIME", "SYSUTCDATETIME",
    "TAN", "TODATETIMEOFFSET", "TRANSLATE", "TRIM", "TRY_CAST", "TRY_CONVERT", "TRY_PARSE",
    "TYPE_NAME",
    "UNICODE", "UPPER", "USER_NAME",
    "YEAR",
];
147
148/// CP03: Function names must be consistently capitalised.
149///
150/// By default, expects UPPER case function names (sqlfluff-compatible).
151#[derive(Debug)]
152pub struct RuleCP03 {
153    pub policy: CapitalisationPolicy,
154}
155
156impl Default for RuleCP03 {
157    fn default() -> Self {
158        Self {
159            policy: CapitalisationPolicy::Upper,
160        }
161    }
162}
163
164impl Rule for RuleCP03 {
165    fn code(&self) -> &'static str {
166        "CP03"
167    }
168    fn name(&self) -> &'static str {
169        "capitalisation.functions"
170    }
171    fn description(&self) -> &'static str {
172        "Function names must be consistently capitalised."
173    }
174    fn explanation(&self) -> &'static str {
175        "Function names like COUNT, SUM, COALESCE should be consistently capitalised. \
176         Whether upper or lower depends on your team's convention."
177    }
178    fn groups(&self) -> &[RuleGroup] {
179        &[RuleGroup::Capitalisation]
180    }
181    fn is_fixable(&self) -> bool {
182        true
183    }
184
185    fn crawl_type(&self) -> CrawlType {
186        if self.policy == CapitalisationPolicy::Consistent {
187            CrawlType::RootOnly
188        } else {
189            CrawlType::Segment(vec![SegmentType::FunctionCall])
190        }
191    }
192
193    fn configure(&mut self, settings: &std::collections::HashMap<String, String>) {
194        if let Some(policy) = settings.get("capitalisation_policy") {
195            self.policy = CapitalisationPolicy::from_config(policy);
196        }
197    }
198
199    fn eval(&self, ctx: &RuleContext) -> Vec<LintViolation> {
200        if self.policy == CapitalisationPolicy::Consistent {
201            return self.eval_consistent(ctx);
202        }
203
204        // FunctionCall's first child should be the function name (Identifier)
205        let children = ctx.segment.children();
206        if children.is_empty() {
207            return vec![];
208        }
209
210        // Walk to find the function name token
211        let name_seg = Self::find_function_name(children);
212        let Some(Segment::Token(t)) = name_seg else {
213            return vec![];
214        };
215        if t.token.kind != TokenKind::Word {
216            return vec![];
217        }
218
219        let text = t.token.text.as_str();
220        let upper = text.to_ascii_uppercase();
221
222        // Only check built-in SQL functions; skip user-defined functions
223        if BUILTIN_FUNCTIONS.binary_search(&upper.as_str()).is_err() {
224            return vec![];
225        }
226
227        let (expected, policy_name) = match self.policy {
228            CapitalisationPolicy::Upper => (upper, "upper"),
229            CapitalisationPolicy::Lower => (text.to_ascii_lowercase(), "lower"),
230            CapitalisationPolicy::Capitalise => (crate::utils::capitalise(text), "capitalised"),
231            CapitalisationPolicy::Consistent => unreachable!(),
232        };
233
234        check_capitalisation(
235            self.code(),
236            "Function names",
237            text,
238            &expected,
239            policy_name,
240            t.token.span,
241        )
242        .into_iter()
243        .collect()
244    }
245}
246
247impl RuleCP03 {
248    fn eval_consistent(&self, ctx: &RuleContext) -> Vec<LintViolation> {
249        let mut tokens = Vec::new();
250        Self::collect_builtin_function_names(ctx.root, &mut tokens);
251
252        if tokens.is_empty() {
253            return vec![];
254        }
255
256        let majority = determine_majority_case(&tokens);
257        let mut violations = Vec::new();
258        for (text, span) in &tokens {
259            let expected = match majority {
260                "upper" => text.to_ascii_uppercase(),
261                _ => text.to_ascii_lowercase(),
262            };
263            if let Some(v) = check_capitalisation(
264                self.code(),
265                "Function names",
266                text,
267                &expected,
268                majority,
269                *span,
270            ) {
271                violations.push(v);
272            }
273        }
274        violations
275    }
276
277    /// Recursively collect built-in function name tokens from the CST.
278    fn collect_builtin_function_names(
279        segment: &Segment,
280        out: &mut Vec<(String, rigsql_core::Span)>,
281    ) {
282        if segment.segment_type() == SegmentType::FunctionCall {
283            if let Some(Segment::Token(t)) = Self::find_function_name(segment.children()) {
284                if t.token.kind == TokenKind::Word {
285                    let upper = t.token.text.to_ascii_uppercase();
286                    if BUILTIN_FUNCTIONS.binary_search(&upper.as_str()).is_ok() {
287                        out.push((t.token.text.to_string(), t.token.span));
288                    }
289                }
290            }
291        }
292        for child in segment.children() {
293            Self::collect_builtin_function_names(child, out);
294        }
295    }
296
297    fn find_function_name(children: &[Segment]) -> Option<&Segment> {
298        for child in children {
299            match child.segment_type() {
300                SegmentType::Identifier => return Some(child),
301                SegmentType::ColumnRef => {
302                    // qualified function: schema.func — get last identifier
303                    let inner = child.children();
304                    return inner
305                        .iter()
306                        .rev()
307                        .find(|s| s.segment_type() == SegmentType::Identifier);
308                }
309                _ if child.segment_type().is_trivia() => continue,
310                _ => break,
311            }
312        }
313        None
314    }
315}
316
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_utils::lint_sql;

    #[test]
    fn test_builtin_functions_sorted_and_uppercase() {
        // Lookups upper-case the token text and use `binary_search`, so an
        // unsorted, duplicated or lower-case entry would silently never match.
        for pair in BUILTIN_FUNCTIONS.windows(2) {
            assert!(
                pair[0] < pair[1],
                "BUILTIN_FUNCTIONS is not strictly sorted at {:?} / {:?}",
                pair[0],
                pair[1]
            );
        }
        for name in BUILTIN_FUNCTIONS {
            assert_eq!(
                *name,
                name.to_ascii_uppercase(),
                "BUILTIN_FUNCTIONS entry {:?} must be upper-case",
                name
            );
        }
    }

    #[test]
    fn test_cp03_flags_lowercase_function() {
        // Default policy is upper, so lowercase should be flagged
        let violations = lint_sql("SELECT count(*) FROM t", RuleCP03::default());
        assert_eq!(violations.len(), 1);
        assert_eq!(violations[0].fixes[0].new_text, "COUNT");
    }

    #[test]
    fn test_cp03_flags_mixed_case() {
        let violations = lint_sql("SELECT Count(*) FROM t", RuleCP03::default());
        assert_eq!(violations.len(), 1);
        assert_eq!(violations[0].fixes[0].new_text, "COUNT");
    }

    #[test]
    fn test_cp03_accepts_all_upper() {
        let violations = lint_sql("SELECT COUNT(*) FROM t", RuleCP03::default());
        assert_eq!(violations.len(), 0);
    }

    #[test]
    fn test_cp03_lower_policy_flags_upper() {
        let rule = RuleCP03 {
            policy: CapitalisationPolicy::Lower,
        };
        let violations = lint_sql("SELECT COUNT(*) FROM t", rule);
        assert_eq!(violations.len(), 1);
        assert_eq!(violations[0].fixes[0].new_text, "count");
    }

    #[test]
    fn test_cp03_lower_policy_accepts_lower() {
        let rule = RuleCP03 {
            policy: CapitalisationPolicy::Lower,
        };
        let violations = lint_sql("SELECT count(*) FROM t", rule);
        assert_eq!(violations.len(), 0);
    }

    #[test]
    fn test_cp03_capitalise_policy() {
        let rule = RuleCP03 {
            policy: CapitalisationPolicy::Capitalise,
        };
        let violations = lint_sql("SELECT count(*) FROM t", rule);
        assert_eq!(violations.len(), 1);
        assert_eq!(violations[0].fixes[0].new_text, "Count");
    }

    #[test]
    fn test_cp03_skips_user_defined_function() {
        let violations = lint_sql(
            "SELECT GetDropdownOptions('a', 'b') FROM t",
            RuleCP03::default(),
        );
        assert_eq!(violations.len(), 0);
    }

    #[test]
    fn test_cp03_consistent_flags_minority() {
        // 2 upper (COUNT, SUM) vs 1 lower (avg) → majority upper, flag "avg"
        let rule = RuleCP03 {
            policy: CapitalisationPolicy::Consistent,
        };
        let violations = lint_sql("SELECT COUNT(*), SUM(x), avg(y) FROM t", rule);
        assert_eq!(violations.len(), 1);
        assert_eq!(violations[0].fixes[0].new_text, "AVG");
    }

    #[test]
    fn test_cp03_consistent_all_same_no_violation() {
        let rule = RuleCP03 {
            policy: CapitalisationPolicy::Consistent,
        };
        let violations = lint_sql("SELECT COUNT(*), SUM(x) FROM t", rule);
        assert_eq!(violations.len(), 0);
    }

    #[test]
    fn test_cp03_consistent_majority_lower() {
        // 2 lower (count, sum) vs 1 upper (AVG) → majority lower, flag "AVG"
        let rule = RuleCP03 {
            policy: CapitalisationPolicy::Consistent,
        };
        let violations = lint_sql("SELECT count(*), sum(x), AVG(y) FROM t", rule);
        assert_eq!(violations.len(), 1);
        assert_eq!(violations[0].fixes[0].new_text, "avg");
    }

    #[test]
    fn test_cp03_consistent_no_functions_no_violation() {
        // With no built-in calls there is nothing to take a majority over.
        let rule = RuleCP03 {
            policy: CapitalisationPolicy::Consistent,
        };
        let violations = lint_sql("SELECT a FROM t", rule);
        assert_eq!(violations.len(), 0);
    }

    #[test]
    fn test_cp03_consistent_ignores_user_defined_functions() {
        // MyFunc is not a built-in, so it neither votes nor gets flagged.
        let rule = RuleCP03 {
            policy: CapitalisationPolicy::Consistent,
        };
        let violations = lint_sql("SELECT count(*), sum(x), MyFunc(y) FROM t", rule);
        assert_eq!(violations.len(), 0);
    }

    #[test]
    fn test_cp03_flags_replace_function() {
        // The issue from #32: replace should be flagged and fixed to REPLACE
        let violations = lint_sql("SELECT replace(col, 'a', 'b') FROM t", RuleCP03::default());
        assert_eq!(violations.len(), 1);
        assert_eq!(violations[0].fixes[0].new_text, "REPLACE");
    }
}
415        let violations = lint_sql("SELECT replace(col, 'a', 'b') FROM t", RuleCP03::default());
416        assert_eq!(violations.len(), 1);
417        assert_eq!(violations[0].fixes[0].new_text, "REPLACE");
418    }
419}