1use crate::linter::config::LintConfig;
8use crate::linter::rule::{LintContext, LintRule};
9use crate::linter::visit::visit_expressions;
10use crate::types::{issue_codes, Issue, IssueAutofixApplicability, IssuePatchEdit, Span};
11use sqlparser::ast::{CastKind, DataType, Expr, Spanned, Statement};
12
/// Casting style the CV11 rule measures violations against, read from the
/// `preferred_type_casting_style` rule option.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum PreferredTypeCastingStyle {
    /// No fixed preference: the style of the left-most cast in the statement
    /// becomes the target for every other cast.
    Consistent,
    /// Prefer `expr::type`.
    Shorthand,
    /// Prefer `CAST(expr AS type)`.
    Cast,
    /// Prefer `CONVERT(type, expr)`.
    Convert,
}
24
25impl PreferredTypeCastingStyle {
26 fn from_config(config: &LintConfig) -> Self {
27 match config
28 .rule_option_str(issue_codes::LINT_CV_011, "preferred_type_casting_style")
29 .unwrap_or("consistent")
30 .to_ascii_lowercase()
31 .as_str()
32 {
33 "shorthand" => Self::Shorthand,
34 "cast" => Self::Cast,
35 "convert" => Self::Convert,
36 _ => Self::Consistent,
37 }
38 }
39}
40
/// Syntactic style of one cast occurrence in the SQL text.
///
/// NOTE: keep the variant order stable. `collect_cast_instances` sorts by
/// `style as u8` and, among casts covering the identical byte span, keeps the
/// one with the lowest discriminant.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum CastStyle {
    /// `CAST(...)`, `TRY_CAST(...)` or `SAFE_CAST(...)`.
    FunctionCast,
    /// `expr::type` shorthand (including chains such as `a::int::text`).
    DoubleColon,
    /// `CONVERT(type, expr)` (optionally with extra arguments).
    Convert,
}
51
/// A single cast located in the statement's SQL text.
///
/// `start..end` is a byte range into the full SQL string (`LintContext::sql`)
/// and is used directly as the autofix edit span.
struct CastInstance {
    /// Which syntactic style this cast uses.
    style: CastStyle,
    /// Byte offset of the first character of the cast text.
    start: usize,
    /// Byte offset one past the last character of the cast text (exclusive).
    end: usize,
    /// `true` when the cast text contains `--` or `/*`. Such casts are still
    /// reported but never autofixed.
    has_comments: bool,
    /// `true` for a `CONVERT` call with more than two arguments. It is
    /// reported but never autofixed, because the extra argument has no
    /// `CAST`/`::` equivalent.
    is_3arg_convert: bool,
}
63
/// Lint rule CV11: flags casts whose style differs from the configured style
/// (or, by default, from the style of the first cast in the statement).
pub struct ConventionCastingStyle {
    /// Target style that violations are measured against.
    preferred_style: PreferredTypeCastingStyle,
}
71
72impl ConventionCastingStyle {
73 pub fn from_config(config: &LintConfig) -> Self {
74 Self {
75 preferred_style: PreferredTypeCastingStyle::from_config(config),
76 }
77 }
78}
79
80impl Default for ConventionCastingStyle {
81 fn default() -> Self {
82 Self {
83 preferred_style: PreferredTypeCastingStyle::Consistent,
84 }
85 }
86}
87
88impl LintRule for ConventionCastingStyle {
89 fn code(&self) -> &'static str {
90 issue_codes::LINT_CV_011
91 }
92
93 fn name(&self) -> &'static str {
94 "Casting style"
95 }
96
97 fn description(&self) -> &'static str {
98 "Enforce consistent type casting style."
99 }
100
101 fn check(&self, statement: &Statement, ctx: &LintContext) -> Vec<Issue> {
102 let sql = ctx.sql;
103 let casts = collect_cast_instances(statement, sql);
104
105 if casts.is_empty() {
106 return Vec::new();
107 }
108
109 let target = match self.preferred_style {
111 PreferredTypeCastingStyle::Consistent => casts[0].style,
112 PreferredTypeCastingStyle::Shorthand => CastStyle::DoubleColon,
113 PreferredTypeCastingStyle::Cast => CastStyle::FunctionCast,
114 PreferredTypeCastingStyle::Convert => CastStyle::Convert,
115 };
116
117 let has_violation = casts.iter().any(|c| c.style != target);
119 if !has_violation {
120 return Vec::new();
121 }
122
123 let message = match self.preferred_style {
124 PreferredTypeCastingStyle::Consistent => {
125 "Use consistent casting style (avoid mixing CAST styles)."
126 }
127 PreferredTypeCastingStyle::Shorthand => "Use `::` shorthand casting style.",
128 PreferredTypeCastingStyle::Cast => "Use `CAST(...)` style casts.",
129 PreferredTypeCastingStyle::Convert => "Use `CONVERT(...)` style casts.",
130 };
131
132 let mut issues = Vec::new();
136 for cast in &casts {
137 if cast.style == target {
138 continue;
139 }
140
141 let mut issue =
142 Issue::info(issue_codes::LINT_CV_011, message).with_statement(ctx.statement_index);
143
144 if !cast.is_3arg_convert && !cast.has_comments {
145 let cast_text = &sql[cast.start..cast.end];
146 if let Some(replacement) = convert_cast(cast_text, cast.style, target) {
147 issue = issue.with_autofix_edits(
148 IssueAutofixApplicability::Unsafe,
149 vec![IssuePatchEdit::new(
150 Span::new(cast.start, cast.end),
151 replacement,
152 )],
153 );
154 }
155 }
156
157 issues.push(issue);
158 }
159
160 issues
161 }
162}
163
/// Collects every cast in `statement` as byte spans into `sql`.
///
/// Sources:
/// * AST `Expr::Cast` nodes (`CAST`/`TRY_CAST`/`SAFE_CAST` and `::`), located
///   via `find_cast_span`;
/// * function calls named `CONVERT` (ASCII case-insensitive);
/// * a raw-text scan for parenthesised shorthand casts `(expr)::type`.
///
/// The result is sorted by position. Casts nested inside another kept cast are
/// removed, so each rewrite covers an outermost cast.
fn collect_cast_instances(statement: &Statement, sql: &str) -> Vec<CastInstance> {
    let mut casts = Vec::new();

    visit_expressions(statement, &mut |expr| {
        match expr {
            Expr::Cast {
                kind,
                expr: inner,
                data_type,
                ..
            } => {
                let style = match kind {
                    CastKind::DoubleColon => CastStyle::DoubleColon,
                    CastKind::Cast | CastKind::TryCast | CastKind::SafeCast => {
                        CastStyle::FunctionCast
                    }
                };

                // `a::int::text` parses as a `::` cast wrapping another `::` cast.
                let is_inner_chain = matches!(
                    inner.as_ref(),
                    Expr::Cast {
                        kind: CastKind::DoubleColon,
                        ..
                    }
                );

                let inner_span = find_cast_span(sql, inner, kind.clone(), data_type);
                if let Some((start, end)) = inner_span {
                    let text = &sql[start..end];
                    let has_comments = text.contains("--") || text.contains("/*");

                    // The outer link of a `::` chain replaces any already-recorded
                    // casts that lie entirely inside its span.
                    if style == CastStyle::DoubleColon && is_inner_chain {
                        casts.retain(|c: &CastInstance| c.start < start || c.end > end);
                    }

                    casts.push(CastInstance {
                        style,
                        start,
                        end,
                        has_comments,
                        is_3arg_convert: false,
                    });
                }
            }
            Expr::Function(function)
                if function.name.to_string().eq_ignore_ascii_case("CONVERT") =>
            {
                if let Some((start, mut end)) = expr_span_offsets(sql, expr) {
                    // Extend the span to include the call's closing `)`.
                    // NOTE(review): presumably the reported function span stops
                    // before it; confirm against the sqlparser version in use.
                    if end < sql.len() && sql.as_bytes().get(end) == Some(&b')') {
                        end += 1;
                    } else {
                        if let Some(close) = find_matching_close_paren(&sql[end..]) {
                            end += close + 1;
                        }
                    }

                    let text = &sql[start..end];
                    let has_comments = text.contains("--") || text.contains("/*");

                    let arg_count = match &function.args {
                        sqlparser::ast::FunctionArguments::List(list) => list.args.len(),
                        _ => 0,
                    };

                    casts.push(CastInstance {
                        style: CastStyle::Convert,
                        start,
                        end,
                        has_comments,
                        is_3arg_convert: arg_count > 2,
                    });
                }
            }
            _ => {}
        }
    });

    // Add text-scanned `(expr)::type` spans unless an AST cast already has
    // exactly that span and style.
    for (start, end) in scan_parenthesized_shorthand_cast_spans(sql) {
        if casts.iter().any(|cast| {
            cast.start == start && cast.end == end && cast.style == CastStyle::DoubleColon
        }) {
            continue;
        }
        let text = &sql[start..end];
        casts.push(CastInstance {
            style: CastStyle::DoubleColon,
            start,
            end,
            has_comments: text.contains("--") || text.contains("/*"),
            is_3arg_convert: false,
        });
    }

    casts.sort_by_key(|c| c.start);

    // Keep only outermost casts. For each candidate, compare it with the first
    // kept cast it relates to:
    // * the kept cast contains it -> drop the candidate;
    // * it contains the kept cast -> replace the kept cast;
    // * same style and partially overlapping -> keep the longer one.
    let mut deduped: Vec<CastInstance> = Vec::with_capacity(casts.len());
    for cast in casts {
        let mut dominated = false;
        let mut replace_index = None;

        for (index, other) in deduped.iter().enumerate() {
            if other.start <= cast.start && other.end >= cast.end {
                dominated = true;
                break;
            }
            if cast.start <= other.start && cast.end >= other.end {
                replace_index = Some(index);
                break;
            }
            if cast.style == other.style
                && spans_overlap(cast.start, cast.end, other.start, other.end)
            {
                let cast_len = cast.end.saturating_sub(cast.start);
                let other_len = other.end.saturating_sub(other.start);
                if cast_len > other_len {
                    replace_index = Some(index);
                } else {
                    dominated = true;
                }
                break;
            }
        }

        if dominated {
            continue;
        }

        if let Some(index) = replace_index {
            deduped[index] = cast;
        } else {
            deduped.push(cast);
        }
    }

    // For identical spans, keep the entry with the lowest `CastStyle`
    // discriminant. `dedup_by` retains the earlier element of each equal run.
    deduped.sort_by_key(|cast| (cast.start, cast.end, cast.style as u8));
    deduped.dedup_by(|left, right| left.start == right.start && left.end == right.end);
    deduped
}
320
/// Returns `true` when the half-open byte ranges `[left_start, left_end)` and
/// `[right_start, right_end)` share at least one position.
fn spans_overlap(left_start: usize, left_end: usize, right_start: usize, right_end: usize) -> bool {
    let disjoint = left_end <= right_start || right_end <= left_start;
    !disjoint
}
324
325fn scan_parenthesized_shorthand_cast_spans(sql: &str) -> Vec<(usize, usize)> {
326 let bytes = sql.as_bytes();
327 let mut out = Vec::new();
328 let mut index = 0usize;
329
330 while index + 1 < bytes.len() {
331 if bytes[index] != b':' || bytes[index + 1] != b':' {
332 index += 1;
333 continue;
334 }
335
336 let mut lhs_end = index;
337 while lhs_end > 0 && bytes[lhs_end - 1].is_ascii_whitespace() {
338 lhs_end -= 1;
339 }
340 if lhs_end == 0 || bytes[lhs_end - 1] != b')' {
341 index += 2;
342 continue;
343 }
344 let close_paren = lhs_end - 1;
345 let Some(open_paren) = find_matching_open_paren(bytes, close_paren) else {
346 index += 2;
347 continue;
348 };
349
350 let Some(type_end) = scan_parenthesized_shorthand_type_end(bytes, index + 2) else {
351 index += 2;
352 continue;
353 };
354
355 out.push((open_paren, type_end));
356 index = type_end;
357 }
358
359 out
360}
361
/// Scans a type name that starts at `start` (just after a `::`).
///
/// Identifier characters and `.` are accepted anywhere. Commas and whitespace
/// are accepted only inside parentheses, e.g. `numeric(10, 2)`. Returns the
/// offset just past the type, or `None` if nothing type-like was consumed.
fn scan_parenthesized_shorthand_type_end(bytes: &[u8], start: usize) -> Option<usize> {
    let mut cursor = start;
    let mut open_parens = 0i32;
    let mut consumed_type_char = false;

    while let Some(&byte) = bytes.get(cursor) {
        let inside_parens = open_parens > 0;
        if byte == b'(' {
            open_parens += 1;
            consumed_type_char = true;
        } else if byte == b')' && inside_parens {
            open_parens -= 1;
        } else if byte.is_ascii_alphanumeric() || byte == b'_' || byte == b'.' {
            consumed_type_char = true;
        } else if !(inside_parens && matches!(byte, b',' | b' ' | b'\t' | b'\n' | b'\r')) {
            break;
        }
        cursor += 1;
    }

    consumed_type_char.then_some(cursor)
}
394
/// Given the index of a `)` in `bytes`, returns the index of the `(` that
/// opens it.
///
/// Returns `None` if `close_paren` is not a `)` or no balancing `(` exists.
/// Quoted text is not treated specially.
fn find_matching_open_paren(bytes: &[u8], close_paren: usize) -> Option<usize> {
    if bytes.get(close_paren) != Some(&b')') {
        return None;
    }
    let mut unmatched = 1i32;
    bytes[..close_paren].iter().rposition(|&byte| {
        match byte {
            b')' => unmatched += 1,
            b'(' => unmatched -= 1,
            _ => {}
        }
        unmatched == 0
    })
}
416
417fn find_cast_span(
426 sql: &str,
427 inner: &Expr,
428 kind: CastKind,
429 data_type: &DataType,
430) -> Option<(usize, usize)> {
431 match kind {
432 CastKind::Cast | CastKind::TryCast | CastKind::SafeCast => {
433 let (inner_start, inner_end) = expr_span_offsets(sql, inner)?;
434
435 let before = &sql[..inner_start];
437 let paren_pos = before.rfind('(')?;
438 let before_paren = before[..paren_pos].trim_end();
439 let kw = match kind {
440 CastKind::TryCast => "TRY_CAST",
441 CastKind::SafeCast => "SAFE_CAST",
442 _ => "CAST",
443 };
444 let kw_len = kw.len();
445 if before_paren.len() < kw_len {
446 return None;
447 }
448 let kw_candidate = &before_paren[before_paren.len() - kw_len..];
449 if !kw_candidate.eq_ignore_ascii_case(kw) {
450 return None;
451 }
452 let start = before_paren.len() - kw_len;
453
454 let after = &sql[inner_end..];
456 let close = find_matching_close_paren(after)?;
457 let end = inner_end + close + 1;
458
459 Some((start, end))
460 }
461 CastKind::DoubleColon => {
462 let base = deepest_base_expr(inner);
464 let (base_start, base_end) = expr_span_offsets(sql, base)?;
465
466 let type_str = data_type.to_string();
468 let mut pos = base_end;
469 loop {
470 let after = &sql[pos..];
471 let dc_pos = match after.find("::") {
472 Some(p) => p,
473 None => break,
474 };
475 let type_start = pos + dc_pos + 2;
476 let type_len = source_type_len(sql, type_start, &type_str);
477 if type_len == 0 {
478 break;
479 }
480 pos = type_start + type_len;
481 let this_type = &sql[type_start..pos];
483 if this_type.eq_ignore_ascii_case(&type_str) {
484 break;
485 }
486 }
487
488 Some((base_start, pos))
489 }
490 }
491}
492
493fn deepest_base_expr(expr: &Expr) -> &Expr {
496 match expr {
497 Expr::Cast {
498 kind: CastKind::DoubleColon,
499 expr: inner,
500 ..
501 } => deepest_base_expr(inner),
502 _ => expr,
503 }
504}
505
/// Returns the byte index of the first `)` in `text` that is not balanced by
/// a preceding `(` within `text`.
///
/// `text` usually starts right after an opening paren, so the result is that
/// paren's closing partner. Single- and double-quoted regions are skipped, and
/// a backslash inside quotes escapes the following byte. Returns `None` if no
/// such `)` exists.
fn find_matching_close_paren(text: &str) -> Option<usize> {
    let bytes = text.as_bytes();
    let mut nesting = 0i32;
    let mut pos = 0usize;

    while let Some(&byte) = bytes.get(pos) {
        if byte == b'(' {
            nesting += 1;
        } else if byte == b')' {
            if nesting == 0 {
                return Some(pos);
            }
            nesting -= 1;
        } else if byte == b'\'' || byte == b'"' {
            pos += 1;
            while let Some(&quoted) = bytes.get(pos) {
                if quoted == byte {
                    break;
                }
                pos += if quoted == b'\\' { 2 } else { 1 };
            }
        }
        pos += 1;
    }

    None
}
537
/// Returns how many bytes of `sql` starting at `pos` form the type name of a
/// shorthand cast.
///
/// If the source begins with `type_display` (ASCII case-insensitive), that
/// length is returned. Otherwise the type is scanned as identifier characters,
/// plus parenthesised parameters that may contain commas and whitespace
/// (including `\r`). `pos` must be a char boundary of `sql`.
fn source_type_len(sql: &str, pos: usize, type_display: &str) -> usize {
    let remaining = &sql[pos..];
    let display_len = type_display.len();

    // `get` instead of indexing: `display_len` may fall inside a multi-byte
    // character of `remaining`, and slicing there would panic.
    if matches!(
        remaining.get(..display_len),
        Some(prefix) if prefix.eq_ignore_ascii_case(type_display)
    ) {
        return display_len;
    }

    let mut len = 0;
    let mut depth = 0i32;
    for &b in remaining.as_bytes() {
        match b {
            b'(' => depth += 1,
            b')' if depth > 0 => depth -= 1,
            b'a'..=b'z' | b'A'..=b'Z' | b'0'..=b'9' | b'_' => {}
            b' ' | b'\t' | b'\n' | b'\r' | b',' if depth > 0 => {}
            _ => break,
        }
        len += 1;
    }
    len
}
574
575fn convert_cast(cast_text: &str, from_style: CastStyle, to_style: CastStyle) -> Option<String> {
580 match (from_style, to_style) {
581 (CastStyle::FunctionCast, CastStyle::DoubleColon) => cast_to_shorthand(cast_text),
582 (CastStyle::FunctionCast, CastStyle::Convert) => cast_to_convert(cast_text),
583 (CastStyle::DoubleColon, CastStyle::FunctionCast) => shorthand_to_cast(cast_text),
584 (CastStyle::DoubleColon, CastStyle::Convert) => shorthand_to_convert(cast_text),
585 (CastStyle::Convert, CastStyle::FunctionCast) => convert_to_cast(cast_text),
586 (CastStyle::Convert, CastStyle::DoubleColon) => convert_to_shorthand(cast_text),
587 _ => None,
588 }
589}
590
591fn parse_cast_interior(cast_text: &str) -> Option<(&str, &str)> {
594 let open = cast_text.find('(')?;
595 let close = cast_text.rfind(')')?;
596 let inner = cast_text[open + 1..close].trim();
597
598 let as_pos = find_top_level_as(inner)?;
599 let expr_part = inner[..as_pos].trim();
600 let type_part = inner[as_pos + 1..].trim();
603 let type_part = type_part
605 .strip_prefix("AS")
606 .or_else(|| type_part.strip_prefix("as"))
607 .or_else(|| type_part.strip_prefix("As"))
608 .or_else(|| type_part.strip_prefix("aS"))
609 .unwrap_or(type_part)
610 .trim();
611 Some((expr_part, type_part))
612}
613
614fn find_top_level_as(inner: &str) -> Option<usize> {
616 let bytes = inner.as_bytes();
617 let mut depth = 0i32;
618 let mut i = 0;
619 while i < bytes.len() {
620 match bytes[i] {
621 b'(' => depth += 1,
622 b')' => depth -= 1,
623 b'\'' | b'"' => {
624 let quote = bytes[i];
625 i += 1;
626 while i < bytes.len() && bytes[i] != quote {
627 if bytes[i] == b'\\' {
628 i += 1;
629 }
630 i += 1;
631 }
632 }
633 _ if depth == 0
634 && is_whitespace_byte(bytes[i])
635 && i + 3 < bytes.len()
636 && bytes[i + 1].eq_ignore_ascii_case(&b'A')
637 && bytes[i + 2].eq_ignore_ascii_case(&b'S')
638 && is_whitespace_byte(bytes[i + 3]) =>
639 {
640 return Some(i);
641 }
642 _ => {}
643 }
644 i += 1;
645 }
646 None
647}
648
/// ASCII space, tab, line feed or carriage return (form feed is excluded).
fn is_whitespace_byte(b: u8) -> bool {
    b" \t\n\r".contains(&b)
}
652
653fn cast_to_shorthand(cast_text: &str) -> Option<String> {
655 let (expr, type_text) = parse_cast_interior(cast_text)?;
656 let needs_parens = expr_is_complex(expr);
657 if needs_parens {
658 Some(format!("({expr})::{type_text}"))
659 } else {
660 Some(format!("{expr}::{type_text}"))
661 }
662}
663
664fn cast_to_convert(cast_text: &str) -> Option<String> {
666 let (expr, type_text) = parse_cast_interior(cast_text)?;
667 Some(format!("convert({type_text}, {expr})"))
668}
669
670fn convert_to_cast(convert_text: &str) -> Option<String> {
672 let (type_text, expr) = parse_convert_interior(convert_text)?;
673 Some(format!("cast({expr} as {type_text})"))
674}
675
676fn convert_to_shorthand(convert_text: &str) -> Option<String> {
678 let (type_text, expr) = parse_convert_interior(convert_text)?;
679 let needs_parens = expr_is_complex(expr);
680 if needs_parens {
681 Some(format!("({expr})::{type_text}"))
682 } else {
683 Some(format!("{expr}::{type_text}"))
684 }
685}
686
687fn shorthand_to_cast(shorthand_text: &str) -> Option<String> {
690 let parts = split_shorthand_chain(shorthand_text)?;
691 if parts.len() < 2 {
692 return None;
693 }
694 let mut result = rewrite_nested_simple_shorthand_to_cast(parts[0]);
695 for type_part in &parts[1..] {
696 result = format!("cast({result} as {type_part})");
697 }
698 Some(result)
699}
700
701fn shorthand_to_convert(shorthand_text: &str) -> Option<String> {
704 let parts = split_shorthand_chain(shorthand_text)?;
705 if parts.len() < 2 {
706 return None;
707 }
708 let mut result = parts[0].to_string();
709 for type_part in &parts[1..] {
710 result = format!("convert({type_part}, {result})");
711 }
712 Some(result)
713}
714
/// Splits text at every `::` that is outside parentheses and quotes.
///
/// For `100::int::text` this gives `["100", "int", "text"]`. Segments are not
/// trimmed. Returns `None` unless at least one such `::` exists.
fn split_shorthand_chain(text: &str) -> Option<Vec<&str>> {
    let bytes = text.as_bytes();
    let mut segments = Vec::new();
    let mut segment_start = 0usize;
    let mut nesting = 0i32;
    let mut pos = 0usize;

    while pos < bytes.len() {
        let byte = bytes[pos];
        if byte == b'(' {
            nesting += 1;
        } else if byte == b')' {
            nesting -= 1;
        } else if byte == b'\'' || byte == b'"' {
            pos += 1;
            while pos < bytes.len() && bytes[pos] != byte {
                pos += if bytes[pos] == b'\\' { 2 } else { 1 };
            }
        } else if byte == b':' && nesting == 0 && bytes.get(pos + 1) == Some(&b':') {
            segments.push(&text[segment_start..pos]);
            pos += 2;
            segment_start = pos;
            continue;
        }
        pos += 1;
    }
    segments.push(&text[segment_start..]);

    (segments.len() >= 2).then_some(segments)
}
755
756fn rewrite_nested_simple_shorthand_to_cast(expr: &str) -> String {
761 let bytes = expr.as_bytes();
762 let mut index = 0usize;
763 let mut out = String::with_capacity(expr.len() + 16);
764
765 while index < bytes.len() {
766 let Some(rel_dc) = expr[index..].find("::") else {
767 out.push_str(&expr[index..]);
768 break;
769 };
770 let dc = index + rel_dc;
771
772 let mut lhs_start = dc;
773 while lhs_start > 0 && is_simple_shorthand_lhs_char(bytes[lhs_start - 1]) {
774 lhs_start -= 1;
775 }
776 if lhs_start == dc {
777 out.push_str(&expr[index..dc + 2]);
778 index = dc + 2;
779 continue;
780 }
781
782 let mut rhs_end = dc + 2;
783 while rhs_end < bytes.len() && is_simple_type_char(bytes[rhs_end]) {
784 rhs_end += 1;
785 }
786 if rhs_end == dc + 2 {
787 out.push_str(&expr[index..dc + 2]);
788 index = dc + 2;
789 continue;
790 }
791
792 out.push_str(&expr[index..lhs_start]);
793 out.push_str("cast(");
794 out.push_str(&expr[lhs_start..dc]);
795 out.push_str(" as ");
796 out.push_str(&expr[dc + 2..rhs_end]);
797 out.push(')');
798 index = rhs_end;
799 }
800
801 out
802}
803
/// Characters that may form the left operand of a nested simple `::` cast:
/// ASCII alphanumerics, `_`, path separators `.`/`:`, the sigils `$`/`@`,
/// and identifier quoting characters (`"`, backtick, `[`, `]`).
fn is_simple_shorthand_lhs_char(byte: u8) -> bool {
    const EXTRA: &[u8] = b"_.:$@\"`[]";
    byte.is_ascii_alphanumeric() || EXTRA.contains(&byte)
}
811
/// Characters that may appear in a type name after `::`: ASCII alphanumerics,
/// `_`, whitespace (space, tab, LF, CR), parentheses and commas.
fn is_simple_type_char(byte: u8) -> bool {
    const EXTRA: &[u8] = b"_ \t\n\r(),";
    byte.is_ascii_alphanumeric() || EXTRA.contains(&byte)
}
819
820fn parse_convert_interior(convert_text: &str) -> Option<(&str, &str)> {
822 let open = convert_text.find('(')?;
823 let close = convert_text.rfind(')')?;
824 let inner = convert_text[open + 1..close].trim();
825 let comma = find_top_level_comma(inner)?;
826 let type_part = inner[..comma].trim();
827 let expr_part = inner[comma + 1..].trim();
828 Some((type_part, expr_part))
829}
830
/// Returns the byte index of the first comma outside parentheses and quotes.
///
/// A backslash inside quotes escapes the next byte. Returns `None` if there is
/// no such comma.
fn find_top_level_comma(inner: &str) -> Option<usize> {
    let bytes = inner.as_bytes();
    let mut nesting = 0i32;
    let mut pos = 0usize;

    while let Some(&byte) = bytes.get(pos) {
        match byte {
            b',' if nesting == 0 => return Some(pos),
            b'(' => nesting += 1,
            b')' => nesting -= 1,
            b'\'' | b'"' => {
                pos += 1;
                while let Some(&quoted) = bytes.get(pos) {
                    if quoted == byte {
                        break;
                    }
                    pos += if quoted == b'\\' { 2 } else { 1 };
                }
            }
            _ => {}
        }
        pos += 1;
    }

    None
}
857
/// Decides whether `expr` must be parenthesised before appending `::type`.
///
/// `::` binds tighter than binary operators, so any top-level operator or
/// whitespace makes the expression complex. A leading `-` (a negative literal)
/// does not. Quoted regions are skipped rather than ending the scan, so
/// `'a' || b` is complex while `'2020-01-01'` is not.
fn expr_is_complex(expr: &str) -> bool {
    let bytes = expr.trim().as_bytes();
    let mut depth = 0i32;
    let mut i = 0usize;

    while i < bytes.len() {
        let byte = bytes[i];
        match byte {
            b'(' => depth += 1,
            b')' => depth -= 1,
            b'\'' | b'"' => {
                i += 1;
                while i < bytes.len() && bytes[i] != byte {
                    if bytes[i] == b'\\' {
                        i += 1;
                    }
                    i += 1;
                }
            }
            b'-' if depth == 0 && i == 0 => {}
            b'|' | b'+' | b'-' | b'*' | b'/' | b'%' | b'=' | b'<' | b'>' | b'&' | b'^'
                if depth == 0 =>
            {
                return true;
            }
            b' ' | b'\t' | b'\n' | b'\r' if depth == 0 => return true,
            _ => {}
        }
        i += 1;
    }

    false
}
881
882fn expr_span_offsets(sql: &str, expr: &Expr) -> Option<(usize, usize)> {
887 let span = expr.span();
888 if span.start.line == 0 || span.start.column == 0 || span.end.line == 0 || span.end.column == 0
889 {
890 return None;
891 }
892
893 let start = line_col_to_offset(sql, span.start.line as usize, span.start.column as usize)?;
894 let end = line_col_to_offset(sql, span.end.line as usize, span.end.column as usize)?;
895 (end >= start).then_some((start, end))
896}
897
/// Maps a 1-based `(line, column)` position to a byte offset in `sql`.
///
/// Columns count characters, not bytes. Column `n + 1` on the last line maps
/// to `sql.len()`. Counting continues past a line's `\n` into later text.
/// Returns `None` for 0 coordinates or positions beyond the input.
fn line_col_to_offset(sql: &str, line: usize, column: usize) -> Option<usize> {
    if line == 0 || column == 0 {
        return None;
    }

    // Line N begins just after the (N - 1)-th newline.
    let line_start = match line {
        1 => 0,
        _ => sql
            .match_indices('\n')
            .nth(line - 2)
            .map(|(newline, _)| newline + 1)?,
    };

    sql[line_start..]
        .char_indices()
        .map(|(relative, _)| line_start + relative)
        .chain(std::iter::once(sql.len()))
        .nth(column - 1)
}
932
#[cfg(test)]
mod tests {
    // Unit tests for CV11: detection of mixed casting styles and the textual
    // autofix rewrites between `::`, `CAST(...)` and `CONVERT(...)`.
    use super::*;
    use crate::parser::parse_sql;

