1use crate::linter::config::LintConfig;
8use crate::linter::rule::{LintContext, LintRule};
9use crate::linter::visit::visit_expressions;
10use crate::types::{issue_codes, Issue, IssueAutofixApplicability, IssuePatchEdit, Span};
11use sqlparser::ast::{CastKind, DataType, Expr, Spanned, Statement};
12
/// Casting style requested through the `preferred_type_casting_style` option.
///
/// `Consistent` adopts whichever style the first cast in a statement uses;
/// the remaining variants pin one specific syntax.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum PreferredTypeCastingStyle {
    /// Follow the style of the first cast found in the statement.
    Consistent,
    /// Require `expr::type`.
    Shorthand,
    /// Require `CAST(expr AS type)`.
    Cast,
    /// Require `CONVERT(type, expr)`.
    Convert,
}
24
25impl PreferredTypeCastingStyle {
26 fn from_config(config: &LintConfig) -> Self {
27 match config
28 .rule_option_str(issue_codes::LINT_CV_011, "preferred_type_casting_style")
29 .unwrap_or("consistent")
30 .to_ascii_lowercase()
31 .as_str()
32 {
33 "shorthand" => Self::Shorthand,
34 "cast" => Self::Cast,
35 "convert" => Self::Convert,
36 _ => Self::Consistent,
37 }
38 }
39}
40
/// Syntactic form of a single cast found in the SQL text.
///
/// Declaration order is significant: `collect_cast_instances` uses the
/// discriminant (`style as u8`) as a tie-breaker when sorting casts.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum CastStyle {
    /// `CAST(...)`, `TRY_CAST(...)` or `SAFE_CAST(...)`.
    FunctionCast,
    /// `expr::type` shorthand.
    DoubleColon,
    /// `CONVERT(type, expr)`.
    Convert,
}
51
52struct CastInstance {
54 style: CastStyle,
55 start: usize,
57 end: usize,
58 has_comments: bool,
60 is_3arg_convert: bool,
62}
63
64pub struct ConventionCastingStyle {
69 preferred_style: PreferredTypeCastingStyle,
70}
71
72impl ConventionCastingStyle {
73 pub fn from_config(config: &LintConfig) -> Self {
74 Self {
75 preferred_style: PreferredTypeCastingStyle::from_config(config),
76 }
77 }
78}
79
80impl Default for ConventionCastingStyle {
81 fn default() -> Self {
82 Self {
83 preferred_style: PreferredTypeCastingStyle::Consistent,
84 }
85 }
86}
87
88impl LintRule for ConventionCastingStyle {
89 fn code(&self) -> &'static str {
90 issue_codes::LINT_CV_011
91 }
92
93 fn name(&self) -> &'static str {
94 "Casting style"
95 }
96
97 fn description(&self) -> &'static str {
98 "Enforce consistent type casting style."
99 }
100
101 fn check(&self, statement: &Statement, ctx: &LintContext) -> Vec<Issue> {
102 let sql = ctx.sql;
103 let casts = collect_cast_instances(statement, sql);
104
105 if casts.is_empty() {
106 return Vec::new();
107 }
108
109 let target = match self.preferred_style {
111 PreferredTypeCastingStyle::Consistent => casts[0].style,
112 PreferredTypeCastingStyle::Shorthand => CastStyle::DoubleColon,
113 PreferredTypeCastingStyle::Cast => CastStyle::FunctionCast,
114 PreferredTypeCastingStyle::Convert => CastStyle::Convert,
115 };
116
117 let has_violation = casts.iter().any(|c| c.style != target);
119 if !has_violation {
120 return Vec::new();
121 }
122
123 let message = match self.preferred_style {
124 PreferredTypeCastingStyle::Consistent => {
125 "Use consistent casting style (avoid mixing CAST styles)."
126 }
127 PreferredTypeCastingStyle::Shorthand => "Use `::` shorthand casting style.",
128 PreferredTypeCastingStyle::Cast => "Use `CAST(...)` style casts.",
129 PreferredTypeCastingStyle::Convert => "Use `CONVERT(...)` style casts.",
130 };
131
132 let mut issues = Vec::new();
136 for cast in &casts {
137 if cast.style == target {
138 continue;
139 }
140
141 let mut issue =
142 Issue::info(issue_codes::LINT_CV_011, message).with_statement(ctx.statement_index);
143
144 if !cast.is_3arg_convert && !cast.has_comments {
145 let cast_text = &sql[cast.start..cast.end];
146 if let Some(replacement) = convert_cast(cast_text, cast.style, target) {
147 issue = issue.with_autofix_edits(
148 IssueAutofixApplicability::Unsafe,
149 vec![IssuePatchEdit::new(
150 Span::new(cast.start, cast.end),
151 replacement,
152 )],
153 );
154 }
155 }
156
157 issues.push(issue);
158 }
159
160 issues
161 }
162}
163
/// Collects every cast in `statement`, with byte spans into `sql`.
///
/// Casts come from two sources:
/// * the AST, via `visit_expressions`:
///   - `Expr::Cast` nodes (function-style and `::`), located with `find_cast_span`;
///   - `CONVERT(...)` function calls, located with `expr_span_offsets`;
/// * a textual scan (`scan_parenthesized_shorthand_cast_spans`) for
///   `(...)::type` casts. A scanned cast is added only when no `::` cast with
///   the same span was already found.
///
/// Overlapping entries are then reduced so that, for example, a whole
/// `a::int::text` chain is kept rather than each link. The result is sorted
/// by `(start, end)`.
///
/// NOTE(review): the textual scan covers all of `sql`, not only this statement.
fn collect_cast_instances(statement: &Statement, sql: &str) -> Vec<CastInstance> {
    let mut casts = Vec::new();

    visit_expressions(statement, &mut |expr| {
        match expr {
            Expr::Cast {
                kind,
                expr: inner,
                data_type,
                ..
            } => {
                let style = match kind {
                    CastKind::DoubleColon => CastStyle::DoubleColon,
                    CastKind::Cast | CastKind::TryCast | CastKind::SafeCast => {
                        CastStyle::FunctionCast
                    }
                };

                // True when the operand is itself a `::` cast, i.e. this node is
                // an outer link of a chain such as `a::int::text`.
                let is_inner_chain = matches!(
                    inner.as_ref(),
                    Expr::Cast {
                        kind: CastKind::DoubleColon,
                        ..
                    }
                );

                let inner_span = find_cast_span(sql, inner, kind.clone(), data_type);
                if let Some((start, end)) = inner_span {
                    let text = &sql[start..end];
                    // Comment markers anywhere in the cast text disable its autofix.
                    let has_comments = text.contains("--") || text.contains("/*");

                    if style == CastStyle::DoubleColon && is_inner_chain {
                        // Drop any already-recorded casts that lie entirely within
                        // this chain's span, so the chain is reported as one unit.
                        casts.retain(|c: &CastInstance| c.start < start || c.end > end);
                    }

                    casts.push(CastInstance {
                        style,
                        start,
                        end,
                        has_comments,
                        is_3arg_convert: false,
                    });
                }
            }
            Expr::Function(function)
                if function.name.to_string().eq_ignore_ascii_case("CONVERT") =>
            {
                if let Some((start, mut end)) = expr_span_offsets(sql, expr) {
                    // Widen the span to include the call's closing `)`: take it
                    // directly when it is the next byte, otherwise search forward
                    // for the matching close paren.
                    // NOTE(review): assumes the function's span stops before `)`;
                    // confirm against sqlparser's span for function calls.
                    if end < sql.len() && sql.as_bytes().get(end) == Some(&b')') {
                        end += 1;
                    } else {
                        if let Some(close) = find_matching_close_paren(&sql[end..]) {
                            end += close + 1;
                        }
                    }

                    let text = &sql[start..end];
                    let has_comments = text.contains("--") || text.contains("/*");

                    let arg_count = match &function.args {
                        sqlparser::ast::FunctionArguments::List(list) => list.args.len(),
                        _ => 0,
                    };

                    casts.push(CastInstance {
                        style: CastStyle::Convert,
                        start,
                        end,
                        has_comments,
                        // More than two arguments (e.g. a trailing style code)
                        // marks the call as not auto-fixable.
                        is_3arg_convert: arg_count > 2,
                    });
                }
            }
            _ => {}
        }
    });

    // Add parenthesised `(...)::type` casts found textually, unless a `::` cast
    // with exactly the same span was already collected from the AST.
    for (start, end) in scan_parenthesized_shorthand_cast_spans(sql) {
        if casts.iter().any(|cast| {
            cast.start == start && cast.end == end && cast.style == CastStyle::DoubleColon
        }) {
            continue;
        }
        let text = &sql[start..end];
        casts.push(CastInstance {
            style: CastStyle::DoubleColon,
            start,
            end,
            has_comments: text.contains("--") || text.contains("/*"),
            is_3arg_convert: false,
        });
    }

    casts.sort_by_key(|c| c.start);

    // Greedy overlap reduction in start order. Each candidate is compared with
    // the kept entries and acts on the first one it relates to:
    // * a kept entry covers the candidate -> drop the candidate;
    // * the candidate covers the kept entry -> replace that entry;
    // * same style with partial overlap -> keep the longer of the two.
    let mut deduped: Vec<CastInstance> = Vec::with_capacity(casts.len());
    for cast in casts {
        let mut dominated = false;
        let mut replace_index = None;

        for (index, other) in deduped.iter().enumerate() {
            if other.start <= cast.start && other.end >= cast.end {
                dominated = true;
                break;
            }
            if cast.start <= other.start && cast.end >= other.end {
                replace_index = Some(index);
                break;
            }
            if cast.style == other.style
                && spans_overlap(cast.start, cast.end, other.start, other.end)
            {
                let cast_len = cast.end.saturating_sub(cast.start);
                let other_len = other.end.saturating_sub(other.start);
                if cast_len > other_len {
                    replace_index = Some(index);
                } else {
                    dominated = true;
                }
                break;
            }
        }

        if dominated {
            continue;
        }

        if let Some(index) = replace_index {
            deduped[index] = cast;
        } else {
            deduped.push(cast);
        }
    }

    // Order by (start, end, style discriminant). Among entries with an identical
    // span, `dedup_by` keeps the first, i.e. the lowest `CastStyle` discriminant.
    deduped.sort_by_key(|cast| (cast.start, cast.end, cast.style as u8));
    deduped.dedup_by(|left, right| left.start == right.start && left.end == right.end);
    deduped
}
320
/// Returns true when the half-open ranges `[left_start, left_end)` and
/// `[right_start, right_end)` are not disjoint.
fn spans_overlap(left_start: usize, left_end: usize, right_start: usize, right_end: usize) -> bool {
    let disjoint = right_end <= left_start || left_end <= right_start;
    !disjoint
}
324
/// Finds parenthesised shorthand casts such as `(a + b)::int` by scanning
/// the raw SQL text.
///
/// Returns `(start, end)` byte spans in source order. Each span runs from the
/// operand's opening `(` to the end of the target type.
///
/// The scan skips single-quoted strings, double-quoted identifiers, `--` line
/// comments and `/* ... */` block comments, so cast-like text inside them is
/// not reported. A quoted run ends at the next identical quote byte; a doubled
/// `''` simply closes and reopens the run. Backslash escapes are not
/// interpreted.
fn scan_parenthesized_shorthand_cast_spans(sql: &str) -> Vec<(usize, usize)> {
    let bytes = sql.as_bytes();
    let mut out = Vec::new();
    let mut index = 0usize;

    while index + 1 < bytes.len() {
        match (bytes[index], bytes[index + 1]) {
            (b'\'' | b'"', _) => {
                index = skip_quoted_run(bytes, index);
                continue;
            }
            (b'-', b'-') => {
                // Line comment: resume after the next newline (or stop).
                index = bytes[index..]
                    .iter()
                    .position(|&b| b == b'\n')
                    .map_or(bytes.len(), |rel| index + rel + 1);
                continue;
            }
            (b'/', b'*') => {
                // Block comment: resume after the closing `*/` (or stop).
                index = sql[index + 2..]
                    .find("*/")
                    .map_or(bytes.len(), |rel| index + 2 + rel + 2);
                continue;
            }
            (b':', b':') => {}
            _ => {
                index += 1;
                continue;
            }
        }

        // `index` is at a `::` outside strings/comments; require a `)` (after
        // optional whitespace) immediately before it.
        let mut lhs_end = index;
        while lhs_end > 0 && bytes[lhs_end - 1].is_ascii_whitespace() {
            lhs_end -= 1;
        }
        if lhs_end == 0 || bytes[lhs_end - 1] != b')' {
            index += 2;
            continue;
        }
        let close_paren = lhs_end - 1;
        let Some(open_paren) = find_matching_open_paren(bytes, close_paren) else {
            index += 2;
            continue;
        };

        let Some(type_end) = scan_parenthesized_shorthand_type_end(bytes, index + 2) else {
            index += 2;
            continue;
        };

        out.push((open_paren, type_end));
        index = type_end;
    }

    out
}

