1use std::collections::HashSet;
6
7use crate::linter::config::LintConfig;
8use crate::linter::rule::{LintContext, LintRule};
9use crate::types::{issue_codes, Dialect, Issue, IssueAutofixApplicability, IssuePatchEdit};
10use regex::Regex;
11use sqlparser::ast::Statement;
12use sqlparser::tokenizer::{Token, TokenWithSpan, Tokenizer};
13
14use super::capitalisation_policy_helpers::{
15 ignored_words_from_config, ignored_words_regex_from_config, token_is_ignored,
16 tokens_violate_policy, CapitalisationPolicy,
17};
18
/// Lint rule CP001: reports SQL keywords whose capitalisation violates the
/// configured policy (or is mutually inconsistent under `Consistent`).
pub struct CapitalisationKeywords {
    // Target capitalisation style for keywords.
    policy: CapitalisationPolicy,
    // Exact keyword values excluded from the check.
    ignore_words: HashSet<String>,
    // Optional pattern; keywords matching it are excluded from the check.
    ignore_words_regex: Option<Regex>,
}
24
25impl CapitalisationKeywords {
26 pub fn from_config(config: &LintConfig) -> Self {
27 Self {
28 policy: CapitalisationPolicy::from_rule_config(
29 config,
30 issue_codes::LINT_CP_001,
31 "capitalisation_policy",
32 ),
33 ignore_words: ignored_words_from_config(config, issue_codes::LINT_CP_001),
34 ignore_words_regex: ignored_words_regex_from_config(config, issue_codes::LINT_CP_001),
35 }
36 }
37}
38
impl Default for CapitalisationKeywords {
    /// Default configuration: the `Consistent` policy with no ignored words
    /// and no ignore pattern.
    fn default() -> Self {
        Self {
            policy: CapitalisationPolicy::Consistent,
            ignore_words: HashSet::new(),
            ignore_words_regex: None,
        }
    }
}
48
49impl LintRule for CapitalisationKeywords {
50 fn code(&self) -> &'static str {
51 issue_codes::LINT_CP_001
52 }
53
54 fn name(&self) -> &'static str {
55 "Keyword capitalisation"
56 }
57
58 fn description(&self) -> &'static str {
59 "Inconsistent capitalisation of keywords."
60 }
61
62 fn check(&self, _statement: &Statement, ctx: &LintContext) -> Vec<Issue> {
63 let keywords =
64 keyword_tokens_for_context(ctx, &self.ignore_words, self.ignore_words_regex.as_ref());
65 let keyword_values = keywords
66 .iter()
67 .map(|candidate| candidate.value.clone())
68 .collect::<Vec<_>>();
69 if !tokens_violate_policy(&keyword_values, self.policy) {
70 Vec::new()
71 } else {
72 let mut issue = Issue::info(
73 issue_codes::LINT_CP_001,
74 "SQL keywords use inconsistent capitalisation.",
75 )
76 .with_statement(ctx.statement_index);
77
78 let autofix_edits = keyword_autofix_edits(ctx, &keywords, self.policy);
79 if !autofix_edits.is_empty() {
80 issue = issue.with_autofix_edits(IssueAutofixApplicability::Safe, autofix_edits);
81 }
82
83 vec![issue]
84 }
85 }
86}
87
/// A keyword occurrence within one statement. `start`/`end` are byte offsets
/// relative to the start of the statement's SQL text.
#[derive(Clone)]
struct KeywordCandidate {
    value: String,
    start: usize,
    end: usize,
}
94
/// Collect tracked keyword occurrences for the current statement.
///
/// Prefers the document-level token stream held by `ctx`; if those tokens are
/// absent, empty, or disagree with the raw source text, falls back to
/// re-tokenizing just this statement's SQL. Returned offsets are relative to
/// the statement start.
fn keyword_tokens_for_context(
    ctx: &LintContext,
    ignore_words: &HashSet<String>,
    ignore_words_regex: Option<&Regex>,
) -> Vec<KeywordCandidate> {
    let from_document_tokens = ctx.with_document_tokens(|tokens| {
        if tokens.is_empty() {
            return None;
        }

        let mut out = Vec::new();
        // True when the previous significant token was `.`, so words that are
        // part of a qualified name (e.g. `t.select`) are skipped.
        let mut prev_is_period = false;
        for token in tokens {
            let Some((start, end)) = token_with_span_offsets(ctx.sql, token) else {
                continue;
            };
            // Only consider tokens that lie inside this statement's range.
            if start < ctx.statement_range.start || end > ctx.statement_range.end {
                continue;
            }

            match &token.token {
                Token::Period => {
                    prev_is_period = true;
                    continue;
                }
                // Whitespace is skipped without clearing the period flag.
                Token::Whitespace(_) => continue,
                _ => {}
            }

            if let Token::Word(word) = &token.token {
                let after_period = prev_is_period;
                prev_is_period = false;
                // A word directly after `.` is a qualified-name part, not a keyword.
                if after_period {
                    continue;
                }
                // If the source slice disagrees with the token value, these
                // document tokens cannot be trusted for this statement:
                // abandon them so the caller falls back to re-tokenizing.
                if !source_word_matches(ctx.sql, start, end, word.value.as_str()) {
                    return None;
                }
                if is_tracked_keyword(word.value.as_str())
                    && !is_excluded_keyword(word.value.as_str())
                    && !token_is_ignored(word.value.as_str(), ignore_words, ignore_words_regex)
                {
                    // Convert document-absolute offsets to statement-relative ones.
                    let Some(local_start) = start.checked_sub(ctx.statement_range.start) else {
                        continue;
                    };
                    let Some(local_end) = end.checked_sub(ctx.statement_range.start) else {
                        continue;
                    };
                    out.push(KeywordCandidate {
                        value: word.value.clone(),
                        start: local_start,
                        end: local_end,
                    });
                }
            } else {
                prev_is_period = false;
            }
        }
        Some(out)
    });

    if let Some(tokens) = from_document_tokens {
        return tokens;
    }

    // Fallback path: tokenize only this statement's SQL with the active dialect.
    keyword_tokens(
        ctx.statement_sql(),
        ignore_words,
        ignore_words_regex,
        ctx.dialect(),
    )
}
170
171fn keyword_tokens(
172 sql: &str,
173 ignore_words: &HashSet<String>,
174 ignore_words_regex: Option<&Regex>,
175 dialect: Dialect,
176) -> Vec<KeywordCandidate> {
177 let dialect = dialect.to_sqlparser_dialect();
178 let mut tokenizer = Tokenizer::new(dialect.as_ref(), sql);
179 let Ok(tokens) = tokenizer.tokenize_with_location() else {
180 return Vec::new();
181 };
182
183 let mut prev_is_period = false;
186 let mut out = Vec::new();
187 for token in &tokens {
188 match &token.token {
189 Token::Period => {
190 prev_is_period = true;
191 continue;
192 }
193 Token::Whitespace(_) => continue,
194 Token::Word(word) => {
195 let after_period = prev_is_period;
196 prev_is_period = false;
197 if after_period {
198 continue;
199 }
200 if is_tracked_keyword(word.value.as_str())
201 && !is_excluded_keyword(word.value.as_str())
202 && !token_is_ignored(word.value.as_str(), ignore_words, ignore_words_regex)
203 {
204 if let Some((start, end)) = token_with_span_offsets(sql, token) {
205 out.push(KeywordCandidate {
206 value: word.value.clone(),
207 start,
208 end,
209 });
210 }
211 }
212 }
213 _ => {
214 prev_is_period = false;
215 }
216 }
217 }
218 out
219}
220
221fn keyword_autofix_edits(
222 ctx: &LintContext,
223 keywords: &[KeywordCandidate],
224 policy: CapitalisationPolicy,
225) -> Vec<IssuePatchEdit> {
226 let resolved_policy = if policy == CapitalisationPolicy::Consistent {
228 resolve_consistent_policy(keywords)
229 } else {
230 policy
231 };
232
233 let mut edits = Vec::new();
234
235 for candidate in keywords {
236 let Some(replacement) = keyword_case_replacement(candidate.value.as_str(), resolved_policy)
237 else {
238 continue;
239 };
240 if replacement == candidate.value {
241 continue;
242 }
243
244 edits.push(IssuePatchEdit::new(
245 ctx.span_from_statement_offset(candidate.start, candidate.end),
246 replacement,
247 ));
248 }
249
250 edits.sort_by_key(|edit| (edit.span.start, edit.span.end));
251 edits.dedup_by(|left, right| {
252 left.span.start == right.span.start
253 && left.span.end == right.span.end
254 && left.replacement == right.replacement
255 });
256 edits
257}
258
259fn resolve_consistent_policy(keywords: &[KeywordCandidate]) -> CapitalisationPolicy {
266 const UPPER: u8 = 0b001;
267 const LOWER: u8 = 0b010;
268 const CAPITALISE: u8 = 0b100;
269
270 let mut refuted: u8 = 0;
271 let mut latest_possible = CapitalisationPolicy::Upper; for kw in keywords {
274 let v = kw.value.as_str();
275
276 let first_is_lower = v
278 .chars()
279 .find(|c| c.is_ascii_alphabetic())
280 .is_some_and(|c| c.is_ascii_lowercase());
281
282 if first_is_lower {
283 refuted |= UPPER | CAPITALISE;
284 if v != v.to_ascii_lowercase() {
285 refuted |= LOWER;
286 }
287 } else {
288 refuted |= LOWER;
289 if v != v.to_ascii_uppercase() {
290 refuted |= UPPER;
291 }
292 if v != capitalise_ascii_token(v) {
293 refuted |= CAPITALISE;
294 }
295 }
296
297 let possible = (UPPER | LOWER | CAPITALISE) & !refuted;
298 if possible == 0 {
299 return latest_possible;
300 }
301
302 if possible & UPPER != 0 {
304 latest_possible = CapitalisationPolicy::Upper;
305 } else if possible & LOWER != 0 {
306 latest_possible = CapitalisationPolicy::Lower;
307 } else {
308 latest_possible = CapitalisationPolicy::Capitalise;
309 }
310 }
311
312 latest_possible
313}
314
315fn keyword_case_replacement(value: &str, policy: CapitalisationPolicy) -> Option<String> {
316 match policy {
317 CapitalisationPolicy::Consistent => {
318 Some(value.to_ascii_lowercase())
321 }
322 CapitalisationPolicy::Lower => Some(value.to_ascii_lowercase()),
323 CapitalisationPolicy::Upper => Some(value.to_ascii_uppercase()),
324 CapitalisationPolicy::Capitalise => Some(capitalise_ascii_token(value)),
325 CapitalisationPolicy::Pascal
327 | CapitalisationPolicy::Camel
328 | CapitalisationPolicy::Snake => None,
329 }
330}
331
/// Uppercase the first ASCII letter of `value`, lowercase every later letter,
/// and pass non-alphabetic characters (digits, underscores, …) through
/// unchanged — e.g. `"SAFE_CAST"` becomes `"Safe_cast"`.
fn capitalise_ascii_token(value: &str) -> String {
    let mut first_alpha_done = false;
    value
        .chars()
        .map(|ch| {
            if !ch.is_ascii_alphabetic() {
                ch
            } else if first_alpha_done {
                ch.to_ascii_lowercase()
            } else {
                first_alpha_done = true;
                ch.to_ascii_uppercase()
            }
        })
        .collect()
}
352
353fn token_with_span_offsets(sql: &str, token: &TokenWithSpan) -> Option<(usize, usize)> {
354 let start = line_col_to_offset(
355 sql,
356 token.span.start.line as usize,
357 token.span.start.column as usize,
358 )?;
359 let end = line_col_to_offset(
360 sql,
361 token.span.end.line as usize,
362 token.span.end.column as usize,
363 )?;
364 Some((start, end))
365}
366
/// Convert a 1-based (line, column) position into a byte offset into `sql`.
///
/// Columns count characters, lines are split on `'\n'`. The position one
/// past the final character maps to `sql.len()`; positions with a zero
/// coordinate or beyond the input yield `None`.
fn line_col_to_offset(sql: &str, line: usize, column: usize) -> Option<usize> {
    if line == 0 || column == 0 {
        return None;
    }

    let (mut cur_line, mut cur_col) = (1usize, 1usize);
    for (offset, ch) in sql.char_indices() {
        if (cur_line, cur_col) == (line, column) {
            return Some(offset);
        }
        if ch == '\n' {
            cur_line += 1;
            cur_col = 1;
        } else {
            cur_col += 1;
        }
    }

    // End-of-input position (used for spans that close the last token).
    ((cur_line, cur_col) == (line, column)).then_some(sql.len())
}
394
/// Check that the raw source slice `sql[start..end]`, after stripping
/// identifier-quoting characters (`"`, `` ` ``, `[`, `]`), equals `value`
/// ignoring ASCII case. Out-of-range or non-char-boundary ranges fail.
fn source_word_matches(sql: &str, start: usize, end: usize, value: &str) -> bool {
    sql.get(start..end)
        .map(|raw| raw.trim_matches(|ch| matches!(ch, '"' | '`' | '[' | ']')))
        .is_some_and(|normalized| normalized.eq_ignore_ascii_case(value))
}
402
/// Returns true when `value` (compared case-insensitively) is one of the SQL
/// keywords whose capitalisation this rule tracks.
fn is_tracked_keyword(value: &str) -> bool {
    // Uppercase forms of every tracked keyword; lookup normalizes the input.
    const TRACKED: &[&str] = &[
        "SELECT", "FROM", "WHERE", "JOIN", "LEFT", "RIGHT", "FULL", "INNER", "OUTER", "ON",
        "GROUP", "BY", "ORDER", "HAVING", "UNION", "INSERT", "INTO", "UPDATE", "DELETE",
        "CREATE", "ALTER", "TABLE", "TYPE", "WITH", "AS", "CASE", "WHEN", "THEN", "ELSE",
        "END", "AND", "OR", "NOT", "IS", "IN", "EXISTS", "DISTINCT", "LIMIT", "OFFSET",
        "INTERVAL", "YEAR", "MONTH", "DAY", "HOUR", "MINUTE", "SECOND", "WEEK", "MONDAY",
        "TUESDAY", "WEDNESDAY", "THURSDAY", "FRIDAY", "SATURDAY", "SUNDAY", "CUBE", "CAST",
        "COALESCE", "SAFE_CAST", "TRY_CAST", "ASC", "DESC", "CROSS", "NATURAL", "OVER",
        "PARTITION", "BETWEEN", "LIKE", "SET", "QUALIFY", "LATERAL", "ROLLUP", "GROUPING",
        "SETS", "ALL", "ANY", "SOME", "EXCEPT", "INTERSECT", "VALUES", "DROP", "IF", "VIEW",
        "USING", "FETCH", "NEXT", "ROWS", "ONLY", "FIRST", "LAST", "RECURSIVE", "WINDOW",
        "RANGE", "UNBOUNDED", "PRECEDING", "FOLLOWING", "CURRENT", "ROW", "NULLS", "TOP",
        "PERCENT", "REPLACE", "GRANT", "REVOKE",
    ];
    TRACKED.contains(&value.to_ascii_uppercase().as_str())
}
511
/// Returns true when `value` (compared case-insensitively) is on the
/// exclusion list — literals, type names, and function-like words that also
/// appear in the tracked set but should not be case-checked as keywords.
fn is_excluded_keyword(value: &str) -> bool {
    // Uppercase forms of every excluded word; lookup normalizes the input.
    const EXCLUDED: &[&str] = &[
        "NULL", "TRUE", "FALSE", "INT", "INTEGER", "BIGINT", "SMALLINT", "TINYINT",
        "VARCHAR", "CHAR", "TEXT", "BOOLEAN", "BOOL", "STRING", "INT64", "FLOAT64",
        "BYTES", "NUMERIC", "DECIMAL", "FLOAT", "DOUBLE", "DATE", "TIME", "TIMESTAMP",
        "INTERVAL", "STRUCT", "ARRAY", "MAP", "ENUM", "COALESCE", "CAST", "SAFE_CAST",
        "TRY_CAST", "ANY", "SOME", "REPLACE", "TYPE",
    ];
    EXCLUDED.contains(&value.to_ascii_uppercase().as_str())
}
557
#[cfg(test)]
mod tests {
    use super::*;
    use crate::linter::config::LintConfig;
    use crate::parser::parse_sql;
    use crate::types::IssueAutofixApplicability;

