rigsql_rules/capitalisation/
cp03.rs1use rigsql_core::{Segment, SegmentType, TokenKind};
2
3use crate::rule::{CrawlType, Rule, RuleContext, RuleGroup};
4use crate::violation::{LintViolation, SourceEdit};
5
/// Upper-case names of the SQL built-in functions whose capitalisation this
/// rule checks. Function calls whose upper-cased name is not listed here are
/// ignored by `RuleCP03::eval`.
///
/// The table is searched with `binary_search`, so it MUST stay sorted in
/// strictly ascending byte order (the ordering of `str`'s `Ord` impl). The
/// compile-time assertion below the table rejects an out-of-order or
/// duplicated entry instead of letting the lookup silently miss names.
const BUILTIN_FUNCTIONS: &[&str] = &[
    "ABS",
    "ACOS",
    "APP_NAME",
    "ASCII",
    "ASIN",
    "ATAN",
    "ATAN2",
    "AVG",
    "CAST",
    "CEILING",
    "CHAR",
    "CHARINDEX",
    "CHOOSE",
    "COALESCE",
    "CONCAT",
    "CONCAT_WS",
    "CONVERT",
    "COS",
    "COT",
    "COUNT",
    "COUNT_BIG",
    "CUME_DIST",
    "CURRENT_TIMESTAMP",
    "CURRENT_USER",
    "CURSOR_STATUS",
    "DATALENGTH",
    "DATEADD",
    "DATEDIFF",
    "DATEDIFF_BIG",
    "DATEFROMPARTS",
    "DATENAME",
    "DATEPART",
    "DATETIME2FROMPARTS",
    "DATETIMEFROMPARTS",
    "DAY",
    "DB_ID",
    "DB_NAME",
    "DENSE_RANK",
    "DIFFERENCE",
    "EOMONTH",
    "ERROR_LINE",
    "ERROR_MESSAGE",
    "ERROR_NUMBER",
    "ERROR_PROCEDURE",
    "ERROR_SEVERITY",
    "ERROR_STATE",
    "EXP",
    "FIRST_VALUE",
    "FLOOR",
    "FORMAT",
    "GETDATE",
    "GETUTCDATE",
    "GREATEST",
    "GROUPING",
    "GROUPING_ID",
    "HAS_PERMS_BY_NAME",
    "HOST_NAME",
    "IDENTITY",
    "IDENT_CURRENT",
    "IFNULL",
    "IIF",
    "ISJSON",
    "ISNULL",
    "ISNUMERIC",
    "JSON_ARRAY",
    "JSON_MODIFY",
    "JSON_OBJECT",
    "JSON_QUERY",
    "JSON_VALUE",
    "LAG",
    "LAST_VALUE",
    "LEAD",
    "LEAST",
    "LEFT",
    "LEN",
    "LENGTH",
    "LOG",
    "LOG10",
    "LOWER",
    "LTRIM",
    "MAX",
    "MIN",
    "MONTH",
    "NCHAR",
    "NEWID",
    "NTILE",
    "NULLIF",
    "NVL",
    "NVL2",
    "OBJECT_ID",
    "OBJECT_NAME",
    "PARSENAME",
    "PATINDEX",
    "PERCENT_RANK",
    "PI",
    "POWER",
    "QUOTENAME",
    "RAND",
    "RANK",
    "REPLACE",
    "REPLICATE",
    "REVERSE",
    "RIGHT",
    "ROUND",
    "ROW_NUMBER",
    "RTRIM",
    "SCHEMA_NAME",
    "SCOPE_IDENTITY",
    "SIGN",
    "SIN",
    "SOUNDEX",
    "SPACE",
    "SQRT",
    "SQUARE",
    "STR",
    "STRING_AGG",
    "STRING_SPLIT",
    "STUFF",
    "SUBSTRING",
    "SUM",
    "SUSER_SNAME",
    "SWITCHOFFSET",
    "SYSDATETIME",
    "SYSUTCDATETIME",
    "TAN",
    "TODATETIMEOFFSET",
    "TRANSLATE",
    "TRIM",
    "TRY_CAST",
    "TRY_CONVERT",
    "TRY_PARSE",
    "TYPE_NAME",
    "UNICODE",
    "UPPER",
    "USER_NAME",
    "YEAR",
];

/// Returns `true` when `a` sorts strictly before `b` in lexicographic byte
/// order, i.e. the same ordering `str`'s `Ord` implementation uses.
const fn bytes_strictly_less(a: &[u8], b: &[u8]) -> bool {
    let mut i = 0;
    while i < a.len() && i < b.len() {
        if a[i] != b[i] {
            return a[i] < b[i];
        }
        i += 1;
    }
    // One is a prefix of the other: the shorter one sorts first, equal
    // slices are not strictly ordered.
    a.len() < b.len()
}