    /// Parses `sql` and runs the rule with its default (consistent) style over
    /// every statement, collecting all issues.
    fn run(sql: &str) -> Vec<Issue> {
        let statements = parse_sql(sql).expect("parse");
        let rule = ConventionCastingStyle::default();
        statements
            .iter()
            .enumerate()
            .flat_map(|(index, statement)| {
                rule.check(
                    statement,
                    &LintContext {
                        sql,
                        statement_range: 0..sql.len(),
                        statement_index: index,
                    },
                )
            })
            .collect()
    }

    /// Same as `run`, but builds the rule from `config`.
    fn run_with_config(sql: &str, config: &LintConfig) -> Vec<Issue> {
        let statements = parse_sql(sql).expect("parse");
        let rule = ConventionCastingStyle::from_config(config);
        statements
            .iter()
            .enumerate()
            .flat_map(|(index, statement)| {
                rule.check(
                    statement,
                    &LintContext {
                        sql,
                        statement_range: 0..sql.len(),
                        statement_index: index,
                    },
                )
            })
            .collect()
    }

    /// Applies `edits` to `sql`, starting with the edit that starts last so
    /// earlier byte offsets remain valid. Assumes edits do not overlap.
    fn apply_edits(sql: &str, edits: &[IssuePatchEdit]) -> String {
        let mut sorted: Vec<_> = edits.iter().collect();
        sorted.sort_by_key(|e| std::cmp::Reverse(e.span.start));
        let mut result = sql.to_string();
        for edit in sorted {
            result.replace_range(edit.span.start..edit.span.end, &edit.replacement);
        }
        result
    }

