1use rigsql_core::{Segment, SegmentType, TokenKind};
2
3use super::CapitalisationPolicy;
4use crate::rule::{CrawlType, Rule, RuleContext, RuleGroup};
5use crate::utils::{check_capitalisation, determine_majority_case};
6use crate::violation::LintViolation;
7
/// Upper-case names of the built-in SQL functions that CP03 checks.
///
/// Lookups use [`slice::binary_search`], so this table MUST stay sorted in
/// strictly ascending byte order with no duplicates. Note that `_` (0x5F)
/// sorts *after* `A`-`Z` and digits sort *before* them, e.g. `IDENTITY` comes
/// before `IDENT_CURRENT` and `ATAN` before `ATAN2`. The `const` assertion
/// below turns any ordering mistake into a compile error instead of a rule
/// that silently stops recognising some functions.
const BUILTIN_FUNCTIONS: &[&str] = &[
    "ABS",
    "ACOS",
    "APP_NAME",
    "ASCII",
    "ASIN",
    "ATAN",
    "ATAN2",
    "AVG",
    "CAST",
    "CEILING",
    "CHAR",
    "CHARINDEX",
    "CHOOSE",
    "COALESCE",
    "CONCAT",
    "CONCAT_WS",
    "CONVERT",
    "COS",
    "COT",
    "COUNT",
    "COUNT_BIG",
    "CUME_DIST",
    "CURRENT_TIMESTAMP",
    "CURRENT_USER",
    "CURSOR_STATUS",
    "DATALENGTH",
    "DATEADD",
    "DATEDIFF",
    "DATEDIFF_BIG",
    "DATEFROMPARTS",
    "DATENAME",
    "DATEPART",
    "DATETIME2FROMPARTS",
    "DATETIMEFROMPARTS",
    "DAY",
    "DB_ID",
    "DB_NAME",
    "DENSE_RANK",
    "DIFFERENCE",
    "EOMONTH",
    "ERROR_LINE",
    "ERROR_MESSAGE",
    "ERROR_NUMBER",
    "ERROR_PROCEDURE",
    "ERROR_SEVERITY",
    "ERROR_STATE",
    "EXP",
    "FIRST_VALUE",
    "FLOOR",
    "FORMAT",
    "GETDATE",
    "GETUTCDATE",
    "GREATEST",
    "GROUPING",
    "GROUPING_ID",
    "HAS_PERMS_BY_NAME",
    "HOST_NAME",
    "IDENTITY",
    "IDENT_CURRENT",
    "IFNULL",
    "IIF",
    "ISJSON",
    "ISNULL",
    "ISNUMERIC",
    "JSON_ARRAY",
    "JSON_MODIFY",
    "JSON_OBJECT",
    "JSON_QUERY",
    "JSON_VALUE",
    "LAG",
    "LAST_VALUE",
    "LEAD",
    "LEAST",
    "LEFT",
    "LEN",
    "LENGTH",
    "LOG",
    "LOG10",
    "LOWER",
    "LTRIM",
    "MAX",
    "MIN",
    "MONTH",
    "NCHAR",
    "NEWID",
    "NTILE",
    "NULLIF",
    "NVL",
    "NVL2",
    "OBJECT_ID",
    "OBJECT_NAME",
    "PARSENAME",
    "PATINDEX",
    "PERCENT_RANK",
    "PI",
    "POWER",
    "QUOTENAME",
    "RAND",
    "RANK",
    "REPLACE",
    "REPLICATE",
    "REVERSE",
    "RIGHT",
    "ROUND",
    "ROW_NUMBER",
    "RTRIM",
    "SCHEMA_NAME",
    "SCOPE_IDENTITY",
    "SIGN",
    "SIN",
    "SOUNDEX",
    "SPACE",
    "SQRT",
    "SQUARE",
    "STR",
    "STRING_AGG",
    "STRING_SPLIT",
    "STUFF",
    "SUBSTRING",
    "SUM",
    "SUSER_SNAME",
    "SWITCHOFFSET",
    "SYSDATETIME",
    "SYSUTCDATETIME",
    "TAN",
    "TODATETIMEOFFSET",
    "TRANSLATE",
    "TRIM",
    "TRY_CAST",
    "TRY_CONVERT",
    "TRY_PARSE",
    "TYPE_NAME",
    "UNICODE",
    "UPPER",
    "USER_NAME",
    "YEAR",
];

/// Returns `true` when `a` sorts strictly before `b` in byte order — the same
/// ordering `Ord for str` (and therefore `binary_search`) uses.
const fn str_lt(a: &str, b: &str) -> bool {
    let (a, b) = (a.as_bytes(), b.as_bytes());
    let mut i = 0;
    while i < a.len() && i < b.len() {
        if a[i] != b[i] {
            return a[i] < b[i];
        }
        i += 1;
    }
    // Equal common prefix: the shorter string sorts first; equal strings do
    // not satisfy "strictly less".
    a.len() < b.len()
}

