1use crate::linter::config::LintConfig;
6use crate::linter::rule::{LintContext, LintRule};
7use crate::parser::parse_sql_with_dialect;
8use crate::types::{issue_codes, Dialect, Issue, IssueAutofixApplicability, IssuePatchEdit};
9use sqlparser::ast::{Query, Select, SetExpr, Statement, TableFactor};
10use std::collections::HashSet;
11
12use super::semantic_helpers::{
13 collect_qualifier_prefixes_in_expr, visit_select_expressions, visit_selects_in_statement,
14};
15
/// Selects which clause positions the rule inspects for derived-table
/// subqueries.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum ForbidSubqueryIn {
    /// Inspect both `FROM` and `JOIN` positions.
    Both,
    /// Inspect `JOIN` positions only.
    Join,
    /// Inspect `FROM` positions only.
    From,
}
22
23impl ForbidSubqueryIn {
24 fn from_config(config: &LintConfig) -> Self {
25 match config
26 .rule_option_str(issue_codes::LINT_ST_005, "forbid_subquery_in")
27 .unwrap_or("join")
28 .to_ascii_lowercase()
29 .as_str()
30 {
31 "join" => Self::Join,
32 "from" => Self::From,
33 _ => Self::Both,
34 }
35 }
36
37 fn forbid_from(self) -> bool {
38 matches!(self, Self::Both | Self::From)
39 }
40
41 fn forbid_join(self) -> bool {
42 matches!(self, Self::Both | Self::Join)
43 }
44}
45
/// Lint rule reporting subqueries used as table sources in `FROM`/`JOIN`
/// clauses.
pub struct StructureSubquery {
    /// Clause positions in which a subquery is reported.
    forbid_subquery_in: ForbidSubqueryIn,
}
49
50impl StructureSubquery {
51 pub fn from_config(config: &LintConfig) -> Self {
52 Self {
53 forbid_subquery_in: ForbidSubqueryIn::from_config(config),
54 }
55 }
56}
57
58impl Default for StructureSubquery {
59 fn default() -> Self {
60 Self {
61 forbid_subquery_in: ForbidSubqueryIn::Join,
62 }
63 }
64}
65
66impl LintRule for StructureSubquery {
67 fn code(&self) -> &'static str {
68 issue_codes::LINT_ST_005
69 }
70
71 fn name(&self) -> &'static str {
72 "Structure subquery"
73 }
74
75 fn description(&self) -> &'static str {
76 "Join/From clauses should not contain subqueries. Use CTEs instead."
77 }
78
79 fn check(&self, statement: &Statement, ctx: &LintContext) -> Vec<Issue> {
80 let mut violations = 0usize;
81
82 visit_selects_in_statement(statement, &mut |select| {
83 let outer_source_names = source_names_in_select(select);
84 for table in &select.from {
85 if self.forbid_subquery_in.forbid_from()
86 && table_factor_contains_derived(&table.relation, &outer_source_names)
87 {
88 violations += 1;
89 }
90 if self.forbid_subquery_in.forbid_join() {
91 for join in &table.joins {
92 if table_factor_contains_derived(&join.relation, &outer_source_names) {
93 violations += 1;
94 }
95 }
96 }
97 }
98 });
99
100 if violations == 0 {
101 return Vec::new();
102 }
103
104 let autofix_edits = st005_subquery_to_cte_rewrite(
105 ctx.statement_sql(),
106 statement,
107 self.forbid_subquery_in,
108 ctx.dialect(),
109 )
110 .filter(|rewritten| rewritten != ctx.statement_sql())
111 .map(|rewritten| {
112 vec![IssuePatchEdit::new(
113 ctx.span_from_statement_offset(0, ctx.statement_sql().len()),
114 rewritten,
115 )]
116 })
117 .unwrap_or_default();
118
119 (0..violations)
120 .map(|index| {
121 let mut issue = Issue::info(
122 issue_codes::LINT_ST_005,
123 "Join/From clauses should not contain subqueries. Use CTEs instead.",
124 )
125 .with_statement(ctx.statement_index);
126 if index == 0 && !autofix_edits.is_empty() {
127 issue = issue.with_autofix_edits(
128 IssueAutofixApplicability::Unsafe,
129 autofix_edits.clone(),
130 );
131 }
132 issue
133 })
134 .collect()
135 }
136}
137
/// Byte positions of one subquery in the SQL text, plus the CTE name it
/// will be extracted under.
#[derive(Debug, Clone)]
struct SubqueryExtraction {
    /// Index of the `(` that opens the subquery.
    open_paren: usize,
    /// Index of the matching `)`.
    close_paren: usize,
    /// Name given to the extracted CTE (the subquery's alias or a generated name).
    alias: String,
    /// Index just past the alias text following the `)`; the span
    /// `open_paren..alias_region_end` is replaced by the CTE name.
    alias_region_end: usize,
}
154
155fn st005_subquery_to_cte_rewrite(
158 sql: &str,
159 stmt: &Statement,
160 forbid_subquery_in: ForbidSubqueryIn,
161 dialect: Dialect,
162) -> Option<String> {
163 const MAX_REWRITE_PASSES: usize = 8;
164
165 let mut current_sql = sql.to_string();
166 let mut current_stmt = stmt.clone();
167 let mut changed = false;
168
169 for _ in 0..MAX_REWRITE_PASSES {
170 let mut subquery_aliases: Vec<(String, bool)> = Vec::new();
172 collect_extractable_subqueries(¤t_stmt, forbid_subquery_in, &mut subquery_aliases);
173 if subquery_aliases.is_empty() {
174 break;
175 }
176
177 let extractions =
179 find_subquery_positions(¤t_sql, forbid_subquery_in, &subquery_aliases);
180 if extractions.is_empty() {
181 break;
182 }
183
184 let Some(rewritten) = apply_cte_extractions(¤t_sql, &extractions, dialect) else {
185 break;
186 };
187 if rewritten == current_sql {
188 break;
189 }
190
191 changed = true;
192 current_sql = rewritten;
193
194 let Ok(mut reparsed) = parse_sql_with_dialect(¤t_sql, dialect) else {
197 break;
198 };
199 let Some(next_stmt) = (reparsed.len() == 1).then(|| reparsed.remove(0)) else {
200 break;
201 };
202 current_stmt = next_stmt;
203 }
204
205 changed.then_some(current_sql)
206}
207
208fn collect_extractable_subqueries(
211 stmt: &Statement,
212 forbid_in: ForbidSubqueryIn,
213 out: &mut Vec<(String, bool)>,
214) {
215 visit_selects_in_statement(stmt, &mut |select| {
216 let outer_source_names = source_names_in_select(select);
217 for table in &select.from {
218 if forbid_in.forbid_from() {
219 collect_from_table_factor(&table.relation, &outer_source_names, out);
220 }
221 if forbid_in.forbid_join() {
222 for join in &table.joins {
223 collect_from_table_factor(&join.relation, &outer_source_names, out);
224 }
225 }
226 }
227 });
228}
229
230fn collect_from_table_factor(
232 tf: &TableFactor,
233 outer_names: &HashSet<String>,
234 out: &mut Vec<(String, bool)>,
235) {
236 match tf {
237 TableFactor::Derived {
238 subquery, alias, ..
239 } => {
240 let is_correlated = query_references_outer_sources(subquery, outer_names);
241 if !is_correlated {
242 let alias_name = alias
243 .as_ref()
244 .map(|a| a.name.value.clone())
245 .unwrap_or_default();
246 out.push((alias_name, is_correlated));
247 }
248 }
249 TableFactor::NestedJoin {
250 table_with_joins, ..
251 } => {
252 collect_from_table_factor(&table_with_joins.relation, outer_names, out);
253 for join in &table_with_joins.joins {
254 collect_from_table_factor(&join.relation, outer_names, out);
255 }
256 }
257 TableFactor::Pivot { table, .. }
258 | TableFactor::Unpivot { table, .. }
259 | TableFactor::MatchRecognize { table, .. } => {
260 collect_from_table_factor(table, outer_names, out);
261 }
262 _ => {}
263 }
264}
265
/// Scans the SQL text for parenthesised subqueries that directly follow a
/// `FROM` or `JOIN` keyword and pairs each one with the next entry of
/// `ast_aliases`.
///
/// Quoted regions, `--` line comments and `/* */` block comments are
/// skipped. A parenthesised group counts as a subquery when its trimmed
/// content starts with `select` or `with` (case-insensitive).
///
/// Returns the extractions in the order they appear in the text.
// NOTE(review): entries of `ast_aliases` are consumed strictly in text
// order, so this presumably relies on the AST collector visiting subqueries
// in the same order as they appear in the text — confirm.
fn find_subquery_positions(
    sql: &str,
    forbid_in: ForbidSubqueryIn,
    ast_aliases: &[(String, bool)],
) -> Vec<SubqueryExtraction> {
    let bytes = sql.as_bytes();
    let mut extractions = Vec::new();
    // Index of the next unconsumed entry in `ast_aliases`.
    let mut ast_idx = 0usize;
    let mut auto_name_counter = 0usize;
    let mut existing_cte_names: HashSet<String> = HashSet::new();
    collect_existing_cte_names(sql, &mut existing_cte_names);

    // Names a generated `prep_N` must avoid: existing CTE names plus every
    // non-empty subquery alias (upper-cased).
    let mut used_names: HashSet<String> = existing_cte_names.clone();
    for (alias, _) in ast_aliases {
        if !alias.is_empty() {
            used_names.insert(alias.to_ascii_uppercase());
        }
    }

    // Names already taken by a CTE: the existing ones plus those claimed by
    // extractions made during this scan.
    let mut claimed_names: HashSet<String> = existing_cte_names;

    let mut pos = 0usize;
    while pos < bytes.len() {
        // Skip string literals and quoted identifiers.
        if let Some(end) = skip_quoted_region(bytes, pos) {
            pos = end;
            continue;
        }
        // Skip a `--` comment up to (not including) the newline.
        if bytes[pos] == b'-' && bytes.get(pos + 1) == Some(&b'-') {
            while pos < bytes.len() && bytes[pos] != b'\n' {
                pos += 1;
            }
            continue;
        }
        // Skip a `/* ... */` comment.
        if bytes[pos] == b'/' && bytes.get(pos + 1) == Some(&b'*') {
            pos += 2;
            while pos + 1 < bytes.len() {
                if bytes[pos] == b'*' && bytes[pos + 1] == b'/' {
                    pos += 2;
                    break;
                }
                pos += 1;
            }
            continue;
        }

        let is_from =
            forbid_in.forbid_from() && match_ascii_keyword_at(bytes, pos, b"FROM").is_some();
        let is_join = forbid_in.forbid_join()
            && (match_ascii_keyword_at(bytes, pos, b"JOIN").is_some()
                || match_join_keyword_sequence(bytes, pos).is_some());

        if is_from || is_join {
            // End of the keyword; a multi-word join sequence is preferred
            // over a bare `JOIN`.
            let keyword_end = if is_from {
                match_ascii_keyword_at(bytes, pos, b"FROM").unwrap()
            } else if let Some(end) = match_join_keyword_sequence(bytes, pos) {
                end
            } else {
                match_ascii_keyword_at(bytes, pos, b"JOIN").unwrap()
            };

            let after_keyword = skip_ascii_whitespace(bytes, keyword_end);

            if after_keyword < bytes.len() && bytes[after_keyword] == b'(' {
                if let Some(close) = find_matching_parenthesis_outside_quotes(sql, after_keyword) {
                    let inner = sql[after_keyword + 1..close].trim();
                    let inner_lower = inner.to_ascii_lowercase();
                    if (inner_lower.starts_with("select") || inner_lower.starts_with("with"))
                        && ast_idx < ast_aliases.len()
                    {
                        let (ref ast_alias, _) = ast_aliases[ast_idx];
                        ast_idx += 1;

                        let alias = if ast_alias.is_empty() {
                            // No alias in the AST: generate a fresh `prep_N` name.
                            let name = generate_prep_name(&mut auto_name_counter, &used_names);
                            let name_key = name.to_ascii_uppercase();
                            used_names.insert(name_key.clone());
                            claimed_names.insert(name_key);
                            name
                        } else {
                            let alias_key = ast_alias.to_ascii_uppercase();
                            // The alias is already a CTE name: leave this
                            // subquery in place and resume after its `)`.
                            if claimed_names.contains(&alias_key) {
                                pos = close + 1;
                                continue;
                            }
                            claimed_names.insert(alias_key.clone());
                            used_names.insert(alias_key);
                            ast_alias.clone()
                        };

                        let (_alias_start, alias_end) =
                            parse_alias_region_after_close_paren(bytes, close);

                        extractions.push(SubqueryExtraction {
                            open_paren: after_keyword,
                            close_paren: close,
                            alias: alias.clone(),
                            alias_region_end: alias_end,
                        });

                        // Resume after the alias; the subquery body itself
                        // is not scanned in this call.
                        pos = alias_end;
                        continue;
                    }
                }
            }
        }

        pos += 1;
    }

    extractions
}
393
/// Returns the next `prep_N` name whose upper-cased form is not in
/// `used_names`.
///
/// `counter` is incremented before each candidate is tried and is left at
/// the `N` of the returned name.
fn generate_prep_name(counter: &mut usize, used_names: &HashSet<String>) -> String {
    loop {
        *counter += 1;
        let candidate = format!("prep_{}", *counter);
        let taken = used_names.contains(candidate.to_ascii_uppercase().as_str());
        if taken {
            continue;
        }
        break candidate;
    }
}
404
405fn collect_existing_cte_names(sql: &str, names: &mut HashSet<String>) {
407 let bytes = sql.as_bytes();
408 let mut pos = skip_ascii_whitespace(bytes, 0);
409
410 if let Some(end) = match_ascii_keyword_at(bytes, pos, b"INSERT") {
413 pos = skip_to_with_or_select(bytes, end);
414 } else if let Some(end) = match_ascii_keyword_at(bytes, pos, b"CREATE") {
415 pos = skip_to_with_or_select(bytes, end);
416 }
417
418 if match_ascii_keyword_at(bytes, pos, b"WITH").is_none() {
419 return;
420 }
421
422 let with_end = match_ascii_keyword_at(bytes, pos, b"WITH").unwrap();
423 pos = skip_ascii_whitespace(bytes, with_end);
424
425 if let Some(end) = match_ascii_keyword_at(bytes, pos, b"RECURSIVE") {
427 pos = skip_ascii_whitespace(bytes, end);
428 }
429
430 loop {
432 let name_start = pos;
434 if let Some(quoted_end) = consume_quoted_identifier(bytes, pos) {
435 let raw = &sql[name_start..quoted_end];
436 let unquoted = raw.trim_matches(|c| c == '"' || c == '`' || c == '[' || c == ']');
437 names.insert(unquoted.to_ascii_uppercase());
438 pos = skip_ascii_whitespace(bytes, quoted_end);
439 } else if let Some(name_end) = consume_ascii_identifier(bytes, pos) {
440 names.insert(sql[name_start..name_end].to_ascii_uppercase());
441 pos = skip_ascii_whitespace(bytes, name_end);
442 } else {
443 break;
444 }
445
446 if let Some(as_end) = match_ascii_keyword_at(bytes, pos, b"AS") {
448 pos = skip_ascii_whitespace(bytes, as_end);
449 } else {
450 break;
451 }
452
453 if pos < bytes.len() && bytes[pos] == b'(' {
455 if let Some(close) = find_matching_parenthesis_outside_quotes(sql, pos) {
456 pos = skip_ascii_whitespace(bytes, close + 1);
457 } else {
458 break;
459 }
460 } else {
461 break;
462 }
463
464 if pos < bytes.len() && bytes[pos] == b',' {
466 pos += 1;
467 pos = skip_ascii_whitespace(bytes, pos);
468 } else {
469 break;
470 }
471 }
472}
473
474fn skip_to_with_or_select(bytes: &[u8], mut pos: usize) -> usize {
476 while pos < bytes.len() {
477 let ws = skip_ascii_whitespace(bytes, pos);
478 if ws > pos {
479 pos = ws;
480 }
481 if match_ascii_keyword_at(bytes, pos, b"WITH").is_some() {
482 return pos;
483 }
484 if match_ascii_keyword_at(bytes, pos, b"SELECT").is_some() {
485 return pos;
486 }
487 pos += 1;
488 }
489 pos
490}
491
492fn parse_alias_region_after_close_paren(bytes: &[u8], close_paren: usize) -> (usize, usize) {
495 let start = close_paren + 1;
496 let mut pos = start;
497 let ws_pos = skip_ascii_whitespace(bytes, pos);
498
499 if let Some(as_end) = match_ascii_keyword_at(bytes, ws_pos, b"AS") {
501 let after_as = skip_ascii_whitespace(bytes, as_end);
502 if let Some(quoted_end) = consume_quoted_identifier(bytes, after_as) {
503 return (start, quoted_end);
504 }
505 if let Some(ident_end) = consume_ascii_identifier(bytes, after_as) {
506 return (start, ident_end);
507 }
508 }
509
510 if let Some(quoted_end) = consume_quoted_identifier(bytes, ws_pos) {
514 return (start, quoted_end);
515 }
516 if let Some(ident_end) = consume_ascii_identifier(bytes, ws_pos) {
517 let word = &bytes[ws_pos..ident_end];
518 if !is_clause_keyword(word) {
519 pos = ident_end;
520 return (start, pos);
521 }
522 }
523
524 (start, start)
525}
526
/// Returns `true` when `word` equals, ignoring ASCII case, one of the SQL
/// keywords listed below — words that must not be read as a table alias.
fn is_clause_keyword(word: &[u8]) -> bool {
    const CLAUSE_KEYWORDS: &[&[u8]] = &[
        b"ON",
        b"USING",
        b"WHERE",
        b"JOIN",
        b"INNER",
        b"LEFT",
        b"RIGHT",
        b"FULL",
        b"OUTER",
        b"CROSS",
        b"NATURAL",
        b"GROUP",
        b"ORDER",
        b"HAVING",
        b"LIMIT",
        b"UNION",
        b"INTERSECT",
        b"EXCEPT",
        b"MINUS",
        b"FROM",
        b"SELECT",
        b"INSERT",
        b"UPDATE",
        b"DELETE",
        b"SET",
        b"INTO",
        b"VALUES",
        b"WITH",
    ];

    CLAUSE_KEYWORDS
        .iter()
        .any(|keyword| word.eq_ignore_ascii_case(keyword))
}
562
563fn apply_cte_extractions(
566 sql: &str,
567 extractions: &[SubqueryExtraction],
568 dialect: Dialect,
569) -> Option<String> {
570 if extractions.is_empty() {
571 return None;
572 }
573
574 let case_pref = detect_case_preference(sql);
575
576 let existing_ctes = parse_existing_cte_ranges(sql);
578
579 struct CteInsertion {
582 definition: String,
583 insert_before: Option<usize>,
585 }
586
587 let mut insertions: Vec<CteInsertion> = Vec::new();
588 let mut replacements: Vec<(usize, usize, String)> = Vec::new();
589
590 for ext in extractions {
591 let subquery_text = &sql[ext.open_paren + 1..ext.close_paren];
592 let as_kw = if case_pref == CasePref::Upper {
593 "AS"
594 } else {
595 "as"
596 };
597 let cte_def = format!("{} {} ({})", ext.alias, as_kw, subquery_text);
598
599 let containing_cte = existing_ctes
601 .iter()
602 .position(|cte| ext.open_paren >= cte.body_start && ext.close_paren <= cte.body_end);
603
604 insertions.push(CteInsertion {
605 definition: cte_def,
606 insert_before: containing_cte,
607 });
608
609 let mut replacement = ext.alias.clone();
610 if ext.open_paren > 0 {
611 let prev = sql.as_bytes()[ext.open_paren - 1];
612 if !prev.is_ascii_whitespace() {
613 replacement.insert(0, ' ');
614 }
615 }
616
617 replacements.push((ext.open_paren, ext.alias_region_end, replacement));
618 }
619
620 let mut result = sql.to_string();
622 for (start, end, replacement) in replacements.into_iter().rev() {
623 result.replace_range(start..end, &replacement);
624 }
625
626 let mut before_insertions: Vec<(usize, String)> = Vec::new(); let mut top_level_defs: Vec<String> = Vec::new();
631
632 for insertion in insertions {
633 match insertion.insert_before {
634 Some(cte_idx) => before_insertions.push((cte_idx, insertion.definition)),
635 None => top_level_defs.push(insertion.definition),
636 }
637 }
638
639 if !before_insertions.is_empty() && !existing_ctes.is_empty() {
640 result = rebuild_with_clause_with_insertions(
642 &result,
643 sql,
644 &existing_ctes,
645 &before_insertions,
646 &top_level_defs,
647 case_pref,
648 );
649 return Some(result);
650 }
651
652 insert_cte_clause(&result, &top_level_defs, case_pref, dialect)
654}
655
/// Byte range of the parenthesised body of a CTE that already exists in the
/// SQL text.
#[derive(Debug, Clone)]
struct ExistingCteRange {
    /// Index of the `(` that opens the CTE body.
    body_start: usize,
    /// Index of the matching `)`.
    body_end: usize,
}
664
665fn parse_existing_cte_ranges(sql: &str) -> Vec<ExistingCteRange> {
667 let bytes = sql.as_bytes();
668 let mut pos = skip_ascii_whitespace(bytes, 0);
669 let mut ranges = Vec::new();
670
671 if match_ascii_keyword_at(bytes, pos, b"INSERT").is_some()
673 || match_ascii_keyword_at(bytes, pos, b"CREATE").is_some()
674 {
675 pos = skip_to_with_or_select(bytes, pos + 6);
676 }
677
678 let with_end = match match_ascii_keyword_at(bytes, pos, b"WITH") {
679 Some(end) => end,
680 None => return ranges,
681 };
682 pos = skip_ascii_whitespace(bytes, with_end);
683
684 if let Some(end) = match_ascii_keyword_at(bytes, pos, b"RECURSIVE") {
686 pos = skip_ascii_whitespace(bytes, end);
687 }
688
689 loop {
690 if let Some(quoted_end) = consume_quoted_identifier(bytes, pos) {
692 pos = skip_ascii_whitespace(bytes, quoted_end);
693 } else if let Some(name_end) = consume_ascii_identifier(bytes, pos) {
694 pos = skip_ascii_whitespace(bytes, name_end);
695 } else {
696 break;
697 }
698
699 if let Some(as_end) = match_ascii_keyword_at(bytes, pos, b"AS") {
701 pos = skip_ascii_whitespace(bytes, as_end);
702 } else {
703 break;
704 }
705
706 if pos < bytes.len() && bytes[pos] == b'(' {
708 if let Some(close) = find_matching_parenthesis_outside_quotes(sql, pos) {
709 ranges.push(ExistingCteRange {
710 body_start: pos,
711 body_end: close,
712 });
713 pos = skip_ascii_whitespace(bytes, close + 1);
714 } else {
715 break;
716 }
717 } else {
718 break;
719 }
720
721 if pos < bytes.len() && bytes[pos] == b',' {
723 pos += 1;
724 pos = skip_ascii_whitespace(bytes, pos);
725 } else {
726 break;
727 }
728 }
729
730 ranges
731}
732
733fn rebuild_with_clause_with_insertions(
735 modified_sql: &str,
736 _original_sql: &str,
737 _existing_ctes: &[ExistingCteRange],
738 before_insertions: &[(usize, String)],
739 top_level_defs: &[String],
740 case_pref: CasePref,
741) -> String {
742 let bytes = modified_sql.as_bytes();
749 let mut pos = skip_ascii_whitespace(bytes, 0);
750
751 if match_ascii_keyword_at(bytes, pos, b"INSERT").is_some()
753 || match_ascii_keyword_at(bytes, pos, b"CREATE").is_some()
754 {
755 pos = skip_to_with_or_select(bytes, pos + 6);
756 }
757
758 let with_kw_start = pos;
759 let with_end = match match_ascii_keyword_at(bytes, pos, b"WITH") {
760 Some(end) => end,
761 None => return modified_sql.to_string(),
762 };
763 pos = skip_ascii_whitespace(bytes, with_end);
764
765 if let Some(end) = match_ascii_keyword_at(bytes, pos, b"RECURSIVE") {
767 pos = skip_ascii_whitespace(bytes, end);
768 }
769
770 let mut cte_texts: Vec<String> = Vec::new();
772 let mut last_cte_end = pos;
773
774 loop {
775 let cte_start = pos;
776
777 if let Some(quoted_end) = consume_quoted_identifier(bytes, pos) {
778 pos = skip_ascii_whitespace(bytes, quoted_end);
779 } else if let Some(name_end) = consume_ascii_identifier(bytes, pos) {
780 pos = skip_ascii_whitespace(bytes, name_end);
781 } else {
782 break;
783 }
784
785 if let Some(as_end) = match_ascii_keyword_at(bytes, pos, b"AS") {
786 pos = skip_ascii_whitespace(bytes, as_end);
787 } else {
788 break;
789 }
790
791 if pos < bytes.len() && bytes[pos] == b'(' {
792 if let Some(close) = find_matching_parenthesis_outside_quotes(modified_sql, pos) {
793 let cte_text = modified_sql[cte_start..close + 1].to_string();
794 cte_texts.push(cte_text);
795 last_cte_end = close + 1;
796 pos = skip_ascii_whitespace(bytes, close + 1);
797 } else {
798 break;
799 }
800 } else {
801 break;
802 }
803
804 if pos < bytes.len() && bytes[pos] == b',' {
805 pos += 1;
806 pos = skip_ascii_whitespace(bytes, pos);
807 } else {
808 break;
809 }
810 }
811
812 let mut new_cte_list: Vec<String> = Vec::new();
814 for (i, cte_text) in cte_texts.iter().enumerate() {
815 for (before_idx, def) in before_insertions {
817 if *before_idx == i {
818 new_cte_list.push(def.clone());
819 }
820 }
821 new_cte_list.push(cte_text.clone());
822 }
823
824 for def in top_level_defs {
826 new_cte_list.push(def.clone());
827 }
828
829 let with_kw = if case_pref == CasePref::Upper {
831 "WITH"
832 } else {
833 "with"
834 };
835 let remainder = &modified_sql[last_cte_end..];
836
837 let mut result = String::with_capacity(modified_sql.len() + 200);
838 result.push_str(&modified_sql[..with_kw_start]);
839 result.push_str(with_kw);
840 result.push(' ');
841 for (i, cte) in new_cte_list.iter().enumerate() {
842 if i > 0 {
843 result.push_str(",\n");
844 }
845 result.push_str(cte);
846 }
847 result.push_str(remainder);
848
849 result
850}
851
/// Letter case used for keywords emitted by the rewrite.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum CasePref {
    /// Emit upper-case keywords (`WITH`, `AS`).
    Upper,
    /// Emit lower-case keywords (`with`, `as`).
    Lower,
}
857
858fn detect_case_preference(sql: &str) -> CasePref {
860 let bytes = sql.as_bytes();
861 let pos = skip_ascii_whitespace(bytes, 0);
862 for kw in &[b"WITH" as &[u8], b"SELECT", b"INSERT", b"CREATE"] {
864 if pos + kw.len() <= bytes.len() {
865 let word = &bytes[pos..pos + kw.len()];
866 if word
867 .iter()
868 .zip(kw.iter())
869 .all(|(a, b)| a.to_ascii_uppercase() == *b)
870 && is_word_boundary_for_keyword(bytes, pos + kw.len())
871 {
872 return if word[0].is_ascii_uppercase() {
873 CasePref::Upper
874 } else {
875 CasePref::Lower
876 };
877 }
878 }
879 }
880 CasePref::Upper
881}
882
883fn insert_cte_clause(
886 sql: &str,
887 cte_defs: &[String],
888 case_pref: CasePref,
889 dialect: Dialect,
890) -> Option<String> {
891 let bytes = sql.as_bytes();
892 let with_kw = if case_pref == CasePref::Upper {
893 "WITH"
894 } else {
895 "with"
896 };
897
898 let scan_pos = skip_ascii_whitespace(bytes, 0);
900
901 let is_insert = match_ascii_keyword_at(bytes, scan_pos, b"INSERT").is_some();
902 let is_create = match_ascii_keyword_at(bytes, scan_pos, b"CREATE").is_some();
903 let is_tsql_insert = is_insert && dialect == Dialect::Mssql;
904
905 if is_tsql_insert {
906 let insert_pos = skip_ascii_whitespace(bytes, 0);
908 return Some(insert_with_before_position(
909 sql, insert_pos, cte_defs, with_kw,
910 ));
911 }
912
913 if is_create {
914 if let Some(body_pos) = find_create_as_body_position(sql) {
915 return insert_with_at_select(sql, body_pos, cte_defs, with_kw);
916 }
917 if let Some(pos) = find_main_select_position(sql) {
919 return insert_with_at_select(sql, pos, cte_defs, with_kw);
920 }
921 return None;
922 }
923
924 if is_insert {
925 let select_pos = find_main_select_position(sql);
927 if let Some(pos) = select_pos {
928 return insert_with_at_select(sql, pos, cte_defs, with_kw);
929 }
930 return None;
931 }
932
933 if let Some(with_info) = find_existing_with_clause(sql) {
935 return Some(append_to_existing_with(sql, &with_info, cte_defs));
937 }
938
939 let insert_pos = skip_ascii_whitespace(bytes, 0);
941 Some(insert_with_before_position(
942 sql, insert_pos, cte_defs, with_kw,
943 ))
944}
945
946fn find_create_as_body_position(sql: &str) -> Option<usize> {
948 let bytes = sql.as_bytes();
949 let mut pos = skip_ascii_whitespace(bytes, 0);
950 let create_end = match_ascii_keyword_at(bytes, pos, b"CREATE")?;
951 pos = create_end;
952
953 let mut depth = 0usize;
954 while pos < bytes.len() {
955 if let Some(end) = skip_quoted_region(bytes, pos) {
956 pos = end;
957 continue;
958 }
959 if bytes[pos] == b'-' && bytes.get(pos + 1) == Some(&b'-') {
960 while pos < bytes.len() && bytes[pos] != b'\n' {
961 pos += 1;
962 }
963 continue;
964 }
965 if bytes[pos] == b'/' && bytes.get(pos + 1) == Some(&b'*') {
966 pos += 2;
967 while pos + 1 < bytes.len() {
968 if bytes[pos] == b'*' && bytes[pos + 1] == b'/' {
969 pos += 2;
970 break;
971 }
972 pos += 1;
973 }
974 continue;
975 }
976
977 if bytes[pos] == b'(' {
978 depth += 1;
979 pos += 1;
980 continue;
981 }
982 if bytes[pos] == b')' {
983 depth = depth.saturating_sub(1);
984 pos += 1;
985 continue;
986 }
987
988 if depth == 0 {
989 if let Some(as_end) = match_ascii_keyword_at(bytes, pos, b"AS") {
990 return Some(skip_ascii_whitespace(bytes, as_end));
991 }
992 }
993
994 pos += 1;
995 }
996
997 None
998}
999
/// Location data for a `WITH` clause found in the SQL text.
struct ExistingWithInfo {
    /// Byte offset just past the closing `)` of the last CTE that was
    /// parsed; when no CTE could be parsed, the offset where the first CTE
    /// name was expected.
    last_cte_end: usize,
}
1004
1005fn find_existing_with_clause(sql: &str) -> Option<ExistingWithInfo> {
1007 let bytes = sql.as_bytes();
1008 let mut pos = skip_ascii_whitespace(bytes, 0);
1009
1010 if match_ascii_keyword_at(bytes, pos, b"INSERT").is_some()
1012 || match_ascii_keyword_at(bytes, pos, b"CREATE").is_some()
1013 {
1014 pos = skip_to_with_or_select(bytes, pos + 6);
1015 }
1016
1017 let _with_end = match_ascii_keyword_at(bytes, pos, b"WITH")?;
1018 let mut cursor = skip_ascii_whitespace(bytes, _with_end);
1019
1020 if let Some(end) = match_ascii_keyword_at(bytes, cursor, b"RECURSIVE") {
1022 cursor = skip_ascii_whitespace(bytes, end);
1023 }
1024
1025 let mut last_cte_end = cursor;
1027 loop {
1028 if let Some(quoted_end) = consume_quoted_identifier(bytes, cursor) {
1030 cursor = skip_ascii_whitespace(bytes, quoted_end);
1031 } else if let Some(name_end) = consume_ascii_identifier(bytes, cursor) {
1032 cursor = skip_ascii_whitespace(bytes, name_end);
1033 } else {
1034 break;
1035 }
1036
1037 if let Some(as_end) = match_ascii_keyword_at(bytes, cursor, b"AS") {
1039 cursor = skip_ascii_whitespace(bytes, as_end);
1040 } else {
1041 break;
1042 }
1043
1044 if cursor < bytes.len() && bytes[cursor] == b'(' {
1046 if let Some(close) = find_matching_parenthesis_outside_quotes(sql, cursor) {
1047 last_cte_end = close + 1;
1048 cursor = skip_ascii_whitespace(bytes, close + 1);
1049 } else {
1050 break;
1051 }
1052 } else {
1053 break;
1054 }
1055
1056 if cursor < bytes.len() && bytes[cursor] == b',' {
1058 cursor += 1;
1059 cursor = skip_ascii_whitespace(bytes, cursor);
1060 } else {
1061 break;
1062 }
1063 }
1064
1065 Some(ExistingWithInfo { last_cte_end })
1066}
1067
1068fn append_to_existing_with(sql: &str, with_info: &ExistingWithInfo, cte_defs: &[String]) -> String {
1070 let insert_pos = with_info.last_cte_end;
1071 let mut result =
1072 String::with_capacity(sql.len() + cte_defs.iter().map(|d| d.len() + 4).sum::<usize>());
1073 result.push_str(&sql[..insert_pos]);
1074 for def in cte_defs {
1075 result.push_str(",\n");
1076 result.push_str(def);
1077 }
1078 result.push_str(&sql[insert_pos..]);
1079 result
1080}
1081
/// Returns `sql` with a new `WITH` clause inserted at byte offset `pos`.
///
/// The inserted text is `with_kw`, a space, the definitions joined by
/// `",\n"`, and a trailing newline.
fn insert_with_before_position(
    sql: &str,
    pos: usize,
    cte_defs: &[String],
    with_kw: &str,
) -> String {
    let (prefix, suffix) = sql.split_at(pos);

    let mut out = String::with_capacity(sql.len() + 100);
    out.push_str(prefix);
    out.push_str(with_kw);
    out.push(' ');
    out.push_str(&cte_defs.join(",\n"));
    out.push('\n');
    out.push_str(suffix);
    out
}
1103
1104fn insert_with_at_select(
1106 sql: &str,
1107 select_pos: usize,
1108 cte_defs: &[String],
1109 with_kw: &str,
1110) -> Option<String> {
1111 let bytes = sql.as_bytes();
1113 if match_ascii_keyword_at(bytes, select_pos, b"WITH").is_some() {
1114 if let Some(with_info) = find_existing_with_clause_at(sql, select_pos) {
1116 return Some(append_to_existing_with(sql, &with_info, cte_defs));
1117 }
1118 }
1119
1120 Some(insert_with_before_position(
1121 sql, select_pos, cte_defs, with_kw,
1122 ))
1123}
1124
1125fn find_existing_with_clause_at(sql: &str, start: usize) -> Option<ExistingWithInfo> {
1127 let bytes = sql.as_bytes();
1128 let _with_end = match_ascii_keyword_at(bytes, start, b"WITH")?;
1129 let mut cursor = skip_ascii_whitespace(bytes, _with_end);
1130
1131 if let Some(end) = match_ascii_keyword_at(bytes, cursor, b"RECURSIVE") {
1133 cursor = skip_ascii_whitespace(bytes, end);
1134 }
1135
1136 let mut last_cte_end = cursor;
1137 loop {
1138 if let Some(quoted_end) = consume_quoted_identifier(bytes, cursor) {
1139 cursor = skip_ascii_whitespace(bytes, quoted_end);
1140 } else if let Some(name_end) = consume_ascii_identifier(bytes, cursor) {
1141 cursor = skip_ascii_whitespace(bytes, name_end);
1142 } else {
1143 break;
1144 }
1145
1146 if let Some(as_end) = match_ascii_keyword_at(bytes, cursor, b"AS") {
1147 cursor = skip_ascii_whitespace(bytes, as_end);
1148 } else {
1149 break;
1150 }
1151
1152 if cursor < bytes.len() && bytes[cursor] == b'(' {
1153 if let Some(close) = find_matching_parenthesis_outside_quotes(sql, cursor) {
1154 last_cte_end = close + 1;
1155 cursor = skip_ascii_whitespace(bytes, close + 1);
1156 } else {
1157 break;
1158 }
1159 } else {
1160 break;
1161 }
1162
1163 if cursor < bytes.len() && bytes[cursor] == b',' {
1164 cursor += 1;
1165 cursor = skip_ascii_whitespace(bytes, cursor);
1166 } else {
1167 break;
1168 }
1169 }
1170
1171 Some(ExistingWithInfo { last_cte_end })
1172}
1173
1174fn find_main_select_position(sql: &str) -> Option<usize> {
1176 let bytes = sql.as_bytes();
1177 let mut pos = 0usize;
1178 let mut depth = 0usize;
1179
1180 while pos < bytes.len() {
1181 if let Some(end) = skip_quoted_region(bytes, pos) {
1182 pos = end;
1183 continue;
1184 }
1185 if bytes[pos] == b'-' && bytes.get(pos + 1) == Some(&b'-') {
1186 while pos < bytes.len() && bytes[pos] != b'\n' {
1187 pos += 1;
1188 }
1189 continue;
1190 }
1191 if bytes[pos] == b'/' && bytes.get(pos + 1) == Some(&b'*') {
1192 pos += 2;
1193 while pos + 1 < bytes.len() {
1194 if bytes[pos] == b'*' && bytes[pos + 1] == b'/' {
1195 pos += 2;
1196 break;
1197 }
1198 pos += 1;
1199 }
1200 continue;
1201 }
1202
1203 if bytes[pos] == b'(' {
1204 depth += 1;
1205 pos += 1;
1206 continue;
1207 }
1208 if bytes[pos] == b')' {
1209 depth = depth.saturating_sub(1);
1210 pos += 1;
1211 continue;
1212 }
1213
1214 if depth == 0 {
1216 if match_ascii_keyword_at(bytes, pos, b"WITH").is_some() {
1217 return Some(pos);
1218 }
1219 if match_ascii_keyword_at(bytes, pos, b"SELECT").is_some() {
1220 return Some(pos);
1221 }
1222 }
1223
1224 pos += 1;
1225 }
1226 None
1227}
1228
1229fn skip_quoted_region(bytes: &[u8], pos: usize) -> Option<usize> {
1232 let b = bytes[pos];
1233 if b == b'\'' {
1234 return Some(skip_to_close_quote(bytes, pos + 1, b'\''));
1235 }
1236 if b == b'"' {
1237 return Some(skip_to_close_quote(bytes, pos + 1, b'"'));
1238 }
1239 if b == b'`' {
1240 return Some(skip_to_close_quote(bytes, pos + 1, b'`'));
1241 }
1242 if b == b'[' {
1243 return Some(skip_to_close_quote(bytes, pos + 1, b']'));
1244 }
1245 None
1246}
1247
/// Scans forward from `pos` (inside a quoted region) and returns the index
/// just past the closing delimiter `close`.
///
/// A doubled delimiter is treated as an escaped one and skipped. When the
/// region is unterminated, the scan runs to the end of `bytes`.
fn skip_to_close_quote(bytes: &[u8], mut pos: usize, close: u8) -> usize {
    while let Some(&byte) = bytes.get(pos) {
        if byte != close {
            pos += 1;
            continue;
        }
        if bytes.get(pos + 1) != Some(&close) {
            return pos + 1;
        }
        // Doubled delimiter: an escaped quote character.
        pos += 2;
    }
    pos
}
1262
1263fn consume_quoted_identifier(bytes: &[u8], pos: usize) -> Option<usize> {
1265 if pos >= bytes.len() {
1266 return None;
1267 }
1268 match bytes[pos] {
1269 b'"' => Some(skip_to_close_quote(bytes, pos + 1, b'"')),
1270 b'`' => Some(skip_to_close_quote(bytes, pos + 1, b'`')),
1271 b'[' => Some(skip_to_close_quote(bytes, pos + 1, b']')),
1272 _ => None,
1273 }
1274}
1275
1276fn match_join_keyword_sequence(bytes: &[u8], pos: usize) -> Option<usize> {
1279 let prefixes: &[&[u8]] = &[b"INNER", b"LEFT", b"RIGHT", b"FULL", b"CROSS", b"NATURAL"];
1282
1283 for prefix in prefixes {
1284 if let Some(prefix_end) = match_ascii_keyword_at(bytes, pos, prefix) {
1285 let mut cursor = skip_ascii_whitespace(bytes, prefix_end);
1286
1287 if let Some(outer_end) = match_ascii_keyword_at(bytes, cursor, b"OUTER") {
1289 cursor = skip_ascii_whitespace(bytes, outer_end);
1290 }
1291
1292 if let Some(join_end) = match_ascii_keyword_at(bytes, cursor, b"JOIN") {
1293 return Some(join_end);
1294 }
1295 }
1296 }
1297 None
1298}
1299
/// Returns the index of the `)` matching the `(` at `open_paren_index`.
///
/// Parentheses inside `'...'`, `"..."`, `` `...` `` and `[...]` regions are
/// ignored; inside such a region a doubled closing delimiter is treated as
/// an escape. Returns `None` when `open_paren_index` is not a `(` or no
/// matching `)` exists.
fn find_matching_parenthesis_outside_quotes(sql: &str, open_paren_index: usize) -> Option<usize> {
    let bytes = sql.as_bytes();
    if bytes.get(open_paren_index) != Some(&b'(') {
        return None;
    }

    let mut nesting = 0usize;
    // Closing delimiter of the quoted region being scanned, if any.
    let mut closing_quote: Option<u8> = None;
    let mut cursor = open_paren_index;

    while let Some(&byte) = bytes.get(cursor) {
        if let Some(terminator) = closing_quote {
            if byte != terminator {
                cursor += 1;
            } else if bytes.get(cursor + 1) == Some(&terminator) {
                // Doubled delimiter: still inside the quoted region.
                cursor += 2;
            } else {
                closing_quote = None;
                cursor += 1;
            }
            continue;
        }

        match byte {
            b'\'' | b'"' | b'`' => closing_quote = Some(byte),
            b'[' => closing_quote = Some(b']'),
            b'(' => nesting += 1,
            b')' => {
                nesting = nesting.checked_sub(1)?;
                if nesting == 0 {
                    return Some(cursor);
                }
            }
            _ => {}
        }
        cursor += 1;
    }

    None
}
1411
/// Returns `true` for the ASCII whitespace bytes recognised by the scanner:
/// space, tab, line feed, vertical tab, form feed and carriage return.
fn is_ascii_whitespace_byte(byte: u8) -> bool {
    // 0x09..=0x0d covers \t, \n, vertical tab, form feed and \r.
    byte == b' ' || (0x09..=0x0d).contains(&byte)
}
1415
/// Returns `true` when `byte` may begin an unquoted ASCII identifier:
/// an ASCII letter or an underscore.
fn is_ascii_ident_start(byte: u8) -> bool {
    matches!(byte, b'A'..=b'Z' | b'a'..=b'z' | b'_')
}
1419
/// Returns `true` when `byte` may appear after the first character of an
/// unquoted ASCII identifier: an ASCII letter, an ASCII digit or an underscore.
fn is_ascii_ident_continue(byte: u8) -> bool {
    matches!(byte, b'0'..=b'9' | b'A'..=b'Z' | b'a'..=b'z' | b'_')
}
1423
1424fn skip_ascii_whitespace(bytes: &[u8], mut index: usize) -> usize {
1425 while index < bytes.len() && is_ascii_whitespace_byte(bytes[index]) {
1426 index += 1;
1427 }
1428 index
1429}
1430
1431fn consume_ascii_identifier(bytes: &[u8], start: usize) -> Option<usize> {
1432 if start >= bytes.len() || !is_ascii_ident_start(bytes[start]) {
1433 return None;
1434 }
1435 let mut index = start + 1;
1436 while index < bytes.len() && is_ascii_ident_continue(bytes[index]) {
1437 index += 1;
1438 }
1439 Some(index)
1440}
1441
1442fn is_word_boundary_for_keyword(bytes: &[u8], index: usize) -> bool {
1443 index == 0 || index >= bytes.len() || !is_ascii_ident_continue(bytes[index])
1444}
1445
1446fn match_ascii_keyword_at(bytes: &[u8], start: usize, keyword_upper: &[u8]) -> Option<usize> {
1447 let end = start.checked_add(keyword_upper.len())?;
1448 if end > bytes.len() {
1449 return None;
1450 }
1451 if !is_word_boundary_for_keyword(bytes, start.saturating_sub(1))
1452 || !is_word_boundary_for_keyword(bytes, end)
1453 {
1454 return None;
1455 }
1456 let matches = bytes[start..end]
1457 .iter()
1458 .zip(keyword_upper.iter())
1459 .all(|(actual, expected)| actual.to_ascii_uppercase() == *expected);
1460 if matches {
1461 Some(end)
1462 } else {
1463 None
1464 }
1465}
1466
1467fn table_factor_contains_derived(
1468 table_factor: &TableFactor,
1469 outer_source_names: &HashSet<String>,
1470) -> bool {
1471 match table_factor {
1472 TableFactor::Derived { subquery, .. } => {
1473 !query_references_outer_sources(subquery, outer_source_names)
1474 }
1475 TableFactor::NestedJoin {
1476 table_with_joins, ..
1477 } => {
1478 table_factor_contains_derived(&table_with_joins.relation, outer_source_names)
1479 || table_with_joins
1480 .joins
1481 .iter()
1482 .any(|join| table_factor_contains_derived(&join.relation, outer_source_names))
1483 }
1484 TableFactor::Pivot { table, .. }
1485 | TableFactor::Unpivot { table, .. }
1486 | TableFactor::MatchRecognize { table, .. } => {
1487 table_factor_contains_derived(table, outer_source_names)
1488 }
1489 _ => false,
1490 }
1491}
1492
1493fn query_references_outer_sources(query: &Query, outer_source_names: &HashSet<String>) -> bool {
1494 if let Some(with) = &query.with {
1495 for cte in &with.cte_tables {
1496 if query_references_outer_sources(&cte.query, outer_source_names) {
1497 return true;
1498 }
1499 }
1500 }
1501
1502 set_expr_references_outer_sources(&query.body, outer_source_names)
1503}
1504
1505fn set_expr_references_outer_sources(
1506 set_expr: &SetExpr,
1507 outer_source_names: &HashSet<String>,
1508) -> bool {
1509 match set_expr {
1510 SetExpr::Select(select) => select_references_outer_sources(select, outer_source_names),
1511 SetExpr::Query(query) => query_references_outer_sources(query, outer_source_names),
1512 SetExpr::SetOperation { left, right, .. } => {
1513 set_expr_references_outer_sources(left, outer_source_names)
1514 || set_expr_references_outer_sources(right, outer_source_names)
1515 }
1516 _ => false,
1517 }
1518}
1519
1520fn select_references_outer_sources(select: &Select, outer_source_names: &HashSet<String>) -> bool {
1521 let mut qualifier_prefixes = HashSet::new();
1522 visit_select_expressions(select, &mut |expr| {
1523 collect_qualifier_prefixes_in_expr(expr, &mut qualifier_prefixes);
1524 });
1525
1526 let local_source_names = source_names_in_select(select);
1527 if qualifier_prefixes
1528 .iter()
1529 .any(|name| outer_source_names.contains(name) && !local_source_names.contains(name))
1530 {
1531 return true;
1532 }
1533
1534 for table in &select.from {
1535 if table_factor_references_outer_sources(&table.relation, outer_source_names) {
1536 return true;
1537 }
1538 for join in &table.joins {
1539 if table_factor_references_outer_sources(&join.relation, outer_source_names) {
1540 return true;
1541 }
1542 }
1543 }
1544 false
1545}
1546
1547fn table_factor_references_outer_sources(
1548 table_factor: &TableFactor,
1549 outer_source_names: &HashSet<String>,
1550) -> bool {
1551 match table_factor {
1552 TableFactor::Derived { subquery, .. } => {
1553 query_references_outer_sources(subquery, outer_source_names)
1554 }
1555 TableFactor::NestedJoin {
1556 table_with_joins, ..
1557 } => {
1558 table_factor_references_outer_sources(&table_with_joins.relation, outer_source_names)
1559 || table_with_joins.joins.iter().any(|join| {
1560 table_factor_references_outer_sources(&join.relation, outer_source_names)
1561 })
1562 }
1563 TableFactor::Pivot { table, .. }
1564 | TableFactor::Unpivot { table, .. }
1565 | TableFactor::MatchRecognize { table, .. } => {
1566 table_factor_references_outer_sources(table, outer_source_names)
1567 }
1568 _ => false,
1569 }
1570}
1571
1572fn source_names_in_select(select: &Select) -> HashSet<String> {
1573 let mut names = HashSet::new();
1574 for table in &select.from {
1575 collect_source_names_from_table_factor(&table.relation, &mut names);
1576 for join in &table.joins {
1577 collect_source_names_from_table_factor(&join.relation, &mut names);
1578 }
1579 }
1580 names
1581}
1582
1583fn collect_source_names_from_table_factor(table_factor: &TableFactor, names: &mut HashSet<String>) {
1584 match table_factor {
1585 TableFactor::Table { name, alias, .. } => {
1586 if let Some(last) = name.0.last().and_then(|part| part.as_ident()) {
1587 names.insert(last.value.to_ascii_uppercase());
1588 }
1589 if let Some(alias) = alias {
1590 names.insert(alias.name.value.to_ascii_uppercase());
1591 }
1592 }
1593 TableFactor::Derived {
1594 alias, subquery, ..
1595 } => {
1596 if let Some(alias) = alias {
1597 names.insert(alias.name.value.to_ascii_uppercase());
1598 }
1599 if let Some(with) = &subquery.with {
1600 for cte in &with.cte_tables {
1601 names.insert(cte.alias.name.value.to_ascii_uppercase());
1602 }
1603 }
1604 }
1605 TableFactor::TableFunction { alias, .. }
1606 | TableFactor::Function { alias, .. }
1607 | TableFactor::UNNEST { alias, .. }
1608 | TableFactor::JsonTable { alias, .. }
1609 | TableFactor::OpenJsonTable { alias, .. } => {
1610 if let Some(alias) = alias {
1611 names.insert(alias.name.value.to_ascii_uppercase());
1612 }
1613 }
1614 TableFactor::NestedJoin {
1615 table_with_joins, ..
1616 } => {
1617 collect_source_names_from_table_factor(&table_with_joins.relation, names);
1618 for join in &table_with_joins.joins {
1619 collect_source_names_from_table_factor(&join.relation, names);
1620 }
1621 }
1622 TableFactor::Pivot { table, .. }
1623 | TableFactor::Unpivot { table, .. }
1624 | TableFactor::MatchRecognize { table, .. } => {
1625 collect_source_names_from_table_factor(table, names);
1626 }
1627 _ => {}
1628 }
1629}
1630
#[cfg(test)]
mod tests {
    use super::*;
    use crate::linter::{config::LintConfig, rule::LintContext, Linter};
    use crate::parse_sql;
    use crate::types::IssueAutofixApplicability;