    /// Gathers the edits of every issue that carries an autofix.
    fn collect_all_edits(issues: &[Issue]) -> Vec<&IssuePatchEdit> {
        issues
            .iter()
            .filter_map(|i| i.autofix.as_ref())
            .flat_map(|a| a.edits.iter())
            .collect()
    }

    /// Applies every autofix edit found in `issues` to `sql`.
    fn apply_all_fixes(sql: &str, issues: &[Issue]) -> String {
        let edits = collect_all_edits(issues);
        let owned: Vec<IssuePatchEdit> = edits.into_iter().cloned().collect();
        apply_edits(sql, &owned)
    }

    // --- Detection -------------------------------------------------------

    #[test]
    fn flags_mixed_casting_styles() {
        let issues = run("SELECT CAST(amount AS INT)::TEXT FROM t");
        assert_eq!(issues.len(), 1);
        assert_eq!(issues[0].code, issue_codes::LINT_CV_011);
    }

    #[test]
    fn does_not_flag_single_casting_style() {
        assert!(run("SELECT amount::INT FROM t").is_empty());
        assert!(run("SELECT CAST(amount AS INT) FROM t").is_empty());
    }

    #[test]
    fn does_not_flag_cast_like_tokens_inside_string_literal() {
        assert!(run("SELECT 'value::TEXT and CAST(value AS INT)' AS note").is_empty());
    }

