Skip to main content

flowscope_core/linter/rules/
cv_011.rs

1//! LINT_CV_011: Casting style.
2//!
3//! SQLFluff CV11 parity: detect mixed use of `::`, `CAST()`, and `CONVERT()`
4//! within the same statement and emit autofix edits to normalise to the
5//! preferred style.
6
7use crate::linter::config::LintConfig;
8use crate::linter::rule::{LintContext, LintRule};
9use crate::linter::visit::visit_expressions;
10use crate::types::{issue_codes, Issue, IssueAutofixApplicability, IssuePatchEdit, Span};
11use sqlparser::ast::{CastKind, DataType, Expr, Spanned, Statement};
12
13// ---------------------------------------------------------------------------
14// Configuration
15// ---------------------------------------------------------------------------
16
/// Casting style the rule has been configured to enforce.
///
/// The value is read from the `preferred_type_casting_style` rule option.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum PreferredTypeCastingStyle {
    /// No fixed preference: the first cast found in a statement sets the
    /// style the remaining casts are compared against.
    Consistent,
    /// Prefer the `expr::type` shorthand.
    Shorthand,
    /// Prefer `CAST(expr AS type)`.
    Cast,
    /// Prefer `CONVERT(type, expr)`.
    Convert,
}
24
25impl PreferredTypeCastingStyle {
26    fn from_config(config: &LintConfig) -> Self {
27        match config
28            .rule_option_str(issue_codes::LINT_CV_011, "preferred_type_casting_style")
29            .unwrap_or("consistent")
30            .to_ascii_lowercase()
31            .as_str()
32        {
33            "shorthand" => Self::Shorthand,
34            "cast" => Self::Cast,
35            "convert" => Self::Convert,
36            _ => Self::Consistent,
37        }
38    }
39}
40
41// ---------------------------------------------------------------------------
42// Cast-expression descriptor
43// ---------------------------------------------------------------------------
44
/// Syntactic form a cast expression was written in.
///
/// The declaration order is significant: `collect_cast_instances` uses
/// `style as u8` as the final component of its sort key.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum CastStyle {
    /// `CAST(expr AS type)`; `TRY_CAST` and `SAFE_CAST` are classified here too.
    FunctionCast,
    /// `expr::type` shorthand.
    DoubleColon,
    /// `CONVERT(...)` function call.
    Convert,
}
51
/// A single cast expression found in the statement.
///
/// `start..end` is used both to slice the original text and as the span of
/// the autofix edit, so the two offsets must describe the complete cast.
struct CastInstance {
    /// Syntactic form the cast was written in.
    style: CastStyle,
    /// Byte offset (inclusive) where the whole expression starts in the full SQL text.
    start: usize,
    /// Byte offset (exclusive) where the whole expression ends in the full SQL text.
    end: usize,
    /// Whether the cast contains embedded comments and should not be auto-fixed.
    has_comments: bool,
    /// For CONVERT: true if it has 3+ arguments (style argument) — can't be converted.
    is_3arg_convert: bool,
}
63
64// ---------------------------------------------------------------------------
65// Rule struct
66// ---------------------------------------------------------------------------
67
/// Lint rule LINT_CV_011: reports casts whose style differs from the
/// preferred (or first-seen) casting style within a statement.
pub struct ConventionCastingStyle {
    /// Style every cast is compared against; `Consistent` defers to the
    /// first cast found in each statement.
    preferred_style: PreferredTypeCastingStyle,
}
71
72impl ConventionCastingStyle {
73    pub fn from_config(config: &LintConfig) -> Self {
74        Self {
75            preferred_style: PreferredTypeCastingStyle::from_config(config),
76        }
77    }
78}
79
80impl Default for ConventionCastingStyle {
81    fn default() -> Self {
82        Self {
83            preferred_style: PreferredTypeCastingStyle::Consistent,
84        }
85    }
86}
87
88impl LintRule for ConventionCastingStyle {
89    fn code(&self) -> &'static str {
90        issue_codes::LINT_CV_011
91    }
92
93    fn name(&self) -> &'static str {
94        "Casting style"
95    }
96
97    fn description(&self) -> &'static str {
98        "Enforce consistent type casting style."
99    }
100
101    fn check(&self, statement: &Statement, ctx: &LintContext) -> Vec<Issue> {
102        let sql = ctx.sql;
103        let casts = collect_cast_instances(statement, sql);
104
105        if casts.is_empty() {
106            return Vec::new();
107        }
108
109        // Determine the target style.
110        let target = match self.preferred_style {
111            PreferredTypeCastingStyle::Consistent => casts[0].style,
112            PreferredTypeCastingStyle::Shorthand => CastStyle::DoubleColon,
113            PreferredTypeCastingStyle::Cast => CastStyle::FunctionCast,
114            PreferredTypeCastingStyle::Convert => CastStyle::Convert,
115        };
116
117        // Check if there is a violation at all.
118        let has_violation = casts.iter().any(|c| c.style != target);
119        if !has_violation {
120            return Vec::new();
121        }
122
123        let message = match self.preferred_style {
124            PreferredTypeCastingStyle::Consistent => {
125                "Use consistent casting style (avoid mixing CAST styles)."
126            }
127            PreferredTypeCastingStyle::Shorthand => "Use `::` shorthand casting style.",
128            PreferredTypeCastingStyle::Cast => "Use `CAST(...)` style casts.",
129            PreferredTypeCastingStyle::Convert => "Use `CONVERT(...)` style casts.",
130        };
131
132        // Emit one issue per non-conforming cast so that partially fixable
133        // statements (e.g. 3-arg CONVERT that can't be converted) still show
134        // improvement in the violation count after autofix.
135        let mut issues = Vec::new();
136        for cast in &casts {
137            if cast.style == target {
138                continue;
139            }
140
141            let mut issue =
142                Issue::info(issue_codes::LINT_CV_011, message).with_statement(ctx.statement_index);
143
144            if !cast.is_3arg_convert && !cast.has_comments {
145                let cast_text = &sql[cast.start..cast.end];
146                if let Some(replacement) = convert_cast(cast_text, cast.style, target) {
147                    issue = issue.with_autofix_edits(
148                        IssueAutofixApplicability::Unsafe,
149                        vec![IssuePatchEdit::new(
150                            Span::new(cast.start, cast.end),
151                            replacement,
152                        )],
153                    );
154                }
155            }
156
157            issues.push(issue);
158        }
159
160        issues
161    }
162}
163
164// ---------------------------------------------------------------------------
165// Collect all cast expressions with their source positions
166// ---------------------------------------------------------------------------
167
/// Collects every cast expression of `statement` together with its byte
/// range in `sql`.
///
/// The work happens in three phases:
/// 1. Walk the AST: `Expr::Cast` nodes and functions named `CONVERT` are
///    turned into [`CastInstance`]s.
/// 2. Add spans found by a lexical scan for `(...)::type` that the AST walk
///    did not already report with the same range.
/// 3. Sort by start offset, drop entries that are contained in (or overlap
///    with) another entry, and return the survivors ordered by
///    `(start, end, style)`.
///
/// NOTE(review): the lexical scan in phase 2 runs over the whole `sql`
/// argument, not over a per-statement slice — confirm that callers never
/// pass text containing casts that belong to a different statement.
fn collect_cast_instances(statement: &Statement, sql: &str) -> Vec<CastInstance> {
    let mut casts = Vec::new();

    visit_expressions(statement, &mut |expr| {
        match expr {
            Expr::Cast {
                kind,
                expr: inner,
                data_type,
                ..
            } => {
                // TRY_CAST / SAFE_CAST count as function-style casts.
                let style = match kind {
                    CastKind::DoubleColon => CastStyle::DoubleColon,
                    CastKind::Cast | CastKind::TryCast | CastKind::SafeCast => {
                        CastStyle::FunctionCast
                    }
                };

                // For chained :: (inner is also ::), skip the inner — we handle
                // the entire chain as one entry via the outermost.
                let is_inner_chain = matches!(
                    inner.as_ref(),
                    Expr::Cast {
                        kind: CastKind::DoubleColon,
                        ..
                    }
                );

                // Compute the byte range of the whole cast expression (not
                // just of its operand).
                let inner_span = find_cast_span(sql, inner, kind.clone(), data_type);
                if let Some((start, end)) = inner_span {
                    let text = &sql[start..end];
                    let has_comments = text.contains("--") || text.contains("/*");

                    if style == CastStyle::DoubleColon && is_inner_chain {
                        // Outermost chained :: — remove previously collected inner.
                        casts.retain(|c: &CastInstance| c.start < start || c.end > end);
                    }

                    casts.push(CastInstance {
                        style,
                        start,
                        end,
                        has_comments,
                        is_3arg_convert: false,
                    });
                }
            }
            Expr::Function(function)
                if function.name.to_string().eq_ignore_ascii_case("CONVERT") =>
            {
                if let Some((start, mut end)) = expr_span_offsets(sql, expr) {
                    // Function::span() may not include the closing paren.
                    // Scan forward to include it.
                    if end < sql.len() && sql.as_bytes().get(end) == Some(&b')') {
                        end += 1;
                    } else {
                        // Try to find the closing paren after the span end.
                        if let Some(close) = find_matching_close_paren(&sql[end..]) {
                            end += close + 1;
                        }
                    }

                    let text = &sql[start..end];
                    let has_comments = text.contains("--") || text.contains("/*");

                    // Anything other than a plain argument list counts as
                    // zero arguments.
                    let arg_count = match &function.args {
                        sqlparser::ast::FunctionArguments::List(list) => list.args.len(),
                        _ => 0,
                    };

                    casts.push(CastInstance {
                        style: CastStyle::Convert,
                        start,
                        end,
                        has_comments,
                        is_3arg_convert: arg_count > 2,
                    });
                }
            }
            _ => {}
        }
    });

    // Parser span extraction can miss parenthesized shorthand casts in some
    // Snowflake semi-structured forms. Add a lightweight lexical fallback.
    for (start, end) in scan_parenthesized_shorthand_cast_spans(sql) {
        // Skip spans the AST walk already reported with the identical range.
        if casts.iter().any(|cast| {
            cast.start == start && cast.end == end && cast.style == CastStyle::DoubleColon
        }) {
            continue;
        }
        let text = &sql[start..end];
        casts.push(CastInstance {
            style: CastStyle::DoubleColon,
            start,
            end,
            has_comments: text.contains("--") || text.contains("/*"),
            is_3arg_convert: false,
        });
    }

    // Sort by position so first-seen logic works correctly.
    casts.sort_by_key(|c| c.start);

    // Deduplicate: remove entries whose ranges are fully contained within
    // another entry's range (handles chained :: where both outer and inner
    // are collected by the visitor). For overlapping shorthand casts, keep
    // the wider range so we don't emit conflicting nested edits.
    let mut deduped: Vec<CastInstance> = Vec::with_capacity(casts.len());
    for cast in casts {
        // `dominated`: an already kept entry makes `cast` redundant.
        // `replace_index`: `cast` supersedes the kept entry at that index.
        let mut dominated = false;
        let mut replace_index = None;

        // Only the first kept entry that relates to `cast` is considered;
        // every branch below ends the search.
        for (index, other) in deduped.iter().enumerate() {
            // `other` fully covers `cast`.
            if other.start <= cast.start && other.end >= cast.end {
                dominated = true;
                break;
            }
            // `cast` fully covers `other`.
            if cast.start <= other.start && cast.end >= other.end {
                replace_index = Some(index);
                break;
            }
            // Partial overlap between casts of the same style: keep the
            // longer of the two.
            if cast.style == other.style
                && spans_overlap(cast.start, cast.end, other.start, other.end)
            {
                let cast_len = cast.end.saturating_sub(cast.start);
                let other_len = other.end.saturating_sub(other.start);
                if cast_len > other_len {
                    replace_index = Some(index);
                } else {
                    dominated = true;
                }
                break;
            }
        }

        if dominated {
            continue;
        }

        if let Some(index) = replace_index {
            deduped[index] = cast;
        } else {
            deduped.push(cast);
        }
    }

    // Replacements above can break the ordering, so sort again and drop
    // entries that ended up with an identical range.
    deduped.sort_by_key(|cast| (cast.start, cast.end, cast.style as u8));
    deduped.dedup_by(|left, right| left.start == right.start && left.end == right.end);
    deduped
}
320
/// Returns true when the half-open ranges `left_start..left_end` and
/// `right_start..right_end` share at least one position.
fn spans_overlap(left_start: usize, left_end: usize, right_start: usize, right_end: usize) -> bool {
    let left_begins_before_right_ends = left_start < right_end;
    let right_begins_before_left_ends = right_start < left_end;
    left_begins_before_right_ends && right_begins_before_left_ends
}
324
325fn scan_parenthesized_shorthand_cast_spans(sql: &str) -> Vec<(usize, usize)> {
326    let bytes = sql.as_bytes();
327    let mut out = Vec::new();
328    let mut index = 0usize;
329
330    while index + 1 < bytes.len() {
331        if bytes[index] != b':' || bytes[index + 1] != b':' {
332            index += 1;
333            continue;
334        }
335
336        let mut lhs_end = index;
337        while lhs_end > 0 && bytes[lhs_end - 1].is_ascii_whitespace() {
338            lhs_end -= 1;
339        }
340        if lhs_end == 0 || bytes[lhs_end - 1] != b')' {
341            index += 2;
342            continue;
343        }
344        let close_paren = lhs_end - 1;
345        let Some(open_paren) = find_matching_open_paren(bytes, close_paren) else {
346            index += 2;
347            continue;
348        };
349
350        let Some(type_end) = scan_parenthesized_shorthand_type_end(bytes, index + 2) else {
351            index += 2;
352            continue;
353        };
354
355        out.push((open_paren, type_end));
356        index = type_end;
357    }
358
359    out
360}
361
/// Scans a type name starting at `start` and returns the offset just past it.
///
/// Letters, digits, `_`, `.` and `(` are always accepted. `)`, `,` and
/// whitespace are only accepted inside parentheses opened by the type
/// itself. Returns `None` when no letter, digit, `_`, `.` or `(` was
/// consumed.
fn scan_parenthesized_shorthand_type_end(bytes: &[u8], start: usize) -> Option<usize> {
    let mut end = start;
    let mut nesting = 0i32;
    let mut consumed_type_token = false;

    for &byte in bytes.iter().skip(start) {
        let accepted = match byte {
            b'(' => {
                nesting += 1;
                consumed_type_token = true;
                true
            }
            b')' if nesting > 0 => {
                nesting -= 1;
                true
            }
            b'_' | b'.' => {
                consumed_type_token = true;
                true
            }
            _ if byte.is_ascii_alphanumeric() => {
                consumed_type_token = true;
                true
            }
            // Separators only belong to the type inside its own parens.
            b',' | b' ' | b'\t' | b'\n' | b'\r' => nesting > 0,
            _ => false,
        };
        if !accepted {
            break;
        }
        end += 1;
    }

    consumed_type_token.then_some(end)
}
394
/// Returns the offset of the `(` that matches the `)` at `close_paren`.
///
/// Returns `None` when `close_paren` is out of range, does not point at a
/// `)`, or no matching `(` precedes it. Only parenthesis bytes are counted.
fn find_matching_open_paren(bytes: &[u8], close_paren: usize) -> Option<usize> {
    if bytes.get(close_paren) != Some(&b')') {
        return None;
    }

    // Number of `)` still waiting for their `(` while walking backwards.
    let mut unmatched = 1i32;
    for (position, &byte) in bytes[..close_paren].iter().enumerate().rev() {
        match byte {
            b')' => unmatched += 1,
            b'(' => {
                unmatched -= 1;
                if unmatched == 0 {
                    return Some(position);
                }
            }
            _ => {}
        }
    }
    None
}
416
417/// Find the full source span of a CAST/:: expression.
418///
419/// sqlparser's `Expr::Cast.span()` only returns the inner expression's span,
420/// so we compute the full span by:
421/// - For CAST/TRY_CAST/SAFE_CAST: scan backwards from inner expr to find the
422///   keyword, then forwards to find the closing paren.
423/// - For `::`: find the deepest base expression, use its span as start, then
424///   scan forwards through all `::type` segments to find the outermost end.
425fn find_cast_span(
426    sql: &str,
427    inner: &Expr,
428    kind: CastKind,
429    data_type: &DataType,
430) -> Option<(usize, usize)> {
431    match kind {
432        CastKind::Cast | CastKind::TryCast | CastKind::SafeCast => {
433            let (inner_start, inner_end) = expr_span_offsets(sql, inner)?;
434
435            // Scan backwards from inner_start to find `CAST(`, `TRY_CAST(`, or `SAFE_CAST(`.
436            let before = &sql[..inner_start];
437            let paren_pos = before.rfind('(')?;
438            let before_paren = before[..paren_pos].trim_end();
439            let kw = match kind {
440                CastKind::TryCast => "TRY_CAST",
441                CastKind::SafeCast => "SAFE_CAST",
442                _ => "CAST",
443            };
444            let kw_len = kw.len();
445            if before_paren.len() < kw_len {
446                return None;
447            }
448            let kw_candidate = &before_paren[before_paren.len() - kw_len..];
449            if !kw_candidate.eq_ignore_ascii_case(kw) {
450                return None;
451            }
452            let start = before_paren.len() - kw_len;
453
454            // Scan forwards from inner_end to find closing paren.
455            let after = &sql[inner_end..];
456            let close = find_matching_close_paren(after)?;
457            let end = inner_end + close + 1;
458
459            Some((start, end))
460        }
461        CastKind::DoubleColon => {
462            // Find the deepest non-:: base expression to get the real start.
463            let base = deepest_base_expr(inner);
464            let (base_start, base_end) = expr_span_offsets(sql, base)?;
465
466            // Scan forward from base_end through all `::type` segments.
467            let type_str = data_type.to_string();
468            let mut pos = base_end;
469            loop {
470                let after = &sql[pos..];
471                let dc_pos = match after.find("::") {
472                    Some(p) => p,
473                    None => break,
474                };
475                let type_start = pos + dc_pos + 2;
476                let type_len = source_type_len(sql, type_start, &type_str);
477                if type_len == 0 {
478                    break;
479                }
480                pos = type_start + type_len;
481                // Check if this type matches the outermost data_type.
482                let this_type = &sql[type_start..pos];
483                if this_type.eq_ignore_ascii_case(&type_str) {
484                    break;
485                }
486            }
487
488            Some((base_start, pos))
489        }
490    }
491}
492
493/// Walk down the `Expr::Cast { kind: DoubleColon }` chain to find the
494/// deepest non-Cast base expression.
495fn deepest_base_expr(expr: &Expr) -> &Expr {
496    match expr {
497        Expr::Cast {
498            kind: CastKind::DoubleColon,
499            expr: inner,
500            ..
501        } => deepest_base_expr(inner),
502        _ => expr,
503    }
504}
505
/// Returns the offset (relative to `text`) of the first `)` that is not
/// matched by a `(` inside `text`.
///
/// Parentheses inside `'...'` or `"..."` are ignored; within such a quoted
/// run a backslash skips the byte that follows it.
fn find_matching_close_paren(text: &str) -> Option<usize> {
    let bytes = text.as_bytes();
    let mut open_count = 0i32;
    let mut pos = 0usize;

    while let Some(&byte) = bytes.get(pos) {
        match byte {
            b'(' => open_count += 1,
            b')' if open_count == 0 => return Some(pos),
            b')' => open_count -= 1,
            quote @ (b'\'' | b'"') => {
                // Advance to the closing quote (or the end of the text).
                pos += 1;
                while let Some(&inside) = bytes.get(pos) {
                    if inside == quote {
                        break;
                    }
                    pos += if inside == b'\\' { 2 } else { 1 };
                }
            }
            _ => {}
        }
        pos += 1;
    }
    None
}
537
/// Determine the byte length of a type name in the source SQL starting at `pos`.
///
/// The type in source may use different casing or spacing than
/// `DataType::to_string()`, so two strategies are tried:
/// 1. If the source text at `pos` equals `type_display` (ASCII
///    case-insensitively), the length of `type_display` is returned.
/// 2. Otherwise the source is scanned: letters, digits, `_` and `(` are
///    always accepted, while `)`, `,`, spaces, tabs and newlines are only
///    accepted inside parentheses opened by the type itself.
///
/// Returns `0` when no type character is found at `pos`, or when `pos` is
/// not a valid offset into `sql`.
fn source_type_len(sql: &str, pos: usize, type_display: &str) -> usize {
    // An out-of-range (or non-boundary) offset simply means "no type here".
    let Some(remaining) = sql.get(pos..) else {
        return 0;
    };
    let source = remaining.as_bytes();
    let display = type_display.as_bytes();

    // Try exact match first (common case). The comparison is done on bytes:
    // slicing the `str` at `display.len()` would panic whenever that offset
    // falls inside a multi-byte character of the source.
    let matches_display = source
        .get(..display.len())
        .is_some_and(|prefix| prefix.eq_ignore_ascii_case(display));
    if matches_display {
        return display.len();
    }

    // Fallback: scan forward through identifier-like characters.
    let mut len = 0;
    let mut depth = 0i32;
    for &b in source {
        match b {
            b'(' => {
                depth += 1;
                len += 1;
            }
            b')' if depth > 0 => {
                depth -= 1;
                len += 1;
            }
            b'a'..=b'z' | b'A'..=b'Z' | b'0'..=b'9' | b'_' => len += 1,
            b' ' | b'\t' | b'\n' | b',' if depth > 0 => len += 1,
            _ => break,
        }
    }
    len
}
574
575// ---------------------------------------------------------------------------
576// Convert a cast expression to the target style
577// ---------------------------------------------------------------------------
578
579fn convert_cast(cast_text: &str, from_style: CastStyle, to_style: CastStyle) -> Option<String> {
580    match (from_style, to_style) {
581        (CastStyle::FunctionCast, CastStyle::DoubleColon) => cast_to_shorthand(cast_text),
582        (CastStyle::FunctionCast, CastStyle::Convert) => cast_to_convert(cast_text),
583        (CastStyle::DoubleColon, CastStyle::FunctionCast) => shorthand_to_cast(cast_text),
584        (CastStyle::DoubleColon, CastStyle::Convert) => shorthand_to_convert(cast_text),
585        (CastStyle::Convert, CastStyle::FunctionCast) => convert_to_cast(cast_text),
586        (CastStyle::Convert, CastStyle::DoubleColon) => convert_to_shorthand(cast_text),
587        _ => None,
588    }
589}
590
591/// Parse the interior of `CAST(expr AS type)` from raw text.
592/// Returns `(expr_text, type_text)`.
593fn parse_cast_interior(cast_text: &str) -> Option<(&str, &str)> {
594    let open = cast_text.find('(')?;
595    let close = cast_text.rfind(')')?;
596    let inner = cast_text[open + 1..close].trim();
597
598    let as_pos = find_top_level_as(inner)?;
599    let expr_part = inner[..as_pos].trim();
600    // The `AS` keyword is typically 2 chars, but ` AS ` starts with a space.
601    // `as_pos` points to the space/newline before `AS`.
602    let type_part = inner[as_pos + 1..].trim();
603    // Strip the leading `AS` keyword.
604    let type_part = type_part
605        .strip_prefix("AS")
606        .or_else(|| type_part.strip_prefix("as"))
607        .or_else(|| type_part.strip_prefix("As"))
608        .or_else(|| type_part.strip_prefix("aS"))
609        .unwrap_or(type_part)
610        .trim();
611    Some((expr_part, type_part))
612}
613
614/// Find the position of top-level whitespace-AS-whitespace in CAST interior.
615fn find_top_level_as(inner: &str) -> Option<usize> {
616    let bytes = inner.as_bytes();
617    let mut depth = 0i32;
618    let mut i = 0;
619    while i < bytes.len() {
620        match bytes[i] {
621            b'(' => depth += 1,
622            b')' => depth -= 1,
623            b'\'' | b'"' => {
624                let quote = bytes[i];
625                i += 1;
626                while i < bytes.len() && bytes[i] != quote {
627                    if bytes[i] == b'\\' {
628                        i += 1;
629                    }
630                    i += 1;
631                }
632            }
633            _ if depth == 0
634                && is_whitespace_byte(bytes[i])
635                && i + 3 < bytes.len()
636                && bytes[i + 1].eq_ignore_ascii_case(&b'A')
637                && bytes[i + 2].eq_ignore_ascii_case(&b'S')
638                && is_whitespace_byte(bytes[i + 3]) =>
639            {
640                return Some(i);
641            }
642            _ => {}
643        }
644        i += 1;
645    }
646    None
647}
648
/// True for the ASCII space, tab, line feed and carriage return bytes.
fn is_whitespace_byte(b: u8) -> bool {
    [b' ', b'\t', b'\n', b'\r'].contains(&b)
}
652
653/// `CAST(expr AS type)` → `expr::type` or `(expr)::type`.
654fn cast_to_shorthand(cast_text: &str) -> Option<String> {
655    let (expr, type_text) = parse_cast_interior(cast_text)?;
656    let needs_parens = expr_is_complex(expr);
657    if needs_parens {
658        Some(format!("({expr})::{type_text}"))
659    } else {
660        Some(format!("{expr}::{type_text}"))
661    }
662}
663
664/// `CAST(expr AS type)` → `convert(type, expr)`.
665fn cast_to_convert(cast_text: &str) -> Option<String> {
666    let (expr, type_text) = parse_cast_interior(cast_text)?;
667    Some(format!("convert({type_text}, {expr})"))
668}
669
670/// `CONVERT(type, expr)` → `cast(expr as type)`.
671fn convert_to_cast(convert_text: &str) -> Option<String> {
672    let (type_text, expr) = parse_convert_interior(convert_text)?;
673    Some(format!("cast({expr} as {type_text})"))
674}
675
676/// `CONVERT(type, expr)` → `expr::type` or `(expr)::type`.
677fn convert_to_shorthand(convert_text: &str) -> Option<String> {
678    let (type_text, expr) = parse_convert_interior(convert_text)?;
679    let needs_parens = expr_is_complex(expr);
680    if needs_parens {
681        Some(format!("({expr})::{type_text}"))
682    } else {
683        Some(format!("{expr}::{type_text}"))
684    }
685}
686
687/// `expr::type` → `cast(expr as type)`.
688/// Handles chained casts: `expr::t1::t2` → `cast(cast(expr as t1) as t2)`.
689fn shorthand_to_cast(shorthand_text: &str) -> Option<String> {
690    let parts = split_shorthand_chain(shorthand_text)?;
691    if parts.len() < 2 {
692        return None;
693    }
694    let mut result = rewrite_nested_simple_shorthand_to_cast(parts[0]);
695    for type_part in &parts[1..] {
696        result = format!("cast({result} as {type_part})");
697    }
698    Some(result)
699}
700
701/// `expr::type` → `convert(type, expr)`.
702/// Handles chained casts: `expr::t1::t2` → `convert(t2, convert(t1, expr))`.
703fn shorthand_to_convert(shorthand_text: &str) -> Option<String> {
704    let parts = split_shorthand_chain(shorthand_text)?;
705    if parts.len() < 2 {
706        return None;
707    }
708    let mut result = parts[0].to_string();
709    for type_part in &parts[1..] {
710        result = format!("convert({type_part}, {result})");
711    }
712    Some(result)
713}
714
/// Splits a `::` chain such as `100::int::text` into `["100", "int", "text"]`.
///
/// Only `::` outside parentheses and outside `'...'` / `"..."` runs acts as
/// a separator. Segments are returned verbatim (no trimming). Returns `None`
/// when the text contains no top-level `::`.
fn split_shorthand_chain(text: &str) -> Option<Vec<&str>> {
    let bytes = text.as_bytes();
    let mut segments = Vec::new();
    let mut segment_start = 0usize;
    let mut nesting = 0i32;
    let mut pos = 0usize;

    while let Some(&byte) = bytes.get(pos) {
        match byte {
            b'(' => nesting += 1,
            b')' => nesting -= 1,
            quote @ (b'\'' | b'"') => {
                // Skip the quoted run; a backslash hides the next byte.
                pos += 1;
                while let Some(&inside) = bytes.get(pos) {
                    if inside == quote {
                        break;
                    }
                    pos += if inside == b'\\' { 2 } else { 1 };
                }
            }
            b':' if nesting == 0 && bytes.get(pos + 1) == Some(&b':') => {
                segments.push(&text[segment_start..pos]);
                pos += 2;
                segment_start = pos;
                continue;
            }
            _ => {}
        }
        pos += 1;
    }
    segments.push(&text[segment_start..]);

    (segments.len() > 1).then_some(segments)
}
755
/// Rewrites simple nested shorthand fragments in an expression, e.g.
/// `value:Longitude::varchar` -> `cast(value:Longitude as varchar)`.
///
/// This is intentionally conservative: only operands made of contiguous
/// identifier-like characters are rewritten, everything else is copied
/// through unchanged. Chains on such an operand become nested casts:
/// `a::int::text` -> `cast(cast(a as int) as text)`.
///
/// The type of a nested cast ends at the first `)` or `,` that does not
/// belong to the type's own parentheses, so surrounding syntax such as
/// `coalesce(a::int, 0)` is left intact.
fn rewrite_nested_simple_shorthand_to_cast(expr: &str) -> String {
    let bytes = expr.as_bytes();
    let mut index = 0usize;
    let mut out = String::with_capacity(expr.len() + 16);

    while index < bytes.len() {
        let Some(rel_dc) = expr[index..].find("::") else {
            out.push_str(&expr[index..]);
            break;
        };
        let dc = index + rel_dc;

        // Walk left over the operand, but never past `index`: everything in
        // front of it has already been copied to `out`.
        let mut lhs_start = dc;
        while lhs_start > index && is_simple_shorthand_lhs_char(bytes[lhs_start - 1]) {
            lhs_start -= 1;
        }
        let rhs_end = simple_shorthand_type_end(bytes, dc + 2);

        // If the operand starts right after a `::` that was copied through
        // verbatim, it is really the type of that earlier cast, not a value.
        let operand_is_earlier_type = lhs_start == index && expr[..index].ends_with("::");

        if lhs_start == dc || rhs_end == dc + 2 || operand_is_earlier_type {
            out.push_str(&expr[index..dc + 2]);
            index = dc + 2;
            continue;
        }

        let mut rewritten = format!(
            "cast({} as {})",
            &expr[lhs_start..dc],
            &expr[dc + 2..rhs_end]
        );

        // Fold directly chained `::type` segments into nested casts.
        let mut cursor = rhs_end;
        while expr[cursor..].starts_with("::") {
            let type_end = simple_shorthand_type_end(bytes, cursor + 2);
            if type_end == cursor + 2 {
                break;
            }
            rewritten = format!("cast({rewritten} as {})", &expr[cursor + 2..type_end]);
            cursor = type_end;
        }

        out.push_str(&expr[index..lhs_start]);
        out.push_str(&rewritten);
        index = cursor;
    }

    out
}