    /// Lints the first statement of `sql` with a `Linter` built from the
    /// default configuration and returns every issue it reports.
    fn run(sql: &str) -> Vec<Issue> {
        let statements = parse_sql(sql).expect("parse sql");
        let linter = Linter::new(LintConfig::default());
        let stmt = &statements[0];
        let ctx = LintContext {
            sql,
            statement_range: 0..sql.len(),
            statement_index: 0,
        };
        linter.check_statement(stmt, &ctx)
    }

    /// Runs only the ST05 rule on the first statement of `sql`, with
    /// `forbid_subquery_in` set to `forbid_in` under the rule-config key
    /// `config_key`.
    fn check_with_option(sql: &str, config_key: &str, forbid_in: &str) -> Vec<Issue> {
        let statements = parse_sql(sql).expect("parse sql");
        let rule = StructureSubquery::from_config(&LintConfig {
            enabled: true,
            disabled_rules: vec![],
            rule_configs: std::collections::BTreeMap::from([(
                config_key.to_string(),
                serde_json::json!({"forbid_subquery_in": forbid_in}),
            )]),
        });
        rule.check(
            &statements[0],
            &LintContext {
                sql,
                statement_range: 0..sql.len(),
                statement_index: 0,
            },
        )
    }

    /// Returns `true` when at least one issue carries the ST05 code.
    fn flags_st05(issues: &[Issue]) -> bool {
        issues
            .iter()
            .any(|issue| issue.code == issue_codes::LINT_ST_005)
    }

