use anyhow::{Result, anyhow};
use sqlx::PgPool;
use crate::db::error_context::SqlErrorContext;
use crate::render::Safety;
use crate::schema_loader::SchemaFile;
/// Executes raw SQL text against a PostgreSQL pool and, on failure, turns the
/// database error into a formatted report according to its [`SqlExecutorConfig`].
pub struct SqlContentExecutor {
    /// Pool the SQL is executed on.
    pool: PgPool,
    /// Presentation settings for progress output and error reports.
    config: SqlExecutorConfig,
}
/// Presentation settings for [`SqlContentExecutor`]: how error reports are rendered
/// and which progress lines are printed.
#[derive(Debug, Clone)]
pub struct SqlExecutorConfig {
    /// Verbosity of rendered error reports (see [`ErrorLevel`]).
    pub error_level: ErrorLevel,
    /// Progress output mode; in this file only `StepCount` and `Detailed` are
    /// consulted (by `execute_step`).
    pub progress_style: ProgressStyle,
    /// Length limit for the SQL preview embedded in error reports.
    pub content_truncation: usize,
    /// Selects the label used in the error header ("schema file", "migration step", ...).
    pub source_context: SourceContextStyle,
    /// When true, step progress lines include a ✅/⚠️ icon and the step's safety class.
    pub safety_indicators: bool,
    // NOTE(review): not read anywhere in this file. Execution errors are always
    // returned as `Err`, so setting this has no effect yet. `dead_code` is allowed
    // only to silence the unused-field warning until it is wired up.
    #[allow(dead_code)]
    pub continue_on_error: bool,
}
/// How much detail `SqlExecutionError::format_error` includes in a report.
#[derive(Debug, Clone, PartialEq)]
pub enum ErrorLevel {
    /// Message, detail/hint/context lines, line number, SQL preview and suggestion.
    Enhanced,
    /// Everything `Enhanced` shows, plus the "Troubleshooting tips" section.
    WithTips,
}
/// Progress output printed while executing SQL.
#[derive(Debug, Clone, PartialEq)]
pub enum ProgressStyle {
    /// No progress output.
    None,
    /// Per-file progress.
    // NOTE(review): nothing in this file prints for this variant; presumably
    // callers handle it — confirm.
    FileCount,
    /// `execute_step` prints one "Executing step N" line per step.
    StepCount,
    /// Like `StepCount`, and `execute_step` also echoes the step's SQL text.
    Detailed,
}
/// Which label the error report header uses for the failed source.
#[derive(Debug, Clone, PartialEq)]
pub enum SourceContextStyle {
    /// Rendered as "schema file".
    File,
    /// Rendered as "migration step".
    Step,
    /// Rendered as "baseline".
    Baseline,
    /// Uses the given text verbatim as the label.
    // Not constructed anywhere in this file, hence the dead_code allowance.
    #[allow(dead_code)]
    Custom(String),
}
impl Default for SqlExecutorConfig {
    /// General-purpose settings: `Enhanced` error reports labelled as schema files,
    /// `FileCount` progress, a preview limit of 300, no safety icons, and
    /// `continue_on_error` off.
    fn default() -> Self {
        let progress_style = ProgressStyle::FileCount;
        let source_context = SourceContextStyle::File;
        Self {
            continue_on_error: false,
            safety_indicators: false,
            content_truncation: 300,
            error_level: ErrorLevel::Enhanced,
            progress_style,
            source_context,
        }
    }
}
/// Structured description of a failed SQL execution; rendered to text by
/// `SqlExecutionError::format_error` (and by its `Display` impl).
#[derive(Debug)]
pub struct SqlExecutionError {
    /// Name of the failing source (a file path, "step N", or a caller-supplied label).
    pub source_context: String,
    /// Full SQL text that was submitted.
    pub sql_content: String,
    /// 1-based line within `sql_content` where the error was located, if known.
    pub line_number: Option<usize>,
    /// Primary error message from the database.
    pub postgres_error: String,
    /// Optional detail text, rendered as "Detail: ...".
    pub pg_detail: Option<String>,
    /// Optional hint text, rendered as "Hint: ...".
    pub pg_hint: Option<String>,
    /// Optional context text (e.g. a PL/pgSQL call site), rendered as "Context: ...".
    pub pg_context: Option<String>,
    /// One-line heuristic suggestion derived from the error text.
    pub suggestion: Option<String>,
    /// Additional heuristic tips; only rendered at `ErrorLevel::WithTips`.
    pub troubleshooting_tips: Vec<String>,
    /// Extra text appended to the report header, e.g. "(depends on: a.sql, b.sql)".
    pub dependencies_info: Option<String>,
}
impl SqlExecutionError {
    /// Renders this error as a multi-line, human-readable report.
    ///
    /// The report contains, in order:
    /// - a header naming the failed source, labelled per `context_style`, plus any
    ///   dependency info;
    /// - the database message with the failing line number;
    /// - optional `Detail` / `Hint` / `Context` lines;
    /// - a preview of the SQL;
    /// - troubleshooting tips (only at [`ErrorLevel::WithTips`]);
    /// - the suggestion, if one exists.
    ///
    /// `truncate_at` caps the SQL preview in characters (not bytes), so multi-byte
    /// UTF-8 content is never split mid-character.
    pub fn format_error(
        &self,
        error_level: &ErrorLevel,
        context_style: &SourceContextStyle,
        truncate_at: usize,
    ) -> String {
        // Both current levels show line numbers and suggestions. The check is kept
        // explicit so a future, terser level can opt out.
        let detailed = matches!(error_level, ErrorLevel::Enhanced | ErrorLevel::WithTips);
        let context_label = match context_style {
            SourceContextStyle::File => "schema file",
            SourceContextStyle::Step => "migration step",
            SourceContextStyle::Baseline => "baseline",
            SourceContextStyle::Custom(label) => label.as_str(),
        };

        let mut error_msg = format!(
            "❌ Failed to apply {} '{}'",
            context_label, self.source_context
        );
        if let Some(deps_info) = &self.dependencies_info {
            error_msg.push_str(&format!(" {}", deps_info));
        }

        error_msg.push_str("\n\n🐘 Database Error:\n");
        error_msg.push_str(&self.postgres_error);
        if let Some(line_num) = self.line_number.filter(|_| detailed) {
            error_msg.push_str(&format!(" (Line {})", line_num));
        }
        if let Some(detail) = &self.pg_detail {
            error_msg.push_str(&format!("\n\n Detail: {}", detail));
        }
        if let Some(hint) = &self.pg_hint {
            error_msg.push_str(&format!("\n Hint: {}", hint));
        }
        if let Some(ctx) = &self.pg_context {
            error_msg.push_str(&format!("\n Context: {}", ctx));
        }

        let content_preview = match self.line_number {
            Some(line_num) => {
                Self::format_content_with_line_context(&self.sql_content, line_num, truncate_at)
            }
            None => Self::truncate_preview(self.sql_content.trim(), truncate_at),
        };
        error_msg.push_str(&format!("\n\n📄 Content:\n{}", content_preview));

        if matches!(error_level, ErrorLevel::WithTips) && !self.troubleshooting_tips.is_empty() {
            error_msg.push_str("\n\nTroubleshooting tips:\n");
            for tip in &self.troubleshooting_tips {
                error_msg.push_str(&format!("• {}\n", tip));
            }
        }
        if let Some(suggestion) = self.suggestion.as_ref().filter(|_| detailed) {
            error_msg.push_str(&format!("\n💡 Suggestion: {}", suggestion));
        }
        error_msg
    }