    /// Run the rule with its default (`Consistent`) policy over every parsed
    /// statement, treating the whole input as a single statement range.
    fn run(sql: &str) -> Vec<Issue> {
        let statements = parse_sql(sql).expect("parse");
        let rule = CapitalisationKeywords::default();
        statements
            .iter()
            .enumerate()
            .flat_map(|(index, statement)| {
                rule.check(
                    statement,
                    &LintContext {
                        sql,
                        statement_range: 0..sql.len(),
                        statement_index: index,
                    },
                )
            })
            .collect()
    }

    /// Apply an issue's autofix edits to `sql`, last span first so earlier
    /// offsets stay valid as the text changes length.
    fn apply_issue_autofix(sql: &str, issue: &Issue) -> Option<String> {
        let autofix = issue.autofix.as_ref()?;
        let mut out = sql.to_string();
        let mut edits = autofix.edits.clone();
        edits.sort_by_key(|edit| (edit.span.start, edit.span.end));
        for edit in edits.into_iter().rev() {
            out.replace_range(edit.span.start..edit.span.end, &edit.replacement);
        }
        Some(out)
    }

    #[test]
    fn flags_mixed_keyword_case() {
        let issues = run("SELECT a from t");
        assert_eq!(issues.len(), 1);
        assert_eq!(issues[0].code, issue_codes::LINT_CP_001);
    }

