use std::{
fs,
io::{self, Read},
};
use crate::{
executor::SqlExecutor,
formatter::{OutputFormat, ResultFormatter},
};
/// Runs SQL scripts statement by statement and prints each result.
///
/// A script can come from a file, from stdin, or from an in-memory string;
/// see the `execute_*` methods.
pub struct ScriptExecutor {
/// Engine used to run each statement and to save the database.
executor: SqlExecutor,
/// Prints the result of every successfully executed statement.
formatter: ResultFormatter,
/// When `true`, per-statement progress lines and the final summary are always printed.
verbose: bool,
/// Path passed to `save_database` after successful modification statements;
/// `None` disables auto-saving.
database_path: Option<String>,
}
impl ScriptExecutor {
    /// Builds a script executor.
    ///
    /// `database` is handed to `SqlExecutor::new` and is also remembered as the
    /// auto-save target. When `format` is `Some`, it is applied to the result
    /// formatter via `set_format`.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `SqlExecutor::new`.
    pub fn new(
        database: Option<String>,
        verbose: bool,
        format: Option<OutputFormat>,
    ) -> anyhow::Result<Self> {
        // The executor consumes its own copy; the original is kept for auto-save.
        let executor = SqlExecutor::new(database.clone())?;

        let mut formatter = ResultFormatter::new();
        if let Some(selected) = format {
            formatter.set_format(selected);
        }

        Ok(Self {
            executor,
            formatter,
            verbose,
            database_path: database,
        })
    }

    /// Reads the whole file at `file_path` and executes it as a script.
    ///
    /// # Errors
    ///
    /// Fails if the file cannot be read as a string, or if
    /// [`execute_script`](Self::execute_script) reports failed statements.
    pub fn execute_file(&mut self, file_path: &str) -> anyhow::Result<()> {
        match fs::read_to_string(file_path) {
            Ok(script) => self.execute_script(&script),
            Err(err) => Err(anyhow::anyhow!(
                "Failed to read file '{}': {}",
                file_path,
                err
            )),
        }
    }

    /// Reads stdin to the end and executes the text as a script.
    ///
    /// # Errors
    ///
    /// Fails if stdin cannot be read as a string, or if
    /// [`execute_script`](Self::execute_script) reports failed statements.
    pub fn execute_stdin(&mut self) -> anyhow::Result<()> {
        let mut script = String::new();
        if let Err(err) = io::stdin().read_to_string(&mut script) {
            return Err(anyhow::anyhow!("Failed to read from stdin: {}", err));
        }
        self.execute_script(&script)
    }

