use super::differ::SqlGenerator;
use super::schema::*;
use crate::config::DatabaseType;
use sea_orm::{ConnectionTrait, TransactionTrait};
use std::path::PathBuf;
/// Applies schema migrations over a sea-orm connection and tracks which
/// versions have been applied via the `dbnexus_migrations` bookkeeping table.
pub struct MigrationExecutor {
// Live database connection migrations are executed against.
pub(crate) connection: sea_orm::DatabaseConnection,
// Dialect-aware SQL generator; also carries the configured DatabaseType,
// which this module consults for backend-specific SQL.
pub(crate) sql_generator: SqlGenerator,
// In-memory cache of applied migrations; refreshed by `load_history`.
pub(crate) history: MigrationHistory,
}
impl MigrationExecutor {
    /// Creates an executor that runs migrations over `connection`, generating
    /// SQL for the dialect selected by `db_type`.
    pub fn new(connection: sea_orm::DatabaseConnection, db_type: DatabaseType) -> Self {
        Self {
            connection,
            sql_generator: SqlGenerator::new(db_type),
            history: MigrationHistory::new(),
        }
    }

    /// Maps the configured database type onto the matching sea-orm backend.
    fn db_backend(&self) -> sea_orm::DbBackend {
        match self.sql_generator.db_type {
            DatabaseType::Postgres => sea_orm::DbBackend::Postgres,
            DatabaseType::MySql => sea_orm::DbBackend::MySql,
            DatabaseType::Sqlite => sea_orm::DbBackend::Sqlite,
        }
    }