    #[test]
    fn emits_safe_autofix_for_mixed_keyword_case() {
        let sql = "SELECT a from t";
        let issues = run(sql);
        assert_eq!(issues.len(), 1);
        let autofix = issues[0].autofix.as_ref().expect("autofix metadata");
        assert_eq!(autofix.applicability, IssueAutofixApplicability::Safe);
        // Majority-uppercase input resolves `Consistent` to uppercase.
        let fixed = apply_issue_autofix(sql, &issues[0]).expect("apply autofix");
        assert_eq!(fixed, "SELECT a FROM t");
    }

    #[test]
    fn does_not_flag_consistent_keyword_case() {
        assert!(run("SELECT a FROM t").is_empty());
    }

    #[test]
    fn does_not_flag_keyword_words_in_strings_or_comments() {
        // Keyword-like words inside literals and comments must be ignored.
        let sql = "SELECT 'select from where' AS txt -- select from where\nFROM t";
        assert!(run(sql).is_empty());
    }

    #[test]
    fn upper_policy_flags_lowercase_keywords() {
        let config = LintConfig {
            enabled: true,
            disabled_rules: vec![],
            rule_configs: std::collections::BTreeMap::from([(
                "capitalisation.keywords".to_string(),
                serde_json::json!({"capitalisation_policy": "upper"}),
            )]),
        };
        let rule = CapitalisationKeywords::from_config(&config);
        let sql = "select a from t";
        let statements = parse_sql(sql).expect("parse");
        let issues = rule.check(
            &statements[0],
            &LintContext {
                sql,
                statement_range: 0..sql.len(),
                statement_index: 0,
            },
        );
        assert_eq!(issues.len(), 1);
    }

