use anyhow::Result;
/// A per-statement directive read from a `--` comment line inside a `GO` batch.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ScriptDirective {
    /// Do not execute the statement; written as `-- [skip]` or `-- [ignore]`
    /// (case-insensitive).
    Skip,
}
/// What a parsed `GO`-delimited batch does when executed.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ScriptStatementType {
    /// SQL text of the batch, including any comment lines it contained.
    Query(String),
    /// An `EXIT` or `EXIT <code>` command; `None` when no code was given.
    Exit(Option<i32>),
}
/// One executable unit of a script: a `GO`-delimited batch plus its directives.
#[derive(Debug, Clone, PartialEq)]
pub struct ScriptStatement {
    /// Whether this batch is a query or an exit command.
    pub statement_type: ScriptStatementType,
    /// Directives gathered from comment lines anywhere in the batch.
    pub directives: Vec<ScriptDirective>,
}
impl ScriptStatement {
    /// Returns `true` when the statement carries a [`ScriptDirective::Skip`].
    #[must_use]
    pub fn should_skip(&self) -> bool {
        self.directives.contains(&ScriptDirective::Skip)
    }

    /// Returns `true` when the statement is an `EXIT` command.
    #[must_use]
    pub fn is_exit(&self) -> bool {
        matches!(self.statement_type, ScriptStatementType::Exit(_))
    }

    /// Returns the exit code for an `EXIT` statement, or `None` for a query.
    ///
    /// A bare `EXIT` without a code yields `Some(0)`.
    #[must_use]
    pub fn get_exit_code(&self) -> Option<i32> {
        // Exhaustive match on purpose: adding a variant must force a decision here.
        match &self.statement_type {
            ScriptStatementType::Exit(code) => Some(code.unwrap_or(0)),
            ScriptStatementType::Query(_) => None,
        }
    }

    /// Returns the SQL text for a query statement, or `None` for `EXIT`.
    #[must_use]
    pub fn get_query(&self) -> Option<&str> {
        match &self.statement_type {
            ScriptStatementType::Query(sql) => Some(sql.as_str()),
            ScriptStatementType::Exit(_) => None,
        }
    }
}
/// Splits a SQL script into `GO`-delimited statements.
#[derive(Debug, Clone)]
pub struct ScriptParser {
    /// Full script text, as passed to `ScriptParser::new`.
    content: String,
    /// Data-file path found in a `-- #!...` comment at construction time, if any.
    data_file_hint: Option<String>,
}
impl ScriptParser {
    /// Creates a parser over `content`.
    ///
    /// The data-file hint (see [`ScriptParser::data_file_hint`]) is extracted
    /// once, here.
    pub fn new(content: &str) -> Self {
        Self {
            data_file_hint: Self::extract_data_file_hint(content),
            content: content.to_string(),
        }
    }

    /// Scans every `--` comment line and returns the first data-file hint.
    ///
    /// Recognised forms, after the `--` and surrounding whitespace:
    /// * `#!data: <path>`
    /// * `#!datafile: <path>`
    /// * `#! <path>`: accepted only when `<path>` contains `.`, `/` or `\`,
    ///   so that bare words are not mistaken for paths.
    ///
    /// A hint whose path is empty is ignored, and scanning continues.
    fn extract_data_file_hint(content: &str) -> Option<String> {
        content.lines().find_map(|line| {
            let directive = line.trim().strip_prefix("--")?.trim().strip_prefix("#!")?;
            let path = match directive
                .strip_prefix("data:")
                .or_else(|| directive.strip_prefix("datafile:"))
            {
                Some(explicit) => explicit.trim(),
                None => {
                    let bare = directive.trim();
                    if !bare.contains(|c: char| matches!(c, '.' | '/' | '\\')) {
                        return None;
                    }
                    bare
                }
            };
            if path.is_empty() {
                None
            } else {
                Some(path.to_string())
            }
        })
    }

    /// Returns the data-file hint found in the script's comments, if any.
    pub fn data_file_hint(&self) -> Option<&str> {
        self.data_file_hint.as_deref()
    }

    /// Produces one [`ScriptDirective::Skip`] for every comment line whose body
    /// is exactly `[skip]` or `[ignore]` (ASCII case-insensitive).
    ///
    /// Lines that do not start with `--` (after trimming) are ignored.
    fn parse_directives(comment_lines: &[String]) -> Vec<ScriptDirective> {
        comment_lines
            .iter()
            .filter_map(|line| line.trim().strip_prefix("--"))
            .map(str::trim)
            .filter(|body| {
                body.eq_ignore_ascii_case("[skip]") || body.eq_ignore_ascii_case("[ignore]")
            })
            .map(|_| ScriptDirective::Skip)
            .collect()
    }

