Skip to main content

flowscope_core/linter/rules/
tq_002.rs

1//! LINT_TQ_002: TSQL procedure BEGIN/END block.
2//!
3//! SQLFluff TQ02 parity: procedures with multiple statements should include a
4//! `BEGIN`/`END` block.
5
6use crate::linter::rule::{LintContext, LintRule};
7use crate::types::{issue_codes, Dialect, Issue, IssueAutofixApplicability, IssuePatchEdit, Span};
8use sqlparser::ast::Statement;
9
10pub struct TsqlProcedureBeginEnd;
11
12impl LintRule for TsqlProcedureBeginEnd {
13    fn code(&self) -> &'static str {
14        issue_codes::LINT_TQ_002
15    }
16
17    fn name(&self) -> &'static str {
18        "TSQL procedure BEGIN/END"
19    }
20
21    fn description(&self) -> &'static str {
22        "Procedure bodies with multiple statements should be wrapped in BEGIN/END."
23    }
24
25    fn check(&self, _statement: &Statement, ctx: &LintContext) -> Vec<Issue> {
26        if ctx.dialect() != Dialect::Mssql {
27            return Vec::new();
28        }
29
30        // Treat TQ02 as source/document scoped so best-effort parser recovery
31        // (which may parse only later statement slices) still allows detection.
32        if ctx.statement_index != 0 {
33            return Vec::new();
34        }
35
36        let has_violation = procedure_requires_begin_end_from_sql(ctx.sql);
37
38        if has_violation {
39            let mut issue = Issue::warning(
40                issue_codes::LINT_TQ_002,
41                "Stored procedures with multiple statements should include BEGIN/END block.",
42            )
43            .with_statement(ctx.statement_index);
44
45            let autofix_edits = tq002_autofix_edits(ctx.sql)
46                .into_iter()
47                .map(|edit| IssuePatchEdit::new(Span::new(edit.start, edit.end), edit.replacement))
48                .collect::<Vec<_>>();
49            if !autofix_edits.is_empty() {
50                issue = issue.with_autofix_edits(IssueAutofixApplicability::Safe, autofix_edits);
51            }
52
53            vec![issue]
54        } else {
55            Vec::new()
56        }
57    }
58}
59
/// A single text replacement proposed by the TQ002 autofix.
///
/// `start..end` is a byte range into the original SQL text; the edits built in
/// this file use `start == end`, i.e. pure insertions.
struct Tq002AutofixEdit {
    /// Byte offset where the replaced range begins.
    start: usize,
    /// Byte offset where the replaced range ends (exclusive).
    end: usize,
    /// Text that takes the place of `start..end`.
    replacement: String,
}
65
/// Byte-offset description of a procedure body located in raw SQL text.
#[derive(Clone, Copy)]
struct ProcedureBodyLayout {
    /// Offset just past the `AS` keyword that introduces the body.
    as_end: usize,
    /// Offset of the first non-whitespace byte after `AS`.
    body_start: usize,
    /// Offset just past the last non-whitespace byte of the SQL text.
    body_end: usize,
    /// Whether the body opens with the `BEGIN` keyword.
    has_begin: bool,
    /// Number of semicolon-delimited statements counted in the body.
    statement_count: usize,
}
74
75fn procedure_requires_begin_end_from_sql(sql: &str) -> bool {
76    let Some(layout) = procedure_body_layout(sql) else {
77        return false;
78    };
79    !layout.has_begin && layout.statement_count > 1
80}
81
82fn tq002_autofix_edits(sql: &str) -> Vec<Tq002AutofixEdit> {
83    let Some(layout) = procedure_body_layout(sql) else {
84        return Vec::new();
85    };
86    if layout.has_begin || layout.statement_count <= 1 {
87        return Vec::new();
88    }
89
90    let multiline_body = sql.as_bytes()[layout.as_end..layout.body_start]
91        .iter()
92        .any(|byte| matches!(*byte, b'\n' | b'\r'));
93
94    let begin_replacement = if multiline_body { "BEGIN\n" } else { "BEGIN " };
95    let end_replacement = if multiline_body { "\nEND" } else { " END" };
96
97    vec![
98        Tq002AutofixEdit {
99            start: layout.body_start,
100            end: layout.body_start,
101            replacement: begin_replacement.to_string(),
102        },
103        Tq002AutofixEdit {
104            start: layout.body_end,
105            end: layout.body_end,
106            replacement: end_replacement.to_string(),
107        },
108    ]
109}
110
111fn procedure_body_layout(sql: &str) -> Option<ProcedureBodyLayout> {
112    let bytes = sql.as_bytes();
113    let header_end = procedure_header_end(bytes)?;
114    let as_end = find_next_keyword(bytes, header_end, b"AS")?;
115    let body_start = skip_ascii_whitespace(bytes, as_end);
116    if body_start >= bytes.len() {
117        return None;
118    }
119
120    let body_end = trim_ascii_whitespace_end(bytes);
121    if body_end <= body_start {
122        return None;
123    }
124
125    let has_begin = match_ascii_keyword_at(bytes, body_start, b"BEGIN").is_some();
126    let statement_count = count_body_statements(&sql[body_start..body_end]);
127
128    Some(ProcedureBodyLayout {
129        as_end,
130        body_start,
131        body_end,
132        has_begin,
133        statement_count,
134    })
135}
136
137fn procedure_header_end(bytes: &[u8]) -> Option<usize> {
138    let mut index = skip_ascii_whitespace_and_comments(bytes, 0);
139
140    if let Some(create_end) = match_ascii_keyword_at(bytes, index, b"CREATE") {
141        index = skip_ascii_whitespace_and_comments(bytes, create_end);
142        if let Some(or_end) = match_ascii_keyword_at(bytes, index, b"OR") {
143            index = skip_ascii_whitespace_and_comments(bytes, or_end);
144            let alter_end = match_ascii_keyword_at(bytes, index, b"ALTER")?;
145            index = skip_ascii_whitespace_and_comments(bytes, alter_end);
146        }
147        let proc_end = match_procedure_keyword(bytes, index)?;
148        return Some(skip_ascii_whitespace_and_comments(bytes, proc_end));
149    }
150
151    if let Some(alter_end) = match_ascii_keyword_at(bytes, index, b"ALTER") {
152        index = skip_ascii_whitespace_and_comments(bytes, alter_end);
153        let proc_end = match_procedure_keyword(bytes, index)?;
154        return Some(skip_ascii_whitespace_and_comments(bytes, proc_end));
155    }
156
157    None
158}
159
160fn match_procedure_keyword(bytes: &[u8], start: usize) -> Option<usize> {
161    match_ascii_keyword_at(bytes, start, b"PROCEDURE")
162        .or_else(|| match_ascii_keyword_at(bytes, start, b"PROC"))
163}
164
165fn find_next_keyword(bytes: &[u8], mut index: usize, keyword_upper: &[u8]) -> Option<usize> {
166    while index < bytes.len() {
167        index = skip_ascii_whitespace_and_comments(bytes, index);
168        if index >= bytes.len() {
169            return None;
170        }
171
172        if let Some(end) = match_ascii_keyword_at(bytes, index, keyword_upper) {
173            return Some(end);
174        }
175
176        if bytes[index] == b'\'' {
177            index = skip_single_quoted_literal(bytes, index);
178            continue;
179        }
180        if bytes[index] == b'"' {
181            index = skip_double_quoted_literal(bytes, index);
182            continue;
183        }
184        if bytes[index] == b'[' {
185            index = skip_bracket_identifier(bytes, index);
186            continue;
187        }
188
189        index += 1;
190    }
191
192    None
193}
194
195fn count_body_statements(sql: &str) -> usize {
196    let bytes = sql.as_bytes();
197    let mut index = 0usize;
198    let mut statement_count = 0usize;
199    let mut statement_has_code = false;
200    let mut paren_depth = 0usize;
201
202    while index < bytes.len() {
203        if bytes[index] == b'-' && index + 1 < bytes.len() && bytes[index + 1] == b'-' {
204            index = skip_line_comment(bytes, index);
205            continue;
206        }
207        if bytes[index] == b'/' && index + 1 < bytes.len() && bytes[index + 1] == b'*' {
208            index = skip_block_comment(bytes, index);
209            continue;
210        }
211        if bytes[index] == b'\'' {
212            statement_has_code = true;
213            index = skip_single_quoted_literal(bytes, index);
214            continue;
215        }
216        if bytes[index] == b'"' {
217            statement_has_code = true;
218            index = skip_double_quoted_literal(bytes, index);
219            continue;
220        }
221        if bytes[index] == b'[' {
222            statement_has_code = true;
223            index = skip_bracket_identifier(bytes, index);
224            continue;
225        }
226
227        match bytes[index] {
228            b'(' => {
229                statement_has_code = true;
230                paren_depth += 1;
231                index += 1;
232            }
233            b')' => {
234                statement_has_code = true;
235                paren_depth = paren_depth.saturating_sub(1);
236                index += 1;
237            }
238            b';' if paren_depth == 0 => {
239                if statement_has_code {
240                    statement_count += 1;
241                    statement_has_code = false;
242                }
243                index += 1;
244            }
245            byte if is_ascii_whitespace_byte(byte) => {
246                index += 1;
247            }
248            _ => {
249                statement_has_code = true;
250                index += 1;
251            }
252        }
253    }
254
255    if statement_has_code {
256        statement_count += 1;
257    }
258
259    statement_count
260}
261
262fn trim_ascii_whitespace_end(bytes: &[u8]) -> usize {
263    let mut tail = bytes.len();
264    while tail > 0 && is_ascii_whitespace_byte(bytes[tail - 1]) {
265        tail -= 1;
266    }
267    tail
268}
269
270fn skip_ascii_whitespace_and_comments(bytes: &[u8], mut index: usize) -> usize {
271    loop {
272        index = skip_ascii_whitespace(bytes, index);
273        if index >= bytes.len() {
274            return index;
275        }
276        if bytes[index] == b'-' && index + 1 < bytes.len() && bytes[index + 1] == b'-' {
277            index = skip_line_comment(bytes, index);
278            continue;
279        }
280        if bytes[index] == b'/' && index + 1 < bytes.len() && bytes[index + 1] == b'*' {
281            index = skip_block_comment(bytes, index);
282            continue;
283        }
284        return index;
285    }
286}
287
/// Given the offset of a `--` marker, returns the offset of the line break
/// that ends the comment, or the end of the input if there is none.
///
/// The line break itself is not consumed.
fn skip_line_comment(bytes: &[u8], index: usize) -> usize {
    // Step over the two-byte `--` marker before looking for the line end.
    let text_start = index + 2;
    match bytes.get(text_start..) {
        Some(rest) => rest
            .iter()
            .position(|byte| matches!(byte, b'\n' | b'\r'))
            .map_or(bytes.len(), |offset| text_start + offset),
        None => text_start,
    }
}
295
/// Given the offset of a `/*` marker, returns the offset just past the `*/`
/// that closes the comment, or the end of the input if it is unterminated.
///
/// T-SQL block comments nest: a `/*` inside a comment opens an inner comment
/// that needs its own `*/`, so the scan tracks nesting depth rather than
/// stopping at the first closer.
fn skip_block_comment(bytes: &[u8], mut index: usize) -> usize {
    // Step over the opening marker so `/*/` is not read as open-then-close.
    index += 2;
    let mut depth = 1usize;

    while index + 1 < bytes.len() {
        match (bytes[index], bytes[index + 1]) {
            (b'/', b'*') => {
                depth += 1;
                index += 2;
            }
            (b'*', b'/') => {
                depth -= 1;
                index += 2;
                if depth == 0 {
                    return index;
                }
            }
            _ => index += 1,
        }
    }

    bytes.len()
}
306
/// Given the offset of an opening `'`, returns the offset just past the
/// closing quote, treating a doubled `''` as an escaped quote.
///
/// Returns the end of the input when the literal is unterminated.
fn skip_single_quoted_literal(bytes: &[u8], index: usize) -> usize {
    let mut cursor = index + 1;
    while let Some(&byte) = bytes.get(cursor) {
        if byte != b'\'' {
            cursor += 1;
        } else if bytes.get(cursor + 1) == Some(&b'\'') {
            // Escaped quote: consume both bytes and keep scanning.
            cursor += 2;
        } else {
            return cursor + 1;
        }
    }
    bytes.len()
}
322
/// Given the offset of an opening `"`, returns the offset just past the
/// closing quote, treating a doubled `""` as an escaped quote.
///
/// Returns the end of the input when the quoted text is unterminated.
fn skip_double_quoted_literal(bytes: &[u8], index: usize) -> usize {
    let mut cursor = index + 1;
    while let Some(&byte) = bytes.get(cursor) {
        if byte != b'"' {
            cursor += 1;
        } else if bytes.get(cursor + 1) == Some(&b'"') {
            // Escaped quote: consume both bytes and keep scanning.
            cursor += 2;
        } else {
            return cursor + 1;
        }
    }
    bytes.len()
}
338
/// Given the offset of an opening `[`, returns the offset just past the
/// closing `]`, treating a doubled `]]` as an escaped bracket.
///
/// Returns the end of the input when the identifier is unterminated.
fn skip_bracket_identifier(bytes: &[u8], index: usize) -> usize {
    let mut cursor = index + 1;
    while let Some(&byte) = bytes.get(cursor) {
        if byte != b']' {
            cursor += 1;
        } else if bytes.get(cursor + 1) == Some(&b']') {
            // Escaped bracket: consume both bytes and keep scanning.
            cursor += 2;
        } else {
            return cursor + 1;
        }
    }
    bytes.len()
}
354
/// Returns `true` for bytes that can continue an ASCII identifier
/// (letters, digits and underscore).
fn is_ascii_ident_continue(byte: u8) -> bool {
    byte.is_ascii_alphanumeric() || byte == b'_'
}

