use anyhow::Result;
use sqlx::PgPool;
use std::path::Path;
use super::connection::initialize_database_session;
use super::error_context::SqlErrorContext;
#[derive(Debug, Clone)]
pub struct SqlExecutorConfig {
pub initialize_session: bool,
pub verbose: bool,
}
impl Default for SqlExecutorConfig {
fn default() -> Self {
Self {
initialize_session: true,
verbose: true,
}
}
}
pub async fn execute_sql_file(
pool: &PgPool,
file_path: &Path,
config: &SqlExecutorConfig,
) -> Result<()> {
let content = std::fs::read_to_string(file_path)
.map_err(|e| anyhow::anyhow!("Failed to read SQL file {}: {}", file_path.display(), e))?;
execute_sql_content(pool, &content, file_path, config).await
}
pub async fn execute_sql_content(
pool: &PgPool,
content: &str,
source: &Path,
config: &SqlExecutorConfig,
) -> Result<()> {
let trimmed_content = content.trim();
if trimmed_content.is_empty() {
if config.verbose {
println!(" ⏭️ Skipping empty file: {}", source.display());
}
return Ok(());
}
if config.initialize_session {
initialize_database_session(pool).await?;
}
match sqlx::raw_sql(content).execute(pool).await {
Ok(result) => {
if config.verbose {
println!(
" ✅ Executed {} (affected {} rows)",
source.display(),
result.rows_affected()
);
}
Ok(())
}
Err(e) => Err(format_sql_error(&e, source, content)),
}
}
fn format_sql_error(e: &sqlx::Error, source: &Path, content: &str) -> anyhow::Error {
let ctx = SqlErrorContext::from_sqlx_error(e, content);
anyhow::anyhow!("{}", ctx.format(&source.display().to_string(), content))
}
pub fn discover_sql_files_ordered(dir: &Path) -> Result<Vec<std::path::PathBuf>> {
if !dir.exists() {
return Err(anyhow::anyhow!(
"Directory does not exist: {}",
dir.display()
));
}
if !dir.is_dir() {
return Err(anyhow::anyhow!(
"Path is not a directory: {}",
dir.display()
));
}
let mut files = Vec::new();
let mut migration_dirs = Vec::new();
let mut total_entries = 0;
fn collect_sql_files_recursive(
dir: &Path,
files: &mut Vec<std::path::PathBuf>,
migration_dirs: &mut Vec<std::path::PathBuf>,
total_entries: &mut usize,
) -> Result<()> {
if !dir.is_dir() {
return Ok(());
}
let mut entries: Vec<_> = std::fs::read_dir(dir)?
.filter_map(|entry| entry.ok())
.collect();
*total_entries += entries.len();
entries.sort_by_key(|a| a.file_name());
for entry in entries {
let path = entry.path();
if path.is_dir() {
if is_migration_directory(&path) {
migration_dirs.push(path.clone());
}
collect_sql_files_recursive(&path, files, migration_dirs, total_entries)?;
} else if is_sql_file(&path) {
files.push(path);
}
}
Ok(())
}
collect_sql_files_recursive(dir, &mut files, &mut migration_dirs, &mut total_entries)?;
if files.is_empty() && total_entries == 0 {
eprintln!("⚠️ Warning: The directory '{}' is empty.", dir.display());
eprintln!(" No files or subdirectories found to import.");
} else if files.is_empty() && total_entries > 0 {
eprintln!(
"⚠️ Warning: No SQL files found in '{}' (searched {} items recursively).",
dir.display(),
total_entries
);
eprintln!(" The directory contains files but no .sql files were discovered.");
if !migration_dirs.is_empty() {
eprintln!(
" Found {} potential migration directories:",
migration_dirs.len()
);
for (i, migration_dir) in migration_dirs.iter().enumerate().take(5) {
eprintln!(" - {}", migration_dir.display());
if i == 4 && migration_dirs.len() > 5 {
eprintln!(" ... and {} more", migration_dirs.len() - 5);
break;
}
}
}
} else {
eprintln!("✅ Found {} SQL files in '{}'", files.len(), dir.display());
if !migration_dirs.is_empty() {
eprintln!(
" Discovered {} migration directories (Prisma/Flyway style)",
migration_dirs.len()
);
}
if files.len() <= 5 {
eprintln!(
" Files: {}",
files
.iter()
.map(|f| f.file_name().unwrap().to_string_lossy())
.collect::<Vec<_>>()
.join(", ")
);
} else {
eprintln!(
" Sample files: {}, ... and {} more",
files
.iter()
.take(3)
.map(|f| f.file_name().unwrap().to_string_lossy())
.collect::<Vec<_>>()
.join(", "),
files.len() - 3
);
}
}
Ok(files)
}
fn is_migration_directory(dir: &Path) -> bool {
if let Ok(entries) = std::fs::read_dir(dir) {
for entry in entries.flatten() {
if is_sql_file(&entry.path()) {
return true;
}
}
}
false
}
fn is_sql_file(path: &Path) -> bool {
if let Some(extension) = path.extension().and_then(|e| e.to_str()) {
if extension.eq_ignore_ascii_case("sql") {
return true;
}
if extension.eq_ignore_ascii_case("psql") || extension.eq_ignore_ascii_case("pgsql") {
return true;
}
}
if let Some(filename) = path.file_name().and_then(|n| n.to_str())
&& (filename.ends_with(".up")
|| filename.ends_with(".down")
|| filename.contains("migration")
|| filename.contains("schema"))
{
if let Ok(content) = std::fs::read_to_string(path) {
let content_lower = content.to_lowercase();
if content_lower.contains("create")
|| content_lower.contains("alter")
|| content_lower.contains("drop")
|| content_lower.contains("insert")
|| content_lower.contains("select")
{
return true;
}
}
}
false
}
#[cfg(test)]
mod tests {
use super::*;
use std::env;
#[test]
fn test_sql_executor_config_default() {
let config = SqlExecutorConfig::default();
assert!(config.initialize_session);
assert!(config.verbose);
}
#[test]
fn test_is_sql_file() {
assert!(is_sql_file(Path::new("test.sql")));
assert!(is_sql_file(Path::new("test.SQL")));
assert!(is_sql_file(Path::new("test.psql")));
assert!(is_sql_file(Path::new("test.pgsql")));
assert!(!is_sql_file(Path::new("test.txt")));
assert!(!is_sql_file(Path::new("test")));
}
#[test]
fn test_discover_sql_files_ordered() {
let temp_dir = env::temp_dir().join("pgmt_test_sql_discovery");
let _ = std::fs::remove_dir_all(&temp_dir);
std::fs::create_dir_all(&temp_dir).unwrap();
std::fs::write(temp_dir.join("02_tables.sql"), "CREATE TABLE test;").unwrap();
std::fs::write(temp_dir.join("01_schema.sql"), "CREATE SCHEMA test;").unwrap();
std::fs::write(temp_dir.join("03_data.sql"), "INSERT INTO test VALUES (1);").unwrap();
let files = discover_sql_files_ordered(&temp_dir).unwrap();
assert_eq!(files.len(), 3);
assert!(
files[0]
.file_name()
.unwrap()
.to_str()
.unwrap()
.starts_with("01_")
);
assert!(
files[1]
.file_name()
.unwrap()
.to_str()
.unwrap()
.starts_with("02_")
);
assert!(
files[2]
.file_name()
.unwrap()
.to_str()
.unwrap()
.starts_with("03_")
);
let _ = std::fs::remove_dir_all(&temp_dir);
}
}