    /// Splits the script into statements on lines consisting solely of `GO`
    /// (ASCII case-insensitive, surrounding whitespace ignored).
    ///
    /// Each batch keeps its comment lines in the statement text. Directives
    /// are read from comment lines anywhere in the batch. Batches that are
    /// empty or contain only comments are dropped, together with their
    /// directives. A trailing batch without a closing `GO` is still emitted.
    pub fn parse_script_statements(&self) -> Vec<ScriptStatement> {
        let mut statements = Vec::new();
        let mut batch = String::new();
        let mut batch_comments: Vec<String> = Vec::new();

        for line in self.content.lines() {
            let trimmed = line.trim();
            if trimmed.eq_ignore_ascii_case("go") {
                statements.extend(Self::finish_batch(&batch, &batch_comments));
                batch.clear();
                batch_comments.clear();
                continue;
            }
            if trimmed.starts_with("--") {
                batch_comments.push(line.to_string());
            }
            if !batch.is_empty() {
                batch.push('\n');
            }
            batch.push_str(line);
        }
        statements.extend(Self::finish_batch(&batch, &batch_comments));
        statements
    }

    /// Converts one accumulated batch into a statement, or `None` when the
    /// batch holds no SQL (empty or comment-only).
    fn finish_batch(batch: &str, comments: &[String]) -> Option<ScriptStatement> {
        let statement = batch.trim();
        if statement.is_empty() || Self::is_comment_only(statement) {
            return None;
        }
        let statement_type = Self::parse_exit_statement(statement)
            .unwrap_or_else(|| ScriptStatementType::Query(statement.to_string()));
        Some(ScriptStatement {
            statement_type,
            directives: Self::parse_directives(comments),
        })
    }

    /// Recognises `EXIT` and `EXIT <i32>` (case-insensitive, trailing `;`
    /// allowed).
    ///
    /// Comment and blank lines are ignored; the remaining lines are joined
    /// with spaces before matching. Anything else, including `EXIT` followed
    /// by a non-integer, yields `None`.
    fn parse_exit_statement(statement: &str) -> Option<ScriptStatementType> {
        let code_lines: Vec<&str> = statement
            .lines()
            .map(str::trim)
            .filter(|line| !line.is_empty() && !line.starts_with("--"))
            .collect();
        let joined = code_lines.join(" ");
        let command = joined.trim().trim_end_matches(';').trim();

        let mut words = command.split_whitespace();
        match (words.next(), words.next(), words.next()) {
            (Some(keyword), None, _) if keyword.eq_ignore_ascii_case("exit") => {
                Some(ScriptStatementType::Exit(None))
            }
            (Some(keyword), Some(code), None) if keyword.eq_ignore_ascii_case("exit") => code
                .parse::<i32>()
                .ok()
                .map(|code| ScriptStatementType::Exit(Some(code))),
            _ => None,
        }
    }

    /// Returns the SQL text of every query statement, skipping `EXIT` commands.
    ///
    /// Directives are not applied here; statements marked `[skip]` are included.
    pub fn parse_statements(&self) -> Vec<String> {
        self.parse_script_statements()
            .into_iter()
            .filter_map(|stmt| match stmt.statement_type {
                ScriptStatementType::Query(sql) => Some(sql),
                ScriptStatementType::Exit(_) => None,
            })
            .collect()
    }

    /// Returns `true` when every line is blank or a `--` comment.
    /// An empty string counts as comment-only.
    fn is_comment_only(statement: &str) -> bool {
        statement
            .lines()
            .map(str::trim)
            .all(|line| line.is_empty() || line.starts_with("--"))
    }