    /// Renders up to three lines of context on each side of `error_line` (1-based).
    ///
    /// Each line is prefixed with its number, and the error line is marked with `❌`.
    /// A header is added when earlier lines are skipped, and a footer when later lines
    /// are skipped. The result is capped at `max_chars` characters.
    ///
    /// If `error_line` lies past the end of `content`, the window is clamped to the
    /// last lines rather than panicking on an inverted slice range. In that case no
    /// line is marked.
    fn format_content_with_line_context(
        content: &str,
        error_line: usize,
        max_chars: usize,
    ) -> String {
        const CONTEXT_LINES: usize = 3;
        let lines: Vec<&str> = content.lines().collect();
        let total_lines = lines.len();
        // Clamping keeps start_idx <= end_idx and prevents overflow in the addition below.
        let error_idx = error_line
            .saturating_sub(1)
            .min(total_lines.saturating_sub(1));
        let start_idx = error_idx.saturating_sub(CONTEXT_LINES);
        let end_idx = (error_idx + CONTEXT_LINES + 1).min(total_lines);

        let mut result = String::new();
        if start_idx > 0 {
            result.push_str(&format!(
                "... [Showing lines {}-{} of {}]\n\n",
                start_idx + 1,
                end_idx,
                total_lines
            ));
        }
        for (line_num, line) in (start_idx + 1..).zip(&lines[start_idx..end_idx]) {
            let prefix = if line_num == error_line { "❌ " } else { " " };
            result.push_str(&format!("{}{:4} | {}\n", prefix, line_num, line));
        }
        if end_idx < total_lines {
            result.push_str(&format!("\n... [{} more lines]", total_lines - end_idx));
        }
        Self::truncate_preview(&result, max_chars)
    }