    /// INSERT statement for the history table, using the placeholder style the
    /// backend's prepared statements actually accept: `$n` for Postgres, `?`
    /// for MySQL/SQLite. (`from_sql_and_values` passes SQL through verbatim,
    /// so a literal `?` is a syntax error under Postgres.)
    fn insert_history_sql(&self) -> &'static str {
        match self.sql_generator.db_type {
            DatabaseType::Postgres => {
                "INSERT INTO dbnexus_migrations (version, description, applied_at, file_path) VALUES ($1, $2, $3, $4)"
            }
            DatabaseType::MySql | DatabaseType::Sqlite => {
                "INSERT INTO dbnexus_migrations (version, description, applied_at, file_path) VALUES (?, ?, ?, ?)"
            }
        }
    }

    /// Borrow of the underlying database connection.
    pub fn connection(&self) -> &sea_orm::DatabaseConnection {
        &self.connection
    }

    /// Borrow of the cached migration history (refresh with `load_history`).
    pub fn history(&self) -> &MigrationHistory {
        &self.history
    }

    /// Reloads `self.history` from the `dbnexus_migrations` table, creating
    /// the table first if necessary. Malformed rows are skipped with a log
    /// message rather than failing the whole load.
    pub async fn load_history(&mut self) -> Result<(), crate::error::DbError> {
        self.ensure_migration_table_exists().await?;
        let rows = {
            use sea_orm::sea_query::{Alias, Expr, Order, Query};
            let mut query = Query::select();
            query.from(Alias::new("dbnexus_migrations"));
            query.column(Alias::new("version"));
            query.column(Alias::new("description"));
            query.column(Alias::new("file_path"));
            // Select applied_at as text so every backend yields a string column.
            match self.sql_generator.db_type {
                DatabaseType::Postgres => {
                    query.expr_as(Expr::cust("applied_at::text"), Alias::new("applied_at"));
                }
                DatabaseType::MySql => {
                    query.expr_as(Expr::cust("CAST(applied_at AS CHAR)"), Alias::new("applied_at"));
                }
                DatabaseType::Sqlite => {
                    // SQLite stores applied_at as TEXT already; no cast needed.
                    query.column(Alias::new("applied_at"));
                }
            }
            query.order_by(Alias::new("version"), Order::Asc);
            self.connection
                .query_all(&query)
                .await
                .map_err(crate::error::DbError::Connection)?
        };
        let mut history = MigrationHistory::new();
        for row in rows {
            if let Some(record) = Self::parse_history_row(&row) {
                history.add_migration(record);
            }
        }
        self.history = history;
        Ok(())
    }

    /// Converts one `dbnexus_migrations` row into a `MigrationVersion`.
    /// Returns `None` (after logging) when the version column is unreadable or
    /// outside the u32 range; other missing columns degrade to defaults.
    fn parse_history_row(row: &sea_orm::QueryResult) -> Option<MigrationVersion> {
        let raw_version: i64 = match row.try_get("", "version") {
            Ok(v) => v,
            Err(e) => {
                tracing::warn!("Failed to read migration version: {}", e);
                return None;
            }
        };
        let Ok(version) = u32::try_from(raw_version) else {
            tracing::warn!("Invalid migration version value: {}", raw_version);
            return None;
        };
        let description: String = row.try_get("", "description").unwrap_or_else(|e| {
            tracing::debug!("Missing description for migration {}: {}", version, e);
            String::new()
        });
        let applied_at_str: String = row.try_get("", "applied_at").unwrap_or_else(|e| {
            tracing::debug!("Missing applied_at for migration {}: {}", version, e);
            String::new()
        });
        // Parse RFC 3339 (the format apply_migration stores). Backends that
        // reformat the timestamp themselves (Postgres/MySQL text casts) may not
        // produce RFC 3339; in that case we fall back to "now" so the history
        // stays usable — the real ordering key is `version`, not the timestamp.
        let applied_at = if applied_at_str.is_empty() {
            time::OffsetDateTime::now_utc()
        } else {
            time::OffsetDateTime::parse(&applied_at_str, &time::format_description::well_known::Rfc3339)
                .unwrap_or_else(|e| {
                    tracing::debug!("Invalid applied_at format for migration {}: {}", version, e);
                    time::OffsetDateTime::now_utc()
                })
        };
        let file_path: String = row.try_get("", "file_path").unwrap_or_else(|e| {
            tracing::debug!("Missing file_path for migration {}: {}", version, e);
            String::new()
        });
        Some(MigrationVersion {
            version,
            description,
            applied_at,
            file_path,
        })
    }

    /// Creates the `dbnexus_migrations` bookkeeping table if it does not
    /// already exist, using dialect-appropriate column types.
    async fn ensure_migration_table_exists(&self) -> Result<(), crate::error::DbError> {
        let create_table_sql = match self.sql_generator.db_type {
            DatabaseType::Postgres => {
                "CREATE TABLE IF NOT EXISTS dbnexus_migrations (
version INTEGER PRIMARY KEY,
description TEXT NOT NULL,
applied_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
file_path TEXT
);"
            }
            DatabaseType::MySql => {
                "CREATE TABLE IF NOT EXISTS dbnexus_migrations (
version INT PRIMARY KEY,
description TEXT NOT NULL,
applied_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
file_path TEXT
);"
            }
            DatabaseType::Sqlite => {
                "CREATE TABLE IF NOT EXISTS dbnexus_migrations (
version INTEGER PRIMARY KEY,
description TEXT NOT NULL,
applied_at TEXT NOT NULL DEFAULT (datetime('now')),
file_path TEXT
);"
            }
        };
        self.connection
            .execute_unprepared(create_table_sql)
            .await
            .map_err(crate::error::DbError::Connection)?;
        Ok(())
    }

    /// Applies a generated migration and records it in the history table
    /// inside a single transaction: both succeed or neither does.
    pub async fn apply_migration(&mut self, migration: &Migration) -> Result<(), crate::error::DbError> {
        let sql = self.sql_generator.generate_migration_sql(migration);
        let txn = self
            .connection
            .begin()
            .await
            .map_err(crate::error::DbError::Connection)?;
        if !sql.is_empty() {
            txn.execute_unprepared(&sql)
                .await
                .map_err(crate::error::DbError::Connection)?;
        }
        let applied_at = migration.timestamp.unwrap_or_else(time::OffsetDateTime::now_utc);
        let version_record = MigrationVersion {
            version: migration.version,
            description: migration.description.clone(),
            applied_at,
            file_path: format!("migration_v{}.sql", migration.version),
        };
        // Store applied_at as RFC 3339 so load_history can parse it back;
        // OffsetDateTime's plain `to_string()` is NOT RFC 3339 and would make
        // every round-trip fall back to "now".
        let applied_at_text = applied_at
            .format(&time::format_description::well_known::Rfc3339)
            .unwrap_or_else(|_| applied_at.to_string());
        let stmt = sea_orm::Statement::from_sql_and_values(
            self.db_backend(),
            self.insert_history_sql().to_string(),
            vec![
                migration.version.into(),
                migration.description.clone().into(),
                applied_at_text.into(),
                version_record.file_path.clone().into(),
            ],
        );
        txn.execute_raw(stmt).await.map_err(crate::error::DbError::Connection)?;
        txn.commit().await.map_err(crate::error::DbError::Connection)?;
        self.history.add_migration(version_record);
        Ok(())
    }

    /// Returns the migrations from `all_migrations` that have not been applied
    /// yet. If the history cannot be loaded, every migration is reported as
    /// pending (best effort rather than failing the caller).
    pub async fn get_pending_migrations<'a>(&'a mut self, all_migrations: &'a [Migration]) -> Vec<&'a Migration> {
        match self.load_history().await {
            Ok(()) => self.history.get_pending_migrations(all_migrations),
            Err(_) => all_migrations.iter().collect(),
        }
    }

    /// All applied version numbers, in history order.
    pub fn get_all_versions(&self) -> Vec<u32> {
        self.history.applied_migrations.iter().map(|m| m.version).collect()
    }

    /// The most recently recorded applied migration, if any.
    pub fn get_latest_migration(&self) -> Option<&MigrationVersion> {
        self.history.applied_migrations.last()
    }

    /// True when the number of applied migrations equals `total_migrations`.
    pub fn is_fully_migrated(&self, total_migrations: usize) -> bool {
        self.history.applied_migrations.len() == total_migrations
    }
}
/// One SQL migration file discovered on disk, identified by the numeric
/// version and description parsed from its filename.
#[derive(Debug, Clone)]
pub struct MigrationFile {
    pub(crate) version: u32,
    pub(crate) description: String,
    pub(crate) file_path: PathBuf,
    pub(crate) content: String,
}

