Skip to main content

rigsql_rules/capitalisation/
cp03.rs

1use rigsql_core::{Segment, SegmentType, TokenKind};
2
3use crate::rule::{CrawlType, Rule, RuleContext, RuleGroup};
4use crate::violation::{LintViolation, SourceEdit};
5
/// Built-in SQL function names, upper case and sorted in strictly ascending byte order.
///
/// `RuleCP03::eval` looks names up with `binary_search` using an upper-cased key, so both
/// properties are required; they are enforced at compile time by the `const _` guard below.
const BUILTIN_FUNCTIONS: &[&str] = &[
    "ABS",
    "ACOS",
    "APP_NAME",
    "ASCII",
    "ASIN",
    "ATAN",
    "ATAN2",
    "AVG",
    "CAST",
    "CEILING",
    "CHAR",
    "CHARINDEX",
    "CHOOSE",
    "COALESCE",
    "CONCAT",
    "CONCAT_WS",
    "CONVERT",
    "COS",
    "COT",
    "COUNT",
    "COUNT_BIG",
    "CUME_DIST",
    "CURRENT_TIMESTAMP",
    "CURRENT_USER",
    "CURSOR_STATUS",
    "DATALENGTH",
    "DATEADD",
    "DATEDIFF",
    "DATEDIFF_BIG",
    "DATEFROMPARTS",
    "DATENAME",
    "DATEPART",
    "DATETIME2FROMPARTS",
    "DATETIMEFROMPARTS",
    "DAY",
    "DB_ID",
    "DB_NAME",
    "DENSE_RANK",
    "DIFFERENCE",
    "EOMONTH",
    "ERROR_LINE",
    "ERROR_MESSAGE",
    "ERROR_NUMBER",
    "ERROR_PROCEDURE",
    "ERROR_SEVERITY",
    "ERROR_STATE",
    "EXP",
    "FIRST_VALUE",
    "FLOOR",
    "FORMAT",
    "GETDATE",
    "GETUTCDATE",
    "GREATEST",
    "GROUPING",
    "GROUPING_ID",
    "HAS_PERMS_BY_NAME",
    "HOST_NAME",
    "IDENTITY",
    "IDENT_CURRENT",
    "IFNULL",
    "IIF",
    "ISJSON",
    "ISNULL",
    "ISNUMERIC",
    "JSON_ARRAY",
    "JSON_MODIFY",
    "JSON_OBJECT",
    "JSON_QUERY",
    "JSON_VALUE",
    "LAG",
    "LAST_VALUE",
    "LEAD",
    "LEAST",
    "LEFT",
    "LEN",
    "LENGTH",
    "LOG",
    "LOG10",
    "LOWER",
    "LTRIM",
    "MAX",
    "MIN",
    "MONTH",
    "NCHAR",
    "NEWID",
    "NTILE",
    "NULLIF",
    "NVL",
    "NVL2",
    "OBJECT_ID",
    "OBJECT_NAME",
    "PARSENAME",
    "PATINDEX",
    "PERCENT_RANK",
    "PI",
    "POWER",
    "QUOTENAME",
    "RAND",
    "RANK",
    "REPLACE",
    "REPLICATE",
    "REVERSE",
    "RIGHT",
    "ROUND",
    "ROW_NUMBER",
    "RTRIM",
    "SCHEMA_NAME",
    "SCOPE_IDENTITY",
    "SIGN",
    "SIN",
    "SOUNDEX",
    "SPACE",
    "SQRT",
    "SQUARE",
    "STR",
    "STRING_AGG",
    "STRING_SPLIT",
    "STUFF",
    "SUBSTRING",
    "SUM",
    "SUSER_SNAME",
    "SWITCHOFFSET",
    "SYSDATETIME",
    "SYSUTCDATETIME",
    "TAN",
    "TODATETIMEOFFSET",
    "TRANSLATE",
    "TRIM",
    "TRY_CAST",
    "TRY_CONVERT",
    "TRY_PARSE",
    "TYPE_NAME",
    "UNICODE",
    "UPPER",
    "USER_NAME",
    "YEAR",
];

/// Byte-wise lexicographic `a < b` (the same ordering as `str`'s `Ord`), usable in
/// const context where the `<` operator on `&str` is not available.
const fn str_less_than(a: &str, b: &str) -> bool {
    let (a, b) = (a.as_bytes(), b.as_bytes());
    let mut i = 0;
    while i < a.len() && i < b.len() {
        if a[i] != b[i] {
            return a[i] < b[i];
        }
        i += 1;
    }
    // Equal common prefix: the shorter string sorts first; equal strings are not "less".
    a.len() < b.len()
}

/// Returns `true` when `s` contains no ASCII lower-case letter (const-context helper).
const fn has_no_ascii_lowercase(s: &str) -> bool {
    let bytes = s.as_bytes();
    let mut i = 0;
    while i < bytes.len() {
        if bytes[i].is_ascii_lowercase() {
            return false;
        }
        i += 1;
    }
    true
}