/// Returns the offset just past the quoted run whose opening quote is at
/// `open`, or the end of input when the run is unterminated.
fn skip_quoted_run(bytes: &[u8], open: usize) -> usize {
    let quote = bytes[open];
    bytes[open + 1..]
        .iter()
        .position(|&b| b == quote)
        .map_or(bytes.len(), |rel| open + 1 + rel + 1)
}

/// Returns the end offset of the type name that starts at `start`.
///
/// The type may consist of identifier characters and `.`, plus balanced
/// parentheses whose contents may also contain commas and whitespace.
/// Returns `None` when no type character is found.
fn scan_parenthesized_shorthand_type_end(bytes: &[u8], start: usize) -> Option<usize> {
    let mut index = start;
    let mut depth = 0i32;
    let mut saw_any = false;

    while index < bytes.len() {
        match bytes[index] {
            b'(' => {
                depth += 1;
                saw_any = true;
                index += 1;
            }
            b')' if depth > 0 => {
                depth -= 1;
                index += 1;
            }
            b'a'..=b'z' | b'A'..=b'Z' | b'0'..=b'9' | b'_' | b'.' => {
                saw_any = true;
                index += 1;
            }
            b',' if depth > 0 => index += 1,
            b' ' | b'\t' | b'\n' | b'\r' if depth > 0 => index += 1,
            _ => break,
        }
    }

    if saw_any {
        Some(index)
    } else {
        None
    }
}