/// Returns `true` when `index` is a valid place for a keyword to end: the end
/// of the input, or a byte that cannot continue an identifier.
///
/// Index 0 is always reported as a boundary, so this helper must not be used
/// to inspect the byte *preceding* a keyword.
fn is_word_boundary_for_keyword(bytes: &[u8], index: usize) -> bool {
    index == 0 || index >= bytes.len() || !is_ascii_ident_continue(bytes[index])
}

/// Case-insensitively matches `keyword_upper` (which must be upper-case
/// ASCII) as a whole word starting at `start`.
///
/// Returns the offset just past the keyword, or `None` when the bytes differ,
/// the keyword would run past the end of the input, or it is directly
/// adjacent to an identifier character on either side.
fn match_ascii_keyword_at(bytes: &[u8], start: usize, keyword_upper: &[u8]) -> Option<usize> {
    let end = start.checked_add(keyword_upper.len())?;
    let candidate = bytes.get(start..end)?;

    // Inspect the preceding byte directly: only `start == 0` has no
    // predecessor. Deriving the index with `saturating_sub` made offset 1
    // look like the start of the input and skipped the check on `bytes[0]`.
    let follows_identifier = start > 0 && is_ascii_ident_continue(bytes[start - 1]);
    if follows_identifier || !is_word_boundary_for_keyword(bytes, end) {
        return None;
    }

    let is_match = candidate
        .iter()
        .zip(keyword_upper)
        .all(|(actual, expected)| actual.to_ascii_uppercase() == *expected);
    if is_match {
        Some(end)
    } else {
        None
    }
}
383
/// Returns `true` for the ASCII whitespace bytes this scanner recognises:
/// space, line feed, carriage return, tab, vertical tab and form feed.
fn is_ascii_whitespace_byte(byte: u8) -> bool {
    b" \n\r\t\x0b\x0c".contains(&byte)
}
387
388fn skip_ascii_whitespace(bytes: &[u8], mut index: usize) -> usize {
389    while index < bytes.len() && is_ascii_whitespace_byte(bytes[index]) {
390        index += 1;
391    }
392    index
393}
394
#[cfg(test)]
mod tests {
    use super::*;
    use crate::linter::rule::with_active_dialect;
    use crate::parser::{parse_sql, parse_sql_with_dialect};
    use crate::types::IssueAutofixApplicability;
    use crate::Dialect;

