1use std::collections::HashSet;
6
7use crate::linter::config::LintConfig;
8use crate::linter::rule::{LintContext, LintRule};
9use crate::types::{issue_codes, Dialect, Issue, IssueAutofixApplicability, IssuePatchEdit, Span};
10use regex::Regex;
11use sqlparser::ast::Statement;
12use sqlparser::tokenizer::{Token, TokenWithSpan, Tokenizer};
13
14use super::capitalisation_policy_helpers::{
15 ignored_words_from_config, ignored_words_regex_from_config, token_is_ignored,
16 tokens_violate_policy, CapitalisationPolicy,
17};
18
/// Lint rule CP005: enforces consistent capitalisation of SQL datatype names.
pub struct CapitalisationTypes {
    // Target capitalisation policy (consistent / upper / lower / capitalise / ...).
    policy: CapitalisationPolicy,
    // Exact words excluded from the check, read from rule configuration.
    ignore_words: HashSet<String>,
    // Optional regex; matching words are excluded from the check.
    ignore_words_regex: Option<Regex>,
}
24
25impl CapitalisationTypes {
26 pub fn from_config(config: &LintConfig) -> Self {
27 Self {
28 policy: CapitalisationPolicy::from_rule_config(
29 config,
30 issue_codes::LINT_CP_005,
31 "extended_capitalisation_policy",
32 ),
33 ignore_words: ignored_words_from_config(config, issue_codes::LINT_CP_005),
34 ignore_words_regex: ignored_words_regex_from_config(config, issue_codes::LINT_CP_005),
35 }
36 }
37}
38
impl Default for CapitalisationTypes {
    /// Default configuration: require self-consistent capitalisation across
    /// all type names, with no ignored words.
    fn default() -> Self {
        Self {
            policy: CapitalisationPolicy::Consistent,
            ignore_words: HashSet::new(),
            ignore_words_regex: None,
        }
    }
}
48
49impl LintRule for CapitalisationTypes {
50 fn code(&self) -> &'static str {
51 issue_codes::LINT_CP_005
52 }
53
54 fn name(&self) -> &'static str {
55 "Type capitalisation"
56 }
57
58 fn description(&self) -> &'static str {
59 "Inconsistent capitalisation of datatypes."
60 }
61
62 fn check(&self, _statement: &Statement, ctx: &LintContext) -> Vec<Issue> {
63 let types =
64 type_tokens_for_context(ctx, &self.ignore_words, self.ignore_words_regex.as_ref());
65 let type_values = types
66 .iter()
67 .map(|candidate| candidate.value.clone())
68 .collect::<Vec<_>>();
69 if !tokens_violate_policy(&type_values, self.policy) {
70 return Vec::new();
71 }
72
73 let autofix_edits = type_autofix_edits(ctx, &types, self.policy);
74
75 if autofix_edits.is_empty() {
77 return vec![Issue::info(
78 issue_codes::LINT_CP_005,
79 "Type names use inconsistent capitalisation.",
80 )
81 .with_statement(ctx.statement_index)];
82 }
83
84 autofix_edits
85 .into_iter()
86 .map(|edit| {
87 let span = Span::new(edit.span.start, edit.span.end);
88 Issue::info(
89 issue_codes::LINT_CP_005,
90 "Type names use inconsistent capitalisation.",
91 )
92 .with_statement(ctx.statement_index)
93 .with_span(span)
94 .with_autofix_edits(IssueAutofixApplicability::Safe, vec![edit])
95 })
96 .collect()
97 }
98}
99
/// A datatype token found in a statement, with its byte offsets relative to
/// the start of that statement.
#[derive(Clone)]
struct TypeCandidate {
    // Token text exactly as written in the source.
    value: String,
    // Statement-local byte offset of the token's first byte.
    start: usize,
    // Statement-local byte offset one past the token's last byte.
    end: usize,
}
106
/// Collects datatype token candidates for the current statement.
///
/// Prefers the document-level token stream cached on the context; if that is
/// unavailable (or looks unreliable for this statement), falls back to
/// re-tokenizing the statement's SQL text from scratch.
fn type_tokens_for_context(
    ctx: &LintContext,
    ignore_words: &HashSet<String>,
    ignore_words_regex: Option<&Regex>,
) -> Vec<TypeCandidate> {
    // NOTE(review): `with_document_tokens` appears to return the closure's
    // `Option` result (flattened with its own availability) — confirm its
    // contract against the LintContext definition.
    let from_document_tokens = ctx.with_document_tokens(|tokens| {
        if tokens.is_empty() {
            return None;
        }

        // Keep only tokens that fall inside this statement's byte range.
        let mut statement_tokens = Vec::new();
        for token in tokens {
            let Some((start, end)) = token_with_span_offsets(ctx.sql, token) else {
                continue;
            };
            if start < ctx.statement_range.start || end > ctx.statement_range.end {
                continue;
            }

            // If a word token's computed offsets do not line up with the
            // source text, the cached token positions cannot be trusted:
            // abort and use the re-tokenizing fallback instead.
            if let Token::Word(word) = &token.token {
                if !source_word_matches(ctx.sql, start, end, word.value.as_str()) {
                    return None;
                }
            }

            statement_tokens.push(token.clone());
        }

        Some(type_candidates_from_tokens(
            ctx.sql,
            ctx.statement_range.start,
            &statement_tokens,
            ignore_words,
            ignore_words_regex,
        ))
    });

    if let Some(tokens) = from_document_tokens {
        return tokens;
    }

    // Fallback: tokenize just this statement's SQL with the context dialect.
    type_tokens(
        ctx.statement_sql(),
        ignore_words,
        ignore_words_regex,
        ctx.dialect(),
    )
}
158
159fn type_tokens(
160 sql: &str,
161 ignore_words: &HashSet<String>,
162 ignore_words_regex: Option<&Regex>,
163 dialect: Dialect,
164) -> Vec<TypeCandidate> {
165 let dialect = dialect.to_sqlparser_dialect();
166 let mut tokenizer = Tokenizer::new(dialect.as_ref(), sql);
167 let Ok(tokens) = tokenizer.tokenize_with_location() else {
168 return Vec::new();
169 };
170
171 type_candidates_from_tokens(sql, 0, &tokens, ignore_words, ignore_words_regex)
172}
173
174fn type_candidates_from_tokens(
175 sql: &str,
176 statement_start: usize,
177 tokens: &[TokenWithSpan],
178 ignore_words: &HashSet<String>,
179 ignore_words_regex: Option<&Regex>,
180) -> Vec<TypeCandidate> {
181 let user_defined_types = collect_user_defined_type_names(tokens);
182
183 tokens
184 .iter()
185 .enumerate()
186 .filter_map(|(index, token)| {
187 if let Token::Word(word) = &token.token {
188 let is_candidate = word.quote_style.is_none()
189 && (is_tracked_type_name(word.value.as_str())
190 || user_defined_types.contains(&word.value.to_ascii_uppercase()))
191 && !token_is_ignored(word.value.as_str(), ignore_words, ignore_words_regex)
192 && !is_keyword_after_as(tokens, index)
193 && !is_constructor_or_function_call(tokens, index);
194 if is_candidate {
195 let (start, end) = token_with_span_offsets(sql, token)?;
196 let local_start = start.checked_sub(statement_start)?;
197 let local_end = end.checked_sub(statement_start)?;
198 return Some(TypeCandidate {
199 value: word.value.clone(),
200 start: local_start,
201 end: local_end,
202 });
203 }
204 }
205
206 None
207 })
208 .collect()
209}
210
211fn is_keyword_after_as(tokens: &[TokenWithSpan], index: usize) -> bool {
214 let Some(prev_index) = prev_non_trivia_index(tokens, index) else {
215 return false;
216 };
217 matches!(
218 &tokens[prev_index].token,
219 Token::Word(w) if w.value.eq_ignore_ascii_case("AS")
220 )
221}
222
223fn prev_non_trivia_index(tokens: &[TokenWithSpan], index: usize) -> Option<usize> {
224 if index == 0 {
225 return None;
226 }
227 let mut i = index - 1;
228 loop {
229 if !matches!(tokens[i].token, Token::Whitespace(_)) {
230 return Some(i);
231 }
232 if i == 0 {
233 return None;
234 }
235 i -= 1;
236 }
237}
238
239fn type_autofix_edits(
240 ctx: &LintContext,
241 types: &[TypeCandidate],
242 policy: CapitalisationPolicy,
243) -> Vec<IssuePatchEdit> {
244 let resolved_policy = if policy == CapitalisationPolicy::Consistent {
246 resolve_consistent_policy(types)
247 } else {
248 policy
249 };
250
251 let mut edits = Vec::new();
252
253 for candidate in types {
254 let Some(replacement) = type_case_replacement(candidate.value.as_str(), resolved_policy)
255 else {
256 continue;
257 };
258 if replacement == candidate.value {
259 continue;
260 }
261
262 edits.push(IssuePatchEdit::new(
263 ctx.span_from_statement_offset(candidate.start, candidate.end),
264 replacement,
265 ));
266 }
267
268 edits.sort_by_key(|edit| (edit.span.start, edit.span.end));
269 edits.dedup_by(|left, right| {
270 left.span.start == right.span.start
271 && left.span.end == right.span.end
272 && left.replacement == right.replacement
273 });
274 edits
275}
276
277fn type_case_replacement(value: &str, policy: CapitalisationPolicy) -> Option<String> {
278 match policy {
279 CapitalisationPolicy::Consistent => {
280 Some(value.to_ascii_lowercase())
282 }
283 CapitalisationPolicy::Lower => Some(value.to_ascii_lowercase()),
284 CapitalisationPolicy::Upper => Some(value.to_ascii_uppercase()),
285 CapitalisationPolicy::Capitalise => Some(capitalise_ascii_token(value)),
286 CapitalisationPolicy::Pascal
288 | CapitalisationPolicy::Camel
289 | CapitalisationPolicy::Snake => None,
290 }
291}
292
293fn resolve_consistent_policy(types: &[TypeCandidate]) -> CapitalisationPolicy {
297 const UPPER: u8 = 0b001;
298 const LOWER: u8 = 0b010;
299 const CAPITALISE: u8 = 0b100;
300
301 let mut refuted: u8 = 0;
302 let mut latest_possible = CapitalisationPolicy::Upper; for typ in types {
305 let v = typ.value.as_str();
306
307 let first_is_lower = v
308 .chars()
309 .find(|c| c.is_ascii_alphabetic())
310 .is_some_and(|c| c.is_ascii_lowercase());
311
312 if first_is_lower {
313 refuted |= UPPER | CAPITALISE;
314 if v != v.to_ascii_lowercase() {
315 refuted |= LOWER;
316 }
317 } else {
318 refuted |= LOWER;
319 if v != v.to_ascii_uppercase() {
320 refuted |= UPPER;
321 }
322 if v != capitalise_ascii_token(v) {
323 refuted |= CAPITALISE;
324 }
325 }
326
327 let possible = (UPPER | LOWER | CAPITALISE) & !refuted;
328 if possible == 0 {
329 return latest_possible;
330 }
331
332 if possible & UPPER != 0 {
333 latest_possible = CapitalisationPolicy::Upper;
334 } else if possible & LOWER != 0 {
335 latest_possible = CapitalisationPolicy::Lower;
336 } else {
337 latest_possible = CapitalisationPolicy::Capitalise;
338 }
339 }
340
341 latest_possible
342}
343
/// Uppercases the first ASCII letter of `value` and lowercases every letter
/// after it; non-alphabetic characters pass through unchanged.
fn capitalise_ascii_token(value: &str) -> String {
    let mut seen_alpha = false;
    value
        .chars()
        .map(|ch| {
            if !ch.is_ascii_alphabetic() {
                ch
            } else if seen_alpha {
                ch.to_ascii_lowercase()
            } else {
                seen_alpha = true;
                ch.to_ascii_uppercase()
            }
        })
        .collect()
}
364
365fn token_with_span_offsets(sql: &str, token: &TokenWithSpan) -> Option<(usize, usize)> {
366 let start = line_col_to_offset(
367 sql,
368 token.span.start.line as usize,
369 token.span.start.column as usize,
370 )?;
371 let end = line_col_to_offset(
372 sql,
373 token.span.end.line as usize,
374 token.span.end.column as usize,
375 )?;
376 Some((start, end))
377}
378
/// Translates a 1-based (line, column) position into a byte offset in `sql`.
/// Column units are characters, matching sqlparser's location counting.
/// A position exactly one past the final character maps to `sql.len()`;
/// anything else out of range (including line/column 0) yields `None`.
fn line_col_to_offset(sql: &str, line: usize, column: usize) -> Option<usize> {
    // Locations are 1-based; 0 signals "unknown".
    if line == 0 || column == 0 {
        return None;
    }

    let (mut cur_line, mut cur_col) = (1usize, 1usize);
    for (offset, ch) in sql.char_indices() {
        if (cur_line, cur_col) == (line, column) {
            return Some(offset);
        }
        if ch == '\n' {
            cur_line += 1;
            cur_col = 1;
        } else {
            cur_col += 1;
        }
    }

    // End-of-input position (one past the last character).
    ((cur_line, cur_col) == (line, column)).then_some(sql.len())
}
406
/// Checks that the source bytes at `start..end` spell `value` (case-
/// insensitively), after stripping identifier quote characters. Returns
/// `false` for out-of-range or non-char-boundary slices.
fn source_word_matches(sql: &str, start: usize, end: usize, value: &str) -> bool {
    sql.get(start..end).is_some_and(|raw| {
        raw.trim_matches(|ch| matches!(ch, '"' | '`' | '[' | ']'))
            .eq_ignore_ascii_case(value)
    })
}
414
415fn collect_user_defined_type_names(tokens: &[TokenWithSpan]) -> HashSet<String> {
416 let mut out = HashSet::new();
417
418 for index in 0..tokens.len() {
419 let Token::Word(first) = &tokens[index].token else {
420 continue;
421 };
422 let head = first.value.to_ascii_uppercase();
423 if head != "CREATE" && head != "ALTER" {
424 continue;
425 }
426
427 let Some(type_index) = next_non_trivia_index(tokens, index + 1) else {
428 continue;
429 };
430 let Token::Word(type_word) = &tokens[type_index].token else {
431 continue;
432 };
433 if !type_word.value.eq_ignore_ascii_case("TYPE") {
434 continue;
435 }
436
437 let Some(name_index) = next_non_trivia_index(tokens, type_index + 1) else {
438 continue;
439 };
440 let Token::Word(name_word) = &tokens[name_index].token else {
441 continue;
442 };
443 out.insert(name_word.value.to_ascii_uppercase());
444 }
445
446 out
447}
448
449fn next_non_trivia_index(tokens: &[TokenWithSpan], mut index: usize) -> Option<usize> {
450 while index < tokens.len() {
451 match &tokens[index].token {
452 Token::Whitespace(_) => index += 1,
453 _ => return Some(index),
454 }
455 }
456 None
457}
458
459fn is_constructor_or_function_call(tokens: &[TokenWithSpan], index: usize) -> bool {
471 let Token::Word(word) = &tokens[index].token else {
472 return false;
473 };
474 let Some(next_idx) = next_non_trivia_index(tokens, index + 1) else {
475 return false;
476 };
477 let upper = word.value.to_ascii_uppercase();
478 match &tokens[next_idx].token {
479 Token::LBracket => upper == "ARRAY",
481 Token::LParen => !type_takes_precision(&upper),
484 _ => false,
485 }
486}
487
/// Whether this (already-uppercased) type name may carry a parenthesised
/// precision/length argument, e.g. `VARCHAR(10)` or `NUMERIC(10, 2)`.
fn type_takes_precision(upper: &str) -> bool {
    const PRECISION_TYPES: &[&str] = &[
        "VARCHAR", "CHAR", "NUMERIC", "DECIMAL", "FLOAT", "DOUBLE", "TIMESTAMP", "TIME", "INT",
        "INTEGER", "BIGINT", "SMALLINT", "TINYINT",
    ];
    PRECISION_TYPES.contains(&upper)
}
508
/// Whether `value` (any casing) is one of the built-in datatype names this
/// rule tracks for capitalisation consistency.
fn is_tracked_type_name(value: &str) -> bool {
    const TRACKED_TYPES: &[&str] = &[
        "INT", "INTEGER", "BIGINT", "SMALLINT", "TINYINT", "VARCHAR", "CHAR", "TEXT", "BOOLEAN",
        "BOOL", "STRING", "INT64", "FLOAT64", "BYTES", "DATE", "TIME", "TIMESTAMP", "INTERVAL",
        "NUMERIC", "DECIMAL", "FLOAT", "DOUBLE", "STRUCT", "ARRAY", "MAP", "ENUM",
    ];
    TRACKED_TYPES.contains(&value.to_ascii_uppercase().as_str())
}
540
#[cfg(test)]
mod tests {
    use super::*;
    use crate::linter::config::LintConfig;
    use crate::parser::parse_sql;
    use crate::types::IssueAutofixApplicability;