impl MigrationFile {
    /// Numeric version parsed from the filename.
    pub fn version(&self) -> u32 {
        self.version
    }

    /// Human-readable description parsed from the filename.
    pub fn description(&self) -> &str {
        self.description.as_str()
    }

    /// Location of the migration file on disk.
    pub fn file_path(&self) -> &PathBuf {
        &self.file_path
    }

    /// Raw SQL text of the migration file.
    pub fn content(&self) -> &str {
        self.content.as_str()
    }
}
#[cfg(feature = "auto-migrate")]
impl MigrationExecutor {
    /// Scans `dir` for `<version>_<description>.sql` files and returns them
    /// sorted by version. A missing directory yields an empty list; files
    /// whose names do not parse are silently skipped.
    pub fn scan_migrations(&self, dir: &std::path::Path) -> Result<Vec<MigrationFile>, crate::error::DbError> {
        let mut migrations = Vec::new();
        if !dir.exists() {
            tracing::warn!("Migration directory does not exist: {}", dir.display());
            return Ok(migrations);
        }
        let entries = std::fs::read_dir(dir)
            .map_err(|e| crate::error::DbError::Config(format!("Failed to read migration directory: {}", e)))?;
        for entry in entries {
            let entry =
                entry.map_err(|e| crate::error::DbError::Config(format!("Failed to read migration entry: {}", e)))?;
            let path = entry.path();
            // Only plain *.sql files with a parseable name are migrations.
            if !path.is_file() || !path.extension().map(|e| e == "sql").unwrap_or(false) {
                continue;
            }
            let Some(filename) = path.file_name().and_then(|n| n.to_str()) else {
                continue;
            };
            let Some((version, description)) = Self::parse_filename(filename) else {
                continue;
            };
            let content = std::fs::read_to_string(&path)
                .map_err(|e| crate::error::DbError::Config(format!("Failed to read migration file: {}", e)))?;
            migrations.push(MigrationFile {
                version,
                description,
                file_path: path,
                content,
            });
        }
        migrations.sort_by_key(|m| m.version);
        tracing::info!("Scanned {} migration files in {}", migrations.len(), dir.display());
        Ok(migrations)
    }

