Skip to main content

flowscope_core/linter/rules/
st_005.rs

1//! LINT_ST_005: Structure subquery.
2//!
3//! SQLFluff ST05 parity: avoid subqueries in FROM/JOIN clauses; prefer CTEs.
4
5use crate::linter::config::LintConfig;
6use crate::linter::rule::{LintContext, LintRule};
7use crate::parser::parse_sql_with_dialect;
8use crate::types::{issue_codes, Dialect, Issue, IssueAutofixApplicability, IssuePatchEdit};
9use sqlparser::ast::{Query, Select, SetExpr, Statement, TableFactor};
10use std::collections::HashSet;
11
12use super::semantic_helpers::{
13    collect_qualifier_prefixes_in_expr, visit_select_expressions, visit_selects_in_statement,
14};
15
16#[derive(Clone, Copy, Debug, Eq, PartialEq)]
17enum ForbidSubqueryIn {
18    Both,
19    Join,
20    From,
21}
22
23impl ForbidSubqueryIn {
24    fn from_config(config: &LintConfig) -> Self {
25        match config
26            .rule_option_str(issue_codes::LINT_ST_005, "forbid_subquery_in")
27            .unwrap_or("join")
28            .to_ascii_lowercase()
29            .as_str()
30        {
31            "join" => Self::Join,
32            "from" => Self::From,
33            _ => Self::Both,
34        }
35    }
36
37    fn forbid_from(self) -> bool {
38        matches!(self, Self::Both | Self::From)
39    }
40
41    fn forbid_join(self) -> bool {
42        matches!(self, Self::Both | Self::Join)
43    }
44}
45
46pub struct StructureSubquery {
47    forbid_subquery_in: ForbidSubqueryIn,
48}
49
50impl StructureSubquery {
51    pub fn from_config(config: &LintConfig) -> Self {
52        Self {
53            forbid_subquery_in: ForbidSubqueryIn::from_config(config),
54        }
55    }
56}
57
58impl Default for StructureSubquery {
59    fn default() -> Self {
60        Self {
61            forbid_subquery_in: ForbidSubqueryIn::Join,
62        }
63    }
64}
65
66impl LintRule for StructureSubquery {
67    fn code(&self) -> &'static str {
68        issue_codes::LINT_ST_005
69    }
70
71    fn name(&self) -> &'static str {
72        "Structure subquery"
73    }
74
75    fn description(&self) -> &'static str {
76        "Join/From clauses should not contain subqueries. Use CTEs instead."
77    }
78
79    fn check(&self, statement: &Statement, ctx: &LintContext) -> Vec<Issue> {
80        let mut violations = 0usize;
81
82        visit_selects_in_statement(statement, &mut |select| {
83            let outer_source_names = source_names_in_select(select);
84            for table in &select.from {
85                if self.forbid_subquery_in.forbid_from()
86                    && table_factor_contains_derived(&table.relation, &outer_source_names)
87                {
88                    violations += 1;
89                }
90                if self.forbid_subquery_in.forbid_join() {
91                    for join in &table.joins {
92                        if table_factor_contains_derived(&join.relation, &outer_source_names) {
93                            violations += 1;
94                        }
95                    }
96                }
97            }
98        });
99
100        if violations == 0 {
101            return Vec::new();
102        }
103
104        let autofix_edits = st005_subquery_to_cte_rewrite(
105            ctx.statement_sql(),
106            statement,
107            self.forbid_subquery_in,
108            ctx.dialect(),
109        )
110        .filter(|rewritten| rewritten != ctx.statement_sql())
111        .map(|rewritten| {
112            vec![IssuePatchEdit::new(
113                ctx.span_from_statement_offset(0, ctx.statement_sql().len()),
114                rewritten,
115            )]
116        })
117        .unwrap_or_default();
118
119        (0..violations)
120            .map(|index| {
121                let mut issue = Issue::info(
122                    issue_codes::LINT_ST_005,
123                    "Join/From clauses should not contain subqueries. Use CTEs instead.",
124                )
125                .with_statement(ctx.statement_index);
126                if index == 0 && !autofix_edits.is_empty() {
127                    issue = issue.with_autofix_edits(
128                        IssueAutofixApplicability::Unsafe,
129                        autofix_edits.clone(),
130                    );
131                }
132                issue
133            })
134            .collect()
135    }
136}
137
138// ---------------------------------------------------------------------------
139// Comprehensive text-preserving subquery-to-CTE rewriter
140// ---------------------------------------------------------------------------
141
/// Text coordinates of one `FROM`/`JOIN` subquery scheduled to become a CTE.
#[derive(Clone, Debug)]
struct SubqueryExtraction {
    /// Byte index of the `(` that opens the subquery.
    open_paren: usize,
    /// Byte index of the matching `)`.
    close_paren: usize,
    /// CTE name to use: the subquery's own alias or a generated `prep_N` name.
    alias: String,
    /// Exclusive end of the trailing alias text (if any); the span
    /// `open_paren..alias_region_end` is what gets replaced by `alias`.
    alias_region_end: usize,
}
154
155/// Rewrite the SQL statement by extracting all subqueries in FROM/JOIN clauses
156/// to CTEs. Returns the rewritten SQL, or None if no rewrite is possible.
157fn st005_subquery_to_cte_rewrite(
158    sql: &str,
159    stmt: &Statement,
160    forbid_subquery_in: ForbidSubqueryIn,
161    dialect: Dialect,
162) -> Option<String> {
163    const MAX_REWRITE_PASSES: usize = 8;
164
165    let mut current_sql = sql.to_string();
166    let mut current_stmt = stmt.clone();
167    let mut changed = false;
168
169    for _ in 0..MAX_REWRITE_PASSES {
170        // Collect all non-correlated subqueries from the current AST.
171        let mut subquery_aliases: Vec<(String, bool)> = Vec::new();
172        collect_extractable_subqueries(&current_stmt, forbid_subquery_in, &mut subquery_aliases);
173        if subquery_aliases.is_empty() {
174            break;
175        }
176
177        // Find subquery positions in the current SQL text.
178        let extractions =
179            find_subquery_positions(&current_sql, forbid_subquery_in, &subquery_aliases);
180        if extractions.is_empty() {
181            break;
182        }
183
184        let Some(rewritten) = apply_cte_extractions(&current_sql, &extractions, dialect) else {
185            break;
186        };
187        if rewritten == current_sql {
188            break;
189        }
190
191        changed = true;
192        current_sql = rewritten;
193
194        // Re-parse the rewritten SQL so later passes can extract newly-exposed
195        // nested subqueries (e.g. inside extracted CTE bodies).
196        let Ok(mut reparsed) = parse_sql_with_dialect(&current_sql, dialect) else {
197            break;
198        };
199        let Some(next_stmt) = (reparsed.len() == 1).then(|| reparsed.remove(0)) else {
200            break;
201        };
202        current_stmt = next_stmt;
203    }
204
205    changed.then_some(current_sql)
206}
207
208/// Walk the AST to collect info about each extractable (non-correlated) subquery.
209/// Collects (alias_name, is_correlated) in document order.
210fn collect_extractable_subqueries(
211    stmt: &Statement,
212    forbid_in: ForbidSubqueryIn,
213    out: &mut Vec<(String, bool)>,
214) {
215    visit_selects_in_statement(stmt, &mut |select| {
216        let outer_source_names = source_names_in_select(select);
217        for table in &select.from {
218            if forbid_in.forbid_from() {
219                collect_from_table_factor(&table.relation, &outer_source_names, out);
220            }
221            if forbid_in.forbid_join() {
222                for join in &table.joins {
223                    collect_from_table_factor(&join.relation, &outer_source_names, out);
224                }
225            }
226        }
227    });
228}
229
230/// Recursively collect extractable subqueries from a table factor.
231fn collect_from_table_factor(
232    tf: &TableFactor,
233    outer_names: &HashSet<String>,
234    out: &mut Vec<(String, bool)>,
235) {
236    match tf {
237        TableFactor::Derived {
238            subquery, alias, ..
239        } => {
240            let is_correlated = query_references_outer_sources(subquery, outer_names);
241            if !is_correlated {
242                let alias_name = alias
243                    .as_ref()
244                    .map(|a| a.name.value.clone())
245                    .unwrap_or_default();
246                out.push((alias_name, is_correlated));
247            }
248        }
249        TableFactor::NestedJoin {
250            table_with_joins, ..
251        } => {
252            collect_from_table_factor(&table_with_joins.relation, outer_names, out);
253            for join in &table_with_joins.joins {
254                collect_from_table_factor(&join.relation, outer_names, out);
255            }
256        }
257        TableFactor::Pivot { table, .. }
258        | TableFactor::Unpivot { table, .. }
259        | TableFactor::MatchRecognize { table, .. } => {
260            collect_from_table_factor(table, outer_names, out);
261        }
262        _ => {}
263    }
264}
265
266/// Scan the SQL text to locate subquery parenthesized expressions in FROM/JOIN
267/// clauses. Returns extractions sorted by position (for correct processing order).
268fn find_subquery_positions(
269    sql: &str,
270    forbid_in: ForbidSubqueryIn,
271    ast_aliases: &[(String, bool)],
272) -> Vec<SubqueryExtraction> {
273    let bytes = sql.as_bytes();
274    let mut extractions = Vec::new();
275    let mut ast_idx = 0usize;
276    let mut auto_name_counter = 0usize;
277    // Collect all names to avoid clashes.
278    let mut existing_cte_names: HashSet<String> = HashSet::new();
279    collect_existing_cte_names(sql, &mut existing_cte_names);
280
281    // Names reserved for generated prep_N CTEs.
282    let mut used_names: HashSet<String> = existing_cte_names.clone();
283    for (alias, _) in ast_aliases {
284        if !alias.is_empty() {
285            used_names.insert(alias.to_ascii_uppercase());
286        }
287    }
288
289    // Names already claimed by explicit/auto extractions in this pass.
290    let mut claimed_names: HashSet<String> = existing_cte_names;
291
292    let mut pos = 0usize;
293    while pos < bytes.len() {
294        // Skip quoted regions.
295        if let Some(end) = skip_quoted_region(bytes, pos) {
296            pos = end;
297            continue;
298        }
299        // Skip line comments.
300        if bytes[pos] == b'-' && bytes.get(pos + 1) == Some(&b'-') {
301            while pos < bytes.len() && bytes[pos] != b'\n' {
302                pos += 1;
303            }
304            continue;
305        }
306        // Skip block comments.
307        if bytes[pos] == b'/' && bytes.get(pos + 1) == Some(&b'*') {
308            pos += 2;
309            while pos + 1 < bytes.len() {
310                if bytes[pos] == b'*' && bytes[pos + 1] == b'/' {
311                    pos += 2;
312                    break;
313                }
314                pos += 1;
315            }
316            continue;
317        }
318
319        // Look for FROM or JOIN keywords followed by a parenthesized subquery.
320        let is_from =
321            forbid_in.forbid_from() && match_ascii_keyword_at(bytes, pos, b"FROM").is_some();
322        let is_join = forbid_in.forbid_join()
323            && (match_ascii_keyword_at(bytes, pos, b"JOIN").is_some()
324                || match_join_keyword_sequence(bytes, pos).is_some());
325
326        if is_from || is_join {
327            let keyword_end = if is_from {
328                match_ascii_keyword_at(bytes, pos, b"FROM").unwrap()
329            } else if let Some(end) = match_join_keyword_sequence(bytes, pos) {
330                end
331            } else {
332                match_ascii_keyword_at(bytes, pos, b"JOIN").unwrap()
333            };
334
335            let after_keyword = skip_ascii_whitespace(bytes, keyword_end);
336
337            // Check for open parenthesis (could be `FROM(` or `FROM (` or `JOIN\n(`).
338            if after_keyword < bytes.len() && bytes[after_keyword] == b'(' {
339                if let Some(close) = find_matching_parenthesis_outside_quotes(sql, after_keyword) {
340                    let inner = sql[after_keyword + 1..close].trim();
341                    let inner_lower = inner.to_ascii_lowercase();
342                    // Only extract if inner content starts with SELECT or WITH,
343                    // and we still have AST aliases to consume.
344                    if (inner_lower.starts_with("select") || inner_lower.starts_with("with"))
345                        && ast_idx < ast_aliases.len()
346                    {
347                        let (ref ast_alias, _) = ast_aliases[ast_idx];
348                        ast_idx += 1;
349
350                        let alias = if ast_alias.is_empty() {
351                            let name = generate_prep_name(&mut auto_name_counter, &used_names);
352                            let name_key = name.to_ascii_uppercase();
353                            used_names.insert(name_key.clone());
354                            claimed_names.insert(name_key);
355                            name
356                        } else {
357                            let alias_key = ast_alias.to_ascii_uppercase();
358                            // If the alias would clash with an existing/previous CTE
359                            // name, leave this subquery in place (SQLFluff parity).
360                            if claimed_names.contains(&alias_key) {
361                                pos = close + 1;
362                                continue;
363                            }
364                            claimed_names.insert(alias_key.clone());
365                            used_names.insert(alias_key);
366                            ast_alias.clone()
367                        };
368
369                        // Parse alias region after close paren.
370                        let (_alias_start, alias_end) =
371                            parse_alias_region_after_close_paren(bytes, close);
372
373                        extractions.push(SubqueryExtraction {
374                            open_paren: after_keyword,
375                            close_paren: close,
376                            alias: alias.clone(),
377                            alias_region_end: alias_end,
378                        });
379
380                        // Skip past the subquery.
381                        pos = alias_end;
382                        continue;
383                    }
384                }
385            }
386        }
387
388        pos += 1;
389    }
390
391    extractions
392}
393
/// Returns the first `prep_N` name, trying N = `*counter + 1`, `*counter + 2`, …,
/// whose uppercase form is not in `used_names`, and leaves `*counter` at that N.
fn generate_prep_name(counter: &mut usize, used_names: &HashSet<String>) -> String {
    let mut n = *counter;
    let name = loop {
        n += 1;
        let candidate = format!("prep_{n}");
        if !used_names.contains(&candidate.to_ascii_uppercase()) {
            break candidate;
        }
    };
    *counter = n;
    name
}
404
405/// Collect CTE names from existing WITH clause in the SQL text.
406fn collect_existing_cte_names(sql: &str, names: &mut HashSet<String>) {
407    let bytes = sql.as_bytes();
408    let mut pos = skip_ascii_whitespace(bytes, 0);
409
410    // Check for INSERT ... WITH or CREATE TABLE ... AS WITH patterns.
411    // Skip past INSERT INTO ... or CREATE TABLE ... AS to find WITH.
412    if let Some(end) = match_ascii_keyword_at(bytes, pos, b"INSERT") {
413        pos = skip_to_with_or_select(bytes, end);
414    } else if let Some(end) = match_ascii_keyword_at(bytes, pos, b"CREATE") {
415        pos = skip_to_with_or_select(bytes, end);
416    }
417
418    if match_ascii_keyword_at(bytes, pos, b"WITH").is_none() {
419        return;
420    }
421
422    let with_end = match_ascii_keyword_at(bytes, pos, b"WITH").unwrap();
423    pos = skip_ascii_whitespace(bytes, with_end);
424
425    // Skip RECURSIVE keyword if present.
426    if let Some(end) = match_ascii_keyword_at(bytes, pos, b"RECURSIVE") {
427        pos = skip_ascii_whitespace(bytes, end);
428    }
429
430    // Parse CTE names: name AS (...), name AS (...), ...
431    loop {
432        // Parse CTE name.
433        let name_start = pos;
434        if let Some(quoted_end) = consume_quoted_identifier(bytes, pos) {
435            let raw = &sql[name_start..quoted_end];
436            let unquoted = raw.trim_matches(|c| c == '"' || c == '`' || c == '[' || c == ']');
437            names.insert(unquoted.to_ascii_uppercase());
438            pos = skip_ascii_whitespace(bytes, quoted_end);
439        } else if let Some(name_end) = consume_ascii_identifier(bytes, pos) {
440            names.insert(sql[name_start..name_end].to_ascii_uppercase());
441            pos = skip_ascii_whitespace(bytes, name_end);
442        } else {
443            break;
444        }
445
446        // Expect AS keyword.
447        if let Some(as_end) = match_ascii_keyword_at(bytes, pos, b"AS") {
448            pos = skip_ascii_whitespace(bytes, as_end);
449        } else {
450            break;
451        }
452
453        // Skip the CTE body parenthesized expression.
454        if pos < bytes.len() && bytes[pos] == b'(' {
455            if let Some(close) = find_matching_parenthesis_outside_quotes(sql, pos) {
456                pos = skip_ascii_whitespace(bytes, close + 1);
457            } else {
458                break;
459            }
460        } else {
461            break;
462        }
463
464        // Check for comma (more CTEs follow).
465        if pos < bytes.len() && bytes[pos] == b',' {
466            pos += 1;
467            pos = skip_ascii_whitespace(bytes, pos);
468        } else {
469            break;
470        }
471    }
472}
473
474/// Skip forward in bytes to find the position of WITH or SELECT keyword.
475fn skip_to_with_or_select(bytes: &[u8], mut pos: usize) -> usize {
476    while pos < bytes.len() {
477        let ws = skip_ascii_whitespace(bytes, pos);
478        if ws > pos {
479            pos = ws;
480        }
481        if match_ascii_keyword_at(bytes, pos, b"WITH").is_some() {
482            return pos;
483        }
484        if match_ascii_keyword_at(bytes, pos, b"SELECT").is_some() {
485            return pos;
486        }
487        pos += 1;
488    }
489    pos
490}
491
492/// Parse the alias region (optional `AS` + identifier) after a close parenthesis.
493/// Returns (region_start, region_end) where region_start is close_paren + 1.
494fn parse_alias_region_after_close_paren(bytes: &[u8], close_paren: usize) -> (usize, usize) {
495    let start = close_paren + 1;
496    let mut pos = start;
497    let ws_pos = skip_ascii_whitespace(bytes, pos);
498
499    // Check for AS keyword.
500    if let Some(as_end) = match_ascii_keyword_at(bytes, ws_pos, b"AS") {
501        let after_as = skip_ascii_whitespace(bytes, as_end);
502        if let Some(quoted_end) = consume_quoted_identifier(bytes, after_as) {
503            return (start, quoted_end);
504        }
505        if let Some(ident_end) = consume_ascii_identifier(bytes, after_as) {
506            return (start, ident_end);
507        }
508    }
509
510    // No AS keyword; check for bare identifier alias.
511    // An identifier here is an alias only if it's not a SQL keyword that would
512    // indicate the start of the next clause (ON, USING, WHERE, JOIN, etc.).
513    if let Some(quoted_end) = consume_quoted_identifier(bytes, ws_pos) {
514        return (start, quoted_end);
515    }
516    if let Some(ident_end) = consume_ascii_identifier(bytes, ws_pos) {
517        let word = &bytes[ws_pos..ident_end];
518        if !is_clause_keyword(word) {
519            pos = ident_end;
520            return (start, pos);
521        }
522    }
523
524    (start, start)
525}
526
/// Returns `true` when `word` (ASCII case-insensitive) is a keyword that can
/// directly follow an unaliased derived table and so must not be consumed as
/// its alias.
///
/// Besides the core clause, join and set-operation keywords, this covers
/// keywords that begin trailing constructs after a table factor, such as
/// `QUALIFY`, `WINDOW`, `OFFSET`, `FETCH`, `FOR`, `PIVOT`, `TABLESAMPLE`, and
/// `ASOF`/`SEMI`/`ANTI` joins. Mistaking one of them for an alias would delete
/// the keyword from the rewritten SQL.
fn is_clause_keyword(word: &[u8]) -> bool {
    const CLAUSE_KEYWORDS: &[&[u8]] = &[
        b"ON",
        b"USING",
        b"WHERE",
        b"JOIN",
        b"INNER",
        b"LEFT",
        b"RIGHT",
        b"FULL",
        b"OUTER",
        b"CROSS",
        b"NATURAL",
        b"GROUP",
        b"ORDER",
        b"HAVING",
        b"LIMIT",
        b"UNION",
        b"INTERSECT",
        b"EXCEPT",
        b"MINUS",
        b"FROM",
        b"SELECT",
        b"INSERT",
        b"UPDATE",
        b"DELETE",
        b"SET",
        b"INTO",
        b"VALUES",
        b"WITH",
        // Trailing constructs / join flavours that may follow a table factor.
        b"QUALIFY",
        b"WINDOW",
        b"OFFSET",
        b"FETCH",
        b"FOR",
        b"PIVOT",
        b"UNPIVOT",
        b"LATERAL",
        b"RETURNING",
        b"TABLESAMPLE",
        b"MATCH_RECOGNIZE",
        b"ASOF",
        b"SEMI",
        b"ANTI",
        b"PREWHERE",
    ];
    CLAUSE_KEYWORDS
        .iter()
        .any(|kw| kw.eq_ignore_ascii_case(word))
}
562
563/// Apply the subquery extractions: build CTE definitions, replace subqueries
564/// with alias references, and insert the WITH clause.
565fn apply_cte_extractions(
566    sql: &str,
567    extractions: &[SubqueryExtraction],
568    dialect: Dialect,
569) -> Option<String> {
570    if extractions.is_empty() {
571        return None;
572    }
573
574    let case_pref = detect_case_preference(sql);
575
576    // Find if there's an existing WITH clause and where each existing CTE lives.
577    let existing_ctes = parse_existing_cte_ranges(sql);
578
579    // For each extraction, determine if it's inside an existing CTE body.
580    // Build (cte_def, insert_before_cte_index) pairs.
581    struct CteInsertion {
582        definition: String,
583        /// None = append at end / prepend for new WITH. Some(i) = insert before existing CTE i.
584        insert_before: Option<usize>,
585    }
586
587    let mut insertions: Vec<CteInsertion> = Vec::new();
588    let mut replacements: Vec<(usize, usize, String)> = Vec::new();
589
590    for ext in extractions {
591        let subquery_text = &sql[ext.open_paren + 1..ext.close_paren];
592        let as_kw = if case_pref == CasePref::Upper {
593            "AS"
594        } else {
595            "as"
596        };
597        let cte_def = format!("{} {} ({})", ext.alias, as_kw, subquery_text);
598
599        // Check if this extraction is inside an existing CTE body.
600        let containing_cte = existing_ctes
601            .iter()
602            .position(|cte| ext.open_paren >= cte.body_start && ext.close_paren <= cte.body_end);
603
604        insertions.push(CteInsertion {
605            definition: cte_def,
606            insert_before: containing_cte,
607        });
608
609        let mut replacement = ext.alias.clone();
610        if ext.open_paren > 0 {
611            let prev = sql.as_bytes()[ext.open_paren - 1];
612            if !prev.is_ascii_whitespace() {
613                replacement.insert(0, ' ');
614            }
615        }
616
617        replacements.push((ext.open_paren, ext.alias_region_end, replacement));
618    }
619
620    // Apply text replacements in reverse order to preserve positions.
621    let mut result = sql.to_string();
622    for (start, end, replacement) in replacements.into_iter().rev() {
623        result.replace_range(start..end, &replacement);
624    }
625
626    // Now insert CTEs. Separate into two groups:
627    // 1. CTEs that need to be inserted before an existing CTE (dependency ordering)
628    // 2. CTEs that are new top-level (no existing WITH, or appended)
629    let mut before_insertions: Vec<(usize, String)> = Vec::new(); // (cte_index, definition)
630    let mut top_level_defs: Vec<String> = Vec::new();
631
632    for insertion in insertions {
633        match insertion.insert_before {
634            Some(cte_idx) => before_insertions.push((cte_idx, insertion.definition)),
635            None => top_level_defs.push(insertion.definition),
636        }
637    }
638
639    if !before_insertions.is_empty() && !existing_ctes.is_empty() {
640        // We need to rebuild the WITH clause with reordered CTEs.
641        result = rebuild_with_clause_with_insertions(
642            &result,
643            sql,
644            &existing_ctes,
645            &before_insertions,
646            &top_level_defs,
647            case_pref,
648        );
649        return Some(result);
650    }
651
652    // Simple case: just insert/append new CTEs.
653    insert_cte_clause(&result, &top_level_defs, case_pref, dialect)
654}
655
/// Byte span of one CTE body in the statement's leading WITH clause.
#[derive(Clone, Debug)]
struct ExistingCteRange {
    /// Offset of the `(` opening the CTE body.
    body_start: usize,
    /// Offset of the `)` closing the CTE body.
    body_end: usize,
}
664
665/// Parse the existing CTE definitions in a WITH clause.
666fn parse_existing_cte_ranges(sql: &str) -> Vec<ExistingCteRange> {
667    let bytes = sql.as_bytes();
668    let mut pos = skip_ascii_whitespace(bytes, 0);
669    let mut ranges = Vec::new();
670
671    // Skip INSERT/CREATE prefix.
672    if match_ascii_keyword_at(bytes, pos, b"INSERT").is_some()
673        || match_ascii_keyword_at(bytes, pos, b"CREATE").is_some()
674    {
675        pos = skip_to_with_or_select(bytes, pos + 6);
676    }
677
678    let with_end = match match_ascii_keyword_at(bytes, pos, b"WITH") {
679        Some(end) => end,
680        None => return ranges,
681    };
682    pos = skip_ascii_whitespace(bytes, with_end);
683
684    // Skip RECURSIVE.
685    if let Some(end) = match_ascii_keyword_at(bytes, pos, b"RECURSIVE") {
686        pos = skip_ascii_whitespace(bytes, end);
687    }
688
689    loop {
690        // CTE name.
691        if let Some(quoted_end) = consume_quoted_identifier(bytes, pos) {
692            pos = skip_ascii_whitespace(bytes, quoted_end);
693        } else if let Some(name_end) = consume_ascii_identifier(bytes, pos) {
694            pos = skip_ascii_whitespace(bytes, name_end);
695        } else {
696            break;
697        }
698
699        // AS keyword.
700        if let Some(as_end) = match_ascii_keyword_at(bytes, pos, b"AS") {
701            pos = skip_ascii_whitespace(bytes, as_end);
702        } else {
703            break;
704        }
705
706        // CTE body paren.
707        if pos < bytes.len() && bytes[pos] == b'(' {
708            if let Some(close) = find_matching_parenthesis_outside_quotes(sql, pos) {
709                ranges.push(ExistingCteRange {
710                    body_start: pos,
711                    body_end: close,
712                });
713                pos = skip_ascii_whitespace(bytes, close + 1);
714            } else {
715                break;
716            }
717        } else {
718            break;
719        }
720
721        // Comma.
722        if pos < bytes.len() && bytes[pos] == b',' {
723            pos += 1;
724            pos = skip_ascii_whitespace(bytes, pos);
725        } else {
726            break;
727        }
728    }
729
730    ranges
731}
732
733/// Rebuild the WITH clause with new CTEs inserted before their containing CTEs.
734fn rebuild_with_clause_with_insertions(
735    modified_sql: &str,
736    _original_sql: &str,
737    _existing_ctes: &[ExistingCteRange],
738    before_insertions: &[(usize, String)],
739    top_level_defs: &[String],
740    case_pref: CasePref,
741) -> String {
742    // The modified_sql has already had subquery text replaced with alias names.
743    // We need to reconstruct the WITH clause with CTEs in dependency order.
744    //
745    // Strategy: find the WITH clause region in modified_sql, extract each CTE text,
746    // then rebuild with new CTEs inserted at the right positions.
747
748    let bytes = modified_sql.as_bytes();
749    let mut pos = skip_ascii_whitespace(bytes, 0);
750
751    // Skip INSERT/CREATE prefix.
752    if match_ascii_keyword_at(bytes, pos, b"INSERT").is_some()
753        || match_ascii_keyword_at(bytes, pos, b"CREATE").is_some()
754    {
755        pos = skip_to_with_or_select(bytes, pos + 6);
756    }
757
758    let with_kw_start = pos;
759    let with_end = match match_ascii_keyword_at(bytes, pos, b"WITH") {
760        Some(end) => end,
761        None => return modified_sql.to_string(),
762    };
763    pos = skip_ascii_whitespace(bytes, with_end);
764
765    // Skip RECURSIVE.
766    if let Some(end) = match_ascii_keyword_at(bytes, pos, b"RECURSIVE") {
767        pos = skip_ascii_whitespace(bytes, end);
768    }
769
770    // Parse CTE texts from modified SQL.
771    let mut cte_texts: Vec<String> = Vec::new();
772    let mut last_cte_end = pos;
773
774    loop {
775        let cte_start = pos;
776
777        if let Some(quoted_end) = consume_quoted_identifier(bytes, pos) {
778            pos = skip_ascii_whitespace(bytes, quoted_end);
779        } else if let Some(name_end) = consume_ascii_identifier(bytes, pos) {
780            pos = skip_ascii_whitespace(bytes, name_end);
781        } else {
782            break;
783        }
784
785        if let Some(as_end) = match_ascii_keyword_at(bytes, pos, b"AS") {
786            pos = skip_ascii_whitespace(bytes, as_end);
787        } else {
788            break;
789        }
790
791        if pos < bytes.len() && bytes[pos] == b'(' {
792            if let Some(close) = find_matching_parenthesis_outside_quotes(modified_sql, pos) {
793                let cte_text = modified_sql[cte_start..close + 1].to_string();
794                cte_texts.push(cte_text);
795                last_cte_end = close + 1;
796                pos = skip_ascii_whitespace(bytes, close + 1);
797            } else {
798                break;
799            }
800        } else {
801            break;
802        }
803
804        if pos < bytes.len() && bytes[pos] == b',' {
805            pos += 1;
806            pos = skip_ascii_whitespace(bytes, pos);
807        } else {
808            break;
809        }
810    }
811
812    // Build new CTE list with insertions at the right positions.
813    let mut new_cte_list: Vec<String> = Vec::new();
814    for (i, cte_text) in cte_texts.iter().enumerate() {
815        // Insert any new CTEs that should go before this existing CTE.
816        for (before_idx, def) in before_insertions {
817            if *before_idx == i {
818                new_cte_list.push(def.clone());
819            }
820        }
821        new_cte_list.push(cte_text.clone());
822    }
823
824    // Append top-level defs at end.
825    for def in top_level_defs {
826        new_cte_list.push(def.clone());
827    }
828
829    // Rebuild the SQL.
830    let with_kw = if case_pref == CasePref::Upper {
831        "WITH"
832    } else {
833        "with"
834    };
835    let remainder = &modified_sql[last_cte_end..];
836
837    let mut result = String::with_capacity(modified_sql.len() + 200);
838    result.push_str(&modified_sql[..with_kw_start]);
839    result.push_str(with_kw);
840    result.push(' ');
841    for (i, cte) in new_cte_list.iter().enumerate() {
842        if i > 0 {
843            result.push_str(",\n");
844        }
845        result.push_str(cte);
846    }
847    result.push_str(remainder);
848
849    result
850}
851
/// Casing used for the keywords the rewriter emits (`WITH`, `AS`).
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum CasePref {
    /// Emit uppercase keywords.
    Upper,
    /// Emit lowercase keywords.
    Lower,
}
857
858/// Detect whether the SQL uses uppercase or lowercase keywords.
859fn detect_case_preference(sql: &str) -> CasePref {
860    let bytes = sql.as_bytes();
861    let pos = skip_ascii_whitespace(bytes, 0);
862    // Check the first keyword.
863    for kw in &[b"WITH" as &[u8], b"SELECT", b"INSERT", b"CREATE"] {
864        if pos + kw.len() <= bytes.len() {
865            let word = &bytes[pos..pos + kw.len()];
866            if word
867                .iter()
868                .zip(kw.iter())
869                .all(|(a, b)| a.to_ascii_uppercase() == *b)
870                && is_word_boundary_for_keyword(bytes, pos + kw.len())
871            {
872                return if word[0].is_ascii_uppercase() {
873                    CasePref::Upper
874                } else {
875                    CasePref::Lower
876                };
877            }
878        }
879    }
880    CasePref::Upper
881}
882
883/// Insert CTE definitions into the SQL, handling existing WITH clauses,
884/// INSERT...SELECT, and CTAS patterns.
885fn insert_cte_clause(
886    sql: &str,
887    cte_defs: &[String],
888    case_pref: CasePref,
889    dialect: Dialect,
890) -> Option<String> {
891    let bytes = sql.as_bytes();
892    let with_kw = if case_pref == CasePref::Upper {
893        "WITH"
894    } else {
895        "with"
896    };
897
898    // Check for INSERT...SELECT or CREATE TABLE...AS patterns.
899    let scan_pos = skip_ascii_whitespace(bytes, 0);
900
901    let is_insert = match_ascii_keyword_at(bytes, scan_pos, b"INSERT").is_some();
902    let is_create = match_ascii_keyword_at(bytes, scan_pos, b"CREATE").is_some();
903    let is_tsql_insert = is_insert && dialect == Dialect::Mssql;
904
905    if is_tsql_insert {
906        // T-SQL: WITH goes before INSERT.
907        let insert_pos = skip_ascii_whitespace(bytes, 0);
908        return Some(insert_with_before_position(
909            sql, insert_pos, cte_defs, with_kw,
910        ));
911    }
912
913    if is_create {
914        if let Some(body_pos) = find_create_as_body_position(sql) {
915            return insert_with_at_select(sql, body_pos, cte_defs, with_kw);
916        }
917        // Fallback for unusual CREATE syntaxes.
918        if let Some(pos) = find_main_select_position(sql) {
919            return insert_with_at_select(sql, pos, cte_defs, with_kw);
920        }
921        return None;
922    }
923
924    if is_insert {
925        // For non-TSQL INSERT: find where SELECT/WITH starts and insert there.
926        let select_pos = find_main_select_position(sql);
927        if let Some(pos) = select_pos {
928            return insert_with_at_select(sql, pos, cte_defs, with_kw);
929        }
930        return None;
931    }
932
933    // Look for existing WITH clause.
934    if let Some(with_info) = find_existing_with_clause(sql) {
935        // Append new CTEs to existing WITH clause.
936        return Some(append_to_existing_with(sql, &with_info, cte_defs));
937    }
938
939    // No existing WITH: prepend.
940    let insert_pos = skip_ascii_whitespace(bytes, 0);
941    Some(insert_with_before_position(
942        sql, insert_pos, cte_defs, with_kw,
943    ))
944}
945
946/// Find the start of a CREATE ... AS body (typically SELECT or WITH).
947fn find_create_as_body_position(sql: &str) -> Option<usize> {
948    let bytes = sql.as_bytes();
949    let mut pos = skip_ascii_whitespace(bytes, 0);
950    let create_end = match_ascii_keyword_at(bytes, pos, b"CREATE")?;
951    pos = create_end;
952
953    let mut depth = 0usize;
954    while pos < bytes.len() {
955        if let Some(end) = skip_quoted_region(bytes, pos) {
956            pos = end;
957            continue;
958        }
959        if bytes[pos] == b'-' && bytes.get(pos + 1) == Some(&b'-') {
960            while pos < bytes.len() && bytes[pos] != b'\n' {
961                pos += 1;
962            }
963            continue;
964        }
965        if bytes[pos] == b'/' && bytes.get(pos + 1) == Some(&b'*') {
966            pos += 2;
967            while pos + 1 < bytes.len() {
968                if bytes[pos] == b'*' && bytes[pos + 1] == b'/' {
969                    pos += 2;
970                    break;
971                }
972                pos += 1;
973            }
974            continue;
975        }
976
977        if bytes[pos] == b'(' {
978            depth += 1;
979            pos += 1;
980            continue;
981        }
982        if bytes[pos] == b')' {
983            depth = depth.saturating_sub(1);
984            pos += 1;
985            continue;
986        }
987
988        if depth == 0 {
989            if let Some(as_end) = match_ascii_keyword_at(bytes, pos, b"AS") {
990                return Some(skip_ascii_whitespace(bytes, as_end));
991            }
992        }
993
994        pos += 1;
995    }
996
997    None
998}
999
/// Location at which an existing `WITH` clause can be extended with more CTEs.
struct ExistingWithInfo {
    /// Byte offset immediately after the closing `)` of the last CTE body in
    /// the clause. New `,\n<definition>` entries are spliced in here.
    last_cte_end: usize,
}
1004
1005/// Find the existing WITH clause and return info about where to append.
1006fn find_existing_with_clause(sql: &str) -> Option<ExistingWithInfo> {
1007    let bytes = sql.as_bytes();
1008    let mut pos = skip_ascii_whitespace(bytes, 0);
1009
1010    // Skip INSERT/CREATE prefix.
1011    if match_ascii_keyword_at(bytes, pos, b"INSERT").is_some()
1012        || match_ascii_keyword_at(bytes, pos, b"CREATE").is_some()
1013    {
1014        pos = skip_to_with_or_select(bytes, pos + 6);
1015    }
1016
1017    let _with_end = match_ascii_keyword_at(bytes, pos, b"WITH")?;
1018    let mut cursor = skip_ascii_whitespace(bytes, _with_end);
1019
1020    // Skip RECURSIVE.
1021    if let Some(end) = match_ascii_keyword_at(bytes, cursor, b"RECURSIVE") {
1022        cursor = skip_ascii_whitespace(bytes, end);
1023    }
1024
1025    // Walk through CTE definitions to find the last one.
1026    let mut last_cte_end = cursor;
1027    loop {
1028        // Skip CTE name.
1029        if let Some(quoted_end) = consume_quoted_identifier(bytes, cursor) {
1030            cursor = skip_ascii_whitespace(bytes, quoted_end);
1031        } else if let Some(name_end) = consume_ascii_identifier(bytes, cursor) {
1032            cursor = skip_ascii_whitespace(bytes, name_end);
1033        } else {
1034            break;
1035        }
1036
1037        // AS keyword.
1038        if let Some(as_end) = match_ascii_keyword_at(bytes, cursor, b"AS") {
1039            cursor = skip_ascii_whitespace(bytes, as_end);
1040        } else {
1041            break;
1042        }
1043
1044        // CTE body.
1045        if cursor < bytes.len() && bytes[cursor] == b'(' {
1046            if let Some(close) = find_matching_parenthesis_outside_quotes(sql, cursor) {
1047                last_cte_end = close + 1;
1048                cursor = skip_ascii_whitespace(bytes, close + 1);
1049            } else {
1050                break;
1051            }
1052        } else {
1053            break;
1054        }
1055
1056        // Comma means more CTEs.
1057        if cursor < bytes.len() && bytes[cursor] == b',' {
1058            cursor += 1;
1059            cursor = skip_ascii_whitespace(bytes, cursor);
1060        } else {
1061            break;
1062        }
1063    }
1064
1065    Some(ExistingWithInfo { last_cte_end })
1066}
1067
1068/// Append new CTE definitions after the last existing CTE.
1069fn append_to_existing_with(sql: &str, with_info: &ExistingWithInfo, cte_defs: &[String]) -> String {
1070    let insert_pos = with_info.last_cte_end;
1071    let mut result =
1072        String::with_capacity(sql.len() + cte_defs.iter().map(|d| d.len() + 4).sum::<usize>());
1073    result.push_str(&sql[..insert_pos]);
1074    for def in cte_defs {
1075        result.push_str(",\n");
1076        result.push_str(def);
1077    }
1078    result.push_str(&sql[insert_pos..]);
1079    result
1080}
1081
/// Build `sql` with a fresh `<with_kw> def1,\ndef2...\n` clause inserted at byte `pos`.
///
/// Definitions are joined with `",\n"` and the clause ends with a newline,
/// so the original text at `pos` starts on its own line. Panics if `pos` is
/// not a char boundary of `sql`.
fn insert_with_before_position(
    sql: &str,
    pos: usize,
    cte_defs: &[String],
    with_kw: &str,
) -> String {
    let (prefix, suffix) = sql.split_at(pos);
    let definitions = cte_defs.join(",\n");

    let mut rebuilt = String::with_capacity(sql.len() + with_kw.len() + definitions.len() + 2);
    for piece in [prefix, with_kw, " ", definitions.as_str(), "\n", suffix] {
        rebuilt.push_str(piece);
    }
    rebuilt
}
1103
1104/// Insert WITH clause before a SELECT that is preceded by INSERT/CREATE.
1105fn insert_with_at_select(
1106    sql: &str,
1107    select_pos: usize,
1108    cte_defs: &[String],
1109    with_kw: &str,
1110) -> Option<String> {
1111    // Check if there's already a WITH clause at this position.
1112    let bytes = sql.as_bytes();
1113    if match_ascii_keyword_at(bytes, select_pos, b"WITH").is_some() {
1114        // Existing WITH at select position — append to it.
1115        if let Some(with_info) = find_existing_with_clause_at(sql, select_pos) {
1116            return Some(append_to_existing_with(sql, &with_info, cte_defs));
1117        }
1118    }
1119
1120    Some(insert_with_before_position(
1121        sql, select_pos, cte_defs, with_kw,
1122    ))
1123}
1124
1125/// Find existing WITH clause starting at a specific position.
1126fn find_existing_with_clause_at(sql: &str, start: usize) -> Option<ExistingWithInfo> {
1127    let bytes = sql.as_bytes();
1128    let _with_end = match_ascii_keyword_at(bytes, start, b"WITH")?;
1129    let mut cursor = skip_ascii_whitespace(bytes, _with_end);
1130
1131    // Skip RECURSIVE.
1132    if let Some(end) = match_ascii_keyword_at(bytes, cursor, b"RECURSIVE") {
1133        cursor = skip_ascii_whitespace(bytes, end);
1134    }
1135
1136    let mut last_cte_end = cursor;
1137    loop {
1138        if let Some(quoted_end) = consume_quoted_identifier(bytes, cursor) {
1139            cursor = skip_ascii_whitespace(bytes, quoted_end);
1140        } else if let Some(name_end) = consume_ascii_identifier(bytes, cursor) {
1141            cursor = skip_ascii_whitespace(bytes, name_end);
1142        } else {
1143            break;
1144        }
1145
1146        if let Some(as_end) = match_ascii_keyword_at(bytes, cursor, b"AS") {
1147            cursor = skip_ascii_whitespace(bytes, as_end);
1148        } else {
1149            break;
1150        }
1151
1152        if cursor < bytes.len() && bytes[cursor] == b'(' {
1153            if let Some(close) = find_matching_parenthesis_outside_quotes(sql, cursor) {
1154                last_cte_end = close + 1;
1155                cursor = skip_ascii_whitespace(bytes, close + 1);
1156            } else {
1157                break;
1158            }
1159        } else {
1160            break;
1161        }
1162
1163        if cursor < bytes.len() && bytes[cursor] == b',' {
1164            cursor += 1;
1165            cursor = skip_ascii_whitespace(bytes, cursor);
1166        } else {
1167            break;
1168        }
1169    }
1170
1171    Some(ExistingWithInfo { last_cte_end })
1172}
1173
1174/// Find the position of the main SELECT keyword in an INSERT or CREATE statement.
1175fn find_main_select_position(sql: &str) -> Option<usize> {
1176    let bytes = sql.as_bytes();
1177    let mut pos = 0usize;
1178    let mut depth = 0usize;
1179
1180    while pos < bytes.len() {
1181        if let Some(end) = skip_quoted_region(bytes, pos) {
1182            pos = end;
1183            continue;
1184        }
1185        if bytes[pos] == b'-' && bytes.get(pos + 1) == Some(&b'-') {
1186            while pos < bytes.len() && bytes[pos] != b'\n' {
1187                pos += 1;
1188            }
1189            continue;
1190        }
1191        if bytes[pos] == b'/' && bytes.get(pos + 1) == Some(&b'*') {
1192            pos += 2;
1193            while pos + 1 < bytes.len() {
1194                if bytes[pos] == b'*' && bytes[pos + 1] == b'/' {
1195                    pos += 2;
1196                    break;
1197                }
1198                pos += 1;
1199            }
1200            continue;
1201        }
1202
1203        if bytes[pos] == b'(' {
1204            depth += 1;
1205            pos += 1;
1206            continue;
1207        }
1208        if bytes[pos] == b')' {
1209            depth = depth.saturating_sub(1);
1210            pos += 1;
1211            continue;
1212        }
1213
1214        // Only at depth 0, look for SELECT or WITH keyword.
1215        if depth == 0 {
1216            if match_ascii_keyword_at(bytes, pos, b"WITH").is_some() {
1217                return Some(pos);
1218            }
1219            if match_ascii_keyword_at(bytes, pos, b"SELECT").is_some() {
1220                return Some(pos);
1221            }
1222        }
1223
1224        pos += 1;
1225    }
1226    None
1227}
1228
1229/// Skip a quoted region (single quote, double quote, backtick, bracket).
1230/// Returns the position after the closing quote, or None if not in a quoted region.
1231fn skip_quoted_region(bytes: &[u8], pos: usize) -> Option<usize> {
1232    let b = bytes[pos];
1233    if b == b'\'' {
1234        return Some(skip_to_close_quote(bytes, pos + 1, b'\''));
1235    }
1236    if b == b'"' {
1237        return Some(skip_to_close_quote(bytes, pos + 1, b'"'));
1238    }
1239    if b == b'`' {
1240        return Some(skip_to_close_quote(bytes, pos + 1, b'`'));
1241    }
1242    if b == b'[' {
1243        return Some(skip_to_close_quote(bytes, pos + 1, b']'));
1244    }
1245    None
1246}
1247
/// Scan from `pos` (just after an opening delimiter) to the matching `close`
/// byte and return the index just past it.
///
/// A doubled delimiter (`''`, `""`, ` `` `, `]]`) is an escaped literal and
/// does not end the region. If no closing delimiter exists, the result is
/// `bytes.len()`, or `pos` itself when `pos` is already past the end.
fn skip_to_close_quote(bytes: &[u8], pos: usize, close: u8) -> usize {
    let mut cursor = pos;
    while let Some(&byte) = bytes.get(cursor) {
        if byte != close {
            cursor += 1;
        } else if bytes.get(cursor + 1) == Some(&close) {
            // Doubled delimiter: an escaped, literal delimiter character.
            cursor += 2;
        } else {
            return cursor + 1;
        }
    }
    cursor
}
1262
1263/// Consume a quoted identifier (double-quoted, backtick-quoted, or bracket-quoted).
1264fn consume_quoted_identifier(bytes: &[u8], pos: usize) -> Option<usize> {
1265    if pos >= bytes.len() {
1266        return None;
1267    }
1268    match bytes[pos] {
1269        b'"' => Some(skip_to_close_quote(bytes, pos + 1, b'"')),
1270        b'`' => Some(skip_to_close_quote(bytes, pos + 1, b'`')),
1271        b'[' => Some(skip_to_close_quote(bytes, pos + 1, b']')),
1272        _ => None,
1273    }
1274}
1275
1276/// Match a multi-word JOIN keyword sequence like INNER JOIN, LEFT JOIN, etc.
1277/// Returns the byte position after the final JOIN keyword.
1278fn match_join_keyword_sequence(bytes: &[u8], pos: usize) -> Option<usize> {
1279    // Check for: INNER JOIN, LEFT [OUTER] JOIN, RIGHT [OUTER] JOIN,
1280    // FULL [OUTER] JOIN, CROSS JOIN, LEFT OUTER JOIN, etc.
1281    let prefixes: &[&[u8]] = &[b"INNER", b"LEFT", b"RIGHT", b"FULL", b"CROSS", b"NATURAL"];
1282
1283    for prefix in prefixes {
1284        if let Some(prefix_end) = match_ascii_keyword_at(bytes, pos, prefix) {
1285            let mut cursor = skip_ascii_whitespace(bytes, prefix_end);
1286
1287            // Optional OUTER keyword.
1288            if let Some(outer_end) = match_ascii_keyword_at(bytes, cursor, b"OUTER") {
1289                cursor = skip_ascii_whitespace(bytes, outer_end);
1290            }
1291
1292            if let Some(join_end) = match_ascii_keyword_at(bytes, cursor, b"JOIN") {
1293                return Some(join_end);
1294            }
1295        }
1296    }
1297    None
1298}
1299
/// Return the byte index of the `)` that closes the `(` at `open_paren_index`.
///
/// The following are ignored when counting parentheses:
///
/// * string literals (`'…'`)
/// * quoted identifiers (`"…"`, `` `…` ``, `[…]`)
/// * `--` line comments and `/* … */` block comments
///
/// Inside a quoted region, a doubled delimiter is an escaped literal.
///
/// Returns `None` if `open_paren_index` is out of range, does not point at
/// `(`, or the parenthesis is never closed.
fn find_matching_parenthesis_outside_quotes(sql: &str, open_paren_index: usize) -> Option<usize> {
    #[derive(Clone, Copy, PartialEq, Eq)]
    enum Mode {
        Outside,
        /// Inside a quoted region terminated by the contained delimiter byte.
        Quoted(u8),
        LineComment,
        BlockComment,
    }

    let bytes = sql.as_bytes();
    if bytes.get(open_paren_index) != Some(&b'(') {
        return None;
    }

    let mut depth = 0usize;
    let mut mode = Mode::Outside;
    let mut index = open_paren_index;

    while index < bytes.len() {
        let byte = bytes[index];
        let next = bytes.get(index + 1).copied();

        match mode {
            Mode::Outside => match byte {
                b'\'' | b'"' | b'`' => mode = Mode::Quoted(byte),
                b'[' => mode = Mode::Quoted(b']'),
                b'-' if next == Some(b'-') => {
                    mode = Mode::LineComment;
                    index += 1;
                }
                b'/' if next == Some(b'*') => {
                    mode = Mode::BlockComment;
                    index += 1;
                }
                b'(' => depth += 1,
                b')' => {
                    depth = depth.checked_sub(1)?;
                    if depth == 0 {
                        return Some(index);
                    }
                }
                _ => {}
            },
            Mode::Quoted(close) => {
                if byte == close {
                    if next == Some(close) {
                        // Doubled delimiter: an escaped literal, stay quoted.
                        index += 1;
                    } else {
                        mode = Mode::Outside;
                    }
                }
            }
            Mode::LineComment => {
                if byte == b'\n' {
                    mode = Mode::Outside;
                }
            }
            Mode::BlockComment => {
                if byte == b'*' && next == Some(b'/') {
                    mode = Mode::Outside;
                    index += 1;
                }
            }
        }

        index += 1;
    }

    None
}
1411
/// Return `true` for the six ASCII whitespace bytes: space, tab, line feed,
/// vertical tab, form feed and carriage return.
///
/// Unlike `u8::is_ascii_whitespace`, vertical tab (0x0b) is included.
fn is_ascii_whitespace_byte(byte: u8) -> bool {
    // 0x09..=0x0d is \t, \n, vertical tab, form feed, \r.
    byte == b' ' || (0x09..=0x0d).contains(&byte)
}
1415
/// Return `true` if `byte` may begin an unquoted identifier (`[A-Za-z_]`).
fn is_ascii_ident_start(byte: u8) -> bool {
    matches!(byte, b'A'..=b'Z' | b'a'..=b'z' | b'_')
}
1419
/// Return `true` if `byte` may appear after the first character of an
/// unquoted identifier (`[A-Za-z0-9_]`).
fn is_ascii_ident_continue(byte: u8) -> bool {
    matches!(byte, b'0'..=b'9' | b'A'..=b'Z' | b'a'..=b'z' | b'_')
}
1423
1424fn skip_ascii_whitespace(bytes: &[u8], mut index: usize) -> usize {
1425    while index < bytes.len() && is_ascii_whitespace_byte(bytes[index]) {
1426        index += 1;
1427    }
1428    index
1429}
1430
1431fn consume_ascii_identifier(bytes: &[u8], start: usize) -> Option<usize> {
1432    if start >= bytes.len() || !is_ascii_ident_start(bytes[start]) {
1433        return None;
1434    }
1435    let mut index = start + 1;
1436    while index < bytes.len() && is_ascii_ident_continue(bytes[index]) {
1437        index += 1;
1438    }
1439    Some(index)
1440}
1441
1442fn is_word_boundary_for_keyword(bytes: &[u8], index: usize) -> bool {
1443    index == 0 || index >= bytes.len() || !is_ascii_ident_continue(bytes[index])
1444}
1445
1446fn match_ascii_keyword_at(bytes: &[u8], start: usize, keyword_upper: &[u8]) -> Option<usize> {
1447    let end = start.checked_add(keyword_upper.len())?;
1448    if end > bytes.len() {
1449        return None;
1450    }
1451    if !is_word_boundary_for_keyword(bytes, start.saturating_sub(1))
1452        || !is_word_boundary_for_keyword(bytes, end)
1453    {
1454        return None;
1455    }
1456    let matches = bytes[start..end]
1457        .iter()
1458        .zip(keyword_upper.iter())
1459        .all(|(actual, expected)| actual.to_ascii_uppercase() == *expected);
1460    if matches {
1461        Some(end)
1462    } else {
1463        None
1464    }
1465}
1466
1467fn table_factor_contains_derived(
1468    table_factor: &TableFactor,
1469    outer_source_names: &HashSet<String>,
1470) -> bool {
1471    match table_factor {
1472        TableFactor::Derived { subquery, .. } => {
1473            !query_references_outer_sources(subquery, outer_source_names)
1474        }
1475        TableFactor::NestedJoin {
1476            table_with_joins, ..
1477        } => {
1478            table_factor_contains_derived(&table_with_joins.relation, outer_source_names)
1479                || table_with_joins
1480                    .joins
1481                    .iter()
1482                    .any(|join| table_factor_contains_derived(&join.relation, outer_source_names))
1483        }
1484        TableFactor::Pivot { table, .. }
1485        | TableFactor::Unpivot { table, .. }
1486        | TableFactor::MatchRecognize { table, .. } => {
1487            table_factor_contains_derived(table, outer_source_names)
1488        }
1489        _ => false,
1490    }
1491}
1492
1493fn query_references_outer_sources(query: &Query, outer_source_names: &HashSet<String>) -> bool {
1494    if let Some(with) = &query.with {
1495        for cte in &with.cte_tables {
1496            if query_references_outer_sources(&cte.query, outer_source_names) {
1497                return true;
1498            }
1499        }
1500    }
1501
1502    set_expr_references_outer_sources(&query.body, outer_source_names)
1503}
1504
1505fn set_expr_references_outer_sources(
1506    set_expr: &SetExpr,
1507    outer_source_names: &HashSet<String>,
1508) -> bool {
1509    match set_expr {
1510        SetExpr::Select(select) => select_references_outer_sources(select, outer_source_names),
1511        SetExpr::Query(query) => query_references_outer_sources(query, outer_source_names),
1512        SetExpr::SetOperation { left, right, .. } => {
1513            set_expr_references_outer_sources(left, outer_source_names)
1514                || set_expr_references_outer_sources(right, outer_source_names)
1515        }
1516        _ => false,
1517    }
1518}
1519
1520fn select_references_outer_sources(select: &Select, outer_source_names: &HashSet<String>) -> bool {
1521    let mut qualifier_prefixes = HashSet::new();
1522    visit_select_expressions(select, &mut |expr| {
1523        collect_qualifier_prefixes_in_expr(expr, &mut qualifier_prefixes);
1524    });
1525
1526    let local_source_names = source_names_in_select(select);
1527    if qualifier_prefixes
1528        .iter()
1529        .any(|name| outer_source_names.contains(name) && !local_source_names.contains(name))
1530    {
1531        return true;
1532    }
1533
1534    for table in &select.from {
1535        if table_factor_references_outer_sources(&table.relation, outer_source_names) {
1536            return true;
1537        }
1538        for join in &table.joins {
1539            if table_factor_references_outer_sources(&join.relation, outer_source_names) {
1540                return true;
1541            }
1542        }
1543    }
1544    false
1545}
1546
1547fn table_factor_references_outer_sources(
1548    table_factor: &TableFactor,
1549    outer_source_names: &HashSet<String>,
1550) -> bool {
1551    match table_factor {
1552        TableFactor::Derived { subquery, .. } => {
1553            query_references_outer_sources(subquery, outer_source_names)
1554        }
1555        TableFactor::NestedJoin {
1556            table_with_joins, ..
1557        } => {
1558            table_factor_references_outer_sources(&table_with_joins.relation, outer_source_names)
1559                || table_with_joins.joins.iter().any(|join| {
1560                    table_factor_references_outer_sources(&join.relation, outer_source_names)
1561                })
1562        }
1563        TableFactor::Pivot { table, .. }
1564        | TableFactor::Unpivot { table, .. }
1565        | TableFactor::MatchRecognize { table, .. } => {
1566            table_factor_references_outer_sources(table, outer_source_names)
1567        }
1568        _ => false,
1569    }
1570}
1571
1572fn source_names_in_select(select: &Select) -> HashSet<String> {
1573    let mut names = HashSet::new();
1574    for table in &select.from {
1575        collect_source_names_from_table_factor(&table.relation, &mut names);
1576        for join in &table.joins {
1577            collect_source_names_from_table_factor(&join.relation, &mut names);
1578        }
1579    }
1580    names
1581}
1582
1583fn collect_source_names_from_table_factor(table_factor: &TableFactor, names: &mut HashSet<String>) {
1584    match table_factor {
1585        TableFactor::Table { name, alias, .. } => {
1586            if let Some(last) = name.0.last().and_then(|part| part.as_ident()) {
1587                names.insert(last.value.to_ascii_uppercase());
1588            }
1589            if let Some(alias) = alias {
1590                names.insert(alias.name.value.to_ascii_uppercase());
1591            }
1592        }
1593        TableFactor::Derived {
1594            alias, subquery, ..
1595        } => {
1596            if let Some(alias) = alias {
1597                names.insert(alias.name.value.to_ascii_uppercase());
1598            }
1599            if let Some(with) = &subquery.with {
1600                for cte in &with.cte_tables {
1601                    names.insert(cte.alias.name.value.to_ascii_uppercase());
1602                }
1603            }
1604        }
1605        TableFactor::TableFunction { alias, .. }
1606        | TableFactor::Function { alias, .. }
1607        | TableFactor::UNNEST { alias, .. }
1608        | TableFactor::JsonTable { alias, .. }
1609        | TableFactor::OpenJsonTable { alias, .. } => {
1610            if let Some(alias) = alias {
1611                names.insert(alias.name.value.to_ascii_uppercase());
1612            }
1613        }
1614        TableFactor::NestedJoin {
1615            table_with_joins, ..
1616        } => {
1617            collect_source_names_from_table_factor(&table_with_joins.relation, names);
1618            for join in &table_with_joins.joins {
1619                collect_source_names_from_table_factor(&join.relation, names);
1620            }
1621        }
1622        TableFactor::Pivot { table, .. }
1623        | TableFactor::Unpivot { table, .. }
1624        | TableFactor::MatchRecognize { table, .. } => {
1625            collect_source_names_from_table_factor(table, names);
1626        }
1627        _ => {}
1628    }
1629}
1630
1631#[cfg(test)]
1632mod tests {
1633    use super::*;
1634    use crate::linter::{config::LintConfig, rule::LintContext, Linter};
1635    use crate::parse_sql;
1636    use crate::types::IssueAutofixApplicability;
1637
1638    fn run(sql: &str) -> Vec<Issue> {
1639        let statements = parse_sql(sql).expect("parse sql");
1640        let linter = Linter::new(LintConfig::default());
1641        let stmt = &statements[0];
1642        let ctx = LintContext {
1643            sql,
1644            statement_range: 0..sql.len(),
1645            statement_index: 0,
1646        };
1647        linter.check_statement(stmt, &ctx)
1648    }
1649
1650    fn apply_issue_autofix(sql: &str, issue: &Issue) -> Option<String> {
1651        let autofix = issue.autofix.as_ref()?;
1652        let mut out = sql.to_string();
1653        let mut edits = autofix.edits.clone();
1654        edits.sort_by_key(|edit| (edit.span.start, edit.span.end));
1655        for edit in edits.into_iter().rev() {
1656            out.replace_range(edit.span.start..edit.span.end, &edit.replacement);
1657        }
1658        Some(out)
1659    }
1660
1661    #[test]
1662    fn default_does_not_flag_subquery_in_from() {
1663        let issues = run("SELECT * FROM (SELECT * FROM t) sub");
1664        assert!(!issues
1665            .iter()
1666            .any(|issue| issue.code == issue_codes::LINT_ST_005));
1667    }
1668
1669    #[test]
1670    fn default_flags_subquery_in_join() {
1671        let issues = run("SELECT * FROM t JOIN (SELECT * FROM u) sub ON t.id = sub.id");
1672        assert!(issues
1673            .iter()
1674            .any(|issue| issue.code == issue_codes::LINT_ST_005));
1675    }
1676
1677    #[test]
1678    fn default_allows_correlated_subquery_join_without_alias() {
1679        let issues = run("SELECT pd.* \
1680             FROM person_dates \
1681             JOIN (SELECT * FROM events WHERE events.name = person_dates.name)");
1682        assert!(!issues
1683            .iter()
1684            .any(|issue| issue.code == issue_codes::LINT_ST_005));
1685    }
1686
1687    #[test]
1688    fn default_allows_correlated_subquery_join_with_alias_reference() {
1689        let issues = run("SELECT pd.* \
1690             FROM person_dates AS pd \
1691             JOIN (SELECT * FROM events AS ce WHERE ce.name = pd.name)");
1692        assert!(!issues
1693            .iter()
1694            .any(|issue| issue.code == issue_codes::LINT_ST_005));
1695    }
1696
1697    #[test]
1698    fn default_allows_correlated_subquery_join_with_outer_table_name_reference() {
1699        let issues = run("SELECT pd.* \
1700             FROM person_dates AS pd \
1701             JOIN (SELECT * FROM events AS ce WHERE ce.name = person_dates.name)");
1702        assert!(!issues
1703            .iter()
1704            .any(|issue| issue.code == issue_codes::LINT_ST_005));
1705    }
1706
1707    #[test]
1708    fn does_not_flag_cte_usage() {
1709        let issues = run("WITH sub AS (SELECT * FROM t) SELECT * FROM sub");
1710        assert!(!issues
1711            .iter()
1712            .any(|issue| issue.code == issue_codes::LINT_ST_005));
1713    }
1714
1715    #[test]
1716    fn does_not_flag_scalar_subquery_in_where() {
1717        let issues = run("SELECT * FROM t WHERE id IN (SELECT id FROM u)");
1718        assert!(!issues
1719            .iter()
1720            .any(|issue| issue.code == issue_codes::LINT_ST_005));
1721    }
1722
1723    #[test]
1724    fn forbid_subquery_in_join_does_not_flag_from_subquery() {
1725        let sql = "SELECT * FROM (SELECT * FROM t) sub";
1726        let statements = parse_sql(sql).expect("parse sql");
1727        let rule = StructureSubquery::from_config(&LintConfig {
1728            enabled: true,
1729            disabled_rules: vec![],
1730            rule_configs: std::collections::BTreeMap::from([(
1731                "structure.subquery".to_string(),
1732                serde_json::json!({"forbid_subquery_in": "join"}),
1733            )]),
1734        });
1735        let issues = rule.check(
1736            &statements[0],
1737            &LintContext {
1738                sql,
1739                statement_range: 0..sql.len(),
1740                statement_index: 0,
1741            },
1742        );
1743        assert!(issues.is_empty());
1744    }
1745
1746    #[test]
1747    fn forbid_subquery_in_from_emits_unsafe_cte_autofix_for_simple_case() {
1748        let sql = "SELECT * FROM (SELECT 1) sub";
1749        let statements = parse_sql(sql).expect("parse sql");
1750        let rule = StructureSubquery::from_config(&LintConfig {
1751            enabled: true,
1752            disabled_rules: vec![],
1753            rule_configs: std::collections::BTreeMap::from([(
1754                "LINT_ST_005".to_string(),
1755                serde_json::json!({"forbid_subquery_in": "from"}),
1756            )]),
1757        });
1758        let issues = rule.check(
1759            &statements[0],
1760            &LintContext {
1761                sql,
1762                statement_range: 0..sql.len(),
1763                statement_index: 0,
1764            },
1765        );
1766        assert_eq!(issues.len(), 1);
1767        let autofix = issues[0].autofix.as_ref().expect("autofix metadata");
1768        assert_eq!(autofix.applicability, IssueAutofixApplicability::Unsafe);
1769        let fixed = apply_issue_autofix(sql, &issues[0]).expect("apply autofix");
1770        assert_eq!(fixed, "WITH sub AS (SELECT 1)\nSELECT * FROM sub");
1771    }
1772
1773    #[test]
1774    fn forbid_subquery_in_from_does_not_flag_join_subquery() {
1775        let sql = "SELECT * FROM t JOIN (SELECT * FROM u) sub ON t.id = sub.id";
1776        let statements = parse_sql(sql).expect("parse sql");
1777        let rule = StructureSubquery::from_config(&LintConfig {
1778            enabled: true,
1779            disabled_rules: vec![],
1780            rule_configs: std::collections::BTreeMap::from([(
1781                "LINT_ST_005".to_string(),
1782                serde_json::json!({"forbid_subquery_in": "from"}),
1783            )]),
1784        });
1785        let issues = rule.check(
1786            &statements[0],
1787            &LintContext {
1788                sql,
1789                statement_range: 0..sql.len(),
1790                statement_index: 0,
1791            },
1792        );
1793        assert!(issues.is_empty());
1794    }
1795
1796    #[test]
1797    fn forbid_both_flags_subquery_inside_cte_body() {
1798        let sql = "WITH b AS (SELECT x, z FROM (SELECT x, z FROM p_cte)) SELECT b.z FROM b";
1799        let statements = parse_sql(sql).expect("parse sql");
1800        let rule = StructureSubquery::from_config(&LintConfig {
1801            enabled: true,
1802            disabled_rules: vec![],
1803            rule_configs: std::collections::BTreeMap::from([(
1804                "structure.subquery".to_string(),
1805                serde_json::json!({"forbid_subquery_in": "both"}),
1806            )]),
1807        });
1808        let issues = rule.check(
1809            &statements[0],
1810            &LintContext {
1811                sql,
1812                statement_range: 0..sql.len(),
1813                statement_index: 0,
1814            },
1815        );
1816        assert_eq!(issues.len(), 1);
1817    }
1818
1819    #[test]
1820    fn forbid_both_flags_subqueries_in_set_operation_second_branch() {
1821        let sql = "SELECT 1 AS value_name UNION SELECT value FROM (SELECT 2 AS value_name) CROSS JOIN (SELECT 1 AS v2)";
1822        let statements = parse_sql(sql).expect("parse sql");
1823        let rule = StructureSubquery::from_config(&LintConfig {
1824            enabled: true,
1825            disabled_rules: vec![],
1826            rule_configs: std::collections::BTreeMap::from([(
1827                "structure.subquery".to_string(),
1828                serde_json::json!({"forbid_subquery_in": "both"}),
1829            )]),
1830        });
1831        let issues = rule.check(
1832            &statements[0],
1833            &LintContext {
1834                sql,
1835                statement_range: 0..sql.len(),
1836                statement_index: 0,
1837            },
1838        );
1839        assert_eq!(issues.len(), 2);
1840    }
1841
1842    // --- Fixture-based rewriter tests ---
1843
1844    fn run_fix(sql: &str, forbid_in: &str) -> Option<String> {
1845        let statements = parse_sql(sql).expect("parse sql");
1846        let rule = StructureSubquery::from_config(&LintConfig {
1847            enabled: true,
1848            disabled_rules: vec![],
1849            rule_configs: std::collections::BTreeMap::from([(
1850                "structure.subquery".to_string(),
1851                serde_json::json!({"forbid_subquery_in": forbid_in}),
1852            )]),
1853        });
1854        let ctx = LintContext {
1855            sql,
1856            statement_range: 0..sql.len(),
1857            statement_index: 0,
1858        };
1859        let issues = rule.check(&statements[0], &ctx);
1860        if issues.is_empty() {
1861            return None;
1862        }
1863        let st05_issue = issues
1864            .iter()
1865            .find(|i| i.code == issue_codes::LINT_ST_005 && i.autofix.is_some())?;
1866        apply_issue_autofix(sql, st05_issue)
1867    }
1868
1869    fn assert_fix_whitespace_eq(actual: &str, expected: &str) {
1870        let norm = |s: &str| s.split_whitespace().collect::<Vec<_>>().join(" ");
1871        assert_eq!(
1872            norm(actual),
1873            norm(expected),
1874            "\n--- actual ---\n{actual}\n--- expected ---\n{expected}\n"
1875        );
1876    }
1877
1878    #[test]
1879    fn fixture_select_fail() {
1880        let sql = "select\n    a.x, a.y, b.z\nfrom a\njoin (\n    select x, z from b\n) as b on (a.x = b.x)\n";
1881        let expected = "with b as (\n    select x, z from b\n)\nselect\n    a.x, a.y, b.z\nfrom a\njoin b on (a.x = b.x)\n";
1882        let fixed = run_fix(sql, "join").expect("should produce fix");
1883        assert_fix_whitespace_eq(&fixed, expected);
1884    }
1885
1886    #[test]
1887    fn fixture_cte_select_fail() {
1888        let sql = "with prep as (\n  select 1 as x, 2 as z\n)\nselect\n    a.x, a.y, b.z\nfrom a\njoin (\n    select x, z from b\n) as b on (a.x = b.x)\n";
1889        let expected = "with prep as (\n  select 1 as x, 2 as z\n),\nb as (\n    select x, z from b\n)\nselect\n    a.x, a.y, b.z\nfrom a\njoin b on (a.x = b.x)\n";
1890        let fixed = run_fix(sql, "join").expect("should produce fix");
1891        assert_fix_whitespace_eq(&fixed, expected);
1892    }
1893
1894    #[test]
1895    fn fixture_from_clause_fail() {
1896        let sql = "select\n    a.x, a.y\nfrom (\n    select * from b\n) as a\n";
1897        let expected = "with a as (\n    select * from b\n)\nselect\n    a.x, a.y\nfrom a\n";
1898        let fixed = run_fix(sql, "from").expect("should produce fix");
1899        assert_fix_whitespace_eq(&fixed, expected);
1900    }
1901
1902    #[test]
1903    fn fixture_both_clause_fail() {
1904        let sql = "select\n    a.x, a.y\nfrom (\n    select * from b\n) as a\n";
1905        let expected = "with a as (\n    select * from b\n)\nselect\n    a.x, a.y\nfrom a\n";
1906        let fixed = run_fix(sql, "both").expect("should produce fix");
1907        assert_fix_whitespace_eq(&fixed, expected);
1908    }
1909
1910    #[test]
1911    fn fixture_cte_with_clashing_name_generates_prep() {
1912        let sql = "with prep_1 as (\n  select 1 as x, 2 as z\n)\nselect\n    a.x, a.y, z\nfrom a\njoin (\n    select x, z from b\n) on a.x = z\n";
1913        let fixed = run_fix(sql, "join").expect("should produce fix");
1914        // Should generate prep_2 since prep_1 exists.
1915        assert!(
1916            fixed.contains("prep_2"),
1917            "expected prep_2 in output: {fixed}"
1918        );
1919    }
1920
1921    #[test]
1922    fn fixture_set_subquery_in_second_query() {
1923        let sql = "SELECT 1 AS value_name\nUNION\nSELECT value\nFROM (SELECT 2 AS value_name);\n";
1924        let expected = "WITH prep_1 AS (SELECT 2 AS value_name)\nSELECT 1 AS value_name\nUNION\nSELECT value\nFROM prep_1;\n";
1925        let fixed = run_fix(sql, "both").expect("should produce fix");
1926        assert_fix_whitespace_eq(&fixed, expected);
1927    }
1928
1929    #[test]
1930    fn fixture_set_subquery_in_second_query_join() {
1931        let sql = "SELECT 1 AS value_name\nUNION\nSELECT value\nFROM (SELECT 2 AS value_name)\nCROSS JOIN (SELECT 1 as v2);\n";
1932        let expected = "WITH prep_1 AS (SELECT 2 AS value_name),\nprep_2 AS (SELECT 1 as v2)\nSELECT 1 AS value_name\nUNION\nSELECT value\nFROM prep_1\nCROSS JOIN prep_2;\n";
1933        let fixed = run_fix(sql, "both").expect("should produce fix");
1934        assert_fix_whitespace_eq(&fixed, expected);
1935    }
1936
1937    #[test]
1938    fn fixture_with_fail_generates_prep_for_unnamed_subquery() {
1939        let sql = "select\n    a.x, a.y, b.z\nfrom a\njoin (\n    with d as (\n        select x, z from b\n    )\n    select * from d\n) using (x)\n";
1940        let fixed = run_fix(sql, "join").expect("should produce fix");
1941        assert!(
1942            fixed.contains("prep_1"),
1943            "expected prep_1 in output: {fixed}"
1944        );
1945    }
1946
1947    #[test]
1948    fn fixture_set_fail() {
1949        let sql = "SELECT\n    a.x, a.y, b.z\nFROM a\nJOIN (\n    select x, z from b\n    union\n    select x, z from d\n) USING (x)\n";
1950        let fixed = run_fix(sql, "join").expect("should produce fix");
1951        assert!(
1952            fixed.contains("prep_1"),
1953            "expected prep_1 in output: {fixed}"
1954        );
1955    }
1956
1957    #[test]
1958    fn fixture_subquery_in_cte_both() {
1959        let sql = "with b as (\n  select x, z from (\n    select x, z from p_cte\n  )\n)\nselect b.z\nfrom b\n";
1960        let expected = "with prep_1 as (\n    select x, z from p_cte\n  ),\nb as (\n  select x, z from prep_1\n)\nselect b.z\nfrom b\n";
1961        let fixed = run_fix(sql, "both").expect("should produce fix");
1962        assert_fix_whitespace_eq(&fixed, expected);
1963    }
1964
1965    #[test]
1966    fn fixture_issue_3598_avoid_looping_1() {
1967        let sql = "WITH cte1 AS (\n    SELECT a\n    FROM (SELECT a)\n)\nSELECT a FROM cte1\n";
1968        let expected = "WITH prep_1 AS (SELECT a),\ncte1 AS (\n    SELECT a\n    FROM prep_1\n)\nSELECT a FROM cte1\n";
1969        let fixed = run_fix(sql, "both").expect("should produce fix");
1970        assert_fix_whitespace_eq(&fixed, expected);
1971    }
1972
1973    #[test]
1974    fn fixture_issue_3598_avoid_looping_2() {
1975        let sql = "WITH cte1 AS (\n    SELECT *\n    FROM (SELECT * FROM mongo.temp)\n)\nSELECT * FROM cte1\n";
1976        let expected = "WITH prep_1 AS (SELECT * FROM mongo.temp),\ncte1 AS (\n    SELECT *\n    FROM prep_1\n)\nSELECT * FROM cte1\n";
1977        let fixed = run_fix(sql, "both").expect("should produce fix");
1978        assert_fix_whitespace_eq(&fixed, expected);
1979    }
1980
1981    #[test]
1982    fn fixture_multijoin_both() {
1983        let sql = "select\n    a.x, d.x as foo, a.y, b.z\nfrom (select a, x from foo) a\njoin d using(x)\njoin (\n    select x, z from b\n) as b using (x)\n";
1984        let fixed = run_fix(sql, "both").expect("should produce fix");
1985        // Should extract both subqueries.
1986        assert!(
1987            fixed.to_ascii_lowercase().contains("with"),
1988            "expected WITH in output: {fixed}"
1989        );
1990    }
1991}