/// Walks backwards from the `)` at `close_paren` to its matching `(`.
///
/// Returns `None` if `close_paren` is not a `)` or no match exists.
fn find_matching_open_paren(bytes: &[u8], close_paren: usize) -> Option<usize> {
    if bytes.get(close_paren).copied() != Some(b')') {
        return None;
    }
    let mut depth = 1i32;
    let mut cursor = close_paren;
    while cursor > 0 {
        cursor -= 1;
        match bytes[cursor] {
            b')' => depth += 1,
            b'(' => {
                depth -= 1;
                if depth == 0 {
                    return Some(cursor);
                }
            }
            _ => {}
        }
    }
    None
}
416
417fn find_cast_span(
426 sql: &str,
427 inner: &Expr,
428 kind: CastKind,
429 data_type: &DataType,
430) -> Option<(usize, usize)> {
431 match kind {
432 CastKind::Cast | CastKind::TryCast | CastKind::SafeCast => {
433 let (inner_start, inner_end) = expr_span_offsets(sql, inner)?;
434
435 let before = &sql[..inner_start];
437 let paren_pos = before.rfind('(')?;
438 let before_paren = before[..paren_pos].trim_end();
439 let kw = match kind {
440 CastKind::TryCast => "TRY_CAST",
441 CastKind::SafeCast => "SAFE_CAST",
442 _ => "CAST",
443 };
444 let kw_len = kw.len();
445 if before_paren.len() < kw_len {
446 return None;
447 }
448 let kw_candidate = &before_paren[before_paren.len() - kw_len..];
449 if !kw_candidate.eq_ignore_ascii_case(kw) {
450 return None;
451 }
452 let start = before_paren.len() - kw_len;
453
454 let after = &sql[inner_end..];
456 let close = find_matching_close_paren(after)?;
457 let end = inner_end + close + 1;
458
459 Some((start, end))
460 }
461 CastKind::DoubleColon => {
462 let base = deepest_base_expr(inner);
464 let (base_start, base_end) = expr_span_offsets(sql, base)?;
465
466 let type_str = data_type.to_string();
468 let mut pos = base_end;
469 loop {
470 let after = &sql[pos..];
471 let dc_pos = match after.find("::") {
472 Some(p) => p,
473 None => break,
474 };
475 let type_start = pos + dc_pos + 2;
476 let type_len = source_type_len(sql, type_start, &type_str);
477 if type_len == 0 {
478 break;
479 }
480 pos = type_start + type_len;
481 let this_type = &sql[type_start..pos];
483 if this_type.eq_ignore_ascii_case(&type_str) {
484 break;
485 }
486 }
487
488 Some((base_start, pos))
489 }
490 }
491}
492
493fn deepest_base_expr(expr: &Expr) -> &Expr {
496 match expr {
497 Expr::Cast {
498 kind: CastKind::DoubleColon,
499 expr: inner,
500 ..
501 } => deepest_base_expr(inner),
502 _ => expr,
503 }
504}
505
/// Returns the index of the first `)` in `text` that is not balanced by an
/// earlier `(` in `text`.
///
/// Quoted runs (`'...'` / `"..."`) are skipped, with a backslash skipping the
/// byte after it.
fn find_matching_close_paren(text: &str) -> Option<usize> {
    let bytes = text.as_bytes();
    let mut nesting = 0i32;
    let mut pos = 0usize;
    while let Some(&byte) = bytes.get(pos) {
        if byte == b'(' {
            nesting += 1;
        } else if byte == b')' {
            if nesting == 0 {
                return Some(pos);
            }
            nesting -= 1;
        } else if byte == b'\'' || byte == b'"' {
            pos += 1;
            while pos < bytes.len() && bytes[pos] != byte {
                if bytes[pos] == b'\\' {
                    pos += 1;
                }
                pos += 1;
            }
        }
        pos += 1;
    }
    None
}
537
/// Returns the byte length of the type name written at `sql[pos..]`.
///
/// If the text starts with `type_display` (ASCII case-insensitive), that
/// length is used directly. Otherwise identifier characters are consumed,
/// together with parenthesised modifiers such as `(10, 2)`, whose contents may
/// include commas and whitespace (including `\r`). Returns 0 when no type text
/// is found.
fn source_type_len(sql: &str, pos: usize, type_display: &str) -> usize {
    let remaining = &sql[pos..];
    let display_len = type_display.len();

    // `get` instead of indexing: `display_len` may fall inside a multi-byte
    // character of the remaining text, which would panic on a direct slice.
    if matches!(
        remaining.get(..display_len),
        Some(prefix) if prefix.eq_ignore_ascii_case(type_display)
    ) {
        return display_len;
    }

    let mut depth = 0u32;
    let mut len = 0usize;
    for &byte in remaining.as_bytes() {
        match byte {
            b'(' => depth += 1,
            b')' if depth > 0 => depth -= 1,
            b'a'..=b'z' | b'A'..=b'Z' | b'0'..=b'9' | b'_' => {}
            b' ' | b'\t' | b'\n' | b'\r' | b',' if depth > 0 => {}
            _ => break,
        }
        len += 1;
    }
    len
}
574
575fn convert_cast(cast_text: &str, from_style: CastStyle, to_style: CastStyle) -> Option<String> {
580 match (from_style, to_style) {
581 (CastStyle::FunctionCast, CastStyle::DoubleColon) => cast_to_shorthand(cast_text),
582 (CastStyle::FunctionCast, CastStyle::Convert) => cast_to_convert(cast_text),
583 (CastStyle::DoubleColon, CastStyle::FunctionCast) => shorthand_to_cast(cast_text),
584 (CastStyle::DoubleColon, CastStyle::Convert) => shorthand_to_convert(cast_text),
585 (CastStyle::Convert, CastStyle::FunctionCast) => convert_to_cast(cast_text),
586 (CastStyle::Convert, CastStyle::DoubleColon) => convert_to_shorthand(cast_text),
587 _ => None,
588 }
589}
590
591fn parse_cast_interior(cast_text: &str) -> Option<(&str, &str)> {
594 let open = cast_text.find('(')?;
595 let close = cast_text.rfind(')')?;
596 let inner = cast_text[open + 1..close].trim();
597
598 let as_pos = find_top_level_as(inner)?;
599 let expr_part = inner[..as_pos].trim();
600 let type_part = inner[as_pos + 1..].trim();
603 let type_part = type_part
605 .strip_prefix("AS")
606 .or_else(|| type_part.strip_prefix("as"))
607 .or_else(|| type_part.strip_prefix("As"))
608 .or_else(|| type_part.strip_prefix("aS"))
609 .unwrap_or(type_part)
610 .trim();
611 Some((expr_part, type_part))
612}
613
/// Finds the ` AS ` keyword at parenthesis depth 0 inside a CAST body.
///
/// Returns the index of the whitespace byte just before `AS`. `AS` inside
/// parentheses is ignored. Quoted runs are skipped, with a backslash skipping
/// the byte after it.
fn find_top_level_as(inner: &str) -> Option<usize> {
    let is_ws = |b: u8| matches!(b, b' ' | b'\t' | b'\n' | b'\r');
    let bytes = inner.as_bytes();
    let len = bytes.len();
    let mut depth = 0i32;
    let mut cursor = 0usize;

    while cursor < len {
        let current = bytes[cursor];
        if current == b'(' {
            depth += 1;
        } else if current == b')' {
            depth -= 1;
        } else if current == b'\'' || current == b'"' {
            cursor += 1;
            while cursor < len && bytes[cursor] != current {
                if bytes[cursor] == b'\\' {
                    cursor += 1;
                }
                cursor += 1;
            }
        } else if depth == 0
            && is_ws(current)
            && cursor + 3 < len
            && bytes[cursor + 1].eq_ignore_ascii_case(&b'A')
            && bytes[cursor + 2].eq_ignore_ascii_case(&b'S')
            && is_ws(bytes[cursor + 3])
        {
            return Some(cursor);
        }
        cursor += 1;
    }
    None
}
649
/// True for the separator bytes the cast parsers recognise: space, tab, line
/// feed and carriage return. Form feed is not included.
fn is_whitespace_byte(b: u8) -> bool {
    b == b' ' || b == b'\t' || b == b'\n' || b == b'\r'
}
653
654fn cast_to_shorthand(cast_text: &str) -> Option<String> {
656 let (expr, type_text) = parse_cast_interior(cast_text)?;
657 let needs_parens = expr_is_complex(expr);
658 if needs_parens {
659 Some(format!("({expr})::{type_text}"))
660 } else {
661 Some(format!("{expr}::{type_text}"))
662 }
663}
664
665fn cast_to_convert(cast_text: &str) -> Option<String> {
667 let (expr, type_text) = parse_cast_interior(cast_text)?;
668 Some(format!("convert({type_text}, {expr})"))
669}
670
671fn convert_to_cast(convert_text: &str) -> Option<String> {
673 let (type_text, expr) = parse_convert_interior(convert_text)?;
674 Some(format!("cast({expr} as {type_text})"))
675}
676
677fn convert_to_shorthand(convert_text: &str) -> Option<String> {
679 let (type_text, expr) = parse_convert_interior(convert_text)?;
680 let needs_parens = expr_is_complex(expr);
681 if needs_parens {
682 Some(format!("({expr})::{type_text}"))
683 } else {
684 Some(format!("{expr}::{type_text}"))
685 }
686}
687
688fn shorthand_to_cast(shorthand_text: &str) -> Option<String> {
691 let parts = split_shorthand_chain(shorthand_text)?;
692 if parts.len() < 2 {
693 return None;
694 }
695 let mut result = rewrite_nested_simple_shorthand_to_cast(parts[0]);
696 for type_part in &parts[1..] {
697 result = format!("cast({result} as {type_part})");
698 }
699 Some(result)
700}
701
702fn shorthand_to_convert(shorthand_text: &str) -> Option<String> {
705 let parts = split_shorthand_chain(shorthand_text)?;
706 if parts.len() < 2 {
707 return None;
708 }
709 let mut result = parts[0].to_string();
710 for type_part in &parts[1..] {
711 result = format!("convert({type_part}, {result})");
712 }
713 Some(result)
714}
715
/// Splits `a::t1::t2` at its top-level `::` into `["a", "t1", "t2"]`.
///
/// `::` inside parentheses or quoted runs is ignored. Pieces are returned
/// untrimmed. Returns `None` when there is no top-level `::`.
fn split_shorthand_chain(text: &str) -> Option<Vec<&str>> {
    let bytes = text.as_bytes();
    let mut pieces = Vec::new();
    let mut depth = 0i32;
    let mut piece_start = 0usize;
    let mut cursor = 0usize;

    while cursor < bytes.len() {
        let byte = bytes[cursor];
        if byte == b'(' {
            depth += 1;
        } else if byte == b')' {
            depth -= 1;
        } else if byte == b'\'' || byte == b'"' {
            cursor += 1;
            while cursor < bytes.len() && bytes[cursor] != byte {
                if bytes[cursor] == b'\\' {
                    cursor += 1;
                }
                cursor += 1;
            }
        } else if byte == b':' && depth == 0 && bytes.get(cursor + 1) == Some(&b':') {
            pieces.push(&text[piece_start..cursor]);
            cursor += 2;
            piece_start = cursor;
            continue;
        }
        cursor += 1;
    }
    pieces.push(&text[piece_start..]);

    (pieces.len() >= 2).then_some(pieces)
}
756
/// Rewrites simple `lhs::type` casts nested inside `expr` as `cast(lhs as type)`.
///
/// `expr` is the operand of an outer `::` cast that is being rewritten to
/// `CAST` style.
///
/// * Only "simple" operands are converted: a run of characters accepted by
///   `is_simple_shorthand_lhs_char` (which includes `:` for `value:path`
///   access) directly before `::`.
/// * The target type spans identifier characters, whitespace and balanced
///   parentheses. It stops at an unmatched `)` or a top-level `,`, so
///   `coalesce(a::int, 0)` becomes `coalesce(cast(a as int), 0)`.
/// * A chain such as `a::int::text` is wrapped in turn, giving
///   `cast(cast(a as int) as text)`.
/// * A `::` whose operand is not simple (e.g. `f()::int`) is copied through
///   unchanged together with its type.
fn rewrite_nested_simple_shorthand_to_cast(expr: &str) -> String {
    let bytes = expr.as_bytes();
    let mut out = String::with_capacity(expr.len() + 16);
    let mut index = 0usize;
    // (offset in `out` where the last emitted `cast(` starts, offset in `expr`
    // just past that cast's type). Lets a directly following `::type` wrap it.
    let mut last_cast: Option<(usize, usize)> = None;

    while index < bytes.len() {
        let Some(rel_dc) = expr[index..].find("::") else {
            out.push_str(&expr[index..]);
            break;
        };
        let dc = index + rel_dc;
        let type_start = dc + 2;
        let type_end = simple_type_end(bytes, type_start);

        if type_end == type_start {
            // No recognisable type after `::`: copy the operator through.
            out.push_str(&expr[index..type_start]);
            index = type_start;
            last_cast = None;
            continue;
        }
        let type_text = &expr[type_start..type_end];

        // Chained cast: this `::type` directly follows the cast emitted last.
        if let Some((out_start, expr_end)) = last_cast {
            if expr_end == dc {
                let operand = out.split_off(out_start);
                out.push_str("cast(");
                out.push_str(&operand);
                out.push_str(" as ");
                out.push_str(type_text);
                out.push(')');
                last_cast = Some((out_start, type_end));
                index = type_end;
                continue;
            }
        }

        // Walk back over the operand, but never into text already emitted.
        let mut lhs_start = dc;
        while lhs_start > index && is_simple_shorthand_lhs_char(bytes[lhs_start - 1]) {
            lhs_start -= 1;
        }
        if lhs_start == dc {
            // Operand is not simple: copy the cast and its type verbatim, so the
            // type is not mistaken for the operand of a following `::`.
            out.push_str(&expr[index..type_end]);
            index = type_end;
            last_cast = None;
            continue;
        }

        out.push_str(&expr[index..lhs_start]);
        let out_start = out.len();
        out.push_str("cast(");
        out.push_str(&expr[lhs_start..dc]);
        out.push_str(" as ");
        out.push_str(type_text);
        out.push(')');
        last_cast = Some((out_start, type_end));
        index = type_end;
    }

    out
}