    /// Returns `text` unchanged when it has at most `max_chars` characters.
    /// Otherwise it keeps the first `max_chars` characters and appends a truncation
    /// notice with the total character count.
    ///
    /// The cut is made on a `char` boundary found via `char_indices`. Byte-index
    /// slicing would panic on multi-byte UTF-8 text such as the `❌` marker or
    /// non-ASCII SQL comments.
    fn truncate_preview(text: &str, max_chars: usize) -> String {
        match text.char_indices().nth(max_chars) {
            Some((cut, _)) => format!(
                "{}...\n\n[Content truncated - {} total characters]",
                &text[..cut],
                text.chars().count()
            ),
            None => text.to_string(),
        }
    }
}
/// Default textual form: an `Enhanced` report labelled as a schema file, with a
/// content preview limit of 500.
impl std::fmt::Display for SqlExecutionError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let rendered = self.format_error(&ErrorLevel::Enhanced, &SourceContextStyle::File, 500);
        f.write_str(&rendered)
    }
}

/// Marker impl so the type can be used wherever `std::error::Error` is required;
/// relies on the `Debug` derive and the `Display` impl above.
impl std::error::Error for SqlExecutionError {}
impl SqlContentExecutor {
pub fn new(pool: PgPool, config: SqlExecutorConfig) -> Self {
Self { pool, config }
}
pub async fn execute_content(&self, content: &str, source: &str) -> Result<()> {
self.execute_content_with_deps(content, source, None).await
}
pub async fn execute_content_with_deps(
&self,
content: &str,
source: &str,
deps_info: Option<String>,
) -> Result<()> {
let trimmed_content = content.trim();
if trimmed_content.is_empty() {
return Ok(());
}
match sqlx::raw_sql(content).execute(&self.pool).await {
Ok(_) => Ok(()),
Err(sqlx_error) => {
let ctx = SqlErrorContext::from_sqlx_error(&sqlx_error, content);
let error_info = SqlExecutionError {
source_context: source.to_string(),
sql_content: content.to_string(),
line_number: ctx.line_number,
postgres_error: ctx.message,
pg_detail: ctx.detail,
pg_hint: ctx.hint,
pg_context: ctx.context,
suggestion: Self::generate_suggestion(&sqlx_error),
troubleshooting_tips: Self::generate_troubleshooting_tips(&sqlx_error),
dependencies_info: deps_info,
};
let formatted_error = error_info.format_error(
&self.config.error_level,
&self.config.source_context,
self.config.content_truncation,
);
Err(anyhow!(formatted_error))
}
}
}
pub async fn execute_schema_file(&self, file: &SchemaFile) -> Result<()> {
let deps_info = if !file.dependencies.is_empty() {
Some(format!("(depends on: {})", file.dependencies.join(", ")))
} else {
None
};
self.execute_content_with_deps(&file.content, &file.relative_path, deps_info)
.await
}
pub async fn execute_baseline(&self, baseline_sql: &str, source: &str) -> Result<()> {
self.execute_content(baseline_sql, source).await
}
pub async fn execute_step(
&self,
step_sql: &str,
safety: Safety,
step_num: usize,
) -> Result<()> {
if matches!(
self.config.progress_style,
ProgressStyle::StepCount | ProgressStyle::Detailed
) {
let step_prefix = match safety {
Safety::Safe => "✅",
Safety::Destructive => "⚠️",
};
if self.config.safety_indicators {
println!(
"{} Executing step {}: {:?} operation",
step_prefix, step_num, safety
);
} else {
println!("Executing step {}", step_num);
}
if matches!(self.config.progress_style, ProgressStyle::Detailed) {
println!("{}", step_sql);
}
}
self.execute_content(step_sql, &format!("step {}", step_num))
.await
}
}
impl SqlContentExecutor {
    /// Derives a one-line suggestion from a database error.
    ///
    /// Matching is heuristic and runs on the lowercased `Display` text of `error`.
    /// The first matching pattern wins; `None` means no known pattern applied.
    fn generate_suggestion(error: &sqlx::Error) -> Option<String> {
        let error_str = error.to_string().to_lowercase();
        let missing_object = error_str.contains("does not exist");

        // PostgreSQL reports an unknown type as `type "name" does not exist`.
        // Matching the quoted form instead of the bare word "type" keeps errors
        // such as `column "type_id" does not exist` or
        // `relation "user_types" does not exist` out of this branch, so they get
        // the column/relation advice below instead.
        if missing_object && error_str.contains("type \"") {
            Some("Check for typos in data type names. Common types: TEXT, INTEGER, BOOLEAN, TIMESTAMP".to_string())
        } else if error_str.contains("syntax error") && error_str.contains("check") {
            Some(
                "Syntax error near CHECK. Verify CHECK constraint syntax and parentheses."
                    .to_string(),
            )
        } else if missing_object && error_str.contains("column") {
            Some(
                "Verify column names and ensure tables are created before referencing them."
                    .to_string(),
            )
        } else if missing_object && error_str.contains("relation") {
            Some("Table/view does not exist. Check dependency order and table names.".to_string())
        } else if error_str.contains("syntax error at or near") {
            let suggestion = match Self::extract_near_word(&error_str) {
                Some(near_word) => format!(
                    "Syntax error near '{}'. Check SQL syntax and keywords.",
                    near_word
                ),
                None => "SQL syntax error. Verify SQL syntax and keywords.".to_string(),
            };
            Some(suggestion)
        } else {
            None
        }
    }