    /// Applies the autofix edits of `issue` to `sql`.
    ///
    /// Edits are applied from the highest span to the lowest so that earlier
    /// spans stay valid. Returns `None` when the issue has no autofix.
    fn apply_issue_autofix(sql: &str, issue: &Issue) -> Option<String> {
        let autofix = issue.autofix.as_ref()?;
        let mut out = sql.to_string();
        let mut edits = autofix.edits.clone();
        edits.sort_by_key(|edit| (edit.span.start, edit.span.end));
        for edit in edits.into_iter().rev() {
            out.replace_range(edit.span.start..edit.span.end, &edit.replacement);
        }
        Some(out)
    }

    /// Runs ST05 with the given `forbid_subquery_in` value and returns `sql`
    /// with the first ST05 autofix applied, or `None` when no ST05 issue with
    /// an autofix was reported.
    fn run_fix(sql: &str, forbid_in: &str) -> Option<String> {
        let issues = check_with_option(sql, "structure.subquery", forbid_in);
        let st05_issue = issues
            .iter()
            .find(|i| i.code == issue_codes::LINT_ST_005 && i.autofix.is_some())?;
        apply_issue_autofix(sql, st05_issue)
    }

    /// Asserts that `actual` and `expected` are equal once every run of
    /// whitespace is collapsed to a single space.
    fn assert_fix_whitespace_eq(actual: &str, expected: &str) {
        let norm = |s: &str| s.split_whitespace().collect::<Vec<_>>().join(" ");
        assert_eq!(
            norm(actual),
            norm(expected),
            "\n--- actual ---\n{actual}\n--- expected ---\n{expected}\n"
        );
    }

