Skip to main content

flowscope_core/linter/rules/
cv_011.rs

1//! LINT_CV_011: Casting style.
2//!
3//! SQLFluff CV11 parity: detect mixed use of `::`, `CAST()`, and `CONVERT()`
4//! within the same statement and emit autofix edits to normalise to the
5//! preferred style.
6
7use crate::linter::config::LintConfig;
8use crate::linter::rule::{LintContext, LintRule};
9use crate::linter::visit::visit_expressions;
10use crate::types::{issue_codes, Issue, IssueAutofixApplicability, IssuePatchEdit, Span};
11use sqlparser::ast::{CastKind, DataType, Expr, Spanned, Statement};
12
13// ---------------------------------------------------------------------------
14// Configuration
15// ---------------------------------------------------------------------------
16
/// Casting style requested through the `preferred_type_casting_style` option
/// of `LINT_CV_011`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum PreferredTypeCastingStyle {
    /// No fixed preference: the first cast (by source position) in the
    /// statement decides the style every other cast must follow.
    Consistent,
    /// Prefer `expr::type`.
    Shorthand,
    /// Prefer `CAST(expr AS type)`.
    Cast,
    /// Prefer `CONVERT(type, expr)`.
    Convert,
}
24
25impl PreferredTypeCastingStyle {
26    fn from_config(config: &LintConfig) -> Self {
27        match config
28            .rule_option_str(issue_codes::LINT_CV_011, "preferred_type_casting_style")
29            .unwrap_or("consistent")
30            .to_ascii_lowercase()
31            .as_str()
32        {
33            "shorthand" => Self::Shorthand,
34            "cast" => Self::Cast,
35            "convert" => Self::Convert,
36            _ => Self::Consistent,
37        }
38    }
39}
40
41// ---------------------------------------------------------------------------
42// Cast-expression descriptor
43// ---------------------------------------------------------------------------
44
/// Syntactic form of one cast expression found in the source.
///
/// NOTE: variant order is significant. `collect_cast_instances` sorts with
/// `style as u8` as a tie-breaker before `dedup_by`, so reordering variants
/// changes which of two same-range entries survives.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum CastStyle {
    /// `CAST(...)`; `TRY_CAST(...)` and `SAFE_CAST(...)` are also classed here.
    FunctionCast,
    /// `expr::type`, including chains such as `expr::t1::t2`.
    DoubleColon,
    /// A function call named `CONVERT(...)`.
    Convert,
}
51
/// A single cast expression found in the statement.
struct CastInstance {
    /// Which casting syntax the source text uses.
    style: CastStyle,
    /// Byte range of the whole expression in the full SQL text.
    start: usize,
    /// Exclusive end of the byte range (the text is `sql[start..end]`).
    end: usize,
    /// Whether the cast contains embedded comments and should not be auto-fixed.
    has_comments: bool,
    /// For CONVERT: true if it has 3+ arguments (style argument) — can't be converted.
    is_3arg_convert: bool,
}
63
64// ---------------------------------------------------------------------------
65// Rule struct
66// ---------------------------------------------------------------------------
67
68pub struct ConventionCastingStyle {
69    preferred_style: PreferredTypeCastingStyle,
70}
71
72impl ConventionCastingStyle {
73    pub fn from_config(config: &LintConfig) -> Self {
74        Self {
75            preferred_style: PreferredTypeCastingStyle::from_config(config),
76        }
77    }
78}
79
80impl Default for ConventionCastingStyle {
81    fn default() -> Self {
82        Self {
83            preferred_style: PreferredTypeCastingStyle::Consistent,
84        }
85    }
86}
87
88impl LintRule for ConventionCastingStyle {
89    fn code(&self) -> &'static str {
90        issue_codes::LINT_CV_011
91    }
92
93    fn name(&self) -> &'static str {
94        "Casting style"
95    }
96
97    fn description(&self) -> &'static str {
98        "Enforce consistent type casting style."
99    }
100
101    fn check(&self, statement: &Statement, ctx: &LintContext) -> Vec<Issue> {
102        let sql = ctx.sql;
103        let casts = collect_cast_instances(statement, sql);
104
105        if casts.is_empty() {
106            return Vec::new();
107        }
108
109        // Determine the target style.
110        let target = match self.preferred_style {
111            PreferredTypeCastingStyle::Consistent => casts[0].style,
112            PreferredTypeCastingStyle::Shorthand => CastStyle::DoubleColon,
113            PreferredTypeCastingStyle::Cast => CastStyle::FunctionCast,
114            PreferredTypeCastingStyle::Convert => CastStyle::Convert,
115        };
116
117        // Check if there is a violation at all.
118        let has_violation = casts.iter().any(|c| c.style != target);
119        if !has_violation {
120            return Vec::new();
121        }
122
123        let message = match self.preferred_style {
124            PreferredTypeCastingStyle::Consistent => {
125                "Use consistent casting style (avoid mixing CAST styles)."
126            }
127            PreferredTypeCastingStyle::Shorthand => "Use `::` shorthand casting style.",
128            PreferredTypeCastingStyle::Cast => "Use `CAST(...)` style casts.",
129            PreferredTypeCastingStyle::Convert => "Use `CONVERT(...)` style casts.",
130        };
131
132        // Emit one issue per non-conforming cast so that partially fixable
133        // statements (e.g. 3-arg CONVERT that can't be converted) still show
134        // improvement in the violation count after autofix.
135        let mut issues = Vec::new();
136        for cast in &casts {
137            if cast.style == target {
138                continue;
139            }
140
141            let mut issue =
142                Issue::info(issue_codes::LINT_CV_011, message).with_statement(ctx.statement_index);
143
144            if !cast.is_3arg_convert && !cast.has_comments {
145                let cast_text = &sql[cast.start..cast.end];
146                if let Some(replacement) = convert_cast(cast_text, cast.style, target) {
147                    issue = issue.with_autofix_edits(
148                        IssueAutofixApplicability::Unsafe,
149                        vec![IssuePatchEdit::new(
150                            Span::new(cast.start, cast.end),
151                            replacement,
152                        )],
153                    );
154                }
155            }
156
157            issues.push(issue);
158        }
159
160        issues
161    }
162}
163
164// ---------------------------------------------------------------------------
165// Collect all cast expressions with their source positions
166// ---------------------------------------------------------------------------
167
/// Collects every cast-like expression in `statement` with its byte range in `sql`.
///
/// Three sources feed the list:
/// - `Expr::Cast` nodes (CAST / TRY_CAST / SAFE_CAST / `::`) found by the AST visitor;
/// - `Expr::Function` nodes whose name is `CONVERT` (case-insensitive);
/// - a lexical scan of `sql` for `( ... )::type` forms whose AST spans are unreliable.
///
/// The result is sorted by start offset. Entries fully contained in another
/// are dropped, same-style overlaps keep the wider range, and exact duplicate
/// ranges are collapsed.
///
/// NOTE(review): the lexical scan covers all of `sql`, not just this
/// statement's text, so callers may need to filter by statement range.
fn collect_cast_instances(statement: &Statement, sql: &str) -> Vec<CastInstance> {
    let mut casts = Vec::new();

    visit_expressions(statement, &mut |expr| {
        match expr {
            Expr::Cast {
                kind,
                expr: inner,
                data_type,
                ..
            } => {
                let style = match kind {
                    CastKind::DoubleColon => CastStyle::DoubleColon,
                    CastKind::Cast | CastKind::TryCast | CastKind::SafeCast => {
                        CastStyle::FunctionCast
                    }
                };

                // For chained :: (inner is also ::), skip the inner — we handle
                // the entire chain as one entry via the outermost.
                let is_inner_chain = matches!(
                    inner.as_ref(),
                    Expr::Cast {
                        kind: CastKind::DoubleColon,
                        ..
                    }
                );

                // Byte range of the whole cast expression (keyword / `::type`
                // included), not just the inner operand.
                let inner_span = find_cast_span(sql, inner, kind.clone(), data_type);
                if let Some((start, end)) = inner_span {
                    let text = &sql[start..end];
                    // Comments inside the cast make textual rewriting unsafe.
                    let has_comments = text.contains("--") || text.contains("/*");

                    if style == CastStyle::DoubleColon && is_inner_chain {
                        // Outermost chained :: — remove previously collected inner.
                        casts.retain(|c: &CastInstance| c.start < start || c.end > end);
                    }

                    casts.push(CastInstance {
                        style,
                        start,
                        end,
                        has_comments,
                        is_3arg_convert: false,
                    });
                }
            }
            Expr::Function(function)
                if function.name.to_string().eq_ignore_ascii_case("CONVERT") =>
            {
                if let Some((start, mut end)) = expr_span_offsets(sql, expr) {
                    // Function::span() may not include the closing paren.
                    // Scan forward to include it.
                    if end < sql.len() && sql.as_bytes().get(end) == Some(&b')') {
                        end += 1;
                    } else {
                        // Try to find the closing paren after the span end.
                        if let Some(close) = find_matching_close_paren(&sql[end..]) {
                            end += close + 1;
                        }
                    }

                    let text = &sql[start..end];
                    let has_comments = text.contains("--") || text.contains("/*");

                    // Non-list argument forms count as zero arguments.
                    let arg_count = match &function.args {
                        sqlparser::ast::FunctionArguments::List(list) => list.args.len(),
                        _ => 0,
                    };

                    casts.push(CastInstance {
                        style: CastStyle::Convert,
                        start,
                        end,
                        has_comments,
                        // A third (style) argument has no CAST / `::` equivalent.
                        is_3arg_convert: arg_count > 2,
                    });
                }
            }
            _ => {}
        }
    });

    // Parser span extraction can miss parenthesized shorthand casts in some
    // Snowflake semi-structured forms. Add a lightweight lexical fallback.
    for (start, end) in scan_parenthesized_shorthand_cast_spans(sql) {
        // Skip ranges the AST pass already recorded as `::` casts.
        if casts.iter().any(|cast| {
            cast.start == start && cast.end == end && cast.style == CastStyle::DoubleColon
        }) {
            continue;
        }
        let text = &sql[start..end];
        casts.push(CastInstance {
            style: CastStyle::DoubleColon,
            start,
            end,
            has_comments: text.contains("--") || text.contains("/*"),
            is_3arg_convert: false,
        });
    }

    // Sort by position so first-seen logic works correctly.
    casts.sort_by_key(|c| c.start);

    // Deduplicate: remove entries whose ranges are fully contained within
    // another entry's range (handles chained :: where both outer and inner
    // are collected by the visitor). For overlapping shorthand casts, keep
    // the wider range so we don't emit conflicting nested edits.
    let mut deduped: Vec<CastInstance> = Vec::with_capacity(casts.len());
    for cast in casts {
        let mut dominated = false;
        let mut replace_index = None;

        // Only the first kept entry that relates to `cast` is considered.
        for (index, other) in deduped.iter().enumerate() {
            if other.start <= cast.start && other.end >= cast.end {
                dominated = true;
                break;
            }
            if cast.start <= other.start && cast.end >= other.end {
                replace_index = Some(index);
                break;
            }
            if cast.style == other.style
                && spans_overlap(cast.start, cast.end, other.start, other.end)
            {
                let cast_len = cast.end.saturating_sub(cast.start);
                let other_len = other.end.saturating_sub(other.start);
                if cast_len > other_len {
                    replace_index = Some(index);
                } else {
                    dominated = true;
                }
                break;
            }
        }

        if dominated {
            continue;
        }

        if let Some(index) = replace_index {
            deduped[index] = cast;
        } else {
            deduped.push(cast);
        }
    }

    // Final ordering; among identical ranges the lower `CastStyle` discriminant
    // sorts first and is the one `dedup_by` keeps.
    deduped.sort_by_key(|cast| (cast.start, cast.end, cast.style as u8));
    deduped.dedup_by(|left, right| left.start == right.start && left.end == right.end);
    deduped
}
320
/// True when the half-open byte ranges `[left_start, left_end)` and
/// `[right_start, right_end)` share at least one position.
fn spans_overlap(left_start: usize, left_end: usize, right_start: usize, right_end: usize) -> bool {
    let disjoint = right_end <= left_start || left_end <= right_start;
    !disjoint
}
324
325fn scan_parenthesized_shorthand_cast_spans(sql: &str) -> Vec<(usize, usize)> {
326    let bytes = sql.as_bytes();
327    let mut out = Vec::new();
328    let mut index = 0usize;
329
330    while index + 1 < bytes.len() {
331        if bytes[index] != b':' || bytes[index + 1] != b':' {
332            index += 1;
333            continue;
334        }
335
336        let mut lhs_end = index;
337        while lhs_end > 0 && bytes[lhs_end - 1].is_ascii_whitespace() {
338            lhs_end -= 1;
339        }
340        if lhs_end == 0 || bytes[lhs_end - 1] != b')' {
341            index += 2;
342            continue;
343        }
344        let close_paren = lhs_end - 1;
345        let Some(open_paren) = find_matching_open_paren(bytes, close_paren) else {
346            index += 2;
347            continue;
348        };
349
350        let Some(type_end) = scan_parenthesized_shorthand_type_end(bytes, index + 2) else {
351            index += 2;
352            continue;
353        };
354
355        out.push((open_paren, type_end));
356        index = type_end;
357    }
358
359    out
360}
361
/// Scans a type name starting at `start` (just after `::`) and returns its
/// exclusive end.
///
/// Accepts ASCII letters, digits, `_` and `.`, plus parenthesised modifiers in
/// which commas and whitespace are allowed. Returns `None` if no type
/// character was consumed.
fn scan_parenthesized_shorthand_type_end(bytes: &[u8], start: usize) -> Option<usize> {
    let mut nesting = 0i32;
    let mut consumed_type_char = false;
    let mut end = start;

    for &byte in bytes.iter().skip(start) {
        let inside_parens = nesting > 0;
        let accepted = match byte {
            b'(' => {
                nesting += 1;
                consumed_type_char = true;
                true
            }
            b')' => {
                if inside_parens {
                    nesting -= 1;
                }
                inside_parens
            }
            _ if byte.is_ascii_alphanumeric() || byte == b'_' || byte == b'.' => {
                consumed_type_char = true;
                true
            }
            b',' | b' ' | b'\t' | b'\n' | b'\r' => inside_parens,
            _ => false,
        };
        if !accepted {
            break;
        }
        end += 1;
    }

    consumed_type_char.then_some(end)
}
394
/// Given the index of a `)` in `bytes`, walks backwards and returns the index
/// of its matching `(`. Returns `None` if `close_paren` is not a `)` or the
/// paren is unbalanced. Quotes are not treated specially.
fn find_matching_open_paren(bytes: &[u8], close_paren: usize) -> Option<usize> {
    if bytes.get(close_paren) != Some(&b')') {
        return None;
    }

    // Closing parens seen (walking left) that still await their `(`.
    let mut unmatched = 0usize;
    for (pos, &byte) in bytes[..close_paren].iter().enumerate().rev() {
        match byte {
            b')' => unmatched += 1,
            b'(' if unmatched == 0 => return Some(pos),
            b'(' => unmatched -= 1,
            _ => {}
        }
    }
    None
}
416
417/// Find the full source span of a CAST/:: expression.
418///
419/// sqlparser's `Expr::Cast.span()` only returns the inner expression's span,
420/// so we compute the full span by:
421/// - For CAST/TRY_CAST/SAFE_CAST: scan backwards from inner expr to find the
422///   keyword, then forwards to find the closing paren.
423/// - For `::`: find the deepest base expression, use its span as start, then
424///   scan forwards through all `::type` segments to find the outermost end.
425fn find_cast_span(
426    sql: &str,
427    inner: &Expr,
428    kind: CastKind,
429    data_type: &DataType,
430) -> Option<(usize, usize)> {
431    match kind {
432        CastKind::Cast | CastKind::TryCast | CastKind::SafeCast => {
433            let (inner_start, inner_end) = expr_span_offsets(sql, inner)?;
434
435            // Scan backwards from inner_start to find `CAST(`, `TRY_CAST(`, or `SAFE_CAST(`.
436            let before = &sql[..inner_start];
437            let paren_pos = before.rfind('(')?;
438            let before_paren = before[..paren_pos].trim_end();
439            let kw = match kind {
440                CastKind::TryCast => "TRY_CAST",
441                CastKind::SafeCast => "SAFE_CAST",
442                _ => "CAST",
443            };
444            let kw_len = kw.len();
445            if before_paren.len() < kw_len {
446                return None;
447            }
448            let kw_candidate = &before_paren[before_paren.len() - kw_len..];
449            if !kw_candidate.eq_ignore_ascii_case(kw) {
450                return None;
451            }
452            let start = before_paren.len() - kw_len;
453
454            // Scan forwards from inner_end to find closing paren.
455            let after = &sql[inner_end..];
456            let close = find_matching_close_paren(after)?;
457            let end = inner_end + close + 1;
458
459            Some((start, end))
460        }
461        CastKind::DoubleColon => {
462            // Find the deepest non-:: base expression to get the real start.
463            let base = deepest_base_expr(inner);
464            let (base_start, base_end) = expr_span_offsets(sql, base)?;
465
466            // Scan forward from base_end through all `::type` segments.
467            let type_str = data_type.to_string();
468            let mut pos = base_end;
469            loop {
470                let after = &sql[pos..];
471                let dc_pos = match after.find("::") {
472                    Some(p) => p,
473                    None => break,
474                };
475                let type_start = pos + dc_pos + 2;
476                let type_len = source_type_len(sql, type_start, &type_str);
477                if type_len == 0 {
478                    break;
479                }
480                pos = type_start + type_len;
481                // Check if this type matches the outermost data_type.
482                let this_type = &sql[type_start..pos];
483                if this_type.eq_ignore_ascii_case(&type_str) {
484                    break;
485                }
486            }
487
488            Some((base_start, pos))
489        }
490    }
491}
492
493/// Walk down the `Expr::Cast { kind: DoubleColon }` chain to find the
494/// deepest non-Cast base expression.
495fn deepest_base_expr(expr: &Expr) -> &Expr {
496    match expr {
497        Expr::Cast {
498            kind: CastKind::DoubleColon,
499            expr: inner,
500            ..
501        } => deepest_base_expr(inner),
502        _ => expr,
503    }
504}
505
/// Returns the offset (relative to `text`) of the first `)` that closes a
/// paren opened *before* `text`, i.e. the first `)` at nesting depth zero.
///
/// `'...'` and `"..."` runs are skipped, and a backslash inside them skips the
/// following byte.
fn find_matching_close_paren(text: &str) -> Option<usize> {
    let bytes = text.as_bytes();
    let mut open = 0i32;
    let mut pos = 0usize;

    while pos < bytes.len() {
        match bytes[pos] {
            b'(' => open += 1,
            b')' if open == 0 => return Some(pos),
            b')' => open -= 1,
            quote @ (b'\'' | b'"') => {
                pos += 1;
                while pos < bytes.len() && bytes[pos] != quote {
                    pos += if bytes[pos] == b'\\' { 2 } else { 1 };
                }
            }
            _ => {}
        }
        pos += 1;
    }
    None
}
537
/// Determine the length (in bytes) of a type name in `sql` starting at `pos`.
///
/// The source may spell the type with different casing or spacing than
/// `DataType::to_string()` (`type_display`).
///
/// Fast path: accept the display text as a case-insensitive prefix, but only
/// when it ends on a character boundary and is not followed by another
/// identifier byte. So `INT` does not match the start of `int4`, and a
/// multi-byte character can never be split, which would panic.
///
/// Fallback: scan identifier bytes plus parenthesised modifiers; whitespace
/// and commas are only accepted inside parens.
///
/// Returns 0 when no type text starts at `pos`.
fn source_type_len(sql: &str, pos: usize, type_display: &str) -> usize {
    let remaining = &sql[pos..];
    let display_len = type_display.len();
    let is_ident_byte = |b: u8| b.is_ascii_alphanumeric() || b == b'_';

    if let Some(prefix) = remaining.get(..display_len) {
        let runs_on = matches!(
            remaining.as_bytes().get(display_len),
            Some(&b) if is_ident_byte(b)
        );
        if !runs_on && prefix.eq_ignore_ascii_case(type_display) {
            return display_len;
        }
    }

    // Fallback: scan forward through identifier-like characters.
    let mut len = 0;
    let mut depth = 0i32;
    for &b in remaining.as_bytes() {
        match b {
            b'(' => {
                depth += 1;
                len += 1;
            }
            b')' if depth > 0 => {
                depth -= 1;
                len += 1;
            }
            _ if is_ident_byte(b) => len += 1,
            b' ' | b'\t' | b'\n' | b'\r' | b',' if depth > 0 => len += 1,
            _ => break,
        }
    }
    len
}
574
575// ---------------------------------------------------------------------------
576// Convert a cast expression to the target style
577// ---------------------------------------------------------------------------
578
579fn convert_cast(cast_text: &str, from_style: CastStyle, to_style: CastStyle) -> Option<String> {
580    match (from_style, to_style) {
581        (CastStyle::FunctionCast, CastStyle::DoubleColon) => cast_to_shorthand(cast_text),
582        (CastStyle::FunctionCast, CastStyle::Convert) => cast_to_convert(cast_text),
583        (CastStyle::DoubleColon, CastStyle::FunctionCast) => shorthand_to_cast(cast_text),
584        (CastStyle::DoubleColon, CastStyle::Convert) => shorthand_to_convert(cast_text),
585        (CastStyle::Convert, CastStyle::FunctionCast) => convert_to_cast(cast_text),
586        (CastStyle::Convert, CastStyle::DoubleColon) => convert_to_shorthand(cast_text),
587        _ => None,
588    }
589}
590
591/// Parse the interior of `CAST(expr AS type)` from raw text.
592/// Returns `(expr_text, type_text)`.
593fn parse_cast_interior(cast_text: &str) -> Option<(&str, &str)> {
594    let open = cast_text.find('(')?;
595    let close = cast_text.rfind(')')?;
596    let inner = cast_text[open + 1..close].trim();
597
598    let as_pos = find_top_level_as(inner)?;
599    let expr_part = inner[..as_pos].trim();
600    // The `AS` keyword is typically 2 chars, but ` AS ` starts with a space.
601    // `as_pos` points to the space/newline before `AS`.
602    let type_part = inner[as_pos + 1..].trim();
603    // Strip the leading `AS` keyword.
604    let type_part = type_part
605        .strip_prefix("AS")
606        .or_else(|| type_part.strip_prefix("as"))
607        .or_else(|| type_part.strip_prefix("As"))
608        .or_else(|| type_part.strip_prefix("aS"))
609        .unwrap_or(type_part)
610        .trim();
611    Some((expr_part, type_part))
612}
613
614/// Find the position of top-level whitespace-AS-whitespace in CAST interior.
615fn find_top_level_as(inner: &str) -> Option<usize> {
616    let bytes = inner.as_bytes();
617    let mut depth = 0i32;
618    let mut i = 0;
619    while i < bytes.len() {
620        match bytes[i] {
621            b'(' => depth += 1,
622            b')' => depth -= 1,
623            b'\'' | b'"' => {
624                let quote = bytes[i];
625                i += 1;
626                while i < bytes.len() && bytes[i] != quote {
627                    if bytes[i] == b'\\' {
628                        i += 1;
629                    }
630                    i += 1;
631                }
632            }
633            _ if depth == 0 => {
634                if is_whitespace_byte(bytes[i])
635                    && i + 3 < bytes.len()
636                    && bytes[i + 1].eq_ignore_ascii_case(&b'A')
637                    && bytes[i + 2].eq_ignore_ascii_case(&b'S')
638                    && is_whitespace_byte(bytes[i + 3])
639                {
640                    return Some(i);
641                }
642            }
643            _ => {}
644        }
645        i += 1;
646    }
647    None
648}
649
/// ASCII space, tab, line feed or carriage return.
fn is_whitespace_byte(b: u8) -> bool {
    b == b' ' || b == b'\t' || b == b'\n' || b == b'\r'
}
653
654/// `CAST(expr AS type)` → `expr::type` or `(expr)::type`.
655fn cast_to_shorthand(cast_text: &str) -> Option<String> {
656    let (expr, type_text) = parse_cast_interior(cast_text)?;
657    let needs_parens = expr_is_complex(expr);
658    if needs_parens {
659        Some(format!("({expr})::{type_text}"))
660    } else {
661        Some(format!("{expr}::{type_text}"))
662    }
663}
664
665/// `CAST(expr AS type)` → `convert(type, expr)`.
666fn cast_to_convert(cast_text: &str) -> Option<String> {
667    let (expr, type_text) = parse_cast_interior(cast_text)?;
668    Some(format!("convert({type_text}, {expr})"))
669}
670
671/// `CONVERT(type, expr)` → `cast(expr as type)`.
672fn convert_to_cast(convert_text: &str) -> Option<String> {
673    let (type_text, expr) = parse_convert_interior(convert_text)?;
674    Some(format!("cast({expr} as {type_text})"))
675}
676
677/// `CONVERT(type, expr)` → `expr::type` or `(expr)::type`.
678fn convert_to_shorthand(convert_text: &str) -> Option<String> {
679    let (type_text, expr) = parse_convert_interior(convert_text)?;
680    let needs_parens = expr_is_complex(expr);
681    if needs_parens {
682        Some(format!("({expr})::{type_text}"))
683    } else {
684        Some(format!("{expr}::{type_text}"))
685    }
686}
687
688/// `expr::type` → `cast(expr as type)`.
689/// Handles chained casts: `expr::t1::t2` → `cast(cast(expr as t1) as t2)`.
690fn shorthand_to_cast(shorthand_text: &str) -> Option<String> {
691    let parts = split_shorthand_chain(shorthand_text)?;
692    if parts.len() < 2 {
693        return None;
694    }
695    let mut result = rewrite_nested_simple_shorthand_to_cast(parts[0]);
696    for type_part in &parts[1..] {
697        result = format!("cast({result} as {type_part})");
698    }
699    Some(result)
700}
701
702/// `expr::type` → `convert(type, expr)`.
703/// Handles chained casts: `expr::t1::t2` → `convert(t2, convert(t1, expr))`.
704fn shorthand_to_convert(shorthand_text: &str) -> Option<String> {
705    let parts = split_shorthand_chain(shorthand_text)?;
706    if parts.len() < 2 {
707        return None;
708    }
709    let mut result = parts[0].to_string();
710    for type_part in &parts[1..] {
711        result = format!("convert({type_part}, {result})");
712    }
713    Some(result)
714}
715
/// Splits a `::` chain such as `100::int::text` into `["100", "int", "text"]`.
///
/// Only `::` at paren depth zero and outside `'...'` / `"..."` splits; a
/// backslash inside quotes skips the following byte. Returns `None` if the
/// text contains no splitting `::`.
fn split_shorthand_chain(text: &str) -> Option<Vec<&str>> {
    let bytes = text.as_bytes();
    let mut pieces = Vec::new();
    let mut segment_start = 0usize;
    let mut nesting = 0i32;
    let mut open_quote: Option<u8> = None;
    let mut pos = 0usize;

    while pos < bytes.len() {
        let byte = bytes[pos];
        if let Some(quote) = open_quote {
            if byte == b'\\' {
                pos += 2;
                continue;
            }
            if byte == quote {
                open_quote = None;
            }
            pos += 1;
            continue;
        }
        match byte {
            b'\'' | b'"' => open_quote = Some(byte),
            b'(' => nesting += 1,
            b')' => nesting -= 1,
            b':' if nesting == 0 && bytes.get(pos + 1) == Some(&b':') => {
                pieces.push(&text[segment_start..pos]);
                pos += 2;
                segment_start = pos;
                continue;
            }
            _ => {}
        }
        pos += 1;
    }
    pieces.push(&text[segment_start..]);

    (pieces.len() >= 2).then_some(pieces)
}
756
757/// Rewrites simple nested shorthand fragments in an expression, e.g.
758/// `value:Longitude::varchar` -> `cast(value:Longitude as varchar)`.
759/// This is intentionally conservative: it only rewrites contiguous identifier
760/// chains and leaves complex nested expressions to the outer conversion pass.
761fn rewrite_nested_simple_shorthand_to_cast(expr: &str) -> String {
762    let bytes = expr.as_bytes();
763    let mut index = 0usize;
764    let mut out = String::with_capacity(expr.len() + 16);
765
766    while index < bytes.len() {
767        let Some(rel_dc) = expr[index..].find("::") else {
768            out.push_str(&expr[index..]);
769            break;
770        };
771        let dc = index + rel_dc;
772
773        let mut lhs_start = dc;
774        while lhs_start > 0 && is_simple_shorthand_lhs_char(bytes[lhs_start - 1]) {
775            lhs_start -= 1;
776        }
777        if lhs_start == dc {
778            out.push_str(&expr[index..dc + 2]);
779            index = dc + 2;
780            continue;
781        }
782
783        let mut rhs_end = dc + 2;
784        while rhs_end < bytes.len() && is_simple_type_char(bytes[rhs_end]) {
785            rhs_end += 1;
786        }
787        if rhs_end == dc + 2 {
788            out.push_str(&expr[index..dc + 2]);
789            index = dc + 2;
790            continue;
791        }
792
793        out.push_str(&expr[index..lhs_start]);
794        out.push_str("cast(");
795        out.push_str(&expr[lhs_start..dc]);
796        out.push_str(" as ");
797        out.push_str(&expr[dc + 2..rhs_end]);
798        out.push(')');
799        index = rhs_end;
800    }
801
802    out
803}
804
/// Bytes allowed in a simple `::` operand: ASCII alphanumerics plus
/// `_ . : $ @ " `` ` `` [ ]` (covers paths like `value:field` and quoted names).
fn is_simple_shorthand_lhs_char(byte: u8) -> bool {
    const EXTRA: &[u8] = b"_.:$@\"`[]";
    byte.is_ascii_alphanumeric() || EXTRA.contains(&byte)
}
812
/// Bytes that may appear in a type name: ASCII alphanumerics, `_`, ASCII
/// whitespace (space/tab/LF/CR), parentheses and commas.
fn is_simple_type_char(byte: u8) -> bool {
    const EXTRA: &[u8] = b"_ \t\n\r(),";
    byte.is_ascii_alphanumeric() || EXTRA.contains(&byte)
}
820
821/// Parse interior of `CONVERT(type, expr)`. Returns `(type_text, expr_text)`.
822fn parse_convert_interior(convert_text: &str) -> Option<(&str, &str)> {
823    let open = convert_text.find('(')?;
824    let close = convert_text.rfind(')')?;
825    let inner = convert_text[open + 1..close].trim();
826    let comma = find_top_level_comma(inner)?;
827    let type_part = inner[..comma].trim();
828    let expr_part = inner[comma + 1..].trim();
829    Some((type_part, expr_part))
830}
831
/// Returns the index of the first comma at paren depth zero, skipping
/// `'...'` / `"..."` runs (a backslash inside quotes skips the next byte).
fn find_top_level_comma(inner: &str) -> Option<usize> {
    let bytes = inner.as_bytes();
    let mut nesting = 0i32;
    let mut pos = 0usize;

    while pos < bytes.len() {
        match bytes[pos] {
            quote @ (b'\'' | b'"') => {
                pos += 1;
                while pos < bytes.len() && bytes[pos] != quote {
                    pos += if bytes[pos] == b'\\' { 2 } else { 1 };
                }
            }
            b'(' => nesting += 1,
            b')' => nesting -= 1,
            b',' if nesting == 0 => return Some(pos),
            _ => {}
        }
        pos += 1;
    }
    None
}
858
/// Returns true if the expression text is "complex" and needs parenthesization
/// when used in shorthand `::` form (`::` binds tighter than binary operators).
///
/// An expression is complex if, outside parentheses and quoted text, it
/// contains any of these:
/// - an operator byte (`| + - * / % = < > & ^ ~ !`), except a leading unary `-`;
/// - ASCII whitespace.
///
/// Quoted literals and identifiers are skipped rather than ending the scan,
/// so `'a' || 'b'` is correctly treated as complex. A backslash inside quotes
/// skips the next byte. Over-parenthesizing is always safe; under-
/// parenthesizing changes what the cast applies to.
fn expr_is_complex(expr: &str) -> bool {
    let bytes = expr.trim().as_bytes();
    let mut depth = 0i32;
    let mut i = 0usize;

    while i < bytes.len() {
        let b = bytes[i];
        match b {
            b'(' => depth += 1,
            b')' => depth -= 1,
            b'\'' | b'"' => {
                i += 1;
                while i < bytes.len() && bytes[i] != b {
                    if bytes[i] == b'\\' {
                        i += 1;
                    }
                    i += 1;
                }
            }
            // Leading unary minus (e.g. `-1`) stays a simple operand.
            b'-' if i == 0 => {}
            b'|' | b'+' | b'-' | b'*' | b'/' | b'%' | b'=' | b'<' | b'>' | b'&' | b'^'
            | b'~' | b'!'
                if depth == 0 =>
            {
                return true;
            }
            b' ' | b'\t' | b'\n' | b'\r' if depth == 0 => return true,
            _ => {}
        }
        i += 1;
    }
    false
}
882
883// ---------------------------------------------------------------------------
884// Span helpers
885// ---------------------------------------------------------------------------
886
887fn expr_span_offsets(sql: &str, expr: &Expr) -> Option<(usize, usize)> {
888    let span = expr.span();
889    if span.start.line == 0 || span.start.column == 0 || span.end.line == 0 || span.end.column == 0
890    {
891        return None;
892    }
893
894    let start = line_col_to_offset(sql, span.start.line as usize, span.start.column as usize)?;
895    let end = line_col_to_offset(sql, span.end.line as usize, span.end.column as usize)?;
896    (end >= start).then_some((start, end))
897}
898
/// Maps a 1-based `(line, column)` position to a byte offset in `sql`.
///
/// Lines are delimited by `'\n'`; columns count characters (not bytes) from
/// the start of the requested line, continuing through the rest of the string.
/// A column one past the final character resolves to `sql.len()`.
/// Returns `None` for a zero line/column, for a line that does not exist, or
/// for a column beyond that end position.
fn line_col_to_offset(sql: &str, line: usize, column: usize) -> Option<usize> {
    if line == 0 || column == 0 {
        return None;
    }

    // Byte offset at which the requested line begins: just past its
    // `(line - 1)`-th newline, or 0 for the first line.
    let line_start = if line == 1 {
        0
    } else {
        sql.match_indices('\n')
            .nth(line - 2)
            .map(|(newline_idx, _)| newline_idx + 1)?
    };

    // Candidate offsets for column 1, 2, ...: each character start, followed
    // by the end-of-input position.
    sql[line_start..]
        .char_indices()
        .map(|(rel_idx, _)| line_start + rel_idx)
        .chain(std::iter::once(sql.len()))
        .nth(column - 1)
}
933
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::parse_sql;

    /// Parses `sql` and runs `rule` over every statement, using the whole input
    /// as each statement's range.
    fn check_all(rule: &ConventionCastingStyle, sql: &str) -> Vec<Issue> {
        let statements = parse_sql(sql).expect("parse");
        statements
            .iter()
            .enumerate()
            .flat_map(|(index, statement)| {
                rule.check(
                    statement,
                    &LintContext {
                        sql,
                        statement_range: 0..sql.len(),
                        statement_index: index,
                    },
                )
            })
            .collect()
    }

    /// Runs the rule with its default configuration.
    fn run(sql: &str) -> Vec<Issue> {
        check_all(&ConventionCastingStyle::default(), sql)
    }

    /// Runs the rule built from `config`.
    fn run_with_config(sql: &str, config: &LintConfig) -> Vec<Issue> {
        check_all(&ConventionCastingStyle::from_config(config), sql)
    }

    /// Builds a config that sets `preferred_type_casting_style` to `style`
    /// under the rule key `rule_key`.
    fn style_config(rule_key: &str, style: &str) -> LintConfig {
        LintConfig {
            enabled: true,
            disabled_rules: vec![],
            rule_configs: std::collections::BTreeMap::from([(
                rule_key.to_string(),
                serde_json::json!({ "preferred_type_casting_style": style }),
            )]),
        }
    }

    /// `style_config` keyed by the dotted rule name used by most tests.
    fn casting_style_config(style: &str) -> LintConfig {
        style_config("convention.casting_style", style)
    }

    /// Applies `edits` to `sql` back-to-front so earlier offsets stay valid.
    fn apply_edits<'a>(sql: &str, edits: impl IntoIterator<Item = &'a IssuePatchEdit>) -> String {
        let mut sorted: Vec<&IssuePatchEdit> = edits.into_iter().collect();
        sorted.sort_by_key(|e| std::cmp::Reverse(e.span.start));
        let mut result = sql.to_string();
        for edit in sorted {
            result.replace_range(edit.span.start..edit.span.end, &edit.replacement);
        }
        result
    }

    /// Gathers every autofix edit attached to `issues`.
    fn collect_all_edits(issues: &[Issue]) -> Vec<&IssuePatchEdit> {
        issues
            .iter()
            .filter_map(|i| i.autofix.as_ref())
            .flat_map(|a| a.edits.iter())
            .collect()
    }

    /// Applies all autofix edits from `issues` to `sql`.
    fn apply_all_fixes(sql: &str, issues: &[Issue]) -> String {
        apply_edits(sql, collect_all_edits(issues))
    }

    /// Asserts that at least one issue was raised and that applying every
    /// autofix to `sql` yields `expected`.
    #[track_caller]
    fn assert_autofix(sql: &str, issues: &[Issue], expected: &str) {
        assert!(!issues.is_empty(), "expected at least one CV11 issue");
        assert_eq!(apply_all_fixes(sql, issues), expected);
    }

    #[test]
    fn flags_mixed_casting_styles() {
        let issues = run("SELECT CAST(amount AS INT)::TEXT FROM t");
        assert_eq!(issues.len(), 1);
        assert_eq!(issues[0].code, issue_codes::LINT_CV_011);
    }

    #[test]
    fn does_not_flag_single_casting_style() {
        assert!(run("SELECT amount::INT FROM t").is_empty());
        assert!(run("SELECT CAST(amount AS INT) FROM t").is_empty());
    }

    #[test]
    fn does_not_flag_cast_like_tokens_inside_string_literal() {
        assert!(run("SELECT 'value::TEXT and CAST(value AS INT)' AS note").is_empty());
    }

    #[test]
    fn flags_mixed_try_cast_and_double_colon_styles() {
        let issues = run("SELECT TRY_CAST(amount AS INT)::TEXT FROM t");
        assert_eq!(issues.len(), 1);
        assert_eq!(issues[0].code, issue_codes::LINT_CV_011);
    }

    #[test]
    fn shorthand_preference_flags_cast_function_style() {
        let config = casting_style_config("shorthand");
        let issues = run_with_config("SELECT CAST(amount AS INT) FROM t", &config);
        assert_eq!(issues.len(), 1);
    }

    #[test]
    fn cast_preference_flags_shorthand_style() {
        // Exercises the rule-code alias for the config key.
        let config = style_config("LINT_CV_011", "cast");
        let issues = run_with_config("SELECT amount::INT FROM t", &config);
        assert_eq!(issues.len(), 1);
    }

    // -----------------------------------------------------------------------
    // Autofix tests — SQLFluff CV11 fixture parity
    // -----------------------------------------------------------------------

    #[test]
    fn autofix_consistent_prior_convert() {
        let sql = "select\n    convert(int, 1) as bar,\n    100::int::text,\n    cast(10\n    as text) as coo\nfrom foo;";
        assert_autofix(
            sql,
            &run(sql),
            "select\n    convert(int, 1) as bar,\n    convert(text, convert(int, 100)),\n    convert(text, 10) as coo\nfrom foo;",
        );
    }

    #[test]
    fn autofix_consistent_prior_cast() {
        let sql = "select\n    cast(10 as text) as coo,\n    convert(int, 1) as bar,\n    100::int::text,\nfrom foo;";
        assert_autofix(
            sql,
            &run(sql),
            "select\n    cast(10 as text) as coo,\n    cast(1 as int) as bar,\n    cast(cast(100 as int) as text),\nfrom foo;",
        );
    }

    #[test]
    fn autofix_consistent_prior_shorthand() {
        let sql = "select\n    100::int::text,\n    cast(10 as text) as coo,\n    convert(int, 1) as bar\nfrom foo;";
        assert_autofix(
            sql,
            &run(sql),
            "select\n    100::int::text,\n    10::text as coo,\n    1::int as bar\nfrom foo;",
        );
    }

    #[test]
    fn autofix_config_cast() {
        let config = casting_style_config("cast");
        let sql = "select\n    convert(int, 1) as bar,\n    100::int::text,\n    cast(10 as text) as coo\nfrom foo;";
        assert_autofix(
            sql,
            &run_with_config(sql, &config),
            "select\n    cast(1 as int) as bar,\n    cast(cast(100 as int) as text),\n    cast(10 as text) as coo\nfrom foo;",
        );
    }

    #[test]
    fn autofix_config_convert() {
        let config = casting_style_config("convert");
        let sql = "select\n    convert(int, 1) as bar,\n    100::int::text,\n    cast(10 as text) as coo\nfrom foo;";
        assert_autofix(
            sql,
            &run_with_config(sql, &config),
            "select\n    convert(int, 1) as bar,\n    convert(text, convert(int, 100)),\n    convert(text, 10) as coo\nfrom foo;",
        );
    }

    #[test]
    fn autofix_config_shorthand() {
        let config = casting_style_config("shorthand");
        let sql = "select\n    convert(int, 1) as bar,\n    100::int::text,\n    cast(10 as text) as coo\nfrom foo;";
        assert_autofix(
            sql,
            &run_with_config(sql, &config),
            "select\n    1::int as bar,\n    100::int::text,\n    10::text as coo\nfrom foo;",
        );
    }

    #[test]
    fn autofix_3arg_convert_skipped_config_cast() {
        let config = casting_style_config("cast");
        let sql = "select\n    convert(int, 1, 126) as bar,\n    100::int::text,\n    cast(10 as text) as coo\nfrom foo;";
        assert_autofix(
            sql,
            &run_with_config(sql, &config),
            "select\n    convert(int, 1, 126) as bar,\n    cast(cast(100 as int) as text),\n    cast(10 as text) as coo\nfrom foo;",
        );
    }

    #[test]
    fn autofix_3arg_convert_skipped_config_shorthand() {
        let config = casting_style_config("shorthand");
        let sql = "select\n    convert(int, 1, 126) as bar,\n    100::int::text,\n    cast(10 as text) as coo\nfrom foo;";
        assert_autofix(
            sql,
            &run_with_config(sql, &config),
            "select\n    convert(int, 1, 126) as bar,\n    100::int::text,\n    10::text as coo\nfrom foo;",
        );
    }

    #[test]
    fn autofix_parenthesize_complex_expr_shorthand_from_cast() {
        let config = casting_style_config("shorthand");
        let sql = "select\n    id::int,\n    cast(calendar_date||' 11:00:00' as timestamp) as calendar_datetime\nfrom foo;";
        assert_autofix(
            sql,
            &run_with_config(sql, &config),
            "select\n    id::int,\n    (calendar_date||' 11:00:00')::timestamp as calendar_datetime\nfrom foo;",
        );
    }

    #[test]
    fn autofix_parenthesize_complex_expr_shorthand_from_convert() {
        let config = casting_style_config("shorthand");
        let sql = "select\n    id::int,\n    convert(timestamp, calendar_date||' 11:00:00') as calendar_datetime\nfrom foo;";
        assert_autofix(
            sql,
            &run_with_config(sql, &config),
            "select\n    id::int,\n    (calendar_date||' 11:00:00')::timestamp as calendar_datetime\nfrom foo;",
        );
    }

    #[test]
    fn autofix_comment_cast_skipped() {
        let sql = "select\n    cast(10 as text) as coo,\n    convert( -- Convert the value\n        int, /*\n              to an integer\n            */ 1) as bar,\n    100::int::text\nfrom foo;";
        assert_autofix(
            sql,
            &run(sql),
            "select\n    cast(10 as text) as coo,\n    convert( -- Convert the value\n        int, /*\n              to an integer\n            */ 1) as bar,\n    cast(cast(100 as int) as text)\nfrom foo;",
        );
    }

    #[test]
    fn autofix_3arg_convert_consistent_prior_cast() {
        let sql = "select\n    cast(10 as text) as coo,\n    convert(int, 1, 126) as bar,\n    100::int::text\nfrom foo;";
        assert_autofix(
            sql,
            &run(sql),
            "select\n    cast(10 as text) as coo,\n    convert(int, 1, 126) as bar,\n    cast(cast(100 as int) as text)\nfrom foo;",
        );
    }

    #[test]
    fn autofix_comment_prior_convert_shorthand_fixed() {
        let sql = "select\n    convert(int, 126) as bar,\n    cast(\n    1 /* cast the value\n        to an integer\n      */ as int) as coo,\n    100::int::text\nfrom foo;";
        assert_autofix(
            sql,
            &run(sql),
            "select\n    convert(int, 126) as bar,\n    cast(\n    1 /* cast the value\n        to an integer\n      */ as int) as coo,\n    convert(text, convert(int, 100))\nfrom foo;",
        );
    }

    #[test]
    fn autofix_comment_prior_shorthand_convert_fixed() {
        let sql = "select\n    100::int::text,\n    convert(int, 126) as bar,\n    cast(\n    1 /* cast the value\n        to an integer\n      */ as int) as coo\nfrom foo;";
        assert_autofix(
            sql,
            &run(sql),
            "select\n    100::int::text,\n    126::int as bar,\n    cast(\n    1 /* cast the value\n        to an integer\n      */ as int) as coo\nfrom foo;",
        );
    }

    #[test]
    fn shorthand_to_cast_rewrites_nested_snowflake_path_casts() {
        let fixed = shorthand_to_cast("(trim(value:Longitude::varchar))::double").expect("rewrite");
        assert_eq!(
            fixed,
            "cast((trim(cast(value:Longitude as varchar))) as double)"
        );
        assert_eq!(
            shorthand_to_cast("col:a.b:c::varchar").expect("rewrite"),
            "cast(col:a.b:c as varchar)"
        );
    }
}