    // Runs the rule with its default (consistent) policy over every parsed
    // statement of `sql`, collecting all reported issues.
    fn run(sql: &str) -> Vec<Issue> {
        let statements = parse_sql(sql).expect("parse");
        let rule = CapitalisationTypes::default();
        statements
            .iter()
            .enumerate()
            .flat_map(|(index, statement)| {
                rule.check(
                    statement,
                    &LintContext {
                        sql,
                        statement_range: 0..sql.len(),
                        statement_index: index,
                    },
                )
            })
            .collect()
    }

    // Applies one issue's autofix edits to `sql`, right-to-left so earlier
    // spans remain valid as later ones are replaced.
    fn apply_issue_autofix(sql: &str, issue: &Issue) -> Option<String> {
        let autofix = issue.autofix.as_ref()?;
        let mut out = sql.to_string();
        let mut edits = autofix.edits.clone();
        edits.sort_by_key(|edit| (edit.span.start, edit.span.end));
        for edit in edits.into_iter().rev() {
            out.replace_range(edit.span.start..edit.span.end, &edit.replacement);
        }
        Some(out)
    }

    #[test]
    fn flags_mixed_type_case() {
        let issues = run("CREATE TABLE t (a INT, b varchar(10))");
        assert_eq!(issues.len(), 1);
        assert_eq!(issues[0].code, issue_codes::LINT_CP_005);
    }