    #[test]
    fn flags_mixed_try_cast_and_double_colon_styles() {
        let issues = run("SELECT TRY_CAST(amount AS INT)::TEXT FROM t");
        assert_eq!(issues.len(), 1);
        assert_eq!(issues[0].code, issue_codes::LINT_CV_011);
    }

    // Configuration is looked up both by dotted rule name and by issue code.
    #[test]
    fn shorthand_preference_flags_cast_function_style() {
        let config = LintConfig {
            enabled: true,
            disabled_rules: vec![],
            rule_configs: std::collections::BTreeMap::from([(
                "convention.casting_style".to_string(),
                serde_json::json!({"preferred_type_casting_style": "shorthand"}),
            )]),
        };
        let rule = ConventionCastingStyle::from_config(&config);
        let sql = "SELECT CAST(amount AS INT) FROM t";
        let statements = parse_sql(sql).expect("parse");
        let issues = rule.check(
            &statements[0],
            &LintContext {
                sql,
                statement_range: 0..sql.len(),
                statement_index: 0,
            },
        );
        assert_eq!(issues.len(), 1);
    }

    #[test]
    fn cast_preference_flags_shorthand_style() {
        let config = LintConfig {
            enabled: true,
            disabled_rules: vec![],
            rule_configs: std::collections::BTreeMap::from([(
                "LINT_CV_011".to_string(),
                serde_json::json!({"preferred_type_casting_style": "cast"}),
            )]),
        };
        let rule = ConventionCastingStyle::from_config(&config);
        let sql = "SELECT amount::INT FROM t";
        let statements = parse_sql(sql).expect("parse");
        let issues = rule.check(
            &statements[0],
            &LintContext {
                sql,
                statement_range: 0..sql.len(),
                statement_index: 0,
            },
        );
        assert_eq!(issues.len(), 1);
    }