    #[test]
    fn upper_policy_emits_uppercase_autofix() {
        let config = LintConfig {
            enabled: true,
            disabled_rules: vec![],
            rule_configs: std::collections::BTreeMap::from([(
                "capitalisation.keywords".to_string(),
                serde_json::json!({"capitalisation_policy": "upper"}),
            )]),
        };
        let rule = CapitalisationKeywords::from_config(&config);
        let sql = "select a from t";
        let statements = parse_sql(sql).expect("parse");
        let issues = rule.check(
            &statements[0],
            &LintContext {
                sql,
                statement_range: 0..sql.len(),
                statement_index: 0,
            },
        );
        assert_eq!(issues.len(), 1);
        let fixed = apply_issue_autofix(sql, &issues[0]).expect("apply autofix");
        assert_eq!(fixed, "SELECT a FROM t");
    }

    #[test]
    fn camel_policy_violation_remains_report_only() {
        let config = LintConfig {
            enabled: true,
            disabled_rules: vec![],
            rule_configs: std::collections::BTreeMap::from([(
                "capitalisation.keywords".to_string(),
                serde_json::json!({"capitalisation_policy": "camel"}),
            )]),
        };
        let rule = CapitalisationKeywords::from_config(&config);
        let sql = "SELECT a FROM t";
        let statements = parse_sql(sql).expect("parse");
        let issues = rule.check(
            &statements[0],
            &LintContext {
                sql,
                statement_range: 0..sql.len(),
                statement_index: 0,
            },
        );
        assert_eq!(issues.len(), 1);
        assert!(
            issues[0].autofix.is_none(),
            "camel/pascal/snake are report-only in current CP001 autofix scope"
        );
    }