/// Returns the offset just past the type name that starts at `start`.
///
/// Bytes are accepted while they satisfy `is_simple_type_char`, with two
/// additional stops: a `)` that closes a parenthesis the type did not open,
/// and a `,` outside the type's own parentheses. Trailing whitespace is not
/// part of the type. Returns `start` when no type was found.
fn simple_shorthand_type_end(bytes: &[u8], start: usize) -> usize {
    let mut end = start;
    let mut depth = 0usize;

    while let Some(&byte) = bytes.get(end) {
        if !is_simple_type_char(byte) {
            break;
        }
        match byte {
            b'(' => depth += 1,
            b')' if depth == 0 => break,
            b')' => depth -= 1,
            b',' if depth == 0 => break,
            _ => {}
        }
        end += 1;
    }

    // Leave trailing whitespace to the surrounding expression.
    while end > start && bytes[end - 1].is_ascii_whitespace() {
        end -= 1;
    }
    end
}

/// Bytes that may appear in the operand of a "simple" shorthand cast:
/// ASCII letters and digits plus `_ . : $ @ " ` [ ]`.
fn is_simple_shorthand_lhs_char(byte: u8) -> bool {
    byte.is_ascii_alphanumeric()
        || matches!(
            byte,
            b'_' | b'.' | b':' | b'$' | b'@' | b'"' | b'`' | b'[' | b']'
        )
}