    // --- Autofix: consistent mode (first cast sets the target) ----------

    #[test]
    fn autofix_consistent_prior_convert() {
        let sql = "select\n convert(int, 1) as bar,\n 100::int::text,\n cast(10\n as text) as coo\nfrom foo;";
        let issues = run(sql);
        assert!(!issues.is_empty());
        let fixed = apply_all_fixes(sql, &issues);
        assert_eq!(
            fixed,
            "select\n convert(int, 1) as bar,\n convert(text, convert(int, 100)),\n convert(text, 10) as coo\nfrom foo;"
        );
    }

    #[test]
    fn autofix_consistent_prior_cast() {
        let sql = "select\n cast(10 as text) as coo,\n convert(int, 1) as bar,\n 100::int::text,\nfrom foo;";
        let issues = run(sql);
        assert!(!issues.is_empty());
        let fixed = apply_all_fixes(sql, &issues);
        assert_eq!(
            fixed,
            "select\n cast(10 as text) as coo,\n cast(1 as int) as bar,\n cast(cast(100 as int) as text),\nfrom foo;"
        );
    }

    #[test]
    fn autofix_consistent_prior_shorthand() {
        let sql = "select\n 100::int::text,\n cast(10 as text) as coo,\n convert(int, 1) as bar\nfrom foo;";
        let issues = run(sql);
        assert!(!issues.is_empty());
        let fixed = apply_all_fixes(sql, &issues);
        assert_eq!(
            fixed,
            "select\n 100::int::text,\n 10::text as coo,\n 1::int as bar\nfrom foo;"
        );
    }