/// Returns `true` when `list` is strictly ascending (sorted, no duplicates).
const fn is_strictly_sorted(list: &[&str]) -> bool {
    let mut i = 1;
    while i < list.len() {
        if !str_lt(list[i - 1], list[i]) {
            return false;
        }
        i += 1;
    }
    true
}

// Fail the build, rather than silently mis-detect functions at lint time, if
// the lookup table is ever edited out of order.
const _: () = assert!(
    is_strictly_sorted(BUILTIN_FUNCTIONS),
    "BUILTIN_FUNCTIONS must be strictly sorted for binary_search"
);
147
148#[derive(Debug)]
152pub struct RuleCP03 {
153 pub policy: CapitalisationPolicy,
154}
155
156impl Default for RuleCP03 {
157 fn default() -> Self {
158 Self {
159 policy: CapitalisationPolicy::Upper,
160 }
161 }
162}
163
164impl Rule for RuleCP03 {
165 fn code(&self) -> &'static str {
166 "CP03"
167 }
168 fn name(&self) -> &'static str {
169 "capitalisation.functions"
170 }
171 fn description(&self) -> &'static str {
172 "Function names must be consistently capitalised."
173 }
174 fn explanation(&self) -> &'static str {
175 "Function names like COUNT, SUM, COALESCE should be consistently capitalised. \
176 Whether upper or lower depends on your team's convention."
177 }
178 fn groups(&self) -> &[RuleGroup] {
179 &[RuleGroup::Capitalisation]
180 }
181 fn is_fixable(&self) -> bool {
182 true
183 }
184
185 fn crawl_type(&self) -> CrawlType {
186 if self.policy == CapitalisationPolicy::Consistent {
187 CrawlType::RootOnly
188 } else {
189 CrawlType::Segment(vec![SegmentType::FunctionCall])
190 }
191 }
192
193 fn configure(&mut self, settings: &std::collections::HashMap<String, String>) {
194 if let Some(policy) = settings.get("capitalisation_policy") {
195 self.policy = CapitalisationPolicy::from_config(policy);
196 }
197 }
198
199 fn eval(&self, ctx: &RuleContext) -> Vec<LintViolation> {
200 if self.policy == CapitalisationPolicy::Consistent {
201 return self.eval_consistent(ctx);
202 }
203
204 let children = ctx.segment.children();
206 if children.is_empty() {
207 return vec![];
208 }
209
210 let name_seg = Self::find_function_name(children);
212 let Some(Segment::Token(t)) = name_seg else {
213 return vec![];
214 };
215 if t.token.kind != TokenKind::Word {
216 return vec![];
217 }
218
219 let text = t.token.text.as_str();
220 let upper = text.to_ascii_uppercase();
221
222 if BUILTIN_FUNCTIONS.binary_search(&upper.as_str()).is_err() {
224 return vec![];
225 }
226
227 let (expected, policy_name) = match self.policy {
228 CapitalisationPolicy::Upper => (upper, "upper"),
229 CapitalisationPolicy::Lower => (text.to_ascii_lowercase(), "lower"),
230 CapitalisationPolicy::Capitalise => (crate::utils::capitalise(text), "capitalised"),
231 CapitalisationPolicy::Consistent => unreachable!(),
232 };
233
234 check_capitalisation(
235 self.code(),
236 "Function names",
237 text,
238 &expected,
239 policy_name,
240 t.token.span,
241 )
242 .into_iter()
243 .collect()
244 }
245}
246
247impl RuleCP03 {
248 fn eval_consistent(&self, ctx: &RuleContext) -> Vec<LintViolation> {
249 let mut tokens = Vec::new();
250 Self::collect_builtin_function_names(ctx.root, &mut tokens);
251
252 if tokens.is_empty() {
253 return vec![];
254 }
255
256 let majority = determine_majority_case(&tokens);
257 let mut violations = Vec::new();
258 for (text, span) in &tokens {
259 let expected = match majority {
260 "upper" => text.to_ascii_uppercase(),
261 _ => text.to_ascii_lowercase(),
262 };
263 if let Some(v) = check_capitalisation(
264 self.code(),
265 "Function names",
266 text,
267 &expected,
268 majority,
269 *span,
270 ) {
271 violations.push(v);
272 }
273 }
274 violations
275 }
276
277 fn collect_builtin_function_names(
279 segment: &Segment,
280 out: &mut Vec<(String, rigsql_core::Span)>,
281 ) {
282 if segment.segment_type() == SegmentType::FunctionCall {
283 if let Some(Segment::Token(t)) = Self::find_function_name(segment.children()) {
284 if t.token.kind == TokenKind::Word {
285 let upper = t.token.text.to_ascii_uppercase();
286 if BUILTIN_FUNCTIONS.binary_search(&upper.as_str()).is_ok() {
287 out.push((t.token.text.to_string(), t.token.span));
288 }
289 }
290 }
291 }
292 for child in segment.children() {
293 Self::collect_builtin_function_names(child, out);
294 }
295 }
296
297 fn find_function_name(children: &[Segment]) -> Option<&Segment> {
298 for child in children {
299 match child.segment_type() {
300 SegmentType::Identifier => return Some(child),
301 SegmentType::ColumnRef => {
302 let inner = child.children();
304 return inner
305 .iter()
306 .rev()
307 .find(|s| s.segment_type() == SegmentType::Identifier);
308 }
309 _ if child.segment_type().is_trivia() => continue,
310 _ => break,
311 }
312 }
313 None
314 }
315}
316
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_utils::lint_sql;

    #[test]
    fn test_cp03_flags_lowercase_function() {
        let violations = lint_sql("SELECT count(*) FROM t", RuleCP03::default());
        assert_eq!(violations.len(), 1);
        assert_eq!(violations[0].fixes[0].new_text, "COUNT");
    }

    #[test]
    fn test_cp03_flags_mixed_case() {
        let violations = lint_sql("SELECT Count(*) FROM t", RuleCP03::default());
        assert_eq!(violations.len(), 1);
        assert_eq!(violations[0].fixes[0].new_text, "COUNT");
    }

    #[test]
    fn test_cp03_accepts_all_upper() {
        let violations = lint_sql("SELECT COUNT(*) FROM t", RuleCP03::default());
        assert_eq!(violations.len(), 0);
    }

    #[test]
    fn test_cp03_lower_policy_flags_upper() {
        let rule = RuleCP03 {
            policy: CapitalisationPolicy::Lower,
        };
        let violations = lint_sql("SELECT COUNT(*) FROM t", rule);
        assert_eq!(violations.len(), 1);
        assert_eq!(violations[0].fixes[0].new_text, "count");
    }

    #[test]
    fn test_cp03_lower_policy_accepts_lower() {
        let rule = RuleCP03 {
            policy: CapitalisationPolicy::Lower,
        };
        let violations = lint_sql("SELECT count(*) FROM t", rule);
        assert_eq!(violations.len(), 0);
    }

    #[test]
    fn test_cp03_capitalise_policy() {
        let rule = RuleCP03 {
            policy: CapitalisationPolicy::Capitalise,
        };
        let violations = lint_sql("SELECT count(*) FROM t", rule);
        assert_eq!(violations.len(), 1);
        assert_eq!(violations[0].fixes[0].new_text, "Count");
    }

    #[test]
    fn test_cp03_skips_user_defined_function() {
        let violations = lint_sql(
            "SELECT GetDropdownOptions('a', 'b') FROM t",
            RuleCP03::default(),
        );
        assert_eq!(violations.len(), 0);
    }

    #[test]
    fn test_cp03_consistent_flags_minority() {
        let rule = RuleCP03 {
            policy: CapitalisationPolicy::Consistent,
        };
        let violations = lint_sql("SELECT COUNT(*), SUM(x), avg(y) FROM t", rule);
        assert_eq!(violations.len(), 1);
        assert_eq!(violations[0].fixes[0].new_text, "AVG");
    }

    #[test]
    fn test_cp03_consistent_all_same_no_violation() {
        let rule = RuleCP03 {
            policy: CapitalisationPolicy::Consistent,
        };
        let violations = lint_sql("SELECT COUNT(*), SUM(x) FROM t", rule);
        assert_eq!(violations.len(), 0);
    }

    #[test]
    fn test_cp03_consistent_majority_lower() {
        let rule = RuleCP03 {
            policy: CapitalisationPolicy::Consistent,
        };
        let violations = lint_sql("SELECT count(*), sum(x), AVG(y) FROM t", rule);
        assert_eq!(violations.len(), 1);
        assert_eq!(violations[0].fixes[0].new_text, "avg");
    }

    #[test]
    fn test_cp03_flags_replace_function() {
        let violations = lint_sql("SELECT replace(col, 'a', 'b') FROM t", RuleCP03::default());
        assert_eq!(violations.len(), 1);
        assert_eq!(violations[0].fixes[0].new_text, "REPLACE");
    }

    /// `binary_search` silently misses entries in an unsorted slice, so the
    /// lookup table must stay strictly ascending (sorted, no duplicates).
    #[test]
    fn test_cp03_builtin_table_is_sorted() {
        for pair in BUILTIN_FUNCTIONS.windows(2) {
            assert!(
                pair[0] < pair[1],
                "{:?} must sort before {:?}",
                pair[0],
                pair[1]
            );
        }
    }

    /// User-defined functions must neither be flagged nor influence the
    /// majority case in Consistent mode.
    #[test]
    fn test_cp03_consistent_ignores_user_defined_function() {
        let rule = RuleCP03 {
            policy: CapitalisationPolicy::Consistent,
        };
        let violations = lint_sql("SELECT COUNT(*), SUM(x), myFunc(y) FROM t", rule);
        assert_eq!(violations.len(), 0);
    }
}