    /// Splits `script` into statements and executes them in order.
    ///
    /// A failing statement is reported on stderr and counted, but does not
    /// stop the remaining statements from running. After each successful
    /// modification statement the database is saved to the configured path
    /// (if any); a failed save only produces a warning on stderr.
    ///
    /// A summary is printed when running verbosely or when at least one
    /// statement failed.
    ///
    /// # Errors
    ///
    /// Returns an error stating how many statements failed, if any did.
    pub fn execute_script(&mut self, script: &str) -> anyhow::Result<()> {
        let statements = parse_statements(script);
        let total = statements.len();

        if total == 0 {
            if self.verbose {
                println!("No SQL statements found in script");
            }
            return Ok(());
        }

        let mut succeeded = 0usize;
        let mut failed = 0usize;

        // Statements are numbered from 1 in all user-facing messages.
        for (number, sql) in (1usize..).zip(statements.iter()) {
            if self.verbose {
                println!("Executing statement {} of {}...", number, total);
            }

            let result = match self.executor.execute(sql) {
                Ok(result) => result,
                Err(err) => {
                    eprintln!("Error executing statement {}: {}", number, err);
                    failed += 1;
                    continue;
                }
            };

            self.formatter.print_result(&result);
            succeeded += 1;

            // Persist right away so later failures cannot lose this change.
            match &self.database_path {
                Some(path) if is_modification_statement(sql) => {
                    if let Err(err) = self.executor.save_database(path) {
                        eprintln!("Warning: Failed to auto-save database: {}", err);
                    }
                }
                _ => {}
            }
        }

        if self.verbose || failed > 0 {
            println!("\n=== Script Execution Summary ===");
            println!("Total statements: {}", total);
            println!("Successful: {}", succeeded);
            println!("Failed: {}", failed);
        }

        if failed == 0 {
            Ok(())
        } else {
            Err(anyhow::anyhow!("{} statements failed", failed))
        }
    }
}
/// Returns `true` when `sql` begins with a keyword that modifies the database
/// schema or its data (`CREATE`, `DROP`, `ALTER`, `INSERT`, `UPDATE`, `DELETE`).
///
/// Only the first whitespace-delimited word is inspected, compared
/// case-insensitively. Any whitespace may follow the keyword, so statements
/// such as `"INSERT\nINTO t ..."` or `"create\ttable ..."` are recognised too.
fn is_modification_statement(sql: &str) -> bool {
    const MODIFYING_KEYWORDS: [&str; 6] =
        ["CREATE", "DROP", "ALTER", "INSERT", "UPDATE", "DELETE"];

    // `split_whitespace` skips leading whitespace and stops at the first
    // whitespace of any kind, without allocating an upper-cased copy.
    sql.split_whitespace().next().map_or(false, |first_word| {
        MODIFYING_KEYWORDS
            .iter()
            .any(|keyword| first_word.eq_ignore_ascii_case(keyword))
    })
}
/// Splits a SQL script into individual statements.
///
/// Statements are separated by `;`. Each returned statement is trimmed and
/// does not include the terminating semicolon; empty statements are dropped.
/// A trailing statement without a semicolon is still returned.
///
/// Lexical rules applied while splitting:
///
/// * `'...'` string literals and `"..."` quoted identifiers are copied
///   verbatim; a doubled quote character inside them is an escaped quote.
///   Semicolons and comment markers inside them have no special meaning.
/// * `-- ...` line comments are removed up to the end of the line; the
///   newline itself is kept.
/// * `/* ... */` block comments are replaced by a single space so the tokens
///   on either side stay separated. Block comments do not nest, and an
///   unterminated block comment swallows the rest of the script.
fn parse_statements(script: &str) -> Vec<String> {
    /// Moves the trimmed contents of `buffer` into `statements` (if non-empty)
    /// and resets `buffer`.
    fn flush(statements: &mut Vec<String>, buffer: &mut String) {
        let statement = buffer.trim();
        if !statement.is_empty() {
            statements.push(statement.to_string());
        }
        buffer.clear();
    }

    let mut statements = Vec::new();
    let mut current = String::new();
    // Quote character of the literal/identifier we are currently inside, if any.
    let mut open_quote: Option<char> = None;
    let mut chars = script.chars().peekable();

    while let Some(ch) = chars.next() {
        // Inside quotes everything is literal text until the matching quote.
        if let Some(quote) = open_quote {
            current.push(ch);
            if ch == quote {
                if chars.peek() == Some(&quote) {
                    // Doubled quote: an escaped quote, we stay inside.
                    chars.next();
                    current.push(quote);
                } else {
                    open_quote = None;
                }
            }
            continue;
        }

        match ch {
            '\'' | '"' => {
                open_quote = Some(ch);
                current.push(ch);
            }
            '/' if chars.peek() == Some(&'*') => {
                chars.next();
                // Skip to the first `*/`. Tracking the previous character
                // (rather than re-checking for `/*`) ensures sequences such as
                // `/*/` inside the comment cannot hide the terminator.
                let mut previous = '\0';
                for c in chars.by_ref() {
                    if previous == '*' && c == '/' {
                        break;
                    }
                    previous = c;
                }
                // A comment counts as whitespace between tokens.
                current.push(' ');
            }
            '-' if chars.peek() == Some(&'-') => {
                // Drop the rest of the line but keep the newline as a separator.
                for c in chars.by_ref() {
                    if c == '\n' {
                        current.push(c);
                        break;
                    }
                }
            }
            ';' => flush(&mut statements, &mut current),
            _ => current.push(ch),
        }
    }

    // Final statement without a terminating semicolon.
    flush(&mut statements, &mut current);
    statements
}
#[cfg(test)]
mod tests {
    //! Unit tests for `parse_statements`: statement splitting, comment
    //! stripping, and string-literal handling.