    #[test]
    fn emits_safe_autofix_for_mixed_type_case() {
        // Majority casing wins under "consistent": varchar -> VARCHAR.
        let sql = "CREATE TABLE t (a INT, b varchar(10))";
        let issues = run(sql);
        assert_eq!(issues.len(), 1);
        let autofix = issues[0].autofix.as_ref().expect("autofix metadata");
        assert_eq!(autofix.applicability, IssueAutofixApplicability::Safe);
        let fixed = apply_issue_autofix(sql, &issues[0]).expect("apply autofix");
        assert_eq!(fixed, "CREATE TABLE t (a INT, b VARCHAR(10))");
    }

    #[test]
    fn does_not_flag_consistent_type_case() {
        assert!(run("CREATE TABLE t (a int, b varchar(10))").is_empty());
    }

    #[test]
    fn does_not_flag_type_words_in_strings_or_comments() {
        // Type names inside literals and comments are not tokenized as words.
        let sql = "SELECT 'INT varchar BOOLEAN' AS txt -- INT varchar\nFROM t";
        assert!(run(sql).is_empty());
    }

    #[test]
    fn upper_policy_flags_lowercase_type_name() {
        let config = LintConfig {
            enabled: true,
            disabled_rules: vec![],
            rule_configs: std::collections::BTreeMap::from([(
                "LINT_CP_005".to_string(),
                serde_json::json!({"extended_capitalisation_policy": "upper"}),
            )]),
        };
        let rule = CapitalisationTypes::from_config(&config);
        let sql = "CREATE TABLE t (a int)";
        let statements = parse_sql(sql).expect("parse");
        let issues = rule.check(
            &statements[0],
            &LintContext {
                sql,
                statement_range: 0..sql.len(),
                statement_index: 0,
            },
        );
        assert_eq!(issues.len(), 1);
    }