    // --- Autofix: explicitly configured target style --------------------

    #[test]
    fn autofix_config_cast() {
        let config = LintConfig {
            enabled: true,
            disabled_rules: vec![],
            rule_configs: std::collections::BTreeMap::from([(
                "convention.casting_style".to_string(),
                serde_json::json!({"preferred_type_casting_style": "cast"}),
            )]),
        };
        let sql = "select\n convert(int, 1) as bar,\n 100::int::text,\n cast(10 as text) as coo\nfrom foo;";
        let issues = run_with_config(sql, &config);
        assert!(!issues.is_empty());
        let fixed = apply_all_fixes(sql, &issues);
        assert_eq!(
            fixed,
            "select\n cast(1 as int) as bar,\n cast(cast(100 as int) as text),\n cast(10 as text) as coo\nfrom foo;"
        );
    }

    #[test]
    fn autofix_config_convert() {
        let config = LintConfig {
            enabled: true,
            disabled_rules: vec![],
            rule_configs: std::collections::BTreeMap::from([(
                "convention.casting_style".to_string(),
                serde_json::json!({"preferred_type_casting_style": "convert"}),
            )]),
        };
        let sql = "select\n convert(int, 1) as bar,\n 100::int::text,\n cast(10 as text) as coo\nfrom foo;";
        let issues = run_with_config(sql, &config);
        assert!(!issues.is_empty());
        let fixed = apply_all_fixes(sql, &issues);
        assert_eq!(
            fixed,
            "select\n convert(int, 1) as bar,\n convert(text, convert(int, 100)),\n convert(text, 10) as coo\nfrom foo;"
        );
    }