    #[test]
    fn ignore_words_excludes_keywords_from_check() {
        let config = LintConfig {
            enabled: true,
            disabled_rules: vec![],
            rule_configs: std::collections::BTreeMap::from([(
                "LINT_CP_001".to_string(),
                serde_json::json!({"ignore_words": ["FROM"]}),
            )]),
        };
        let rule = CapitalisationKeywords::from_config(&config);
        let sql = "SELECT a from t";
        let statements = parse_sql(sql).expect("parse");
        let issues = rule.check(
            &statements[0],
            &LintContext {
                sql,
                statement_range: 0..sql.len(),
                statement_index: 0,
            },
        );
        assert!(issues.is_empty());
    }

    #[test]
    fn ignore_words_regex_excludes_keywords_from_check() {
        let config = LintConfig {
            enabled: true,
            disabled_rules: vec![],
            rule_configs: std::collections::BTreeMap::from([(
                "capitalisation.keywords".to_string(),
                serde_json::json!({"ignore_words_regex": "^from$"}),
            )]),
        };
        let rule = CapitalisationKeywords::from_config(&config);
        let sql = "SELECT a from t";
        let statements = parse_sql(sql).expect("parse");
        let issues = rule.check(
            &statements[0],
            &LintContext {
                sql,
                statement_range: 0..sql.len(),
                statement_index: 0,
            },
        );
        assert!(issues.is_empty());
    }
}