    /// Collects troubleshooting tips for every known pattern found in the
    /// (case-sensitive) `Display` text of `error`.
    ///
    /// Several groups of tips may apply at once. The result is empty when nothing
    /// matches.
    fn generate_troubleshooting_tips(error: &sqlx::Error) -> Vec<String> {
        let error_str = error.to_string();
        let mut tips = Vec::new();
        if error_str.contains("cannot insert multiple commands into a prepared statement") {
            tips.push(
                "Multiple SQL commands detected. This should be handled automatically.".to_string(),
            );
            tips.push(
                "Ensure each SQL statement ends with a semicolon and is separated by newlines."
                    .to_string(),
            );
        }
        if error_str.contains("already exists") {
            tips.push(
                "This object already exists. Check if this file is being applied multiple times."
                    .to_string(),
            );
            tips.push(
                "Consider if there are duplicate definitions or manual changes to the database."
                    .to_string(),
            );
        }
        if error_str.contains("does not exist") {
            tips.push(
                "A referenced object doesn't exist. Check if dependencies are properly specified."
                    .to_string(),
            );
            tips.push("Verify that dependent files are listed in the correct order.".to_string());
            tips.push("Check if `-- require:` headers are present and correct.".to_string());
        }
        if error_str.contains("syntax error") || error_str.contains("parse error") {
            tips.push(
                "There's a SQL syntax error. Check the SQL syntax in this content.".to_string(),
            );
            tips.push("Look for missing semicolons, unmatched parentheses, or typos.".to_string());
        }
        if error_str.contains("permission denied") {
            tips.push(
                "Permission denied. Check database user permissions for this operation."
                    .to_string(),
            );
        }
        tips
    }