/// Bytes that may appear in the type of a "simple" shorthand cast:
/// ASCII letters and digits, `_`, whitespace, parentheses and `,`.
fn is_simple_type_char(byte: u8) -> bool {
    byte.is_ascii_alphanumeric()
        || matches!(
            byte,
            b'_' | b' ' | b'\t' | b'\n' | b'\r' | b'(' | b')' | b','
        )
}
819
820/// Parse interior of `CONVERT(type, expr)`. Returns `(type_text, expr_text)`.
821fn parse_convert_interior(convert_text: &str) -> Option<(&str, &str)> {
822    let open = convert_text.find('(')?;
823    let close = convert_text.rfind(')')?;
824    let inner = convert_text[open + 1..close].trim();
825    let comma = find_top_level_comma(inner)?;
826    let type_part = inner[..comma].trim();
827    let expr_part = inner[comma + 1..].trim();
828    Some((type_part, expr_part))
829}
830
/// Returns the offset of the first comma that is outside parentheses and
/// outside `'...'` / `"..."` runs, or `None` when there is none.
fn find_top_level_comma(inner: &str) -> Option<usize> {
    let bytes = inner.as_bytes();
    let mut nesting = 0i32;
    let mut pos = 0usize;

    while let Some(&byte) = bytes.get(pos) {
        match byte {
            b',' if nesting == 0 => return Some(pos),
            b'(' => nesting += 1,
            b')' => nesting -= 1,
            quote @ (b'\'' | b'"') => {
                // Skip the quoted run; a backslash hides the next byte.
                pos += 1;
                while let Some(&inside) = bytes.get(pos) {
                    if inside == quote {
                        break;
                    }
                    pos += if inside == b'\\' { 2 } else { 1 };
                }
            }
            _ => {}
        }
        pos += 1;
    }
    None
}
857
/// Returns true if the expression text is "complex" and needs parenthesization
/// when used in shorthand `::` form.
///
/// An expression is complex when, outside parentheses and outside quoted
/// runs, it contains whitespace or one of the operator characters
/// `| + - * / %`. A `-` as the very first character is treated as a sign
/// and does not count.
///
/// Quoted runs (`'...'` and `"..."`) are skipped rather than ending the
/// analysis: `'hello world'` is simple, but `'a' || 'b'` and `f('x') + 1`
/// are complex, because `::` binds tighter than the operator that follows
/// the literal.
fn expr_is_complex(expr: &str) -> bool {
    let bytes = expr.trim().as_bytes();
    let mut depth = 0i32;
    let mut i = 0usize;

    while i < bytes.len() {
        match bytes[i] {
            b'(' => depth += 1,
            b')' => depth -= 1,
            quote @ (b'\'' | b'"') => {
                // Jump to the closing quote; spaces and operator characters
                // inside a literal or quoted identifier are irrelevant.
                i += 1;
                while i < bytes.len() && bytes[i] != quote {
                    if bytes[i] == b'\\' {
                        i += 1;
                    }
                    i += 1;
                }
            }
            // Leading sign, e.g. `-1`.
            b'-' if depth == 0 && i == 0 => {}
            b'|' | b'+' | b'-' | b'*' | b'/' | b'%' if depth == 0 => return true,
            b' ' | b'\t' | b'\n' | b'\r' if depth == 0 => return true,
            _ => {}
        }
        i += 1;
    }
    false
}
881
882// ---------------------------------------------------------------------------
883// Span helpers
884// ---------------------------------------------------------------------------
885
886fn expr_span_offsets(sql: &str, expr: &Expr) -> Option<(usize, usize)> {
887    let span = expr.span();
888    if span.start.line == 0 || span.start.column == 0 || span.end.line == 0 || span.end.column == 0
889    {
890        return None;
891    }
892
893    let start = line_col_to_offset(sql, span.start.line as usize, span.start.column as usize)?;
894    let end = line_col_to_offset(sql, span.end.line as usize, span.end.column as usize)?;
895    (end >= start).then_some((start, end))
896}
897
/// Convert a 1-based `(line, column)` position into a byte offset in `sql`.
///
/// Lines are separated by `'\n'`. Columns are counted in characters (not
/// bytes) starting from the first character of the requested line; counting
/// does not stop at the end of that line, so a column past the line's last
/// character continues into the following text. The column immediately after
/// the final character of `sql` maps to `sql.len()`.
///
/// Returns `None` when `line` or `column` is zero, when `sql` has fewer lines
/// than `line`, or when `column` lies more than one position past the end of
/// the text.
fn line_col_to_offset(sql: &str, line: usize, column: usize) -> Option<usize> {
    if line == 0 || column == 0 {
        return None;
    }

    // The requested line begins right after the (line - 1)-th newline.
    let line_start = if line == 1 {
        0
    } else {
        sql.match_indices('\n')
            .nth(line - 2)
            .map(|(newline_at, _)| newline_at + 1)?
    };

    let rest = &sql[line_start..];
    match rest.char_indices().nth(column - 1) {
        Some((relative, _)) => Some(line_start + relative),
        // One past the last remaining character addresses the end of input.
        None => (rest.chars().count() + 1 == column).then_some(sql.len()),
    }
}
932
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::parse_sql;

    /// Config key used by most tests to address the rule by name.
    const RULE_NAME_KEY: &str = "convention.casting_style";

    /// Fixture shared by the configured-style autofix tests.
    const MIXED_SQL: &str = "select\n    convert(int, 1) as bar,\n    100::int::text,\n    cast(10 as text) as coo\nfrom foo;";

    /// Same as `MIXED_SQL` but with a three-argument `convert` call.
    const MIXED_SQL_3ARG_CONVERT: &str = "select\n    convert(int, 1, 126) as bar,\n    100::int::text,\n    cast(10 as text) as coo\nfrom foo;";

    /// Parse `sql` and run `rule` over every statement, gathering all issues.
    fn check_all(rule: &ConventionCastingStyle, sql: &str) -> Vec<Issue> {
        let statements = parse_sql(sql).expect("parse");
        let mut issues = Vec::new();
        for (statement_index, statement) in statements.iter().enumerate() {
            let ctx = LintContext {
                sql,
                statement_range: 0..sql.len(),
                statement_index,
            };
            issues.extend(rule.check(statement, &ctx));
        }
        issues
    }

    /// Lint `sql` with the default rule configuration.
    fn run(sql: &str) -> Vec<Issue> {
        check_all(&ConventionCastingStyle::default(), sql)
    }

    /// Lint `sql` with a rule built from `config`.
    fn run_with_config(sql: &str, config: &LintConfig) -> Vec<Issue> {
        check_all(&ConventionCastingStyle::from_config(config), sql)
    }

    /// Build a config that sets `preferred_type_casting_style` to `style`
    /// under the given rule key.
    fn style_config(rule_key: &str, style: &str) -> LintConfig {
        LintConfig {
            enabled: true,
            disabled_rules: vec![],
            rule_configs: std::collections::BTreeMap::from([(
                rule_key.to_string(),
                serde_json::json!({"preferred_type_casting_style": style}),
            )]),
        }
    }

    /// Apply `edits` to `sql`, working from the highest start offset down so
    /// that earlier offsets stay valid while later text is replaced.
    fn apply_edits(sql: &str, edits: &[IssuePatchEdit]) -> String {
        let mut ordered: Vec<&IssuePatchEdit> = edits.iter().collect();
        ordered.sort_by(|a, b| b.span.start.cmp(&a.span.start));
        ordered.into_iter().fold(sql.to_string(), |mut text, edit| {
            text.replace_range(edit.span.start..edit.span.end, &edit.replacement);
            text
        })
    }

    /// Gather the autofix edits of every issue that carries an autofix.
    fn collect_all_edits(issues: &[Issue]) -> Vec<&IssuePatchEdit> {
        let mut edits = Vec::new();
        for issue in issues {
            if let Some(autofix) = issue.autofix.as_ref() {
                edits.extend(autofix.edits.iter());
            }
        }
        edits
    }

    /// Apply every autofix edit found in `issues` to `sql`.
    fn apply_all_fixes(sql: &str, issues: &[Issue]) -> String {
        let owned: Vec<IssuePatchEdit> = collect_all_edits(issues)
            .into_iter()
            .cloned()
            .collect();
        apply_edits(sql, &owned)
    }

    /// Assert that at least one issue was reported and that applying all
    /// autofixes to `sql` yields `expected`.
    #[track_caller]
    fn assert_fixed(sql: &str, issues: &[Issue], expected: &str) {
        assert!(!issues.is_empty());
        let fixed = apply_all_fixes(sql, issues);
        assert_eq!(fixed, expected);
    }

    #[test]
    fn flags_mixed_casting_styles() {
        let issues = run("SELECT CAST(amount AS INT)::TEXT FROM t");
        assert_eq!(issues.len(), 1);
        assert_eq!(issues[0].code, issue_codes::LINT_CV_011);
    }

    #[test]
    fn does_not_flag_single_casting_style() {
        assert!(run("SELECT amount::INT FROM t").is_empty());
        assert!(run("SELECT CAST(amount AS INT) FROM t").is_empty());
    }

    #[test]
    fn does_not_flag_cast_like_tokens_inside_string_literal() {
        assert!(run("SELECT 'value::TEXT and CAST(value AS INT)' AS note").is_empty());
    }

    #[test]
    fn flags_mixed_try_cast_and_double_colon_styles() {
        let issues = run("SELECT TRY_CAST(amount AS INT)::TEXT FROM t");
        assert_eq!(issues.len(), 1);
        assert_eq!(issues[0].code, issue_codes::LINT_CV_011);
    }

    #[test]
    fn shorthand_preference_flags_cast_function_style() {
        let config = style_config(RULE_NAME_KEY, "shorthand");
        let issues = run_with_config("SELECT CAST(amount AS INT) FROM t", &config);
        assert_eq!(issues.len(), 1);
    }

    #[test]
    fn cast_preference_flags_shorthand_style() {
        // Addresses the rule by its issue code rather than by its name.
        let config = style_config("LINT_CV_011", "cast");
        let issues = run_with_config("SELECT amount::INT FROM t", &config);
        assert_eq!(issues.len(), 1);
    }

    // -----------------------------------------------------------------------
    // Autofix tests — SQLFluff CV11 fixture parity
    // -----------------------------------------------------------------------

    #[test]
    fn autofix_consistent_prior_convert() {
        let sql = "select\n    convert(int, 1) as bar,\n    100::int::text,\n    cast(10\n    as text) as coo\nfrom foo;";
        assert_fixed(
            sql,
            &run(sql),
            "select\n    convert(int, 1) as bar,\n    convert(text, convert(int, 100)),\n    convert(text, 10) as coo\nfrom foo;",
        );
    }

    #[test]
    fn autofix_consistent_prior_cast() {
        let sql = "select\n    cast(10 as text) as coo,\n    convert(int, 1) as bar,\n    100::int::text,\nfrom foo;";
        assert_fixed(
            sql,
            &run(sql),
            "select\n    cast(10 as text) as coo,\n    cast(1 as int) as bar,\n    cast(cast(100 as int) as text),\nfrom foo;",
        );
    }

    #[test]
    fn autofix_consistent_prior_shorthand() {
        let sql = "select\n    100::int::text,\n    cast(10 as text) as coo,\n    convert(int, 1) as bar\nfrom foo;";
        assert_fixed(
            sql,
            &run(sql),
            "select\n    100::int::text,\n    10::text as coo,\n    1::int as bar\nfrom foo;",
        );
    }

    #[test]
    fn autofix_config_cast() {
        let config = style_config(RULE_NAME_KEY, "cast");
        assert_fixed(
            MIXED_SQL,
            &run_with_config(MIXED_SQL, &config),
            "select\n    cast(1 as int) as bar,\n    cast(cast(100 as int) as text),\n    cast(10 as text) as coo\nfrom foo;",
        );
    }

    #[test]
    fn autofix_config_convert() {
        let config = style_config(RULE_NAME_KEY, "convert");
        assert_fixed(
            MIXED_SQL,
            &run_with_config(MIXED_SQL, &config),
            "select\n    convert(int, 1) as bar,\n    convert(text, convert(int, 100)),\n    convert(text, 10) as coo\nfrom foo;",
        );
    }

    #[test]
    fn autofix_config_shorthand() {
        let config = style_config(RULE_NAME_KEY, "shorthand");
        assert_fixed(
            MIXED_SQL,
            &run_with_config(MIXED_SQL, &config),
            "select\n    1::int as bar,\n    100::int::text,\n    10::text as coo\nfrom foo;",
        );
    }

    #[test]
    fn autofix_3arg_convert_skipped_config_cast() {
        let config = style_config(RULE_NAME_KEY, "cast");
        assert_fixed(
            MIXED_SQL_3ARG_CONVERT,
            &run_with_config(MIXED_SQL_3ARG_CONVERT, &config),
            "select\n    convert(int, 1, 126) as bar,\n    cast(cast(100 as int) as text),\n    cast(10 as text) as coo\nfrom foo;",
        );
    }

    #[test]
    fn autofix_3arg_convert_skipped_config_shorthand() {
        let config = style_config(RULE_NAME_KEY, "shorthand");
        assert_fixed(
            MIXED_SQL_3ARG_CONVERT,
            &run_with_config(MIXED_SQL_3ARG_CONVERT, &config),
            "select\n    convert(int, 1, 126) as bar,\n    100::int::text,\n    10::text as coo\nfrom foo;",
        );
    }

    #[test]
    fn autofix_parenthesize_complex_expr_shorthand_from_cast() {
        let config = style_config(RULE_NAME_KEY, "shorthand");
        let sql = "select\n    id::int,\n    cast(calendar_date||' 11:00:00' as timestamp) as calendar_datetime\nfrom foo;";
        assert_fixed(
            sql,
            &run_with_config(sql, &config),
            "select\n    id::int,\n    (calendar_date||' 11:00:00')::timestamp as calendar_datetime\nfrom foo;",
        );
    }

    #[test]
    fn autofix_parenthesize_complex_expr_shorthand_from_convert() {
        let config = style_config(RULE_NAME_KEY, "shorthand");
        let sql = "select\n    id::int,\n    convert(timestamp, calendar_date||' 11:00:00') as calendar_datetime\nfrom foo;";
        assert_fixed(
            sql,
            &run_with_config(sql, &config),
            "select\n    id::int,\n    (calendar_date||' 11:00:00')::timestamp as calendar_datetime\nfrom foo;",
        );
    }

    #[test]
    fn autofix_comment_cast_skipped() {
        let sql = "select\n    cast(10 as text) as coo,\n    convert( -- Convert the value\n        int, /*\n              to an integer\n            */ 1) as bar,\n    100::int::text\nfrom foo;";
        assert_fixed(
            sql,
            &run(sql),
            "select\n    cast(10 as text) as coo,\n    convert( -- Convert the value\n        int, /*\n              to an integer\n            */ 1) as bar,\n    cast(cast(100 as int) as text)\nfrom foo;",
        );
    }

    #[test]
    fn autofix_3arg_convert_consistent_prior_cast() {
        let sql = "select\n    cast(10 as text) as coo,\n    convert(int, 1, 126) as bar,\n    100::int::text\nfrom foo;";
        assert_fixed(
            sql,
            &run(sql),
            "select\n    cast(10 as text) as coo,\n    convert(int, 1, 126) as bar,\n    cast(cast(100 as int) as text)\nfrom foo;",
        );
    }

    #[test]
    fn autofix_comment_prior_convert_shorthand_fixed() {
        let sql = "select\n    convert(int, 126) as bar,\n    cast(\n    1 /* cast the value\n        to an integer\n      */ as int) as coo,\n    100::int::text\nfrom foo;";
        assert_fixed(
            sql,
            &run(sql),
            "select\n    convert(int, 126) as bar,\n    cast(\n    1 /* cast the value\n        to an integer\n      */ as int) as coo,\n    convert(text, convert(int, 100))\nfrom foo;",
        );
    }

    #[test]
    fn autofix_comment_prior_shorthand_convert_fixed() {
        let sql = "select\n    100::int::text,\n    convert(int, 126) as bar,\n    cast(\n    1 /* cast the value\n        to an integer\n      */ as int) as coo\nfrom foo;";
        assert_fixed(
            sql,
            &run(sql),
            "select\n    100::int::text,\n    126::int as bar,\n    cast(\n    1 /* cast the value\n        to an integer\n      */ as int) as coo\nfrom foo;",
        );
    }

    #[test]
    fn shorthand_to_cast_rewrites_nested_snowflake_path_casts() {
        let nested = shorthand_to_cast("(trim(value:Longitude::varchar))::double").expect("rewrite");
        assert_eq!(
            nested,
            "cast((trim(cast(value:Longitude as varchar))) as double)"
        );

        let path = shorthand_to_cast("col:a.b:c::varchar").expect("rewrite");
        assert_eq!(path, "cast(col:a.b:c as varchar)");
    }
}