    /// Parses `<version>_<description>.sql` into its parts.
    ///
    /// Only the trailing ".sql" is stripped, so descriptions containing the
    /// substring "sql" survive intact (the old `replace(".sql", "")` mangled
    /// e.g. `0001_add_sql_table.sql` into `add__table`). A file with no
    /// description part (`0001.sql`) yields an empty description.
    pub(crate) fn parse_filename(filename: &str) -> Option<(u32, String)> {
        let stem = filename.strip_suffix(".sql").unwrap_or(filename);
        let (version_part, description) = match stem.split_once('_') {
            Some((v, rest)) => (v, rest.to_string()),
            None => (stem, String::new()),
        };
        let version = version_part.parse::<u32>().ok()?;
        Some((version, description))
    }

    /// Scans `dir`, filters out already-applied versions, and applies the
    /// remaining migrations in version order. Returns the number applied;
    /// stops at (and returns) the first failure.
    pub async fn run_migrations(&mut self, dir: &std::path::Path) -> Result<u32, crate::error::DbError> {
        let migration_files = self.scan_migrations(dir)?;
        let mut applied_versions = std::collections::HashSet::new();
        for migration_file in &migration_files {
            if self.is_migration_applied(migration_file.version).await? {
                applied_versions.insert(migration_file.version);
            }
        }
        let pending: Vec<_> = migration_files
            .into_iter()
            .filter(|m| !applied_versions.contains(&m.version))
            .collect();
        if pending.is_empty() {
            tracing::info!("No pending migrations to apply");
            return Ok(0);
        }
        tracing::info!("Found {} pending migrations", pending.len());
        let mut applied_count = 0;
        for migration_file in &pending {
            tracing::info!(
                "Applying migration v{} - {}",
                migration_file.version,
                migration_file.description
            );
            match self.apply_migration_file(migration_file).await {
                Ok(_) => {
                    applied_count += 1;
                    tracing::info!(
                        "Successfully applied migration v{} - {}",
                        migration_file.version,
                        migration_file.description
                    );
                }
                Err(e) => {
                    tracing::error!(
                        "Failed to apply migration v{} - {}: {}",
                        migration_file.version,
                        migration_file.description,
                        e
                    );
                    return Err(e);
                }
            }
        }
        Ok(applied_count)
    }

    /// Checks whether `version` already has a row in `dbnexus_migrations`,
    /// creating the table first if needed.
    async fn is_migration_applied(&self, version: u32) -> Result<bool, crate::error::DbError> {
        self.ensure_migration_table_exists().await?;
        let row = {
            use sea_orm::sea_query::{Alias, Expr, ExprTrait, Query};
            let mut query = Query::select();
            query.expr(Expr::cust("1"));
            query.from(Alias::new("dbnexus_migrations"));
            query.and_where(Expr::col(Alias::new("version")).eq(version));
            query.limit(1);
            self.connection
                .query_one(&query)
                .await
                .map_err(crate::error::DbError::Connection)?
        };
        Ok(row.is_some())
    }

    /// Applies one migration file: runs its UP section and records the
    /// history row, both in a single transaction.
    async fn apply_migration_file(&mut self, migration_file: &MigrationFile) -> Result<(), crate::error::DbError> {
        let sql = Self::extract_up_sql(&migration_file.content);
        let txn = self
            .connection
            .begin()
            .await
            .map_err(crate::error::DbError::Connection)?;
        if !sql.is_empty() {
            txn.execute_unprepared(sql)
                .await
                .map_err(crate::error::DbError::Connection)?;
        }
        let applied_at = time::OffsetDateTime::now_utc();
        // Postgres prepared statements use `$n` placeholders; a literal `?`
        // is only valid for MySQL and SQLite.
        let insert_sql = match self.sql_generator.db_type {
            DatabaseType::Postgres => {
                "INSERT INTO dbnexus_migrations (version, description, applied_at, file_path) VALUES ($1, $2, $3, $4)"
            }
            DatabaseType::MySql | DatabaseType::Sqlite => {
                "INSERT INTO dbnexus_migrations (version, description, applied_at, file_path) VALUES (?, ?, ?, ?)"
            }
        };
        let backend = match self.sql_generator.db_type {
            DatabaseType::Postgres => sea_orm::DbBackend::Postgres,
            DatabaseType::MySql => sea_orm::DbBackend::MySql,
            DatabaseType::Sqlite => sea_orm::DbBackend::Sqlite,
        };
        // RFC 3339 so load_history can round-trip the timestamp (plain
        // `to_string()` is not RFC 3339).
        let applied_at_text = applied_at
            .format(&time::format_description::well_known::Rfc3339)
            .unwrap_or_else(|_| applied_at.to_string());
        let stmt = sea_orm::Statement::from_sql_and_values(
            backend,
            insert_sql.to_string(),
            vec![
                migration_file.version.into(),
                migration_file.description.clone().into(),
                applied_at_text.into(),
                migration_file.file_path.to_string_lossy().into(),
            ],
        );
        txn.execute_raw(stmt).await.map_err(crate::error::DbError::Connection)?;
        txn.commit().await.map_err(crate::error::DbError::Connection)?;
        self.history.add_migration(MigrationVersion {
            version: migration_file.version,
            description: migration_file.description.clone(),
            applied_at,
            file_path: migration_file.file_path.to_string_lossy().to_string(),
        });
        Ok(())
    }