    #[test]
    fn default_does_not_flag_subquery_in_from() {
        let issues = run("SELECT * FROM (SELECT * FROM t) sub");
        assert!(!flags_st05(&issues));
    }

    #[test]
    fn default_flags_subquery_in_join() {
        let issues = run("SELECT * FROM t JOIN (SELECT * FROM u) sub ON t.id = sub.id");
        assert!(flags_st05(&issues));
    }

    #[test]
    fn default_allows_correlated_subquery_join_without_alias() {
        let issues = run("SELECT pd.* \
 FROM person_dates \
 JOIN (SELECT * FROM events WHERE events.name = person_dates.name)");
        assert!(!flags_st05(&issues));
    }

    #[test]
    fn default_allows_correlated_subquery_join_with_alias_reference() {
        let issues = run("SELECT pd.* \
 FROM person_dates AS pd \
 JOIN (SELECT * FROM events AS ce WHERE ce.name = pd.name)");
        assert!(!flags_st05(&issues));
    }

    #[test]
    fn default_allows_correlated_subquery_join_with_outer_table_name_reference() {
        let issues = run("SELECT pd.* \
 FROM person_dates AS pd \
 JOIN (SELECT * FROM events AS ce WHERE ce.name = person_dates.name)");
        assert!(!flags_st05(&issues));
    }