    /// Parses `sql` as T-SQL and runs the rule once per parsed statement,
    /// always handing it the full source text.
    fn run(sql: &str) -> Vec<Issue> {
        let rule = TsqlProcedureBeginEnd;
        let parsed = parse_sql_with_dialect(sql, Dialect::Mssql).expect("parse");
        with_active_dialect(Dialect::Mssql, || {
            let mut found = Vec::new();
            for (statement_index, statement) in parsed.iter().enumerate() {
                let ctx = LintContext {
                    sql,
                    statement_range: 0..sql.len(),
                    statement_index,
                };
                found.extend(rule.check(statement, &ctx));
            }
            found
        })
    }

    /// Runs the rule against `sql` without parsing it, using a placeholder
    /// statement, to exercise the text-only detection path.
    fn run_statementless(sql: &str) -> Vec<Issue> {
        let rule = TsqlProcedureBeginEnd;
        let placeholder = parse_sql("SELECT 1").expect("parse synthetic statement");
        with_active_dialect(Dialect::Mssql, || {
            let ctx = LintContext {
                sql,
                statement_range: 0..sql.len(),
                statement_index: 0,
            };
            rule.check(&placeholder[0], &ctx)
        })
    }

    /// Applies the issue's autofix edits to `sql`, last edit first so earlier
    /// offsets stay valid. Returns `None` when the issue carries no autofix.
    fn apply_issue_autofix(sql: &str, issue: &Issue) -> Option<String> {
        issue.autofix.as_ref().map(|autofix| {
            let mut ordered = autofix.edits.clone();
            ordered.sort_by_key(|edit| (edit.span.start, edit.span.end));
            ordered.iter().rev().fold(sql.to_string(), |mut text, edit| {
                text.replace_range(edit.span.start..edit.span.end, &edit.replacement);
                text
            })
        })
    }

