use sqlx::Row;
use super::definitions::{Migration, MigrationRecord, RollbackResult};
use super::runner::MigrationRunner;
use crate::error::{OrmError, OrmResult};
/// Rollback operations for migrations that have already been applied.
///
/// Implemented for [`MigrationRunner`] in this file.
// `async fn` in a public trait trips the `async_fn_in_trait` lint, which is
// silenced here. NOTE(review): the lint exists because callers cannot require
// auto traits such as `Send` on the returned futures — confirm that is
// acceptable for this trait's users.
#[allow(async_fn_in_trait)]
pub trait MigrationRollback {
/// Rolls back the most recent batch of applied migrations and reports what
/// was undone.
async fn rollback_last_batch(&self) -> OrmResult<RollbackResult>;
/// Rolls back the migrations recorded under batch number `batch`.
async fn rollback_batch(&self, batch: i32) -> OrmResult<RollbackResult>;
/// Rolls back the single migration identified by `migration_id`.
///
/// The `MigrationRunner` implementation refuses anything other than the most
/// recently applied migration.
async fn rollback_migration(&self, migration_id: &str) -> OrmResult<()>;
/// Rolls back all applied migrations.
///
/// The `MigrationRunner` implementation repeats `rollback_last_batch` until a
/// call reports that nothing was rolled back.
async fn rollback_all(&self) -> OrmResult<RollbackResult>;
/// Fetches the records of the migrations stored under batch number `batch`.
async fn get_migrations_in_batch(&self, batch: i32) -> OrmResult<Vec<MigrationRecord>>;
}
impl MigrationRollback for MigrationRunner {
/// Rolls back the batch with the highest recorded batch number.
///
/// When the latest batch number is `0` an empty [`RollbackResult`] is
/// returned; otherwise the work is delegated to `rollback_batch`.
async fn rollback_last_batch(&self) -> OrmResult<RollbackResult> {
    let timer = std::time::Instant::now();

    match self.get_latest_batch_number().await? {
        // Batch 0 means there is nothing to undo.
        0 => Ok(RollbackResult {
            rolled_back_count: 0,
            rolled_back_migrations: Vec::new(),
            execution_time_ms: timer.elapsed().as_millis(),
        }),
        newest_batch => self.rollback_batch(newest_batch).await,
    }
}
async fn rollback_batch(&self, batch: i32) -> OrmResult<RollbackResult> {
let start_time = std::time::Instant::now();
let batch_migrations = self.get_migrations_in_batch(batch).await?;
if batch_migrations.is_empty() {
return Ok(RollbackResult {
rolled_back_count: 0,
rolled_back_migrations: Vec::new(),
execution_time_ms: start_time.elapsed().as_millis(),
});
}
let all_migrations = self.manager().load_migrations().await?;
let migration_map: std::collections::HashMap<String, Migration> = all_migrations
.into_iter()
.map(|m| (m.id.clone(), m))
.collect();
let mut rolled_back_migrations = Vec::new();
for record in batch_migrations.iter().rev() {
if let Some(migration) = migration_map.get(&record.id) {
println!(
"Rolling back migration: {} - {}",
migration.id, migration.name
);
let mut transaction = self.pool().begin().await.map_err(|e| {
OrmError::Migration(format!("Failed to start rollback transaction: {}", e))
})?;
if !migration.down_sql.trim().is_empty() {
for statement in self.manager().split_sql_statements(&migration.down_sql)? {
if !statement.trim().is_empty() {
sqlx::query(&statement)
.execute(&mut *transaction)
.await
.map_err(|e| {
OrmError::Migration(format!(
"Failed to rollback migration {}: {}",
migration.id, e
))
})?;
}
}
}
let (remove_sql, params) = self.remove_migration_sql(&migration.id);
let mut query = sqlx::query(&remove_sql);
for param in params {
query = query.bind(param);
}
query.execute(&mut *transaction).await.map_err(|e| {
OrmError::Migration(format!("Failed to remove migration record: {}", e))
})?;
transaction.commit().await.map_err(|e| {
OrmError::Migration(format!("Failed to commit rollback: {}", e))
})?;
rolled_back_migrations.push(record.id.clone());
} else {
return Err(OrmError::Migration(format!(
"Migration file not found for applied migration: {}",
record.id
)));
}
}
Ok(RollbackResult {
rolled_back_count: rolled_back_migrations.len(),
rolled_back_migrations,
execution_time_ms: start_time.elapsed().as_millis(),
})
}
/// Rolls back the single migration identified by `migration_id`.
///
/// Only the most recently applied migration (the first record when ordered by
/// batch and then `applied_at`, both descending) may be rolled back this way.
/// Its down statements and the removal of its record run in one transaction.
async fn rollback_migration(&self, migration_id: &str) -> OrmResult<()> {
    let applied = self.get_applied_migrations_ordered().await?;

    // The migration must have a record before anything else is checked.
    if !applied.iter().any(|rec| rec.id == migration_id) {
        return Err(OrmError::Migration(format!(
            "Migration {} is not applied",
            migration_id
        )));
    }

    // `applied` is ordered newest first, so its head is the only valid target.
    if applied.first().is_some_and(|newest| newest.id != migration_id) {
        return Err(OrmError::Migration(
            "Can only rollback the most recent migration. Use rollback_batch for batch operations.".to_string()
        ));
    }

    let available = self.manager().load_migrations().await?;
    let Some(target) = available
        .iter()
        .find(|candidate| candidate.id == migration_id)
    else {
        return Err(OrmError::Migration(format!(
            "Migration file {} not found",
            migration_id
        )));
    };

    let mut tx = self.pool().begin().await.map_err(|err| {
        OrmError::Migration(format!("Failed to start rollback transaction: {}", err))
    })?;

    if !target.down_sql.trim().is_empty() {
        for sql in self.manager().split_sql_statements(&target.down_sql)? {
            // Skip fragments that are only whitespace.
            if sql.trim().is_empty() {
                continue;
            }
            sqlx::query(&sql)
                .execute(&mut *tx)
                .await
                .map_err(|err| {
                    OrmError::Migration(format!(
                        "Failed to rollback migration {}: {}",
                        target.id, err
                    ))
                })?;
        }
    }

    // Delete the bookkeeping row inside the same transaction as the down SQL.
    let (delete_sql, bindings) = self.remove_migration_sql(&target.id);
    let mut delete_query = sqlx::query(&delete_sql);
    for value in bindings {
        delete_query = delete_query.bind(value);
    }
    delete_query.execute(&mut *tx).await.map_err(|err| {
        OrmError::Migration(format!("Failed to remove migration record: {}", err))
    })?;

    tx.commit()
        .await
        .map_err(|err| OrmError::Migration(format!("Failed to commit rollback: {}", err)))?;

    println!("Rolled back migration: {} - {}", target.id, target.name);
    Ok(())
}
/// Rolls back all applied migrations, one batch at a time.
///
/// Calls `rollback_last_batch` repeatedly until a call reports that nothing
/// was rolled back, then returns the combined list of migration ids together
/// with the total elapsed time.
async fn rollback_all(&self) -> OrmResult<RollbackResult> {
    let timer = std::time::Instant::now();
    let mut undone = Vec::new();

    loop {
        let RollbackResult {
            rolled_back_count,
            rolled_back_migrations,
            ..
        } = self.rollback_last_batch().await?;

        // An empty result signals that no batches are left.
        if rolled_back_count == 0 {
            break;
        }
        undone.extend(rolled_back_migrations);
    }

    let rolled_back_count = undone.len();
    Ok(RollbackResult {
        rolled_back_count,
        rolled_back_migrations: undone,
        execution_time_ms: timer.elapsed().as_millis(),
    })
}
/// Fetches the records stored under batch number `batch`, ordered by
/// `applied_at` descending.
///
/// Any query or column-decoding failure is reported as
/// [`OrmError::Migration`].
async fn get_migrations_in_batch(&self, batch: i32) -> OrmResult<Vec<MigrationRecord>> {
    let select_sql = format!(
        "SELECT id, applied_at, batch FROM {} WHERE batch = $1 ORDER BY applied_at DESC",
        self.manager().config().migrations_table
    );

    let fetched = sqlx::query(&select_sql)
        .bind(batch)
        .fetch_all(self.pool())
        .await
        .map_err(|err| OrmError::Migration(format!("Failed to query batch migrations: {}", err)))?;

    // Decode row by row; the first failing column aborts the whole collection.
    fetched
        .into_iter()
        .map(|row| -> OrmResult<MigrationRecord> {
            let id: String = row
                .try_get("id")
                .map_err(|err| OrmError::Migration(format!("Failed to get migration id: {}", err)))?;
            let applied_at: chrono::DateTime<chrono::Utc> = row
                .try_get("applied_at")
                .map_err(|err| OrmError::Migration(format!("Failed to get applied_at: {}", err)))?;
            let batch_number: i32 = row
                .try_get("batch")
                .map_err(|err| OrmError::Migration(format!("Failed to get batch: {}", err)))?;
            Ok(MigrationRecord {
                id,
                applied_at,
                batch: batch_number,
            })
        })
        .collect()
}
}
impl MigrationRunner {
/// Loads every row of the migrations table, ordered by batch and then
/// `applied_at`, both descending, so the first element is the newest record.
///
/// Any query or column-decoding failure is reported as
/// [`OrmError::Migration`].
async fn get_applied_migrations_ordered(&self) -> OrmResult<Vec<MigrationRecord>> {
    let select_sql = format!(
        "SELECT id, applied_at, batch FROM {} ORDER BY batch DESC, applied_at DESC",
        self.manager().config().migrations_table
    );

    let fetched = sqlx::query(&select_sql)
        .fetch_all(self.pool())
        .await
        .map_err(|err| {
            OrmError::Migration(format!("Failed to query applied migrations: {}", err))
        })?;

    // Decode row by row; the first failing column aborts the whole collection.
    fetched
        .into_iter()
        .map(|row| -> OrmResult<MigrationRecord> {
            let id: String = row
                .try_get("id")
                .map_err(|err| OrmError::Migration(format!("Failed to get migration id: {}", err)))?;
            let applied_at: chrono::DateTime<chrono::Utc> = row
                .try_get("applied_at")
                .map_err(|err| OrmError::Migration(format!("Failed to get applied_at: {}", err)))?;
            let batch: i32 = row
                .try_get("batch")
                .map_err(|err| OrmError::Migration(format!("Failed to get batch: {}", err)))?;
            Ok(MigrationRecord {
                id,
                applied_at,
                batch,
            })
        })
        .collect()
}
/// Returns the highest batch number in the migrations table, or `0` when the
/// table has no rows.
///
/// # Errors
///
/// Returns [`OrmError::Migration`] when the query fails or when the result
/// column cannot be decoded as `i32`.
async fn get_latest_batch_number(&self) -> OrmResult<i32> {
    let sql = format!(
        "SELECT COALESCE(MAX(batch), 0) FROM {}",
        self.manager().config().migrations_table
    );

    let row = sqlx::query(&sql)
        .fetch_one(self.pool())
        .await
        .map_err(|e| OrmError::Migration(format!("Failed to get latest batch: {}", e)))?;

    // COALESCE already maps the empty-table case to 0, so a decode failure is
    // a genuine error (e.g. an unexpected column type). Surfacing it keeps
    // callers from mistaking it for "nothing to roll back".
    row.try_get::<i32, _>(0)
        .map_err(|e| OrmError::Migration(format!("Failed to decode latest batch: {}", e)))
}
/// Builds the `DELETE` statement that removes one migration's record.
///
/// Returns the SQL text, which uses a single `$1` placeholder, together with
/// the values to bind to it (just `migration_id`).
fn remove_migration_sql(&self, migration_id: &str) -> (String, Vec<String>) {
    let statement = format!(
        "DELETE FROM {} WHERE id = $1",
        self.manager().config().migrations_table
    );
    let bindings = vec![migration_id.to_owned()];
    (statement, bindings)
}
}