    #[test]
    fn does_not_flag_cte_usage() {
        let issues = run("WITH sub AS (SELECT * FROM t) SELECT * FROM sub");
        assert!(!flags_st05(&issues));
    }

    #[test]
    fn does_not_flag_scalar_subquery_in_where() {
        let issues = run("SELECT * FROM t WHERE id IN (SELECT id FROM u)");
        assert!(!flags_st05(&issues));
    }

    #[test]
    fn forbid_subquery_in_join_does_not_flag_from_subquery() {
        let sql = "SELECT * FROM (SELECT * FROM t) sub";
        let issues = check_with_option(sql, "structure.subquery", "join");
        assert!(issues.is_empty());
    }

    #[test]
    fn forbid_subquery_in_from_emits_unsafe_cte_autofix_for_simple_case() {
        let sql = "SELECT * FROM (SELECT 1) sub";
        let issues = check_with_option(sql, "LINT_ST_005", "from");
        assert_eq!(issues.len(), 1);
        let autofix = issues[0].autofix.as_ref().expect("autofix metadata");
        assert_eq!(autofix.applicability, IssueAutofixApplicability::Unsafe);
        let fixed = apply_issue_autofix(sql, &issues[0]).expect("apply autofix");
        assert_eq!(fixed, "WITH sub AS (SELECT 1)\nSELECT * FROM sub");
    }