    use super::*;

    /// A single terminated statement is returned without its semicolon.
    #[test]
    fn test_parse_single_statement() {
        assert_eq!(
            parse_statements("SELECT * FROM users;"),
            ["SELECT * FROM users"]
        );
    }

    /// Several statements on one line are split at each semicolon.
    #[test]
    fn test_parse_multiple_statements() {
        let parsed =
            parse_statements("CREATE TABLE users (id INT); INSERT INTO users VALUES (1);");
        assert_eq!(
            parsed,
            ["CREATE TABLE users (id INT)", "INSERT INTO users VALUES (1)"]
        );
    }

    /// Surrounding whitespace is trimmed from every statement.
    #[test]
    fn test_parse_with_whitespace() {
        assert_eq!(
            parse_statements("  SELECT 1;  \n  SELECT 2;  "),
            ["SELECT 1", "SELECT 2"]
        );
    }

    /// A leading line comment does not become part of the statement.
    #[test]
    fn test_parse_with_comments() {
        assert_eq!(
            parse_statements("-- This is a comment\nSELECT 1;"),
            ["SELECT 1"]
        );
    }

    /// An empty script yields no statements.
    #[test]
    fn test_parse_empty_script() {
        assert!(parse_statements("").is_empty());
    }

    /// A semicolon inside a string literal does not end the statement.
    #[test]
    fn test_parse_semicolon_in_string() {
        assert_eq!(
            parse_statements("INSERT INTO test VALUES (1, 'Error at position 10; expected value');"),
            ["INSERT INTO test VALUES (1, 'Error at position 10; expected value')"]
        );
    }

    /// A doubled single quote is kept as an escaped quote inside the literal.
    #[test]
    fn test_parse_escaped_quotes_in_string() {
        assert_eq!(
            parse_statements("INSERT INTO test VALUES ('It''s a test');"),
            ["INSERT INTO test VALUES ('It''s a test')"]
        );
    }

    /// A block comment spanning lines is removed.
    #[test]
    fn test_parse_multiline_comment() {
        assert_eq!(
            parse_statements("/* This is a\nmulti-line comment */\nSELECT 1;"),
            ["SELECT 1"]
        );
    }

    /// A semicolon inside a line comment is ignored.
    #[test]
    fn test_parse_comment_with_semicolon() {
        assert_eq!(
            parse_statements(
                "-- This comment has a semicolon; but it should be ignored\nSELECT 1;"
            ),
            ["SELECT 1"]
        );
    }

    /// Escaped quotes and a semicolon inside one literal stay in one statement.
    #[test]
    fn test_parse_complex_error_message() {
        let parsed = parse_statements(
            r#"INSERT INTO test_results (error_message) VALUES ('query result mismatch: [SQL] SELECT TIMESTAMP ''2025-11-15 00:00:00'' [Diff] expected; actual');"#,
        );
        assert_eq!(parsed.len(), 1);

        let only = &parsed[0];
        assert!(only.contains("TIMESTAMP"));
        assert!(only.contains("expected; actual"));
    }

    /// Line comments, block comments and string literals mixed in one script.
    #[test]
    fn test_parse_multiple_with_strings_and_comments() {
        let parsed = parse_statements(
            r#"
-- First statement
INSERT INTO logs VALUES (1, 'Error: parse failed; retry');
/* Second statement
with comment */
INSERT INTO logs VALUES (2, 'Success');
-- Done
"#,
        );
        assert_eq!(parsed.len(), 2);
        assert!(parsed[0].contains("parse failed; retry"));
        assert!(parsed[1].contains("Success"));
    }
}