/// Returns `true` when every entry of `names` sorts strictly before its
/// successor (sorted, no duplicates). Empty and single-entry tables pass.
const fn is_strictly_sorted(names: &[&str]) -> bool {
    let mut i = 1;
    while i < names.len() {
        if !bytes_strictly_less(names[i - 1].as_bytes(), names[i].as_bytes()) {
            return false;
        }
        i += 1;
    }
    true
}

// Enforce the binary-search precondition at compile time.
const _: () = assert!(
    is_strictly_sorted(BUILTIN_FUNCTIONS),
    "BUILTIN_FUNCTIONS must be sorted in ascending byte order without duplicates"
);
145
/// Lint rule CP03 (`capitalisation.functions`).
///
/// Reports calls to known built-in SQL functions whose name mixes upper- and
/// lower-case ASCII letters (for example `Count`). Names written entirely in
/// upper case or entirely in lower case are accepted, and functions that are
/// not in the built-in table are ignored.
///
/// The type carries no state; all behaviour lives in its `Rule` implementation.
#[derive(Debug, Default)]
pub struct RuleCP03;
151
152impl Rule for RuleCP03 {
153 fn code(&self) -> &'static str {
154 "CP03"
155 }
156 fn name(&self) -> &'static str {
157 "capitalisation.functions"
158 }
159 fn description(&self) -> &'static str {
160 "Function names must be consistently capitalised."
161 }
162 fn explanation(&self) -> &'static str {
163 "Function names like COUNT, SUM, COALESCE should be consistently capitalised. \
164 Whether upper or lower depends on your team's convention."
165 }
166 fn groups(&self) -> &[RuleGroup] {
167 &[RuleGroup::Capitalisation]
168 }
169 fn is_fixable(&self) -> bool {
170 true
171 }
172
173 fn crawl_type(&self) -> CrawlType {
174 CrawlType::Segment(vec![SegmentType::FunctionCall])
175 }
176
177 fn eval(&self, ctx: &RuleContext) -> Vec<LintViolation> {
178 let children = ctx.segment.children();
180 if children.is_empty() {
181 return vec![];
182 }
183
184 let name_seg = Self::find_function_name(children);
186 let Some(Segment::Token(t)) = name_seg else {
187 return vec![];
188 };
189 if t.token.kind != TokenKind::Word {
190 return vec![];
191 }
192
193 let text = t.token.text.as_str();
195 let upper = text.to_ascii_uppercase();
196
197 if BUILTIN_FUNCTIONS.binary_search(&upper.as_str()).is_err() {
199 return vec![];
200 }
201
202 let is_all_upper = text
206 .chars()
207 .all(|c| !c.is_ascii_alphabetic() || c.is_ascii_uppercase());
208 let is_all_lower = text
209 .chars()
210 .all(|c| !c.is_ascii_alphabetic() || c.is_ascii_lowercase());
211 if is_all_upper || is_all_lower {
212 return vec![];
213 }
214
215 vec![LintViolation::with_fix_and_msg_key(
216 self.code(),
217 format!(
218 "Function name '{}' has inconsistent capitalisation. Use all upper or all lower case.",
219 text
220 ),
221 t.token.span,
222 vec![SourceEdit::replace(t.token.span, upper)],
223 "rules.CP03.msg",
224 vec![("name".to_string(), text.to_string())],
225 )]
226 }
227}
228
229impl RuleCP03 {
230 fn find_function_name(children: &[Segment]) -> Option<&Segment> {
231 for child in children {
232 match child.segment_type() {
233 SegmentType::Identifier => return Some(child),
234 SegmentType::ColumnRef => {
235 let inner = child.children();
237 return inner
238 .iter()
239 .rev()
240 .find(|s| s.segment_type() == SegmentType::Identifier);
241 }
242 _ if child.segment_type().is_trivia() => continue,
243 _ => break,
244 }
245 }
246 None
247 }
248}
249
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_utils::lint_sql;

    /// Number of CP03 violations reported for `sql`.
    fn violation_count(sql: &str) -> usize {
        lint_sql(sql, RuleCP03).len()
    }

    /// A built-in written in mixed case is reported once.
    #[test]
    fn test_cp03_flags_mixed_case() {
        assert_eq!(violation_count("SELECT Count(*) FROM t"), 1);
    }

    /// An all-upper-case built-in is accepted.
    #[test]
    fn test_cp03_accepts_all_upper() {
        assert_eq!(violation_count("SELECT COUNT(*) FROM t"), 0);
    }

    /// An all-lower-case built-in is accepted.
    #[test]
    fn test_cp03_accepts_all_lower() {
        assert_eq!(violation_count("SELECT count(*) FROM t"), 0);
    }

    /// A mixed-case name that is not a built-in is ignored.
    #[test]
    fn test_cp03_skips_user_defined_function() {
        assert_eq!(
            violation_count("SELECT GetDropdownOptions('a', 'b') FROM t"),
            0
        );
    }
}