    /// Finds the first occurrence of any marker variant, returning its
    /// position and length so callers can slice past the whole marker.
    fn find_marker(content: &str, variants: &[&str]) -> Option<(usize, usize)> {
        variants
            .iter()
            .find_map(|m| content.find(m).map(|pos| (pos, m.len())))
    }

    /// Returns the UP section of a migration file: the text between the UP
    /// marker and the DOWN marker (or end of file), or everything before a
    /// lone DOWN marker, or the whole file when no markers exist.
    ///
    /// Slices past the full marker length; the old fixed `+5` offset left a
    /// stray `:` after "-- UP:" (6 chars) and ate two SQL characters after a
    /// bare "UP:" (3 chars).
    fn extract_up_sql(content: &str) -> &str {
        let up = Self::find_marker(content, &["-- UP:", "-- up:", "UP:"]);
        let down = Self::find_marker(content, &["-- DOWN:", "-- down:", "DOWN:"]);
        match (up, down) {
            (Some((up_pos, up_len)), Some((down_pos, _))) if down_pos > up_pos => {
                &content[up_pos + up_len..down_pos]
            }
            (Some((up_pos, up_len)), _) => &content[up_pos + up_len..],
            (None, Some((down_pos, _))) => &content[..down_pos],
            (None, None) => content,
        }
        .trim()
    }

    /// Returns the DOWN section (everything after the DOWN marker), or `None`
    /// when the file has no non-empty DOWN section.
    fn extract_down_sql(content: &str) -> Option<&str> {
        let (pos, len) = Self::find_marker(content, &["-- DOWN:", "-- down:", "DOWN:"])?;
        let sql = content[pos + len..].trim();
        (!sql.is_empty()).then_some(sql)
    }

    /// Rolls back every applied migration in descending version order.
    /// Returns the number rolled back; stops at the first failure.
    pub async fn rollback_all(&mut self) -> Result<u32, crate::error::DbError> {
        self.load_history().await?;
        if self.history.applied_migrations.is_empty() {
            tracing::info!("No migrations to rollback");
            return Ok(0);
        }
        let mut rollback_count = 0;
        // Collect versions first: rollback_migration mutates the history.
        let mut versions: Vec<u32> = self.history.applied_migrations.iter().map(|m| m.version).collect();
        versions.sort_by_key(|v| std::cmp::Reverse(*v));
        for version in versions {
            tracing::info!("Rolling back migration v{}", version);
            match self.rollback_migration(version).await {
                Ok(_) => {
                    rollback_count += 1;
                    tracing::info!("Successfully rolled back migration v{}", version);
                }
                Err(e) => {
                    tracing::error!("Failed to rollback migration v{}: {}", version, e);
                    return Err(e);
                }
            }
        }
        Ok(rollback_count)
    }