    #[test]
    fn forbid_subquery_in_from_does_not_flag_join_subquery() {
        let sql = "SELECT * FROM t JOIN (SELECT * FROM u) sub ON t.id = sub.id";
        let issues = check_with_option(sql, "LINT_ST_005", "from");
        assert!(issues.is_empty());
    }

    #[test]
    fn forbid_both_flags_subquery_inside_cte_body() {
        let sql = "WITH b AS (SELECT x, z FROM (SELECT x, z FROM p_cte)) SELECT b.z FROM b";
        let issues = check_with_option(sql, "structure.subquery", "both");
        assert_eq!(issues.len(), 1);
    }

    #[test]
    fn forbid_both_flags_subqueries_in_set_operation_second_branch() {
        let sql = "SELECT 1 AS value_name UNION SELECT value FROM (SELECT 2 AS value_name) CROSS JOIN (SELECT 1 AS v2)";
        let issues = check_with_option(sql, "structure.subquery", "both");
        assert_eq!(issues.len(), 2);
    }

    #[test]
    fn fixture_select_fail() {
        let sql = "select\n a.x, a.y, b.z\nfrom a\njoin (\n select x, z from b\n) as b on (a.x = b.x)\n";
        let expected = "with b as (\n select x, z from b\n)\nselect\n a.x, a.y, b.z\nfrom a\njoin b on (a.x = b.x)\n";
        let fixed = run_fix(sql, "join").expect("should produce fix");
        assert_fix_whitespace_eq(&fixed, expected);
    }

