use crate::parser::{parse_sql_with_dialect, parse_sql_with_dialect_output};
use crate::types::{issue_codes, AnalyzeRequest, Dialect, Issue, Span};
use sqlparser::ast::Statement;
use sqlparser::dialect::MsSqlDialect;
use sqlparser::tokenizer::{Token, Tokenizer};
use std::borrow::Cow;
use std::ops::Range;
use std::rc::Rc;
use thiserror::Error;
#[cfg(feature = "templating")]
use crate::templater::{template_sql, TemplateMode};
const MAX_MERGE_ITERATIONS: usize = 10_000;
/// Builds an error-level [`Issue`] describing a templating failure, tagging
/// it with the originating source name when one is known.
#[cfg(feature = "templating")]
fn template_error_issue(
    error: &crate::templater::TemplateError,
    source_name: Option<&str>,
) -> Issue {
    match source_name {
        Some(name) => Issue::error(
            issue_codes::TEMPLATE_ERROR,
            format!("Template error in {name}: {error}"),
        )
        .with_source_name(name),
        None => Issue::error(issue_codes::TEMPLATE_ERROR, format!("Template error: {error}")),
    }
}
/// Runs the templater over `sql` when a non-raw template config is present.
///
/// Returns the (possibly rendered) SQL together with a flag recording
/// whether templating actually ran; otherwise borrows the input untouched.
#[cfg(feature = "templating")]
fn apply_template<'a>(
    sql: &'a str,
    config: Option<&crate::templater::TemplateConfig>,
) -> Result<(Cow<'a, str>, bool), crate::templater::TemplateError> {
    let Some(cfg) = config else {
        return Ok((Cow::Borrowed(sql), false));
    };
    if cfg.mode == TemplateMode::Raw {
        return Ok((Cow::Borrowed(sql), false));
    }
    template_sql(sql, cfg).map(|rendered| (Cow::Owned(rendered), true))
}
/// Why statement byte ranges computed by the lightweight splitter could not
/// be aligned one-to-one with the statements returned by the real parser.
#[derive(Debug, Error)]
enum RangeAlignmentError {
    /// The splitter produced no ranges although statements were parsed.
    #[error("no ranges provided when {0} statements were expected")]
    NoRanges(usize),
    /// Fewer raw ranges than parsed statements; ranges can only be merged,
    /// never split, so alignment is impossible.
    #[error("fewer ranges ({0}) than statements ({1}), cannot split ranges")]
    FewerRangesThanStatements(usize, usize),
    /// Greedy merging could not reach the parsed statement count.
    #[error("failed to merge ranges to match statement count")]
    MergeFailed,
    /// Safety valve tripped while merging (see `MAX_MERGE_ITERATIONS`).
    #[error("iteration limit ({0}) exceeded during merge, possible infinite loop")]
    IterationLimitExceeded(usize),
    /// A candidate range pointed past the end of the source text.
    #[error("range end ({0}) exceeds source SQL length ({1})")]
    OutOfBounds(usize, usize),
    /// A candidate range had start > end.
    #[error("invalid range: start ({0}) > end ({1})")]
    InvalidRange(usize, usize),
}
/// Shared inputs for parsing one SQL buffer (a single file or the inline
/// request SQL).
struct ParseContext<'a> {
    /// The text handed to the parser (post-templating when templating ran).
    source_sql: Cow<'a, str>,
    /// Display name of the buffer; `Rc` so every statement from this buffer
    /// can share it without re-cloning the `String`.
    source_name: Option<Rc<String>>,
    /// Dialect used for both parsing and statement splitting.
    dialect: Dialect,
    /// Original text before templating; `None` when templating did not run.
    untemplated_sql: Option<Cow<'a, str>>,
    /// Whether the templater was applied to produce `source_sql`.
    templating_applied: bool,
}
/// One parsed statement plus enough source context to report diagnostics
/// against the text it came from.
pub(crate) struct StatementInput<'a> {
    /// The parsed AST node.
    pub(crate) statement: Statement,
    /// Originating buffer name, shared across statements from that buffer.
    pub(crate) source_name: Option<Rc<String>>,
    /// Full (post-templating) SQL buffer this statement belongs to.
    pub(crate) source_sql: Cow<'a, str>,
    /// Byte range of this statement within `source_sql`.
    pub(crate) source_range: Range<usize>,
    /// Pre-templating buffer, present only when templating was applied.
    pub(crate) source_sql_untemplated: Option<Cow<'a, str>>,
    /// Byte range within the untemplated buffer, when it could be aligned.
    pub(crate) source_range_untemplated: Option<Range<usize>>,
    /// True when `source_sql` was produced by the templater.
    pub(crate) templating_applied: bool,
    /// Propagated from the parser output; indicates a fallback parse path
    /// was used for this buffer.
    pub(crate) parser_fallback_used: bool,
}
/// Gathers parsed statements from every input on the request (files first,
/// then inline SQL), accumulating per-input issues instead of failing fast.
///
/// Returns the collected statements together with any issues (invalid
/// request, template failures, parse warnings/errors).
pub(crate) fn collect_statements<'a>(
    request: &'a AnalyzeRequest,
) -> (Vec<StatementInput<'a>>, Vec<Issue>) {
    let mut issues = Vec::new();
    let mut statements = Vec::new();
    // An all-whitespace `sql` field counts as "no inline SQL".
    let has_sql = !request.sql.trim().is_empty();
    let has_files = request
        .files
        .as_ref()
        .map(|files| !files.is_empty())
        .unwrap_or(false);
    if !has_sql && !has_files {
        issues.push(Issue::error(
            issue_codes::INVALID_REQUEST,
            "Provide inline SQL or at least one file to analyze",
        ));
        return (Vec::new(), issues);
    }
    if let Some(files) = &request.files {
        for file in files {
            // With templating enabled, render the file first; a template
            // failure skips this file but keeps processing the others.
            #[cfg(feature = "templating")]
            let (source_sql, templating_applied): (Cow<'_, str>, bool) = {
                match apply_template(&file.content, request.template_config.as_ref()) {
                    Ok((sql, applied)) => (sql, applied),
                    Err(e) => {
                        issues.push(template_error_issue(&e, Some(&file.name)));
                        continue;
                    }
                }
            };
            #[cfg(not(feature = "templating"))]
            let (source_sql, templating_applied): (Cow<'_, str>, bool) =
                (Cow::Borrowed(file.content.as_str()), false);
            let ctx = ParseContext {
                source_sql,
                source_name: Some(Rc::new(file.name.clone())),
                dialect: request.dialect,
                // Keep the raw text around only when templating changed it.
                untemplated_sql: templating_applied.then_some(Cow::Borrowed(file.content.as_str())),
                templating_applied,
            };
            let (file_stmts, file_issues) = parse_statements_individually(&ctx);
            statements.extend(file_stmts);
            issues.extend(file_issues);
        }
    }
    if has_sql {
        // Same flow for the inline buffer; a template failure here ends
        // collection but still returns everything gathered from files.
        #[cfg(feature = "templating")]
        let (source_sql, templating_applied): (Cow<'_, str>, bool) = {
            match apply_template(&request.sql, request.template_config.as_ref()) {
                Ok((sql, applied)) => (sql, applied),
                Err(e) => {
                    issues.push(template_error_issue(&e, request.source_name.as_deref()));
                    return (statements, issues);
                }
            }
        };
        #[cfg(not(feature = "templating"))]
        let (source_sql, templating_applied): (Cow<'_, str>, bool) =
            (Cow::Borrowed(request.sql.as_str()), false);
        let ctx = ParseContext {
            source_sql,
            source_name: request.source_name.clone().map(Rc::new),
            dialect: request.dialect,
            untemplated_sql: templating_applied.then_some(Cow::Borrowed(request.sql.as_str())),
            templating_applied,
        };
        let (inline_stmts, inline_issues) = parse_statements_individually(&ctx);
        statements.extend(inline_stmts);
        issues.extend(inline_issues);
    }
    (statements, issues)
}
/// Parses one buffer, preferring a whole-buffer parse and degrading to
/// per-statement best-effort parsing when that fails.
///
/// When the degradation was caused by a range-alignment problem (rather than
/// a plain parse failure), a warning issue is prepended so the caller knows
/// spans may be approximate.
fn parse_statements_individually<'a>(
    ctx: &ParseContext<'a>,
) -> (Vec<StatementInput<'a>>, Vec<Issue>) {
    let statement_ranges = compute_statement_ranges_for_dialect(&ctx.source_sql, ctx.dialect);
    let fallback_error = match parse_full_sql_buffer(ctx, &statement_ranges) {
        Ok(statements) => return (statements, Vec::new()),
        Err(error) => error,
    };
    let (statements, mut issues) = parse_statement_ranges_best_effort(ctx, statement_ranges);
    if let Some(error) = fallback_error {
        let source_info = match ctx.source_name.as_deref() {
            Some(name) => format!(" in {name}"),
            None => String::new(),
        };
        let mut issue = Issue::warning(
            issue_codes::PARSE_ERROR,
            format!("Full SQL parsing failed{source_info}, using best-effort mode: {error}"),
        );
        if let Some(name) = ctx.source_name.as_deref() {
            issue = issue.with_source_name(name);
        }
        issues.insert(0, issue);
    }
    (statements, issues)
}
/// Parses the whole buffer in one pass and maps each parsed statement onto a
/// pre-computed byte range of the source.
///
/// Returns `Err(None)` when the buffer itself fails to parse (no extra issue
/// is needed — best-effort parsing will report per-statement errors), and
/// `Err(Some(_))` when parsing succeeded but the splitter ranges could not
/// be aligned with the parsed statement count.
fn parse_full_sql_buffer<'a>(
    ctx: &ParseContext<'a>,
    statement_ranges: &[Range<usize>],
) -> Result<Vec<StatementInput<'a>>, Option<RangeAlignmentError>> {
    let parsed_output =
        parse_sql_with_dialect_output(&ctx.source_sql, ctx.dialect).map_err(|_| None)?;
    let parser_fallback_used = parsed_output.parser_fallback_used;
    let parsed = parsed_output.statements;
    if parsed.is_empty() {
        return Ok(Vec::new());
    }
    let aligned_ranges = match align_statement_ranges(
        &ctx.source_sql,
        statement_ranges,
        ctx.dialect,
        parsed.len(),
    ) {
        Ok(ranges) => ranges,
        Err(e) => {
            #[cfg(feature = "tracing")]
            tracing::debug!(
                source = ?ctx.source_name.as_deref(),
                error = %e,
                "Failed to align statement ranges, falling back to best-effort parsing"
            );
            return Err(Some(e));
        }
    };
    // Best-effort alignment against the pre-templating text; this may be
    // absent when templating changed the statement layout.
    let aligned_untemplated_ranges = ctx.untemplated_sql.as_deref().and_then(|sql| {
        let ranges = compute_statement_ranges_for_dialect(sql, ctx.dialect);
        align_statement_ranges(sql, &ranges, ctx.dialect, parsed.len()).ok()
    });
    let mut statements = Vec::with_capacity(parsed.len());
    for (index, (stmt, range)) in parsed.into_iter().zip(aligned_ranges).enumerate() {
        statements.push(StatementInput {
            statement: stmt,
            source_name: ctx.source_name.clone(),
            source_sql: ctx.source_sql.clone(),
            source_range: range,
            source_sql_untemplated: ctx.untemplated_sql.clone(),
            // Untemplated spans are optional: missing entries simply mean no
            // mapping could be established for that statement.
            source_range_untemplated: aligned_untemplated_ranges
                .as_ref()
                .and_then(|ranges| ranges.get(index).cloned()),
            templating_applied: ctx.templating_applied,
            parser_fallback_used,
        });
    }
    Ok(statements)
}
/// Reconciles the splitter's byte ranges with the number of statements the
/// parser produced, merging surplus ranges when necessary.
fn align_statement_ranges(
    source_sql: &str,
    statement_ranges: &[Range<usize>],
    dialect: Dialect,
    statement_count: usize,
) -> Result<Vec<Range<usize>>, RangeAlignmentError> {
    // Nothing parsed means nothing to align.
    if statement_count == 0 {
        return Ok(Vec::new());
    }
    if statement_ranges.is_empty() {
        return Err(RangeAlignmentError::NoRanges(statement_count));
    }
    let available = statement_ranges.len();
    if available == statement_count {
        return Ok(statement_ranges.to_vec());
    }
    if available < statement_count {
        // Ranges can only be merged, never split.
        return Err(RangeAlignmentError::FewerRangesThanStatements(
            available,
            statement_count,
        ));
    }
    merge_statement_ranges(source_sql, statement_ranges, dialect, statement_count)
}
/// Greedily merges consecutive splitter ranges until each merged range
/// parses as exactly one statement, yielding `statement_count` ranges.
///
/// Handles constructs like procedure bodies, where the naive `;` splitter
/// produces more ranges than there are real statements.
fn merge_statement_ranges(
    source_sql: &str,
    statement_ranges: &[Range<usize>],
    dialect: Dialect,
    statement_count: usize,
) -> Result<Vec<Range<usize>>, RangeAlignmentError> {
    let mut merged = Vec::with_capacity(statement_count);
    let mut range_index = 0usize;
    for _ in 0..statement_count {
        if range_index >= statement_ranges.len() {
            // Ran out of raw ranges before producing every statement.
            return Err(RangeAlignmentError::MergeFailed);
        }
        let mut statement_iterations = 0usize;
        let mut current_range = statement_ranges[range_index].clone();
        range_index += 1;
        loop {
            statement_iterations += 1;
            // Hard cap so a pathological input cannot loop forever.
            if statement_iterations > MAX_MERGE_ITERATIONS {
                return Err(RangeAlignmentError::IterationLimitExceeded(
                    MAX_MERGE_ITERATIONS,
                ));
            }
            if current_range.start > current_range.end {
                return Err(RangeAlignmentError::InvalidRange(
                    current_range.start,
                    current_range.end,
                ));
            }
            if current_range.end > source_sql.len() {
                return Err(RangeAlignmentError::OutOfBounds(
                    current_range.end,
                    source_sql.len(),
                ));
            }
            let snippet = &source_sql[current_range.clone()];
            match parse_sql_with_dialect(snippet, dialect) {
                Ok(parsed) if parsed.len() == 1 => {
                    // Snippet now parses as a single statement: commit it.
                    merged.push(current_range);
                    break;
                }
                _ => {
                    // Still incomplete: absorb the next raw range and retry.
                    if range_index >= statement_ranges.len() {
                        return Err(RangeAlignmentError::MergeFailed);
                    }
                    current_range = current_range.start..statement_ranges[range_index].end;
                    range_index += 1;
                }
            }
        }
    }
    // Every raw range must be consumed, otherwise alignment is ambiguous.
    if range_index != statement_ranges.len() {
        return Err(RangeAlignmentError::MergeFailed);
    }
    Ok(merged)
}
/// Fallback parser: parses each pre-split range independently so one broken
/// statement does not hide the valid statements around it.
///
/// Produces one PARSE_ERROR issue (with a span) per range that fails.
fn parse_statement_ranges_best_effort<'a>(
    ctx: &ParseContext<'a>,
    statement_ranges: Vec<Range<usize>>,
) -> (Vec<StatementInput<'a>>, Vec<Issue>) {
    let mut statements = Vec::new();
    let mut issues = Vec::new();
    let source_sql_ref: &str = &ctx.source_sql;
    for range in statement_ranges {
        // Defensive: silently drop ranges that do not index into the source.
        if range.start > range.end || range.end > source_sql_ref.len() {
            continue;
        }
        let statement_sql = &source_sql_ref[range.clone()];
        match parse_sql_with_dialect_output(statement_sql, ctx.dialect) {
            Ok(parsed_output) => {
                let parser_fallback_used = parsed_output.parser_fallback_used;
                for stmt in parsed_output.statements {
                    statements.push(StatementInput {
                        statement: stmt,
                        source_name: ctx.source_name.clone(),
                        source_sql: ctx.source_sql.clone(),
                        source_range: range.clone(),
                        source_sql_untemplated: ctx.untemplated_sql.clone(),
                        // Untemplated mapping is not attempted in this mode.
                        source_range_untemplated: None,
                        templating_applied: ctx.templating_applied,
                        parser_fallback_used,
                    });
                }
            }
            Err(e) => {
                // Attach a span so callers can highlight the broken range.
                let message = match ctx.source_name.as_deref() {
                    Some(name) => format!("Parse error in {name}: {e}"),
                    None => format!("Parse error: {e}"),
                };
                let mut issue = Issue::error(issue_codes::PARSE_ERROR, message)
                    .with_span(Span::new(range.start, range.end));
                if let Some(name) = ctx.source_name.as_deref() {
                    issue = issue.with_source_name(name);
                }
                issues.push(issue);
            }
        }
    }
    (statements, issues)
}
/// Splits `sql` into per-statement [`Span`]s using dialect-aware splitting.
pub(crate) fn split_statement_spans_with_dialect(sql: &str, dialect: Dialect) -> Vec<Span> {
    let ranges = compute_statement_ranges_for_dialect(sql, dialect);
    let mut spans = Vec::with_capacity(ranges.len());
    for range in ranges {
        spans.push(Span::new(range.start, range.end));
    }
    spans
}
/// Computes statement ranges, additionally splitting on `GO` batch
/// separators when the dialect is MSSQL.
fn compute_statement_ranges_for_dialect(sql: &str, dialect: Dialect) -> Vec<Range<usize>> {
    let ranges = compute_statement_ranges(sql);
    if matches!(dialect, Dialect::Mssql) {
        split_ranges_on_mssql_go_separators(sql, ranges)
    } else {
        ranges
    }
}
/// Further splits statement ranges at MSSQL `GO` batch-separator lines.
///
/// A `GO` line fully contained in a range cuts it in two; the separator line
/// itself is dropped and each resulting chunk is trimmed.
fn split_ranges_on_mssql_go_separators(sql: &str, ranges: Vec<Range<usize>>) -> Vec<Range<usize>> {
    let go_line_ranges = mssql_go_line_ranges(sql);
    if go_line_ranges.is_empty() {
        return ranges;
    }
    let mut out = Vec::new();
    for range in ranges {
        let mut cursor = range.start;
        for go_range in &go_line_ranges {
            // Only separators fully inside this range and not already passed.
            let inside = go_range.start >= range.start && go_range.end <= range.end;
            if !inside || go_range.start < cursor {
                continue;
            }
            out.extend(trim_statement_range(sql, cursor, go_range.start));
            cursor = go_range.end;
        }
        // Whatever remains after the last separator (or the whole range).
        out.extend(trim_statement_range(sql, cursor, range.end));
    }
    out
}
/// Finds the byte ranges of lines that act as MSSQL `GO` batch separators.
///
/// Tokenizes with the MSSQL dialect to locate `GO` words, then keeps only
/// those whose whole line is just `GO` (per `line_is_go_separator`).
fn mssql_go_line_ranges(sql: &str) -> Vec<Range<usize>> {
    let Ok(tokens) = Tokenizer::new(&MsSqlDialect {}, sql).tokenize_with_location() else {
        // Tokenizer failure: treat the input as having no separators.
        return Vec::new();
    };
    let mut go_lines: Vec<usize> = tokens
        .into_iter()
        .filter_map(|token| {
            let line = token.span.start.line as usize;
            match token.token {
                Token::Word(word) if word.value.eq_ignore_ascii_case("GO") => Some(line),
                _ => None,
            }
        })
        .filter(|&line| line_is_go_separator(sql, line))
        .collect();
    go_lines.sort_unstable();
    go_lines.dedup();
    go_lines
        .into_iter()
        .filter_map(|line| line_byte_range(sql, line))
        .collect()
}
/// True when the given 1-indexed line contains nothing but `GO` (any case).
fn line_is_go_separator(sql: &str, line_number: usize) -> bool {
    match line_text(sql, line_number) {
        Some(text) => text.trim().eq_ignore_ascii_case("GO"),
        None => false,
    }
}
/// Returns the text of the 1-indexed line without its trailing newline(s).
fn line_text(sql: &str, line_number: usize) -> Option<&str> {
    line_byte_range(sql, line_number).map(|range| sql[range].trim_end_matches(['\n', '\r']))
}
/// Returns the byte range of the 1-indexed `line_number`, including its
/// trailing `\n` when present; `None` for line 0 or past the last line.
///
/// Note: input ending in `\n` has a final empty line (e.g. `"a\n"` line 2 is
/// the empty range at the end), matching the previous behavior.
fn line_byte_range(sql: &str, line_number: usize) -> Option<Range<usize>> {
    // Lines are 1-indexed; line 0 does not exist.
    if line_number == 0 {
        return None;
    }
    // Walk forward one newline at a time instead of materializing a Vec of
    // every line start: avoids an O(len) allocation on each call.
    let mut start = 0usize;
    for _ in 1..line_number {
        let newline = sql[start..].find('\n')?;
        start += newline + 1;
    }
    // The line ends just past its newline, or at end-of-input for the final
    // (possibly unterminated) line.
    let end = match sql[start..].find('\n') {
        Some(offset) => start + offset + 1,
        None => sql.len(),
    };
    Some(start..end)
}
/// Splits `sql` into per-statement byte ranges at top-level `;` separators.
///
/// A small hand-rolled lexer tracks quoting (`'…'`, `"…"`, `` `…` ``,
/// `[…]`, and `$tag$` dollar-quotes) and comments (`--`, `#`, `/* */`) so
/// that semicolons inside them do not split. Each produced range is trimmed
/// of surrounding whitespace/comments; empty chunks are dropped.
fn compute_statement_ranges(sql: &str) -> Vec<Range<usize>> {
    let mut ranges = Vec::new();
    if sql.is_empty() {
        return ranges;
    }
    let mut start = 0usize;
    let mut i = 0usize;
    let len = sql.len();
    // Lexer state: at most one of these modes is active at a time.
    let mut in_single_quote = false;
    let mut in_double_quote = false;
    let mut in_backtick = false;
    let mut in_bracket = false;
    let mut in_line_comment = false;
    let mut in_block_comment = false;
    // Holds the full `$tag$` delimiter while inside a dollar-quoted string.
    let mut dollar_delimiter: Option<String> = None;
    while i < len {
        if let Some(delim) = &dollar_delimiter {
            // Inside a dollar-quoted string: only the matching closer ends it.
            if sql[i..].starts_with(delim) {
                i += delim.len();
                dollar_delimiter = None;
            } else {
                let (_, advance) = next_char(sql, i);
                i += advance;
            }
            continue;
        }
        if in_line_comment {
            let (ch, advance) = next_char(sql, i);
            i += advance;
            if ch == '\n' || ch == '\r' {
                in_line_comment = false;
            }
            continue;
        }
        if in_block_comment {
            if starts_with_at(sql, i, "*/") {
                i += 2;
                in_block_comment = false;
            } else {
                let (_, advance) = next_char(sql, i);
                i += advance;
            }
            continue;
        }
        if in_single_quote {
            let (ch, advance) = next_char(sql, i);
            i += advance;
            if ch == '\'' {
                // A doubled quote ('') is an escaped quote, not a closer.
                if let Some((next, next_len)) = char_at(sql, i) {
                    if next == '\'' {
                        i += next_len;
                    } else {
                        in_single_quote = false;
                    }
                } else {
                    in_single_quote = false;
                }
            }
            continue;
        }
        if in_double_quote {
            let (ch, advance) = next_char(sql, i);
            i += advance;
            if ch == '"' {
                // "" escapes a double quote inside a quoted identifier.
                if let Some((next, next_len)) = char_at(sql, i) {
                    if next == '"' {
                        i += next_len;
                    } else {
                        in_double_quote = false;
                    }
                } else {
                    in_double_quote = false;
                }
            }
            continue;
        }
        if in_backtick {
            let (ch, advance) = next_char(sql, i);
            i += advance;
            if ch == '`' {
                // `` escapes a backtick inside a backtick-quoted identifier.
                if let Some((next, next_len)) = char_at(sql, i) {
                    if next == '`' {
                        i += next_len;
                    } else {
                        in_backtick = false;
                    }
                } else {
                    in_backtick = false;
                }
            }
            continue;
        }
        if in_bracket {
            let (ch, advance) = next_char(sql, i);
            i += advance;
            if ch == ']' {
                // ]] escapes a closing bracket inside [bracketed] names.
                if let Some((next, next_len)) = char_at(sql, i) {
                    if next == ']' {
                        i += next_len;
                    } else {
                        in_bracket = false;
                    }
                } else {
                    in_bracket = false;
                }
            }
            continue;
        }
        // Neutral state: look for openers, comment starts, or a separator.
        let (ch, advance) = next_char(sql, i);
        match ch {
            '\'' => {
                in_single_quote = true;
                i += advance;
                continue;
            }
            '"' => {
                in_double_quote = true;
                i += advance;
                continue;
            }
            '`' => {
                in_backtick = true;
                i += advance;
                continue;
            }
            '[' => {
                in_bracket = true;
                i += advance;
                continue;
            }
            '-' if starts_with_at(sql, i + advance, "-") => {
                in_line_comment = true;
                i += advance + 1;
                continue;
            }
            '#' => {
                in_line_comment = true;
                i += advance;
                continue;
            }
            '/' if starts_with_at(sql, i + advance, "*") => {
                in_block_comment = true;
                i += advance + 1;
                continue;
            }
            '$' => {
                // `$` only opens a string when a full `$tag$` opener follows.
                if let Some((delim, end_idx)) = detect_dollar_quote(sql, i) {
                    dollar_delimiter = Some(delim);
                    i = end_idx;
                    continue;
                }
            }
            ';' => {
                // Top-level separator: close the current statement here.
                push_statement_range(&mut ranges, sql, start, i);
                start = i + advance;
            }
            _ => {}
        }
        i += advance;
    }
    // Final statement (input need not end with a semicolon).
    push_statement_range(&mut ranges, sql, start, len);
    ranges
}
/// Detects a dollar-quote opener (`$$` or `$tag$`) beginning at `start`
/// (which must point at a `$`).
///
/// Returns the full delimiter string and the byte index just past it, or
/// `None` when the characters after `$` are not a valid tag (tag characters
/// are ASCII alphanumerics and `_`).
fn detect_dollar_quote(sql: &str, start: usize) -> Option<(String, usize)> {
    // Everything after the opening `$` ('$' is one byte, so this boundary
    // is always valid).
    let body = sql.get(start + 1..)?;
    for (offset, ch) in body.char_indices() {
        let end = start + 1 + offset + ch.len_utf8();
        if ch == '$' {
            // Closing `$` found: the delimiter spans both dollar signs.
            return Some((sql[start..end].to_string(), end));
        }
        if !(ch == '_' || ch.is_ascii_alphanumeric()) {
            return None;
        }
    }
    None
}
/// True when `pattern` occurs at byte offset `index` of `sql`.
/// Out-of-range or mid-codepoint indices never match.
fn starts_with_at(sql: &str, index: usize, pattern: &str) -> bool {
    index < sql.len() && sql.is_char_boundary(index) && sql[index..].starts_with(pattern)
}
/// Returns the char at byte offset `index` and its UTF-8 width in bytes.
/// `index` must be a valid char boundary inside `sql`.
fn next_char(sql: &str, index: usize) -> (char, usize) {
    debug_assert!(sql.is_char_boundary(index));
    let ch = sql[index..]
        .chars()
        .next()
        .expect("index should point to a char boundary");
    (ch, ch.len_utf8())
}
/// Like `next_char`, but returns `None` instead of panicking when `index`
/// is past the end of `sql` or not on a char boundary.
fn char_at(sql: &str, index: usize) -> Option<(char, usize)> {
    // `get` rejects both out-of-range and mid-codepoint indices.
    let rest = sql.get(index..)?;
    let ch = rest.chars().next()?;
    Some((ch, ch.len_utf8()))
}
/// Trims `start..end` and appends it to `ranges` when non-empty.
fn push_statement_range(ranges: &mut Vec<Range<usize>>, sql: &str, start: usize, end: usize) {
    // `trim_statement_range` yields `None` for empty/whitespace-only chunks.
    ranges.extend(trim_statement_range(sql, start, end));
}
/// Shrinks `start..end` so it excludes leading whitespace/comments and
/// trailing whitespace; returns `None` when nothing meaningful remains.
///
/// Works on bytes, which is safe here because everything skipped (`--`,
/// `/* */`, `#`, ASCII whitespace) is ASCII, so the resulting indices stay
/// on char boundaries.
fn trim_statement_range(sql: &str, start: usize, end: usize) -> Option<Range<usize>> {
    if start >= end {
        return None;
    }
    let mut s = start;
    let mut e = end;
    let bytes = sql.as_bytes();
    // Advance `s` past leading comments and whitespace.
    while s < e {
        if s + 1 < e {
            let first = bytes[s];
            let second = bytes[s + 1];
            if first == b'-' && second == b'-' {
                s = skip_line_comment(bytes, s + 2, e);
                continue;
            }
            if first == b'/' && second == b'*' {
                s = skip_block_comment(bytes, s + 2, e);
                continue;
            }
        }
        let b = bytes[s];
        match b {
            b'#' => {
                s = skip_line_comment(bytes, s + 1, e);
            }
            b' ' | b'\t' | b'\r' | b'\n' => {
                s += 1;
            }
            _ => break,
        }
    }
    // Trim trailing ASCII whitespace (trailing comments are kept as-is).
    while s < e {
        let b = bytes[e - 1];
        match b {
            b' ' | b'\t' | b'\r' | b'\n' => {
                e -= 1;
            }
            _ => break,
        }
    }
    if s >= e {
        return None;
    }
    Some(s..e)
}
/// Returns the index just past the first line terminator in `index..end`,
/// or `end` when the comment runs to the end of the window.
fn skip_line_comment(bytes: &[u8], index: usize, end: usize) -> usize {
    bytes[index..end]
        .iter()
        .position(|&b| b == b'\n' || b == b'\r')
        .map_or(end, |offset| index + offset + 1)
}
/// Returns the index just past the closing `*/` in `index..end`, or `end`
/// when the block comment is unterminated within the window.
fn skip_block_comment(bytes: &[u8], index: usize, end: usize) -> usize {
    bytes[index..end]
        .windows(2)
        .position(|pair| pair == b"*/")
        .map_or(end, |offset| index + offset + 2)
}
// Unit tests for statement collection, splitting, and best-effort parsing.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{Dialect, FileSource};

    /// Minimal request skeleton; individual tests override what they need.
    fn base_request() -> AnalyzeRequest {
        AnalyzeRequest {
            sql: String::new(),
            files: None,
            dialect: Dialect::Generic,
            source_name: None,
            options: None,
            schema: None,
            #[cfg(feature = "templating")]
            template_config: None,
        }
    }

    // Files are processed before inline SQL; both contribute statements.
    #[test]
    fn collects_file_and_inline_statements() {
        let mut request = base_request();
        request.sql = "SELECT 2".to_string();
        request.source_name = Some("inline.sql".to_string());
        request.files = Some(vec![FileSource {
            name: "file.sql".to_string(),
            content: "SELECT 1".to_string(),
        }]);
        let (statements, issues) = collect_statements(&request);
        assert!(issues.is_empty());
        assert_eq!(statements.len(), 2);
        assert_eq!(
            statements[0].source_name.as_deref().map(String::as_str),
            Some("file.sql")
        );
        assert_eq!(
            statements[0].source_sql[statements[0].source_range.clone()].trim(),
            "SELECT 1"
        );
        assert_eq!(
            statements[1].source_name.as_deref().map(String::as_str),
            Some("inline.sql")
        );
        assert_eq!(
            statements[1].source_sql[statements[1].source_range.clone()].trim(),
            "SELECT 2"
        );
    }

    #[test]
    fn reports_invalid_request_without_inputs() {
        let request = base_request();
        let (_statements, issues) = collect_statements(&request);
        assert_eq!(issues.len(), 1);
        assert_eq!(issues[0].code, issue_codes::INVALID_REQUEST);
    }

    // Semicolons inside string literals must not split statements.
    #[test]
    fn statement_ranges_respect_strings() {
        let sql = "SELECT ';' as value;SELECT 2;";
        let ranges = compute_statement_ranges(sql);
        assert_eq!(ranges.len(), 2);
        assert_eq!(&sql[ranges[0].clone()], "SELECT ';' as value");
        assert_eq!(&sql[ranges[1].clone()], "SELECT 2");
    }

    // Semicolons inside line and block comments must not split statements.
    #[test]
    fn statement_ranges_skip_comments() {
        let sql = "SELECT 1; -- comment; still comment\nSELECT 2; /* block; comment */ SELECT 3;";
        let ranges = compute_statement_ranges(sql);
        assert_eq!(ranges.len(), 3);
        assert_eq!(&sql[ranges[0].clone()], "SELECT 1");
        assert_eq!(&sql[ranges[1].clone()], "SELECT 2");
        assert_eq!(&sql[ranges[2].clone()], "SELECT 3");
    }

    // Postgres-style $$ bodies keep their inner semicolons intact.
    #[test]
    fn statement_ranges_handle_dollar_quoting() {
        let sql = "DO $$ BEGIN RAISE NOTICE ';'; END $$; SELECT 1;";
        let ranges = compute_statement_ranges(sql);
        assert_eq!(ranges.len(), 2);
        assert_eq!(
            &sql[ranges[0].clone()],
            "DO $$ BEGIN RAISE NOTICE ';'; END $$"
        );
        assert_eq!(&sql[ranges[1].clone()], "SELECT 1");
    }

    #[test]
    fn mssql_statement_ranges_split_go_batch_separators() {
        let sql = "CREATE SCHEMA staging;\nGO\nCREATE TABLE test (id INT)\n";
        let ranges = compute_statement_ranges_for_dialect(sql, Dialect::Mssql);
        assert_eq!(ranges.len(), 2);
        assert_eq!(&sql[ranges[0].clone()], "CREATE SCHEMA staging");
        assert_eq!(&sql[ranges[1].clone()], "CREATE TABLE test (id INT)");
    }

    #[test]
    fn collect_statements_mssql_go_batch_without_final_semicolon_parses_statements() {
        let mut request = base_request();
        request.dialect = Dialect::Mssql;
        request.sql = "CREATE SCHEMA staging;\nGO\nCREATE TABLE test (id INT)\n".to_string();
        let (statements, issues) = collect_statements(&request);
        assert!(
            issues.is_empty(),
            "MSSQL GO separators should not produce parse errors: {issues:?}"
        );
        assert_eq!(statements.len(), 2);
    }

    // Exercises range merging: the splitter over-splits the procedure body,
    // and alignment merges ranges back to one range per parsed statement.
    #[test]
    fn parses_procedure_with_inner_semicolons() {
        let mut request = base_request();
        request.dialect = Dialect::Snowflake;
        request.sql = r#"
CREATE PROCEDURE demo()
LANGUAGE SQL
AS
BEGIN
SELECT 'a';
SELECT 'b';
RETURN 'done';
END;
SELECT 1;
"#
        .to_string();
        let (statements, issues) = collect_statements(&request);
        assert!(issues.is_empty(), "Expected no issues, got {issues:?}");
        assert_eq!(
            statements.len(),
            2,
            "Expected procedure and trailing select"
        );
        assert!(matches!(
            statements[0].statement,
            Statement::CreateProcedure { .. }
        ));
        let procedure_source = &statements[0].source_sql[statements[0].source_range.clone()];
        assert!(
            procedure_source.contains("SELECT 'b';") && procedure_source.contains("RETURN 'done';"),
            "Procedure source should include entire body: {procedure_source:?}"
        );
        assert!(matches!(statements[1].statement, Statement::Query(_)));
    }

    #[test]
    fn best_effort_parsing_continues_after_error() {
        let mut request = base_request();
        request.sql = r#"
SELECT 1 FROM users;
SELECT FROM;
SELECT 2 FROM orders;
"#
        .to_string();
        let (statements, issues) = collect_statements(&request);
        assert_eq!(statements.len(), 2, "Expected 2 valid statements");
        assert_eq!(issues.len(), 1, "Expected 1 parse error");
        assert_eq!(issues[0].code, issue_codes::PARSE_ERROR);
        assert!(issues[0].span.is_some(), "Error should have span info");
    }

    #[test]
    fn best_effort_parsing_with_file_source() {
        let mut request = base_request();
        request.files = Some(vec![FileSource {
            name: "test.sql".to_string(),
            content: r#"
SELECT a FROM t1;
INVALID SYNTAX HERE;
SELECT b FROM t2;
"#
            .to_string(),
        }]);
        let (statements, issues) = collect_statements(&request);
        assert_eq!(statements.len(), 2, "Expected 2 valid statements");
        assert_eq!(issues.len(), 1, "Expected 1 parse error");
        assert!(
            issues[0].message.contains("test.sql"),
            "Error should mention file name"
        );
        assert_eq!(
            issues[0].source_name.as_deref(),
            Some("test.sql"),
            "Issue should have source_name set"
        );
    }

    #[test]
    fn best_effort_parsing_multiple_errors() {
        let mut request = base_request();
        request.sql = r#"
SELECT 1;
BROKEN STATEMENT 1;
SELECT 2;
BROKEN STATEMENT 2;
SELECT 3;
"#
        .to_string();
        let (statements, issues) = collect_statements(&request);
        assert_eq!(statements.len(), 3, "Expected 3 valid statements");
        assert_eq!(issues.len(), 2, "Expected 2 parse errors");
    }

    #[test]
    fn empty_sql_returns_no_statements() {
        let sql = "";
        let ranges = compute_statement_ranges(sql);
        assert!(ranges.is_empty(), "Empty SQL should produce no ranges");
    }

    #[test]
    fn whitespace_only_sql_returns_no_statements() {
        let sql = " \n\t\r\n ";
        let ranges = compute_statement_ranges(sql);
        assert!(
            ranges.is_empty(),
            "Whitespace-only SQL should produce no ranges"
        );
    }

    #[test]
    fn comments_only_sql_returns_no_statements() {
        let sql = "-- just a comment\n/* another comment */";
        let ranges = compute_statement_ranges(sql);
        assert!(
            ranges.is_empty(),
            "Comments-only SQL should produce no ranges"
        );
    }

    // Whitespace-only inline SQL counts as "no inline SQL" but the file
    // input still makes the request valid.
    #[test]
    fn empty_inline_sql_with_valid_file() {
        let mut request = base_request();
        request.sql = " ".to_string();
        request.files = Some(vec![FileSource {
            name: "file.sql".to_string(),
            content: "SELECT 1".to_string(),
        }]);
        let (statements, issues) = collect_statements(&request);
        assert!(issues.is_empty());
        assert_eq!(statements.len(), 1);
        assert_eq!(
            statements[0].source_name.as_deref().map(String::as_str),
            Some("file.sql")
        );
    }

    // Multi-byte characters must not break byte-range computation.
    #[test]
    fn statement_ranges_handle_unicode_identifiers() {
        let sql = "SELECT '日本語' AS 名前; SELECT '🎉' AS emoji;";
        let ranges = compute_statement_ranges(sql);
        assert_eq!(ranges.len(), 2);
        assert_eq!(&sql[ranges[0].clone()], "SELECT '日本語' AS 名前");
        assert_eq!(&sql[ranges[1].clone()], "SELECT '🎉' AS emoji");
    }

    #[test]
    fn statement_ranges_handle_unicode_in_strings() {
        let sql = "SELECT '你好;世界' AS greeting; SELECT 2;";
        let ranges = compute_statement_ranges(sql);
        assert_eq!(ranges.len(), 2);
        assert_eq!(&sql[ranges[0].clone()], "SELECT '你好;世界' AS greeting");
        assert_eq!(&sql[ranges[1].clone()], "SELECT 2");
    }

    #[test]
    fn statement_ranges_handle_mixed_ascii_unicode() {
        let sql = "SELECT 'café' AS drink; SELECT 'naïve' AS word; SELECT 'Müller' AS name;";
        let ranges = compute_statement_ranges(sql);
        assert_eq!(ranges.len(), 3);
        assert_eq!(&sql[ranges[0].clone()], "SELECT 'café' AS drink");
        assert_eq!(&sql[ranges[1].clone()], "SELECT 'naïve' AS word");
        assert_eq!(&sql[ranges[2].clone()], "SELECT 'Müller' AS name");
    }

    #[test]
    fn unicode_sql_parses_correctly() {
        let mut request = base_request();
        request.sql = "SELECT '日本' AS country; SELECT 'émoji: 🚀' AS test;".to_string();
        let (statements, issues) = collect_statements(&request);
        assert!(issues.is_empty(), "Expected no issues, got {issues:?}");
        assert_eq!(statements.len(), 2);
        let first_sql = &statements[0].source_sql[statements[0].source_range.clone()];
        let second_sql = &statements[1].source_sql[statements[1].source_range.clone()];
        assert!(
            first_sql.contains("日本"),
            "First statement should contain Japanese: {first_sql}"
        );
        assert!(
            second_sql.contains("🚀"),
            "Second statement should contain rocket emoji: {second_sql}"
        );
    }
}