    #[test]
    fn autofix_config_shorthand() {
        let config = LintConfig {
            enabled: true,
            disabled_rules: vec![],
            rule_configs: std::collections::BTreeMap::from([(
                "convention.casting_style".to_string(),
                serde_json::json!({"preferred_type_casting_style": "shorthand"}),
            )]),
        };
        let sql = "select\n convert(int, 1) as bar,\n 100::int::text,\n cast(10 as text) as coo\nfrom foo;";
        let issues = run_with_config(sql, &config);
        assert!(!issues.is_empty());
        let fixed = apply_all_fixes(sql, &issues);
        assert_eq!(
            fixed,
            "select\n 1::int as bar,\n 100::int::text,\n 10::text as coo\nfrom foo;"
        );
    }

    // --- Autofix: CONVERT with a style argument is never rewritten -------

    #[test]
    fn autofix_3arg_convert_skipped_config_cast() {
        let config = LintConfig {
            enabled: true,
            disabled_rules: vec![],
            rule_configs: std::collections::BTreeMap::from([(
                "convention.casting_style".to_string(),
                serde_json::json!({"preferred_type_casting_style": "cast"}),
            )]),
        };
        let sql = "select\n convert(int, 1, 126) as bar,\n 100::int::text,\n cast(10 as text) as coo\nfrom foo;";
        let issues = run_with_config(sql, &config);
        assert!(!issues.is_empty());
        let fixed = apply_all_fixes(sql, &issues);
        assert_eq!(
            fixed,
            "select\n convert(int, 1, 126) as bar,\n cast(cast(100 as int) as text),\n cast(10 as text) as coo\nfrom foo;"
        );
    }

