1use crate::linter::rule::{LintContext, LintRule};
7use crate::types::{issue_codes, Dialect, Issue, IssueAutofixApplicability, IssuePatchEdit};
8use sqlparser::ast::{Query, Statement};
9use sqlparser::keywords::Keyword;
10use sqlparser::tokenizer::{Token, TokenWithSpan, Tokenizer, Whitespace};
11use std::ops::Range;
12
13pub struct LayoutCteBracket;
14
15impl LintRule for LayoutCteBracket {
16 fn code(&self) -> &'static str {
17 issue_codes::LINT_LT_007
18 }
19
20 fn name(&self) -> &'static str {
21 "Layout CTE bracket"
22 }
23
24 fn description(&self) -> &'static str {
25 "'WITH' clause closing bracket should be on a new line."
26 }
27
28 fn check(&self, statement: &Statement, ctx: &LintContext) -> Vec<Issue> {
29 let tokens = tokenize_with_offsets_for_context(ctx);
30 let violation = misplaced_cte_closing_bracket_for_statement(
31 statement,
32 ctx,
33 tokens.as_deref(),
34 )
35 .or_else(|| {
36 misplaced_cte_closing_bracket(ctx.statement_sql(), ctx.dialect(), tokens.as_deref())
37 });
38
39 if let Some(((start, end), fix_span)) = violation {
40 let mut issue = Issue::warning(
41 issue_codes::LINT_LT_007,
42 "CTE AS clause appears to be missing surrounding brackets.",
43 )
44 .with_statement(ctx.statement_index)
45 .with_span(ctx.span_from_statement_offset(start, end));
46 if let Some((fix_start, fix_end)) = fix_span {
47 issue = issue.with_autofix_edits(
48 IssueAutofixApplicability::Safe,
49 vec![IssuePatchEdit::new(
50 ctx.span_from_statement_offset(fix_start, fix_end),
51 "\n",
52 )],
53 );
54 }
55 vec![issue]
56 } else {
57 Vec::new()
58 }
59 }
60}
61
/// A lexed token annotated with resolved byte offsets and the 1-based line
/// numbers reported by the tokenizer, so layout rules can reason about both
/// position and line placement.
#[derive(Clone)]
struct LocatedToken {
    token: Token,
    // Byte offset where the token starts, relative to the SQL being analyzed
    // (statement-relative in the document-token path).
    start: usize,
    // Byte offset where the token's span ends, resolved the same way as `start`.
    end: usize,
    // 1-based line on which the token starts, as reported by the tokenizer.
    start_line: usize,
    // 1-based line on which the token ends.
    end_line: usize,
}
70
// (start, end) byte offsets of the span reported for the misplaced bracket.
type Lt07Span = (usize, usize);
// (start, end) byte offsets of the whitespace gap an autofix may replace with "\n".
type Lt07AutofixSpan = (usize, usize);
// A detected violation: the span to report plus an optional safe-autofix span.
type Lt07Violation = (Lt07Span, Option<Lt07AutofixSpan>);
74
75fn misplaced_cte_closing_bracket_for_statement(
76 statement: &Statement,
77 ctx: &LintContext,
78 tokens: Option<&[LocatedToken]>,
79) -> Option<Lt07Violation> {
80 let query = match statement {
81 Statement::Query(query) => query.as_ref(),
82 Statement::CreateView { query, .. } => query.as_ref(),
83 _ => return None,
84 };
85
86 misplaced_cte_closing_bracket_in_query(query, ctx, tokens)
87}
88
/// AST-guided detection: for each CTE in the query's WITH clause, locate its
/// closing parenthesis in the token stream and flag it when the CTE body spans
/// multiple lines but the `)` still shares a line with preceding code.
///
/// Token offsets are statement-relative, while `cte.closing_paren_token` spans
/// resolve against the full document, so the offset is translated through
/// `ctx.statement_range`. Returns the report span plus an optional whitespace
/// gap that a safe autofix can replace with a newline.
fn misplaced_cte_closing_bracket_in_query(
    query: &Query,
    ctx: &LintContext,
    tokens: Option<&[LocatedToken]>,
) -> Option<Lt07Violation> {
    let with = query.with.as_ref()?;
    let sql = ctx.statement_sql();
    // Tokenize lazily only when the caller did not supply a token stream.
    let owned_tokens;
    let tokens = if let Some(tokens) = tokens {
        tokens
    } else {
        owned_tokens = tokenize_with_offsets(sql, ctx.dialect())?;
        &owned_tokens
    };

    for cte in &with.cte_tables {
        // Absolute (document-level) offset of this CTE's closing paren.
        let Some(close_abs) = token_start_offset(ctx.sql, &cte.closing_paren_token.0) else {
            continue;
        };
        // Ignore parens that fall outside the statement being linted.
        if close_abs < ctx.statement_range.start || close_abs >= ctx.statement_range.end {
            continue;
        }
        let close_rel = close_abs - ctx.statement_range.start;
        // Find the same `)` in the statement-relative token stream.
        let Some(close_idx) = tokens
            .iter()
            .position(|token| matches!(token.token, Token::RParen) && token.start == close_rel)
        else {
            continue;
        };
        let Some(open_idx) = matching_open_paren_index(tokens, close_idx) else {
            continue;
        };

        let body_end = tokens[close_idx].start;
        if body_end > sql.len() {
            continue;
        }
        // Single-line CTE bodies are fine; only multi-line bodies must close
        // the bracket on its own line.
        if !cte_body_has_line_break(tokens, sql, open_idx, close_idx) {
            continue;
        }
        // When only spacing precedes `)` on its line, the bracket already sits
        // on its own line and there is nothing to report.
        let Some(prev_idx) = last_non_spacing_token_before_on_same_line(tokens, close_idx) else {
            continue;
        };

        let report_span = (tokens[close_idx].start, tokens[close_idx].end);
        let fix_span = safe_newline_fix_span(sql, tokens, prev_idx, close_idx);
        return Some((report_span, fix_span));
    }

    None
}
140
141fn matching_open_paren_index(tokens: &[LocatedToken], close_idx: usize) -> Option<usize> {
142 if !matches!(tokens.get(close_idx)?.token, Token::RParen) {
143 return None;
144 }
145
146 let mut depth = 0usize;
147 for index in (0..=close_idx).rev() {
148 match tokens[index].token {
149 Token::RParen => depth += 1,
150 Token::LParen => {
151 depth = depth.saturating_sub(1);
152 if depth == 0 {
153 return Some(index);
154 }
155 }
156 _ => {}
157 }
158 }
159 None
160}
161
162fn tokenize_with_offsets(sql: &str, dialect: Dialect) -> Option<Vec<LocatedToken>> {
163 let dialect = dialect.to_sqlparser_dialect();
164 let mut tokenizer = Tokenizer::new(dialect.as_ref(), sql);
165 let tokens = tokenizer.tokenize_with_location().ok()?;
166
167 let mut out = Vec::with_capacity(tokens.len());
168 for token in tokens {
169 let Some(start) = line_col_to_offset(
170 sql,
171 token.span.start.line as usize,
172 token.span.start.column as usize,
173 ) else {
174 continue;
175 };
176 let Some(end) = line_col_to_offset(
177 sql,
178 token.span.end.line as usize,
179 token.span.end.column as usize,
180 ) else {
181 continue;
182 };
183 out.push(LocatedToken {
184 token: token.token,
185 start,
186 end,
187 start_line: token.span.start.line as usize,
188 end_line: token.span.end.line as usize,
189 });
190 }
191
192 Some(out)
193}
194
/// Builds statement-relative `LocatedToken`s, preferring the document-level
/// token stream exposed by the context and falling back to re-tokenizing just
/// this statement's SQL.
///
/// The document tokens are only trusted when every kept token's text matches
/// the source it claims to cover and the uncovered gaps within the statement
/// are pure whitespace; any mismatch makes the closure return `None`, which
/// triggers the fallback tokenizer.
fn tokenize_with_offsets_for_context(ctx: &LintContext) -> Option<Vec<LocatedToken>> {
    let statement_start = ctx.statement_range.start;
    let from_document_tokens = ctx.with_document_tokens(|tokens| {
        if tokens.is_empty() {
            return None;
        }

        let mut out = Vec::new();
        let mut covered_ranges = Vec::new();
        for token in tokens {
            // Resolve the token's (line, column) span to absolute byte offsets.
            let (start, end) = token_with_span_offsets(ctx.sql, token)?;
            // Keep only tokens that fall entirely inside this statement.
            if start < ctx.statement_range.start || end > ctx.statement_range.end {
                continue;
            }
            // If a token's claimed span does not match the source text, the
            // document spans cannot be trusted at all — abandon this path.
            if !token_span_matches_source(ctx.sql, start, end, &token.token) {
                return None;
            }
            covered_ranges.push(start..end);

            out.push(LocatedToken {
                token: token.token.clone(),
                // Store offsets relative to the statement start.
                start: start - statement_start,
                end: end - statement_start,
                start_line: token.span.start.line as usize,
                end_line: token.span.end.line as usize,
            });
        }

        // The kept tokens must tile the statement up to whitespace; otherwise
        // something was dropped and the stream is unreliable.
        if !gaps_are_whitespace_only(ctx.sql, ctx.statement_range.clone(), &covered_ranges) {
            return None;
        }
        Some(out)
    });

    if let Some(tokens) = from_document_tokens {
        return Some(tokens);
    }

    // Fallback: tokenize just this statement's SQL from scratch.
    tokenize_with_offsets(ctx.statement_sql(), ctx.dialect())
}
238
239fn token_span_matches_source(sql: &str, start: usize, end: usize, token: &Token) -> bool {
240 if start > end || end > sql.len() {
241 return false;
242 }
243
244 match token {
245 Token::Word(word) => source_word_matches(sql, start, end, word.value.as_str()),
246 Token::LParen => sql.get(start..end) == Some("("),
247 Token::RParen => sql.get(start..end) == Some(")"),
248 _ => true,
249 }
250}
251
/// True when the source slice at `start..end`, with identifier quoting
/// characters stripped, case-insensitively equals `value`.
fn source_word_matches(sql: &str, start: usize, end: usize, value: &str) -> bool {
    sql.get(start..end).map_or(false, |raw| {
        raw.trim_matches(|ch| matches!(ch, '"' | '`' | '[' | ']'))
            .eq_ignore_ascii_case(value)
    })
}
259
/// True when every part of `range` not covered by one of the `covered` spans
/// contains only whitespace. Rejects ranges that do not fit inside `sql`.
fn gaps_are_whitespace_only(sql: &str, range: Range<usize>, covered: &[Range<usize>]) -> bool {
    if range.start > range.end || range.end > sql.len() {
        return false;
    }

    let mut spans = covered.to_vec();
    spans.sort_by_key(|span| (span.start, span.end));

    // A gap passes only when it can be sliced and holds nothing but whitespace.
    let blank = |from: usize, to: usize| -> bool {
        sql.get(from..to)
            .map_or(false, |gap| gap.chars().all(char::is_whitespace))
    };

    let mut cursor = range.start;
    for span in spans {
        // Spans entirely outside the range contribute nothing.
        if span.end <= range.start || span.start >= range.end {
            continue;
        }
        let clipped_start = span.start.max(range.start);
        let clipped_end = span.end.min(range.end);
        if clipped_start > cursor && !blank(cursor, clipped_start) {
            return false;
        }
        cursor = cursor.max(clipped_end);
    }

    // Whatever trails the last covered span must also be blank.
    cursor >= range.end || blank(cursor, range.end)
}
297
298fn token_start_offset(sql: &str, token: &sqlparser::tokenizer::TokenWithSpan) -> Option<usize> {
299 line_col_to_offset(
300 sql,
301 token.span.start.line as usize,
302 token.span.start.column as usize,
303 )
304}
305
306fn token_with_span_offsets(sql: &str, token: &TokenWithSpan) -> Option<(usize, usize)> {
307 let start = line_col_to_offset(
308 sql,
309 token.span.start.line as usize,
310 token.span.start.column as usize,
311 )?;
312 let end = line_col_to_offset(
313 sql,
314 token.span.end.line as usize,
315 token.span.end.column as usize,
316 )?;
317 Some((start, end))
318}
319
/// Converts a 1-based (line, column) position into a byte offset into `sql`.
/// A position exactly one past the final character resolves to `sql.len()`;
/// zero or out-of-range positions yield `None`.
fn line_col_to_offset(sql: &str, line: usize, column: usize) -> Option<usize> {
    if line == 0 || column == 0 {
        return None;
    }

    let mut cursor = (1usize, 1usize);
    for (offset, ch) in sql.char_indices() {
        if cursor == (line, column) {
            return Some(offset);
        }
        // Only '\n' advances the line counter; every other char (including
        // '\r') advances the column.
        if ch == '\n' {
            cursor = (cursor.0 + 1, 1);
        } else {
            cursor.1 += 1;
        }
    }

    // The end-of-input position is addressable too.
    (cursor == (line, column)).then_some(sql.len())
}
347
/// Token-only fallback scan (no AST required): walks the token stream looking
/// for `AS (` groups and flags the matching `)` when the body between the
/// parens spans multiple lines yet the `)` still shares its line with code.
///
/// Bails out early when the statement contains no WITH keyword. Returns the
/// report span plus an optional whitespace gap usable for a safe autofix.
fn misplaced_cte_closing_bracket(
    sql: &str,
    dialect: Dialect,
    tokens: Option<&[LocatedToken]>,
) -> Option<Lt07Violation> {
    // Tokenize lazily only when the caller did not supply a token stream.
    let owned_tokens;
    let tokens = if let Some(tokens) = tokens {
        tokens
    } else {
        owned_tokens = tokenize_with_offsets(sql, dialect)?;
        &owned_tokens
    };

    if tokens.is_empty() {
        return None;
    }

    // Without a WITH keyword there can be no CTE to check.
    let has_with = tokens
        .iter()
        .any(|token| matches!(token.token, Token::Word(ref word) if word.keyword == Keyword::WITH));
    if !has_with {
        return None;
    }

    let mut index = 0usize;
    while let Some(as_idx) = find_next_as_keyword(tokens, index) {
        // The token after AS (ignoring whitespace/comments) must be `(`.
        let Some(open_idx) = next_non_trivia_index(tokens, as_idx + 1) else {
            index = as_idx + 1;
            continue;
        };
        if !matches!(tokens[open_idx].token, Token::LParen) {
            index = as_idx + 1;
            continue;
        }

        let Some(close_idx) = matching_close_paren_index(tokens, open_idx) else {
            index = open_idx + 1;
            continue;
        };

        // Flag only non-empty multi-line bodies whose `)` shares its line with
        // a preceding non-spacing token.
        let body_start = tokens[open_idx].end;
        let body_end = tokens[close_idx].start;
        if body_start < body_end
            && body_end <= sql.len()
            && cte_body_has_line_break(tokens, sql, open_idx, close_idx)
        {
            if let Some(prev_idx) = last_non_spacing_token_before_on_same_line(tokens, close_idx) {
                let report_span = (tokens[close_idx].start, tokens[close_idx].end);
                let fix_span = safe_newline_fix_span(sql, tokens, prev_idx, close_idx);
                return Some((report_span, fix_span));
            }
        }

        // Resume scanning after this balanced group.
        index = close_idx + 1;
    }

    None
}
406
407fn cte_body_has_line_break(
408 tokens: &[LocatedToken],
409 sql: &str,
410 open_idx: usize,
411 close_idx: usize,
412) -> bool {
413 if close_idx <= open_idx + 1 {
414 return false;
415 }
416
417 tokens
418 .iter()
419 .take(close_idx)
420 .skip(open_idx + 1)
421 .any(|token| token.end <= sql.len() && count_line_breaks(&sql[token.start..token.end]) > 0)
422}
423
424fn last_non_spacing_token_before_on_same_line(
425 tokens: &[LocatedToken],
426 close_idx: usize,
427) -> Option<usize> {
428 let close = tokens.get(close_idx)?;
429 let line = close.start_line;
430
431 for (index, token) in tokens[..close_idx].iter().enumerate().rev() {
432 if token.end_line < line {
433 break;
434 }
435 if token.start_line != line {
436 continue;
437 }
438 if is_spacing_whitespace(&token.token) {
439 continue;
440 }
441 return Some(index);
442 }
443
444 None
445}
446
447fn safe_newline_fix_span(
448 sql: &str,
449 tokens: &[LocatedToken],
450 prev_idx: usize,
451 close_idx: usize,
452) -> Option<Lt07AutofixSpan> {
453 let gap_start = tokens.get(prev_idx)?.end;
454 let gap_end = tokens.get(close_idx)?.start;
455 if gap_start > gap_end || gap_end > sql.len() {
456 return None;
457 }
458
459 let gap = &sql[gap_start..gap_end];
460 if gap.chars().all(char::is_whitespace) && !gap.contains('\n') && !gap.contains('\r') {
461 Some((gap_start, gap_end))
462 } else {
463 None
464 }
465}
466
/// Test-only convenience wrapper over the token-based fallback scan.
#[cfg(test)]
fn has_misplaced_cte_closing_bracket(sql: &str, dialect: Dialect) -> bool {
    misplaced_cte_closing_bracket(sql, dialect, None).is_some()
}
471
472fn find_next_as_keyword(tokens: &[LocatedToken], mut index: usize) -> Option<usize> {
473 while index < tokens.len() {
474 if matches!(
475 tokens[index].token,
476 Token::Word(ref word) if word.keyword == Keyword::AS
477 ) {
478 return Some(index);
479 }
480 index += 1;
481 }
482 None
483}
484
485fn next_non_trivia_index(tokens: &[LocatedToken], mut index: usize) -> Option<usize> {
486 while index < tokens.len() {
487 if !is_trivia_token(&tokens[index].token) {
488 return Some(index);
489 }
490 index += 1;
491 }
492 None
493}
494
495fn matching_close_paren_index(tokens: &[LocatedToken], open_idx: usize) -> Option<usize> {
496 if !matches!(tokens.get(open_idx)?.token, Token::LParen) {
497 return None;
498 }
499
500 let mut depth = 0usize;
501 for (idx, token) in tokens.iter().enumerate().skip(open_idx) {
502 match token.token {
503 Token::LParen => depth += 1,
504 Token::RParen => {
505 depth -= 1;
506 if depth == 0 {
507 return Some(idx);
508 }
509 }
510 _ => {}
511 }
512 }
513
514 None
515}
516
517fn is_trivia_token(token: &Token) -> bool {
518 matches!(
519 token,
520 Token::Whitespace(Whitespace::Space | Whitespace::Tab | Whitespace::Newline)
521 | Token::Whitespace(Whitespace::SingleLineComment { .. })
522 | Token::Whitespace(Whitespace::MultiLineComment(_))
523 )
524}
525
526fn is_spacing_whitespace(token: &Token) -> bool {
527 matches!(
528 token,
529 Token::Whitespace(Whitespace::Space | Whitespace::Tab | Whitespace::Newline)
530 )
531}
532
/// Counts line breaks in `text`, treating "\n", "\r", and the two-byte
/// sequence "\r\n" each as a single break.
fn count_line_breaks(text: &str) -> usize {
    // Byte-wise scan is safe: '\r' and '\n' never occur inside a multi-byte
    // UTF-8 sequence.
    let bytes = text.as_bytes();
    let mut count = 0usize;
    let mut i = 0usize;
    while i < bytes.len() {
        match bytes[i] {
            b'\n' => {
                count += 1;
                i += 1;
            }
            b'\r' => {
                count += 1;
                // Consume a following '\n' so "\r\n" counts once.
                i += if bytes.get(i + 1) == Some(&b'\n') { 2 } else { 1 };
            }
            _ => i += 1,
        }
    }
    count
}
550
551#[cfg(test)]
552mod tests {
553 use super::*;
554 use crate::parser::parse_sql;
555 use crate::types::IssueAutofixApplicability;
556
    /// Parses `sql` and runs the LT07 rule over every parsed statement,
    /// collecting all reported issues. The whole input serves as the
    /// statement range for each statement.
    fn run(sql: &str) -> Vec<Issue> {
        let statements = parse_sql(sql).expect("parse");
        let rule = LayoutCteBracket;
        statements
            .iter()
            .enumerate()
            .flat_map(|(index, statement)| {
                rule.check(
                    statement,
                    &LintContext {
                        sql,
                        statement_range: 0..sql.len(),
                        statement_index: index,
                    },
                )
            })
            .collect()
    }
575
    /// Applies an issue's autofix edits to `sql` and returns the patched text,
    /// or `None` when the issue carries no autofix. Edits are sorted and then
    /// applied back-to-front so earlier spans stay valid.
    fn apply_issue_autofix(sql: &str, issue: &Issue) -> Option<String> {
        let autofix = issue.autofix.as_ref()?;
        let mut out = sql.to_string();
        let mut edits = autofix.edits.clone();
        edits.sort_by_key(|edit| (edit.span.start, edit.span.end));
        for edit in edits.into_iter().rev() {
            out.replace_range(edit.span.start..edit.span.end, &edit.replacement);
        }
        Some(out)
    }
586
    // A multi-line CTE whose `)` trails the body on the same line must be
    // flagged, and the safe autofix should move the bracket to its own line.
    #[test]
    fn flags_closing_paren_after_sql_code_in_multiline_cte() {
        let sql = "with cte_1 as (\n select foo\n from tbl_1)\nselect * from cte_1";
        let issues = run(sql);
        assert_eq!(issues.len(), 1);
        assert_eq!(issues[0].code, issue_codes::LINT_LT_007);
        let autofix = issues[0].autofix.as_ref().expect("autofix metadata");
        assert_eq!(autofix.applicability, IssueAutofixApplicability::Safe);
        let fixed = apply_issue_autofix(sql, &issues[0]).expect("apply autofix");
        assert_eq!(
            fixed,
            "with cte_1 as (\n select foo\n from tbl_1\n)\nselect * from cte_1"
        );
    }
601
    // A CTE body kept on a single line may close inline without a warning.
    #[test]
    fn does_not_flag_single_line_cte_body() {
        assert!(run("WITH cte AS (SELECT 1) SELECT * FROM cte").is_empty());
    }
606
    // A multi-line CTE whose `)` already starts its own line is compliant.
    #[test]
    fn does_not_flag_multiline_cte_with_own_line_close() {
        let sql = "with cte as (\n select 1\n) select * from cte";
        assert!(run(sql).is_empty());
    }
612
    // Exercises the token-only fallback directly: templating markers around the
    // CTE do not stop the scan from flagging the inline `)`.
    #[test]
    fn flags_templated_close_paren_on_same_line_as_cte_body_code() {
        let sql =
            "with\n{% if true %}\n cte as (\n select 1)\n{% endif %}\nselect * from cte";
        assert!(has_misplaced_cte_closing_bracket(sql, Dialect::Generic));
    }
619
    // A trailing comment before `)` still counts as content on the same line;
    // the autofix inserts the newline between the comment and the bracket.
    #[test]
    fn flags_close_paren_when_comment_precedes_on_same_line() {
        let sql = "WITH cte AS (\n SELECT 1 /* trailing comment */)\nSELECT * FROM cte";
        let issues = run(sql);
        assert_eq!(issues.len(), 1);
        assert_eq!(issues[0].code, issue_codes::LINT_LT_007);
        let fixed = apply_issue_autofix(sql, &issues[0]).expect("apply autofix");
        assert_eq!(
            fixed,
            "WITH cte AS (\n SELECT 1 /* trailing comment */\n)\nSELECT * FROM cte"
        );
    }
632}