    #[test]
    fn does_not_flag_single_statement_procedure_without_begin_end() {
        assert!(run("CREATE PROCEDURE p AS SELECT 1;").is_empty());
    }

    #[test]
    fn does_not_flag_procedure_with_begin_end() {
        assert!(run("CREATE PROCEDURE p AS BEGIN SELECT 1; END;").is_empty());
    }

    #[test]
    fn detects_multi_statement_create_procedure_in_statementless_mode() {
        let sql = "CREATE PROCEDURE p AS SELECT 1; SELECT 2;";
        let issues = run_statementless(sql);
        assert_eq!(issues.len(), 1);

        let issue = &issues[0];
        assert_eq!(issue.code, issue_codes::LINT_TQ_002);

        let autofix = issue.autofix.as_ref().expect("autofix metadata");
        assert_eq!(autofix.applicability, IssueAutofixApplicability::Safe);

        let fixed = apply_issue_autofix(sql, issue).expect("apply autofix");
        assert_eq!(fixed, "CREATE PROCEDURE p AS BEGIN SELECT 1; SELECT 2; END");
    }

    #[test]
    fn detects_alter_procedure_in_statementless_mode() {
        let issues = run_statementless("ALTER PROCEDURE dbo.p AS SELECT 1; SELECT 2;");
        assert_eq!(issues.len(), 1);
    }

    #[test]
    fn detects_create_or_alter_procedure_in_statementless_mode() {
        let issues = run_statementless("CREATE OR ALTER PROCEDURE dbo.p AS SELECT 1; SELECT 2;");
        assert_eq!(issues.len(), 1);
    }

    #[test]
    fn does_not_flag_external_name_statementless_procedure() {
        let issues = run_statementless(
            "CREATE PROCEDURE dbo.ExternalProc AS EXTERNAL NAME Assembly.Class.Method;",
        );
        assert!(issues.is_empty());
    }

    #[test]
    fn does_not_flag_procedure_text_inside_string_literal() {
        let issues = run_statementless("SELECT 'CREATE PROCEDURE p AS SELECT 1' AS sql_snippet");
        assert!(issues.is_empty());
    }
}
501}