    #[test]
    fn fixture_cte_select_fail() {
        let sql = "with prep as (\n select 1 as x, 2 as z\n)\nselect\n a.x, a.y, b.z\nfrom a\njoin (\n select x, z from b\n) as b on (a.x = b.x)\n";
        let expected = "with prep as (\n select 1 as x, 2 as z\n),\nb as (\n select x, z from b\n)\nselect\n a.x, a.y, b.z\nfrom a\njoin b on (a.x = b.x)\n";
        let fixed = run_fix(sql, "join").expect("should produce fix");
        assert_fix_whitespace_eq(&fixed, expected);
    }

    #[test]
    fn fixture_from_clause_fail() {
        let sql = "select\n a.x, a.y\nfrom (\n select * from b\n) as a\n";
        let expected = "with a as (\n select * from b\n)\nselect\n a.x, a.y\nfrom a\n";
        let fixed = run_fix(sql, "from").expect("should produce fix");
        assert_fix_whitespace_eq(&fixed, expected);
    }

    #[test]
    fn fixture_both_clause_fail() {
        let sql = "select\n a.x, a.y\nfrom (\n select * from b\n) as a\n";
        let expected = "with a as (\n select * from b\n)\nselect\n a.x, a.y\nfrom a\n";
        let fixed = run_fix(sql, "both").expect("should produce fix");
        assert_fix_whitespace_eq(&fixed, expected);
    }