    /// Like [`ScriptParser::parse_statements`], but fails when nothing executable is found.
    ///
    /// # Errors
    ///
    /// Returns an error if the script has no query statements, or if any
    /// statement is blank (reported with its 1-based position).
    pub fn parse_and_validate(&self) -> Result<Vec<String>> {
        let statements = self.parse_statements();
        if statements.is_empty() {
            anyhow::bail!("No SQL statements found in script");
        }
        // Defensive: the parser never emits blank statements today.
        if let Some(index) = statements.iter().position(|stmt| stmt.trim().is_empty()) {
            anyhow::bail!("Empty statement at position {}", index + 1);
        }
        Ok(statements)
    }
}
/// Outcome of executing a single script statement.
#[derive(Debug, Clone, PartialEq)]
pub struct StatementResult {
    /// Caller-supplied position of the statement within the script.
    pub statement_number: usize,
    /// SQL text that was executed.
    pub sql: String,
    /// Whether the statement executed without error.
    pub success: bool,
    /// Rows affected as reported by the caller; always 0 for failures
    /// recorded through `ScriptResult::add_failure`.
    pub rows_affected: usize,
    /// Error text for a failed statement; `None` on success.
    pub error_message: Option<String>,
    /// Execution time reported by the caller (milliseconds, per the field name).
    pub execution_time_ms: f64,
}
/// Aggregate outcome of a script run, built up via `add_success` / `add_failure`.
#[derive(Debug)]
pub struct ScriptResult {
/// Number of statements recorded so far (successes plus failures).
pub total_statements: usize,
/// Number of statements recorded via `add_success`.
pub successful_statements: usize,
/// Number of statements recorded via `add_failure`.
pub failed_statements: usize,
/// Sum of `execution_time_ms` over every recorded statement.
pub total_execution_time_ms: f64,
/// Per-statement details, in the order they were recorded.
pub statement_results: Vec<StatementResult>,
}
impl ScriptResult {
pub fn new() -> Self {
Self {
total_statements: 0,
successful_statements: 0,
failed_statements: 0,
total_execution_time_ms: 0.0,
statement_results: Vec::new(),
}
}
pub fn add_success(&mut self, statement_number: usize, sql: String, rows: usize, time_ms: f64) {
self.total_statements += 1;
self.successful_statements += 1;
self.total_execution_time_ms += time_ms;
self.statement_results.push(StatementResult {
statement_number,
sql,
success: true,
rows_affected: rows,
error_message: None,
execution_time_ms: time_ms,
});
}
pub fn add_failure(
&mut self,
statement_number: usize,
sql: String,
error: String,
time_ms: f64,
) {
self.total_statements += 1;
self.failed_statements += 1;
self.total_execution_time_ms += time_ms;
self.statement_results.push(StatementResult {
statement_number,
sql,
success: false,
rows_affected: 0,
error_message: Some(error),
execution_time_ms: time_ms,
});
}
pub fn all_successful(&self) -> bool {
self.failed_statements == 0
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_single_statement() {
        let script = "SELECT * FROM users";
        let parser = ScriptParser::new(script);
        let statements = parser.parse_statements();
        assert_eq!(statements.len(), 1);
        assert_eq!(statements[0], "SELECT * FROM users");
    }

    #[test]
    fn test_parse_multiple_statements_with_go() {
        let script = r"
SELECT * FROM users
GO
SELECT * FROM orders
GO
SELECT * FROM products
";
        let parser = ScriptParser::new(script);
        let statements = parser.parse_statements();
        assert_eq!(statements.len(), 3);
        assert_eq!(statements[0].trim(), "SELECT * FROM users");
        assert_eq!(statements[1].trim(), "SELECT * FROM orders");
        assert_eq!(statements[2].trim(), "SELECT * FROM products");
    }

    #[test]
    fn test_go_case_insensitive() {
        let script = r"
SELECT 1
go
SELECT 2
Go
SELECT 3
GO
";
        let parser = ScriptParser::new(script);
        let statements = parser.parse_statements();
        assert_eq!(statements.len(), 3);
    }

    #[test]
    fn test_go_in_string_not_separator() {
        let script = r"
SELECT 'This string contains GO but should not split' as test
GO
SELECT 'Another statement' as test2
";
        let parser = ScriptParser::new(script);
        let statements = parser.parse_statements();
        assert_eq!(statements.len(), 2);
        assert!(statements[0].contains("GO but should not split"));
    }

    #[test]
    fn test_multiline_statements() {
        let script = r"
SELECT
id,
name,
email
FROM users
WHERE active = true
GO
SELECT COUNT(*)
FROM orders
";
        let parser = ScriptParser::new(script);
        let statements = parser.parse_statements();
        assert_eq!(statements.len(), 2);
        assert!(statements[0].contains("WHERE active = true"));
    }

    #[test]
    fn test_empty_statements_filtered() {
        let script = r"
GO
SELECT 1
GO
GO
SELECT 2
GO
";
        let parser = ScriptParser::new(script);
        let statements = parser.parse_statements();
        assert_eq!(statements.len(), 2);
        assert_eq!(statements[0].trim(), "SELECT 1");
        assert_eq!(statements[1].trim(), "SELECT 2");
    }

    #[test]
    fn test_skip_directive_applies_to_its_batch_only() {
        let parser = ScriptParser::new("-- [skip]\nSELECT 1\nGO\nSELECT 2");
        let statements = parser.parse_script_statements();
        assert_eq!(statements.len(), 2);
        assert!(statements[0].should_skip());
        assert!(!statements[1].should_skip());
        // Comment lines stay part of the statement text.
        assert_eq!(statements[0].get_query(), Some("-- [skip]\nSELECT 1"));
        assert_eq!(statements[1].get_query(), Some("SELECT 2"));
    }

    #[test]
    fn test_ignore_directive_after_sql_is_case_insensitive() {
        let parser = ScriptParser::new("SELECT 1\n-- [IGNORE]\nGO");
        let statements = parser.parse_script_statements();
        assert_eq!(statements.len(), 1);
        assert!(statements[0].should_skip());
    }

    #[test]
    fn test_directive_in_comment_only_batch_is_dropped() {
        let parser = ScriptParser::new("-- [skip]\nGO\nSELECT 1");
        let statements = parser.parse_script_statements();
        assert_eq!(statements.len(), 1);
        assert!(!statements[0].should_skip());
        assert_eq!(statements[0].get_query(), Some("SELECT 1"));
    }

    #[test]
    fn test_exit_statements() {
        let parser = ScriptParser::new("SELECT 1\nGO\nEXIT 3;\nGO\nSELECT 2");
        let statements = parser.parse_script_statements();
        assert_eq!(statements.len(), 3);
        assert!(statements[1].is_exit());
        assert_eq!(statements[1].get_exit_code(), Some(3));
        assert_eq!(statements[1].statement_type, ScriptStatementType::Exit(Some(3)));
        assert_eq!(
            parser.parse_statements(),
            vec!["SELECT 1".to_string(), "SELECT 2".to_string()]
        );

        let bare = ScriptParser::new("-- done\nexit").parse_script_statements();
        assert_eq!(bare.len(), 1);
        assert_eq!(bare[0].statement_type, ScriptStatementType::Exit(None));
        assert_eq!(bare[0].get_exit_code(), Some(0));

        let not_exit = ScriptParser::new("EXIT abc").parse_script_statements();
        assert_eq!(
            not_exit[0].statement_type,
            ScriptStatementType::Query("EXIT abc".to_string())
        );
    }

    #[test]
    fn test_data_file_hint() {
        assert_eq!(
            ScriptParser::new("-- #!data: data/sales.csv\nSELECT 1").data_file_hint(),
            Some("data/sales.csv")
        );
        assert_eq!(
            ScriptParser::new("--#!datafile:orders.json\nSELECT 1").data_file_hint(),
            Some("orders.json")
        );
        assert_eq!(
            ScriptParser::new("-- #! ../fixtures/x.csv\nSELECT 1").data_file_hint(),
            Some("../fixtures/x.csv")
        );
        assert_eq!(ScriptParser::new("-- #! notapath\nSELECT 1").data_file_hint(), None);
        assert_eq!(ScriptParser::new("SELECT 1").data_file_hint(), None);
    }

    #[test]
    fn test_crlf_line_endings() {
        let parser = ScriptParser::new("SELECT 1\r\nGO\r\nSELECT 2\r\n");
        assert_eq!(
            parser.parse_statements(),
            vec!["SELECT 1".to_string(), "SELECT 2".to_string()]
        );
    }

    #[test]
    fn test_parse_and_validate() {
        let ok = ScriptParser::new("SELECT 1\nGO\nSELECT 2")
            .parse_and_validate()
            .unwrap();
        assert_eq!(ok, vec!["SELECT 1".to_string(), "SELECT 2".to_string()]);

        let err = ScriptParser::new("-- only a comment\nGO\nEXIT")
            .parse_and_validate()
            .unwrap_err();
        assert!(err.to_string().contains("No SQL statements found"));
    }

    #[test]
    fn test_script_result_counters() {
        let mut result = ScriptResult::new();
        assert!(result.all_successful());
        result.add_success(1, "SELECT 1".to_string(), 2, 1.0);
        result.add_failure(2, "SELEC".to_string(), "syntax error".to_string(), 0.5);
        assert_eq!(result.total_statements, 2);
        assert_eq!(result.successful_statements, 1);
        assert_eq!(result.failed_statements, 1);
        assert_eq!(result.total_execution_time_ms, 1.5);
        assert!(!result.all_successful());
        assert_eq!(
            result.statement_results[1].error_message.as_deref(),
            Some("syntax error")
        );
        assert_eq!(result.statement_results[1].rows_affected, 0);
    }
}