    #[test]
    fn upper_policy_emits_uppercase_autofix() {
        let config = LintConfig {
            enabled: true,
            disabled_rules: vec![],
            rule_configs: std::collections::BTreeMap::from([(
                "LINT_CP_005".to_string(),
                serde_json::json!({"extended_capitalisation_policy": "upper"}),
            )]),
        };
        let rule = CapitalisationTypes::from_config(&config);
        let sql = "CREATE TABLE t (a int)";
        let statements = parse_sql(sql).expect("parse");
        let issues = rule.check(
            &statements[0],
            &LintContext {
                sql,
                statement_range: 0..sql.len(),
                statement_index: 0,
            },
        );
        assert_eq!(issues.len(), 1);
        let fixed = apply_issue_autofix(sql, &issues[0]).expect("apply autofix");
        assert_eq!(fixed, "CREATE TABLE t (a INT)");
    }

    #[test]
    fn camel_policy_violation_remains_report_only() {
        // camel/pascal/snake have no safe ASCII transform, so no autofix.
        let config = LintConfig {
            enabled: true,
            disabled_rules: vec![],
            rule_configs: std::collections::BTreeMap::from([(
                "LINT_CP_005".to_string(),
                serde_json::json!({"extended_capitalisation_policy": "camel"}),
            )]),
        };
        let rule = CapitalisationTypes::from_config(&config);
        let sql = "CREATE TABLE t (a INT)";
        let statements = parse_sql(sql).expect("parse");
        let issues = rule.check(
            &statements[0],
            &LintContext {
                sql,
                statement_range: 0..sql.len(),
                statement_index: 0,
            },
        );
        assert_eq!(issues.len(), 1);
        assert!(
            issues[0].autofix.is_none(),
            "camel/pascal/snake are report-only in current CP005 autofix scope"
        );
    }