/// Returns the end offset of a simple type name starting at `start`.
///
/// Accepts `is_simple_type_char` bytes, tracks parenthesis depth, and stops at
/// an unmatched `)` or a `,` at depth 0. Trailing whitespace is not included.
fn simple_type_end(bytes: &[u8], start: usize) -> usize {
    let mut depth = 0usize;
    let mut end = start;
    for (offset, &byte) in bytes.iter().enumerate().skip(start) {
        match byte {
            b'(' => depth += 1,
            b')' if depth == 0 => break,
            b')' => depth -= 1,
            b',' if depth == 0 => break,
            _ if is_simple_type_char(byte) => {}
            _ => break,
        }
        if !byte.is_ascii_whitespace() {
            end = offset + 1;
        }
    }
    end
}

/// Characters allowed in a "simple" `::` operand: identifiers, qualified
/// names, path separators (`:`), parameters (`$`, `@`) and quoted/bracketed
/// identifiers.
fn is_simple_shorthand_lhs_char(byte: u8) -> bool {
    byte.is_ascii_alphanumeric()
        || matches!(
            byte,
            b'_' | b'.' | b':' | b'$' | b'@' | b'"' | b'`' | b'[' | b']'
        )
}

/// Characters allowed in a simple target type, including whitespace,
/// parentheses and commas for modifiers such as `numeric(10, 2)`.
fn is_simple_type_char(byte: u8) -> bool {
    byte.is_ascii_alphanumeric()
        || matches!(
            byte,
            b'_' | b' ' | b'\t' | b'\n' | b'\r' | b'(' | b')' | b','
        )
}
820
821fn parse_convert_interior(convert_text: &str) -> Option<(&str, &str)> {
823 let open = convert_text.find('(')?;
824 let close = convert_text.rfind(')')?;
825 let inner = convert_text[open + 1..close].trim();
826 let comma = find_top_level_comma(inner)?;
827 let type_part = inner[..comma].trim();
828 let expr_part = inner[comma + 1..].trim();
829 Some((type_part, expr_part))
830}
831
/// Returns the index of the first `,` at parenthesis depth 0.
///
/// Quoted runs are skipped, with a backslash skipping the byte after it.
fn find_top_level_comma(inner: &str) -> Option<usize> {
    let bytes = inner.as_bytes();
    let mut depth = 0i32;
    let mut cursor = 0usize;
    while let Some(&byte) = bytes.get(cursor) {
        match byte {
            b'(' => depth += 1,
            b')' => depth -= 1,
            b',' if depth == 0 => return Some(cursor),
            b'\'' | b'"' => {
                cursor += 1;
                while let Some(&next) = bytes.get(cursor) {
                    if next == byte {
                        break;
                    }
                    if next == b'\\' {
                        cursor += 1;
                    }
                    cursor += 1;
                }
            }
            _ => {}
        }
        cursor += 1;
    }
    None
}
858
/// Returns true when `expr` must be parenthesised before appending `::type`.
///
/// The (trimmed) expression counts as complex if, outside parentheses and
/// quoted runs, it contains whitespace or a binary-operator character
/// (`| + - * / % = < > & ^`). A leading `-` (unary minus) does not count.
/// Quoted runs (`'...'`, `"..."`) are skipped rather than ending the scan, so
/// `'a' || b` is complex while `'a b'` is not.
fn expr_is_complex(expr: &str) -> bool {
    let bytes = expr.trim().as_bytes();
    let mut depth = 0i32;
    let mut i = 0usize;
    while i < bytes.len() {
        let b = bytes[i];
        match b {
            b'(' => depth += 1,
            b')' => depth -= 1,
            b'\'' | b'"' => {
                // Skip to the closing quote; a backslash skips the next byte,
                // matching the other scanners in this file.
                i += 1;
                while i < bytes.len() && bytes[i] != b {
                    if bytes[i] == b'\\' {
                        i += 1;
                    }
                    i += 1;
                }
            }
            b'-' if i == 0 => {}
            b'|' | b'+' | b'-' | b'*' | b'/' | b'%' | b'=' | b'<' | b'>' | b'&' | b'^'
                if depth == 0 =>
            {
                return true;
            }
            b' ' | b'\t' | b'\n' | b'\r' if depth == 0 => return true,
            _ => {}
        }
        i += 1;
    }
    false
}
882
883fn expr_span_offsets(sql: &str, expr: &Expr) -> Option<(usize, usize)> {
888 let span = expr.span();
889 if span.start.line == 0 || span.start.column == 0 || span.end.line == 0 || span.end.column == 0
890 {
891 return None;
892 }
893
894 let start = line_col_to_offset(sql, span.start.line as usize, span.start.column as usize)?;
895 let end = line_col_to_offset(sql, span.end.line as usize, span.end.column as usize)?;
896 (end >= start).then_some((start, end))
897}
898
/// Maps a 1-based `line`/`column` pair onto a byte offset in `sql`.
///
/// Columns count characters, not bytes. Counting continues past the end of
/// the line into following lines, and the position one past the last
/// character maps to `sql.len()`. Returns `None` for a zero line or column, or
/// for a position beyond the text.
fn line_col_to_offset(sql: &str, line: usize, column: usize) -> Option<usize> {
    if line == 0 || column == 0 {
        return None;
    }

    // Byte offset where `line` begins: just after its preceding newline.
    let line_start = match line {
        1 => 0,
        n => sql
            .match_indices('\n')
            .nth(n - 2)
            .map(|(newline, _)| newline + 1)?,
    };

    sql[line_start..]
        .char_indices()
        .map(|(offset, _)| line_start + offset)
        .chain(std::iter::once(sql.len()))
        .nth(column - 1)
}
933
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::parse_sql;

    /// Three-way mix of casting styles used by several autofix tests.
    const MIXED_SQL: &str = "select\n convert(int, 1) as bar,\n 100::int::text,\n cast(10 as text) as coo\nfrom foo;";
    /// Same as `MIXED_SQL`, but with a 3-argument CONVERT.
    const MIXED_3ARG_SQL: &str = "select\n convert(int, 1, 126) as bar,\n 100::int::text,\n cast(10 as text) as coo\nfrom foo;";

    /// Builds a config that sets `preferred_type_casting_style` to `style`
    /// under the rule key `rule_key`.
    fn casting_config(rule_key: &str, style: &str) -> LintConfig {
        LintConfig {
            enabled: true,
            disabled_rules: vec![],
            rule_configs: std::collections::BTreeMap::from([(
                rule_key.to_string(),
                serde_json::json!({ "preferred_type_casting_style": style }),
            )]),
        }
    }

    /// Parses `sql` and runs `rule` on every statement, each one getting the
    /// full text as its statement range.
    fn run_rule(sql: &str, rule: &ConventionCastingStyle) -> Vec<Issue> {
        let statements = parse_sql(sql).expect("parse");
        statements
            .iter()
            .enumerate()
            .flat_map(|(index, statement)| {
                let ctx = LintContext {
                    sql,
                    statement_range: 0..sql.len(),
                    statement_index: index,
                };
                rule.check(statement, &ctx)
            })
            .collect()
    }

    fn run(sql: &str) -> Vec<Issue> {
        run_rule(sql, &ConventionCastingStyle::default())
    }

    fn run_with_config(sql: &str, config: &LintConfig) -> Vec<Issue> {
        run_rule(sql, &ConventionCastingStyle::from_config(config))
    }

    /// Applies `edits` to `sql`, last edit first, so earlier offsets stay valid.
    fn apply_edits(sql: &str, edits: &[IssuePatchEdit]) -> String {
        let mut ordered: Vec<&IssuePatchEdit> = edits.iter().collect();
        ordered.sort_by_key(|edit| std::cmp::Reverse(edit.span.start));
        ordered.into_iter().fold(sql.to_string(), |mut text, edit| {
            text.replace_range(edit.span.start..edit.span.end, &edit.replacement);
            text
        })
    }

    fn collect_all_edits(issues: &[Issue]) -> Vec<&IssuePatchEdit> {
        issues
            .iter()
            .filter_map(|issue| issue.autofix.as_ref())
            .flat_map(|autofix| autofix.edits.iter())
            .collect()
    }

    fn apply_all_fixes(sql: &str, issues: &[Issue]) -> String {
        let owned: Vec<IssuePatchEdit> = collect_all_edits(issues).into_iter().cloned().collect();
        apply_edits(sql, &owned)
    }

    /// Asserts that `issues` is non-empty and that applying all of its
    /// fixes to `sql` yields `expected`.
    fn assert_autofix(sql: &str, issues: &[Issue], expected: &str) {
        assert!(!issues.is_empty());
        assert_eq!(apply_all_fixes(sql, issues), expected);
    }

    #[test]
    fn flags_mixed_casting_styles() {
        let issues = run("SELECT CAST(amount AS INT)::TEXT FROM t");
        assert_eq!(issues.len(), 1);
        assert_eq!(issues[0].code, issue_codes::LINT_CV_011);
    }

    #[test]
    fn does_not_flag_single_casting_style() {
        assert!(run("SELECT amount::INT FROM t").is_empty());
        assert!(run("SELECT CAST(amount AS INT) FROM t").is_empty());
    }

    #[test]
    fn does_not_flag_cast_like_tokens_inside_string_literal() {
        assert!(run("SELECT 'value::TEXT and CAST(value AS INT)' AS note").is_empty());
    }

    #[test]
    fn flags_mixed_try_cast_and_double_colon_styles() {
        let issues = run("SELECT TRY_CAST(amount AS INT)::TEXT FROM t");
        assert_eq!(issues.len(), 1);
        assert_eq!(issues[0].code, issue_codes::LINT_CV_011);
    }

    #[test]
    fn shorthand_preference_flags_cast_function_style() {
        let config = casting_config("convention.casting_style", "shorthand");
        let issues = run_with_config("SELECT CAST(amount AS INT) FROM t", &config);
        assert_eq!(issues.len(), 1);
    }

    #[test]
    fn cast_preference_flags_shorthand_style() {
        let config = casting_config("LINT_CV_011", "cast");
        let issues = run_with_config("SELECT amount::INT FROM t", &config);
        assert_eq!(issues.len(), 1);
    }

    #[test]
    fn autofix_consistent_prior_convert() {
        let sql = "select\n convert(int, 1) as bar,\n 100::int::text,\n cast(10\n as text) as coo\nfrom foo;";
        assert_autofix(
            sql,
            &run(sql),
            "select\n convert(int, 1) as bar,\n convert(text, convert(int, 100)),\n convert(text, 10) as coo\nfrom foo;",
        );
    }

    #[test]
    fn autofix_consistent_prior_cast() {
        let sql = "select\n cast(10 as text) as coo,\n convert(int, 1) as bar,\n 100::int::text,\nfrom foo;";
        assert_autofix(
            sql,
            &run(sql),
            "select\n cast(10 as text) as coo,\n cast(1 as int) as bar,\n cast(cast(100 as int) as text),\nfrom foo;",
        );
    }

    #[test]
    fn autofix_consistent_prior_shorthand() {
        let sql = "select\n 100::int::text,\n cast(10 as text) as coo,\n convert(int, 1) as bar\nfrom foo;";
        assert_autofix(
            sql,
            &run(sql),
            "select\n 100::int::text,\n 10::text as coo,\n 1::int as bar\nfrom foo;",
        );
    }

    #[test]
    fn autofix_config_cast() {
        let config = casting_config("convention.casting_style", "cast");
        assert_autofix(
            MIXED_SQL,
            &run_with_config(MIXED_SQL, &config),
            "select\n cast(1 as int) as bar,\n cast(cast(100 as int) as text),\n cast(10 as text) as coo\nfrom foo;",
        );
    }

    #[test]
    fn autofix_config_convert() {
        let config = casting_config("convention.casting_style", "convert");
        assert_autofix(
            MIXED_SQL,
            &run_with_config(MIXED_SQL, &config),
            "select\n convert(int, 1) as bar,\n convert(text, convert(int, 100)),\n convert(text, 10) as coo\nfrom foo;",
        );
    }

    #[test]
    fn autofix_config_shorthand() {
        let config = casting_config("convention.casting_style", "shorthand");
        assert_autofix(
            MIXED_SQL,
            &run_with_config(MIXED_SQL, &config),
            "select\n 1::int as bar,\n 100::int::text,\n 10::text as coo\nfrom foo;",
        );
    }

    #[test]
    fn autofix_3arg_convert_skipped_config_cast() {
        let config = casting_config("convention.casting_style", "cast");
        assert_autofix(
            MIXED_3ARG_SQL,
            &run_with_config(MIXED_3ARG_SQL, &config),
            "select\n convert(int, 1, 126) as bar,\n cast(cast(100 as int) as text),\n cast(10 as text) as coo\nfrom foo;",
        );
    }

    #[test]
    fn autofix_3arg_convert_skipped_config_shorthand() {
        let config = casting_config("convention.casting_style", "shorthand");
        assert_autofix(
            MIXED_3ARG_SQL,
            &run_with_config(MIXED_3ARG_SQL, &config),
            "select\n convert(int, 1, 126) as bar,\n 100::int::text,\n 10::text as coo\nfrom foo;",
        );
    }

    #[test]
    fn autofix_parenthesize_complex_expr_shorthand_from_cast() {
        let config = casting_config("convention.casting_style", "shorthand");
        let sql = "select\n id::int,\n cast(calendar_date||' 11:00:00' as timestamp) as calendar_datetime\nfrom foo;";
        assert_autofix(
            sql,
            &run_with_config(sql, &config),
            "select\n id::int,\n (calendar_date||' 11:00:00')::timestamp as calendar_datetime\nfrom foo;",
        );
    }

    #[test]
    fn autofix_parenthesize_complex_expr_shorthand_from_convert() {
        let config = casting_config("convention.casting_style", "shorthand");
        let sql = "select\n id::int,\n convert(timestamp, calendar_date||' 11:00:00') as calendar_datetime\nfrom foo;";
        assert_autofix(
            sql,
            &run_with_config(sql, &config),
            "select\n id::int,\n (calendar_date||' 11:00:00')::timestamp as calendar_datetime\nfrom foo;",
        );
    }

    #[test]
    fn autofix_comment_cast_skipped() {
        let sql = "select\n cast(10 as text) as coo,\n convert( -- Convert the value\n int, /*\n to an integer\n */ 1) as bar,\n 100::int::text\nfrom foo;";
        assert_autofix(
            sql,
            &run(sql),
            "select\n cast(10 as text) as coo,\n convert( -- Convert the value\n int, /*\n to an integer\n */ 1) as bar,\n cast(cast(100 as int) as text)\nfrom foo;",
        );
    }

    #[test]
    fn autofix_3arg_convert_consistent_prior_cast() {
        let sql = "select\n cast(10 as text) as coo,\n convert(int, 1, 126) as bar,\n 100::int::text\nfrom foo;";
        assert_autofix(
            sql,
            &run(sql),
            "select\n cast(10 as text) as coo,\n convert(int, 1, 126) as bar,\n cast(cast(100 as int) as text)\nfrom foo;",
        );
    }

    #[test]
    fn autofix_comment_prior_convert_shorthand_fixed() {
        let sql = "select\n convert(int, 126) as bar,\n cast(\n 1 /* cast the value\n to an integer\n */ as int) as coo,\n 100::int::text\nfrom foo;";
        assert_autofix(
            sql,
            &run(sql),
            "select\n convert(int, 126) as bar,\n cast(\n 1 /* cast the value\n to an integer\n */ as int) as coo,\n convert(text, convert(int, 100))\nfrom foo;",
        );
    }

    #[test]
    fn autofix_comment_prior_shorthand_convert_fixed() {
        let sql = "select\n 100::int::text,\n convert(int, 126) as bar,\n cast(\n 1 /* cast the value\n to an integer\n */ as int) as coo\nfrom foo;";
        assert_autofix(
            sql,
            &run(sql),
            "select\n 100::int::text,\n 126::int as bar,\n cast(\n 1 /* cast the value\n to an integer\n */ as int) as coo\nfrom foo;",
        );
    }

    #[test]
    fn shorthand_to_cast_rewrites_nested_snowflake_path_casts() {
        let fixed = shorthand_to_cast("(trim(value:Longitude::varchar))::double").expect("rewrite");
        assert_eq!(
            fixed,
            "cast((trim(cast(value:Longitude as varchar))) as double)"
        );
        assert_eq!(
            shorthand_to_cast("col:a.b:c::varchar").expect("rewrite"),
            "cast(col:a.b:c as varchar)"
        );
    }
}