    /// Rolls back one applied migration by executing the DOWN section of its
    /// file and deleting its history row, both in one transaction.
    pub async fn rollback_migration(&mut self, version: u32) -> Result<(), crate::error::DbError> {
        let migration = self
            .history
            .applied_migrations
            .iter()
            .find(|m| m.version == version)
            .ok_or_else(|| crate::error::DbError::Migration(format!("Migration version {} not found", version)))?;
        let down_sql = if !migration.file_path.is_empty() {
            let content = std::fs::read_to_string(&migration.file_path)
                .map_err(|e| crate::error::DbError::Migration(format!("Failed to read migration file: {}", e)))?;
            // Execute only the DOWN section. The previous code ran the whole
            // file content as "rollback", which would re-apply the UP
            // statements instead of undoing them.
            match Self::extract_down_sql(&content) {
                Some(sql) => sql.to_string(),
                None => {
                    return Err(crate::error::DbError::Migration(
                        "No DOWN SQL found in migration file".to_string(),
                    ));
                }
            }
        } else {
            #[cfg(feature = "sql-parser")]
            {
                return Err(crate::error::DbError::Migration(
                    "Migration file path not found, cannot generate rollback SQL".to_string(),
                ));
            }
            #[cfg(not(feature = "sql-parser"))]
            {
                return Err(crate::error::DbError::Migration(
                    "Migration file path not found and sql-parser feature is disabled".to_string(),
                ));
            }
        };
        let backend = match self.sql_generator.db_type {
            DatabaseType::Postgres => sea_orm::DbBackend::Postgres,
            DatabaseType::MySql => sea_orm::DbBackend::MySql,
            DatabaseType::Sqlite => sea_orm::DbBackend::Sqlite,
        };
        // `$1` for Postgres prepared statements; `?` elsewhere.
        let delete_sql_text = match self.sql_generator.db_type {
            DatabaseType::Postgres => "DELETE FROM dbnexus_migrations WHERE version = $1",
            DatabaseType::MySql | DatabaseType::Sqlite => "DELETE FROM dbnexus_migrations WHERE version = ?",
        };
        let txn: sea_orm::DatabaseTransaction = self
            .connection
            .begin()
            .await
            .map_err(crate::error::DbError::Connection)?;
        txn.execute_unprepared(&down_sql)
            .await
            .map_err(|e| crate::error::DbError::Migration(format!("Rollback execution failed: {}", e)))?;
        let delete_sql = sea_orm::Statement::from_sql_and_values(
            backend,
            delete_sql_text.to_string(),
            vec![version.into()],
        );
        txn.execute_raw(delete_sql)
            .await
            .map_err(crate::error::DbError::Connection)?;
        txn.commit().await.map_err(crate::error::DbError::Connection)?;
        self.history.applied_migrations.retain(|m| m.version != version);
        Ok(())
    }
}
/// Parses and sanity-checks raw migration-file text.
pub struct MigrationFileParser;

impl MigrationFileParser {
    /// Validates `content` and returns `(description, content)`.
    ///
    /// The second element is the *full* file content, not a single section;
    /// callers slice out the UP/DOWN parts themselves.
    ///
    /// # Errors
    /// Returns a message when the file contains neither UP/DOWN markers nor
    /// any recognizable SQL statement keyword.
    pub fn parse_migration_file(content: &str) -> Result<(String, String), String> {
        let description = Self::extract_description(content);
        Self::validate_sql_syntax(content)?;
        Ok((description, content.to_string()))
    }

    /// Returns the description from a leading `-- Migration: <text>` comment,
    /// scanning only the initial comment block; defaults to "Migration".
    fn extract_description(content: &str) -> String {
        for line in content.lines() {
            let trimmed = line.trim();
            if let Some(rest) = trimmed.strip_prefix("-- Migration:") {
                // strip_prefix removes the whole 13-char marker; the old
                // `trimmed[12..]` sliced one byte short and left every
                // description prefixed with ':'.
                return rest.trim().to_string();
            }
            if trimmed.starts_with("/*") || trimmed.starts_with("--") {
                continue;
            }
            // First non-comment line ends the header block.
            break;
        }
        "Migration".to_string()
    }

    /// Heuristic check that the file contains either UP/DOWN markers or at
    /// least one recognizable SQL statement keyword.
    fn validate_sql_syntax(content: &str) -> Result<(), String> {
        // Uppercase once; the original re-allocated an uppercased copy for
        // every individual `contains` probe.
        let upper = content.to_uppercase();
        let has_up = content.contains("UP") || content.contains("up") || upper.contains("-- UP");
        let has_down = content.contains("DOWN") || content.contains("down") || upper.contains("-- DOWN");
        if !has_up && !has_down {
            let sql_statements = ["CREATE", "ALTER", "DROP", "INSERT", "UPDATE", "DELETE"];
            let contains_sql = sql_statements.iter().any(|stmt| upper.contains(stmt));
            if !contains_sql {
                return Err("Migration file does not contain recognizable SQL statements".to_string());
            }
        }
        Ok(())
    }
}