// Compile-time guard: `binary_search` silently gives wrong answers on an unsorted slice,
// and lookups use an upper-cased key, so reject any lower-case, out-of-order or duplicate
// entry at build time instead of letting the rule quietly miss functions.
const _: () = {
    let mut i = 0;
    while i < BUILTIN_FUNCTIONS.len() {
        assert!(
            has_no_ascii_lowercase(BUILTIN_FUNCTIONS[i]),
            "BUILTIN_FUNCTIONS entries must be upper case"
        );
        if i > 0 {
            assert!(
                str_less_than(BUILTIN_FUNCTIONS[i - 1], BUILTIN_FUNCTIONS[i]),
                "BUILTIN_FUNCTIONS must be strictly sorted for binary_search"
            );
        }
        i += 1;
    }
};
145
/// CP03: Function names must be consistently capitalised.
///
/// Only built-in SQL functions (those listed in `BUILTIN_FUNCTIONS`) are checked;
/// user-defined functions are ignored. A name is flagged when it mixes upper- and
/// lower-case letters (e.g. `Count`). Both all-upper (`COUNT`) and all-lower (`count`)
/// spellings are accepted. The autofix rewrites a flagged name to upper case.
#[derive(Debug, Default)]
pub struct RuleCP03;
151
152impl Rule for RuleCP03 {
153    fn code(&self) -> &'static str {
154        "CP03"
155    }
156    fn name(&self) -> &'static str {
157        "capitalisation.functions"
158    }
159    fn description(&self) -> &'static str {
160        "Function names must be consistently capitalised."
161    }
162    fn explanation(&self) -> &'static str {
163        "Function names like COUNT, SUM, COALESCE should be consistently capitalised. \
164         Whether upper or lower depends on your team's convention."
165    }
166    fn groups(&self) -> &[RuleGroup] {
167        &[RuleGroup::Capitalisation]
168    }
169    fn is_fixable(&self) -> bool {
170        true
171    }
172
173    fn crawl_type(&self) -> CrawlType {
174        CrawlType::Segment(vec![SegmentType::FunctionCall])
175    }
176
177    fn eval(&self, ctx: &RuleContext) -> Vec<LintViolation> {
178        // FunctionCall's first child should be the function name (Identifier)
179        let children = ctx.segment.children();
180        if children.is_empty() {
181            return vec![];
182        }
183
184        // Walk to find the function name token
185        let name_seg = Self::find_function_name(children);
186        let Some(Segment::Token(t)) = name_seg else {
187            return vec![];
188        };
189        if t.token.kind != TokenKind::Word {
190            return vec![];
191        }
192
193        // Check: function names should be consistent (default: lower)
194        let text = t.token.text.as_str();
195        let upper = text.to_ascii_uppercase();
196
197        // Only check built-in SQL functions; skip user-defined functions
198        if BUILTIN_FUNCTIONS.binary_search(&upper.as_str()).is_err() {
199            return vec![];
200        }
201
202        // Skip if it's all upper or all lower (both are acceptable in many configs)
203        // Default: we don't enforce function name case (many projects use either)
204        // Only flag mixed case
205        let is_all_upper = text
206            .chars()
207            .all(|c| !c.is_ascii_alphabetic() || c.is_ascii_uppercase());
208        let is_all_lower = text
209            .chars()
210            .all(|c| !c.is_ascii_alphabetic() || c.is_ascii_lowercase());
211        if is_all_upper || is_all_lower {
212            return vec![];
213        }
214
215        vec![LintViolation::with_fix_and_msg_key(
216            self.code(),
217            format!(
218                "Function name '{}' has inconsistent capitalisation. Use all upper or all lower case.",
219                text
220            ),
221            t.token.span,
222            vec![SourceEdit::replace(t.token.span, upper)],
223            "rules.CP03.msg",
224            vec![("name".to_string(), text.to_string())],
225        )]
226    }
227}
228
229impl RuleCP03 {
230    fn find_function_name(children: &[Segment]) -> Option<&Segment> {
231        for child in children {
232            match child.segment_type() {
233                SegmentType::Identifier => return Some(child),
234                SegmentType::ColumnRef => {
235                    // qualified function: schema.func — get last identifier
236                    let inner = child.children();
237                    return inner
238                        .iter()
239                        .rev()
240                        .find(|s| s.segment_type() == SegmentType::Identifier);
241                }
242                _ if child.segment_type().is_trivia() => continue,
243                _ => break,
244            }
245        }
246        None
247    }
248}
249
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_utils::lint_sql;

    #[test]
    fn test_cp03_flags_mixed_case() {
        let violations = lint_sql("SELECT Count(*) FROM t", RuleCP03);
        assert_eq!(violations.len(), 1);
    }

    #[test]
    fn test_cp03_flags_mixed_case_other_builtin() {
        let violations = lint_sql("SELECT Sum(x) FROM t", RuleCP03);
        assert_eq!(violations.len(), 1);
    }

    #[test]
    fn test_cp03_accepts_all_upper() {
        let violations = lint_sql("SELECT COUNT(*) FROM t", RuleCP03);
        assert_eq!(violations.len(), 0);
    }

    #[test]
    fn test_cp03_accepts_all_lower() {
        let violations = lint_sql("SELECT count(*) FROM t", RuleCP03);
        assert_eq!(violations.len(), 0);
    }

    #[test]
    fn test_cp03_skips_user_defined_function() {
        let violations = lint_sql("SELECT GetDropdownOptions('a', 'b') FROM t", RuleCP03);
        assert_eq!(violations.len(), 0);
    }

    /// `eval` relies on `binary_search` with an upper-cased key, which only works if the
    /// table is upper case and strictly ascending; an unsorted entry would be silently missed.
    #[test]
    fn test_builtin_functions_sorted_and_upper_case() {
        assert!(
            BUILTIN_FUNCTIONS.windows(2).all(|w| w[0] < w[1]),
            "BUILTIN_FUNCTIONS must be strictly sorted for binary_search"
        );
        assert!(
            BUILTIN_FUNCTIONS
                .iter()
                .all(|name| !name.bytes().any(|b| b.is_ascii_lowercase())),
            "BUILTIN_FUNCTIONS entries must be upper case"
        );
    }
}
278}