    /// Returns the text between the first `at or near "` marker and the next `"`.
    ///
    /// For example, it yields `check` for `syntax error at or near "check"`. Returns
    /// `None` when the marker or the closing quote is missing.
    fn extract_near_word(error_str: &str) -> Option<String> {
        let (_, after_marker) = error_str.split_once("at or near \"")?;
        let (word, _) = after_marker.split_once('"')?;
        Some(word.to_string())
    }
}
/// Executor preset for applying schema files.
///
/// Error reports include troubleshooting tips, use a preview limit of 400, and are
/// labelled as "schema file".
pub struct SchemaFileExecutor {
    inner: SqlContentExecutor,
}

impl SchemaFileExecutor {
    /// Builds the preset. `verbose` selects `ProgressStyle::FileCount` instead of `None`.
    pub fn new(pool: PgPool, verbose: bool) -> Self {
        let progress_style = if verbose { ProgressStyle::FileCount } else { ProgressStyle::None };
        let config = SqlExecutorConfig {
            progress_style,
            error_level: ErrorLevel::WithTips,
            source_context: SourceContextStyle::File,
            content_truncation: 400,
            safety_indicators: false,
            continue_on_error: false,
        };
        let inner = SqlContentExecutor::new(pool, config);
        Self { inner }
    }

    /// Applies one schema file; delegates to `SqlContentExecutor::execute_schema_file`.
    pub async fn execute_schema_file(&self, file: &SchemaFile) -> Result<()> {
        self.inner.execute_schema_file(file).await
    }
}
/// Executor preset for applying baseline SQL; reports are labelled "baseline" with
/// a preview limit of 300.
pub struct BaselineExecutor {
    inner: SqlContentExecutor,
}

impl BaselineExecutor {
    /// Builds the preset.
    ///
    /// - `verbose` selects `Detailed` progress instead of `None`.
    /// - `force` drops the troubleshooting tips (`Enhanced` instead of `WithTips`)
    ///   and sets `continue_on_error`.
    // NOTE(review): `continue_on_error` is not read anywhere in this file, so
    // `force` currently only changes report verbosity — confirm intended behaviour.
    pub fn new(pool: PgPool, verbose: bool, force: bool) -> Self {
        let error_level = if force { ErrorLevel::Enhanced } else { ErrorLevel::WithTips };
        let progress_style = if verbose { ProgressStyle::Detailed } else { ProgressStyle::None };
        let config = SqlExecutorConfig {
            error_level,
            progress_style,
            source_context: SourceContextStyle::Baseline,
            content_truncation: 300,
            safety_indicators: false,
            continue_on_error: force,
        };
        BaselineExecutor {
            inner: SqlContentExecutor::new(pool, config),
        }
    }