    #[test]
    fn autofix_3arg_convert_skipped_config_shorthand() {
        let config = LintConfig {
            enabled: true,
            disabled_rules: vec![],
            rule_configs: std::collections::BTreeMap::from([(
                "convention.casting_style".to_string(),
                serde_json::json!({"preferred_type_casting_style": "shorthand"}),
            )]),
        };
        let sql = "select\n convert(int, 1, 126) as bar,\n 100::int::text,\n cast(10 as text) as coo\nfrom foo;";
        let issues = run_with_config(sql, &config);
        assert!(!issues.is_empty());
        let fixed = apply_all_fixes(sql, &issues);
        assert_eq!(
            fixed,
            "select\n convert(int, 1, 126) as bar,\n 100::int::text,\n 10::text as coo\nfrom foo;"
        );
    }

    // --- Autofix: complex operands are parenthesised for `::` -------------

    #[test]
    fn autofix_parenthesize_complex_expr_shorthand_from_cast() {
        let config = LintConfig {
            enabled: true,
            disabled_rules: vec![],
            rule_configs: std::collections::BTreeMap::from([(
                "convention.casting_style".to_string(),
                serde_json::json!({"preferred_type_casting_style": "shorthand"}),
            )]),
        };
        let sql = "select\n id::int,\n cast(calendar_date||' 11:00:00' as timestamp) as calendar_datetime\nfrom foo;";
        let issues = run_with_config(sql, &config);
        assert!(!issues.is_empty());
        let fixed = apply_all_fixes(sql, &issues);
        assert_eq!(
            fixed,
            "select\n id::int,\n (calendar_date||' 11:00:00')::timestamp as calendar_datetime\nfrom foo;"
        );
    }

    #[test]
    fn autofix_parenthesize_complex_expr_shorthand_from_convert() {
        let config = LintConfig {
            enabled: true,
            disabled_rules: vec![],
            rule_configs: std::collections::BTreeMap::from([(
                "convention.casting_style".to_string(),
                serde_json::json!({"preferred_type_casting_style": "shorthand"}),
            )]),
        };
        let sql = "select\n id::int,\n convert(timestamp, calendar_date||' 11:00:00') as calendar_datetime\nfrom foo;";
        let issues = run_with_config(sql, &config);
        assert!(!issues.is_empty());
        let fixed = apply_all_fixes(sql, &issues);
        assert_eq!(
            fixed,
            "select\n id::int,\n (calendar_date||' 11:00:00')::timestamp as calendar_datetime\nfrom foo;"
        );
    }

    // --- Autofix: casts containing comments are reported but left alone ---

    #[test]
    fn autofix_comment_cast_skipped() {
        let sql = "select\n cast(10 as text) as coo,\n convert( -- Convert the value\n int, /*\n to an integer\n */ 1) as bar,\n 100::int::text\nfrom foo;";
        let issues = run(sql);
        assert!(!issues.is_empty());
        let fixed = apply_all_fixes(sql, &issues);
        assert_eq!(
            fixed,
            "select\n cast(10 as text) as coo,\n convert( -- Convert the value\n int, /*\n to an integer\n */ 1) as bar,\n cast(cast(100 as int) as text)\nfrom foo;"
        );
    }

    #[test]
    fn autofix_3arg_convert_consistent_prior_cast() {
        let sql = "select\n cast(10 as text) as coo,\n convert(int, 1, 126) as bar,\n 100::int::text\nfrom foo;";
        let issues = run(sql);
        assert!(!issues.is_empty());
        let fixed = apply_all_fixes(sql, &issues);
        assert_eq!(
            fixed,
            "select\n cast(10 as text) as coo,\n convert(int, 1, 126) as bar,\n cast(cast(100 as int) as text)\nfrom foo;"
        );
    }

    #[test]
    fn autofix_comment_prior_convert_shorthand_fixed() {
        let sql = "select\n convert(int, 126) as bar,\n cast(\n 1 /* cast the value\n to an integer\n */ as int) as coo,\n 100::int::text\nfrom foo;";
        let issues = run(sql);
        assert!(!issues.is_empty());
        let fixed = apply_all_fixes(sql, &issues);
        assert_eq!(
            fixed,
            "select\n convert(int, 126) as bar,\n cast(\n 1 /* cast the value\n to an integer\n */ as int) as coo,\n convert(text, convert(int, 100))\nfrom foo;"
        );
    }

    #[test]
    fn autofix_comment_prior_shorthand_convert_fixed() {
        let sql = "select\n 100::int::text,\n convert(int, 126) as bar,\n cast(\n 1 /* cast the value\n to an integer\n */ as int) as coo\nfrom foo;";
        let issues = run(sql);
        assert!(!issues.is_empty());
        let fixed = apply_all_fixes(sql, &issues);
        assert_eq!(
            fixed,
            "select\n 100::int::text,\n 126::int as bar,\n cast(\n 1 /* cast the value\n to an integer\n */ as int) as coo\nfrom foo;"
        );
    }

    // --- Helper-level rewrite of nested Snowflake path casts -------------

    #[test]
    fn shorthand_to_cast_rewrites_nested_snowflake_path_casts() {
        let fixed = shorthand_to_cast("(trim(value:Longitude::varchar))::double").expect("rewrite");
        assert_eq!(
            fixed,
            "cast((trim(cast(value:Longitude as varchar))) as double)"
        );
        assert_eq!(
            shorthand_to_cast("col:a.b:c::varchar").expect("rewrite"),
            "cast(col:a.b:c as varchar)"
        );
    }
}