    #[test]
    fn ignore_words_regex_excludes_types_from_check() {
        let config = LintConfig {
            enabled: true,
            disabled_rules: vec![],
            rule_configs: std::collections::BTreeMap::from([(
                "LINT_CP_005".to_string(),
                serde_json::json!({"ignore_words_regex": "^varchar$"}),
            )]),
        };
        let rule = CapitalisationTypes::from_config(&config);
        let sql = "CREATE TABLE t (a INT, b varchar(10))";
        let statements = parse_sql(sql).expect("parse");
        let issues = rule.check(
            &statements[0],
            &LintContext {
                sql,
                statement_range: 0..sql.len(),
                statement_index: 0,
            },
        );
        assert!(issues.is_empty());
    }

    #[test]
    fn array_constructor_is_not_a_type_candidate() {
        // ARRAY[...] is a literal constructor, not the ARRAY type.
        assert!(run("SELECT COALESCE(x, ARRAY[]::text[]) FROM t").is_empty());
    }

    #[test]
    fn date_function_is_not_a_type_candidate() {
        // DATE(...) is a function call; only the consistent `text` cast remains.
        assert!(run("SELECT DATE(created_at), col::text FROM t").is_empty());
    }

    #[test]
    fn date_cast_is_still_a_type_candidate() {
        // ::DATE vs ::text is a genuine casing inconsistency.
        let issues = run("SELECT col::DATE, x::text FROM t");
        assert_eq!(issues.len(), 1);
    }
}
729}