    /// Executes the baseline SQL; delegates to `SqlContentExecutor::execute_baseline`.
    pub async fn execute_baseline(&self, baseline_sql: &str, source: &str) -> Result<()> {
        self.inner.execute_baseline(baseline_sql, source).await
    }
}
/// Executor preset for applying migration steps.
///
/// Reports use the `Enhanced` level, a preview limit of 250, and the
/// "migration step" label.
pub struct ApplyStepExecutor {
    inner: SqlContentExecutor,
}

impl ApplyStepExecutor {
    /// Builds the preset.
    ///
    /// - Without `verbose`, no progress is printed.
    /// - With `verbose`, `dry_run` selects `Detailed` (which also echoes each step's
    ///   SQL) instead of `StepCount`.
    /// - `show_safety` enables the safety icons on progress lines.
    // NOTE(review): `dry_run` does not stop `execute_step` from running the SQL in
    // this type; presumably callers skip execution themselves — confirm.
    pub fn new(pool: PgPool, verbose: bool, show_safety: bool, dry_run: bool) -> Self {
        let progress_style = match (verbose, dry_run) {
            (false, _) => ProgressStyle::None,
            (true, true) => ProgressStyle::Detailed,
            (true, false) => ProgressStyle::StepCount,
        };
        Self {
            inner: SqlContentExecutor::new(
                pool,
                SqlExecutorConfig {
                    error_level: ErrorLevel::Enhanced,
                    progress_style,
                    content_truncation: 250,
                    source_context: SourceContextStyle::Step,
                    safety_indicators: show_safety,
                    continue_on_error: false,
                },
            ),
        }
    }

    /// Executes one step; delegates to `SqlContentExecutor::execute_step`.
    pub async fn execute_step(
        &self,
        step_sql: &str,
        safety: Safety,
        step_num: usize,
    ) -> Result<()> {
        self.inner.execute_step(step_sql, safety, step_num).await
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an error with only the required fields set, reported at line 1;
    /// each test fills in the optional fields it exercises.
    fn base_error(source: &str, sql: &str, pg_message: &str) -> SqlExecutionError {
        SqlExecutionError {
            source_context: source.to_string(),
            sql_content: sql.to_string(),
            line_number: Some(1),
            postgres_error: pg_message.to_string(),
            pg_detail: None,
            pg_hint: None,
            pg_context: None,
            suggestion: None,
            troubleshooting_tips: Vec::new(),
            dependencies_info: None,
        }
    }

    /// Asserts that every needle appears in the rendered report.
    fn assert_contains_all(haystack: &str, needles: &[&str]) {
        for needle in needles {
            assert!(
                haystack.contains(needle),
                "expected {:?} in:\n{}",
                needle,
                haystack
            );
        }
    }

    #[test]
    fn test_sql_execution_error_display() {
        let error = SqlExecutionError {
            pg_hint: Some("Check spelling of type name".to_string()),
            suggestion: Some("Check for typos in data type names".to_string()),
            ..base_error(
                "tables/users.sql",
                "CREATE TABLE users (id SERIAL, email TEXTT)",
                "type \"textt\" does not exist",
            )
        };
        let rendered = error.to_string();
        assert_contains_all(
            &rendered,
            &["tables/users.sql", "Line 1", "Check for typos", "Hint: Check spelling"],
        );
    }

    #[test]
    fn test_sql_execution_error_with_pg_context() {
        let error = SqlExecutionError {
            pg_detail: Some("Table was dropped in migration V123".to_string()),
            pg_context: Some("PL/pgSQL function test() line 1 at SQL statement".to_string()),
            ..base_error(
                "functions/process.sql",
                "CREATE FUNCTION test() RETURNS void AS $$ BEGIN SELECT * FROM missing_table; END; $$ LANGUAGE plpgsql;",
                "relation \"missing_table\" does not exist",
            )
        };
        let rendered = error.to_string();
        assert_contains_all(
            &rendered,
            &[
                "functions/process.sql",
                "Detail: Table was dropped",
                "Context: PL/pgSQL function",
            ],
        );
    }
}