    #[test]
    fn fixture_cte_with_clashing_name_generates_prep() {
        let sql = "with prep_1 as (\n select 1 as x, 2 as z\n)\nselect\n a.x, a.y, z\nfrom a\njoin (\n select x, z from b\n) on a.x = z\n";
        let fixed = run_fix(sql, "join").expect("should produce fix");
        assert!(
            fixed.contains("prep_2"),
            "expected prep_2 in output: {fixed}"
        );
    }

    #[test]
    fn fixture_set_subquery_in_second_query() {
        let sql = "SELECT 1 AS value_name\nUNION\nSELECT value\nFROM (SELECT 2 AS value_name);\n";
        let expected = "WITH prep_1 AS (SELECT 2 AS value_name)\nSELECT 1 AS value_name\nUNION\nSELECT value\nFROM prep_1;\n";
        let fixed = run_fix(sql, "both").expect("should produce fix");
        assert_fix_whitespace_eq(&fixed, expected);
    }

    #[test]
    fn fixture_set_subquery_in_second_query_join() {
        let sql = "SELECT 1 AS value_name\nUNION\nSELECT value\nFROM (SELECT 2 AS value_name)\nCROSS JOIN (SELECT 1 as v2);\n";
        let expected = "WITH prep_1 AS (SELECT 2 AS value_name),\nprep_2 AS (SELECT 1 as v2)\nSELECT 1 AS value_name\nUNION\nSELECT value\nFROM prep_1\nCROSS JOIN prep_2;\n";
        let fixed = run_fix(sql, "both").expect("should produce fix");
        assert_fix_whitespace_eq(&fixed, expected);
    }

    #[test]
    fn fixture_with_fail_generates_prep_for_unnamed_subquery() {
        let sql = "select\n a.x, a.y, b.z\nfrom a\njoin (\n with d as (\n select x, z from b\n )\n select * from d\n) using (x)\n";
        let fixed = run_fix(sql, "join").expect("should produce fix");
        assert!(
            fixed.contains("prep_1"),
            "expected prep_1 in output: {fixed}"
        );
    }

    #[test]
    fn fixture_set_fail() {
        let sql = "SELECT\n a.x, a.y, b.z\nFROM a\nJOIN (\n select x, z from b\n union\n select x, z from d\n) USING (x)\n";
        let fixed = run_fix(sql, "join").expect("should produce fix");
        assert!(
            fixed.contains("prep_1"),
            "expected prep_1 in output: {fixed}"
        );
    }

    #[test]
    fn fixture_subquery_in_cte_both() {
        let sql = "with b as (\n select x, z from (\n select x, z from p_cte\n )\n)\nselect b.z\nfrom b\n";
        let expected = "with prep_1 as (\n select x, z from p_cte\n ),\nb as (\n select x, z from prep_1\n)\nselect b.z\nfrom b\n";
        let fixed = run_fix(sql, "both").expect("should produce fix");
        assert_fix_whitespace_eq(&fixed, expected);
    }

    #[test]
    fn fixture_issue_3598_avoid_looping_1() {
        let sql = "WITH cte1 AS (\n SELECT a\n FROM (SELECT a)\n)\nSELECT a FROM cte1\n";
        let expected = "WITH prep_1 AS (SELECT a),\ncte1 AS (\n SELECT a\n FROM prep_1\n)\nSELECT a FROM cte1\n";
        let fixed = run_fix(sql, "both").expect("should produce fix");
        assert_fix_whitespace_eq(&fixed, expected);
    }

    #[test]
    fn fixture_issue_3598_avoid_looping_2() {
        let sql = "WITH cte1 AS (\n SELECT *\n FROM (SELECT * FROM mongo.temp)\n)\nSELECT * FROM cte1\n";
        let expected = "WITH prep_1 AS (SELECT * FROM mongo.temp),\ncte1 AS (\n SELECT *\n FROM prep_1\n)\nSELECT * FROM cte1\n";
        let fixed = run_fix(sql, "both").expect("should produce fix");
        assert_fix_whitespace_eq(&fixed, expected);
    }

    #[test]
    fn fixture_multijoin_both() {
        let sql = "select\n a.x, d.x as foo, a.y, b.z\nfrom (select a, x from foo) a\njoin d using(x)\njoin (\n select x, z from b\n) as b using (x)\n";
        let fixed = run_fix(sql, "both").expect("should produce fix");
        assert!(
            fixed.to_ascii_lowercase().contains("with"),
            "expected WITH in output: {fixed}"
        );
    }
}