elif_orm/migrations/rollback.rs

//! Migration Rollback - Handles rolling back applied migrations
//!
//! Provides functionality to roll back migrations by batch or individually,
//! executing their DOWN statements to reverse schema changes.
5
6use sqlx::Row;
7
8use super::definitions::{Migration, MigrationRecord, RollbackResult};
9use super::runner::MigrationRunner;
10use crate::error::{OrmError, OrmResult};
11
12/// Extension trait for MigrationRunner to add rollback functionality
13#[allow(async_fn_in_trait)]
14pub trait MigrationRollback {
15    /// Rollback the last batch of migrations
16    async fn rollback_last_batch(&self) -> OrmResult<RollbackResult>;
17
18    /// Rollback all migrations in a specific batch
19    async fn rollback_batch(&self, batch: i32) -> OrmResult<RollbackResult>;
20
21    /// Rollback a specific migration by ID
22    async fn rollback_migration(&self, migration_id: &str) -> OrmResult<()>;
23
24    /// Rollback all applied migrations
25    async fn rollback_all(&self) -> OrmResult<RollbackResult>;
26
27    /// Get migrations in a specific batch
28    async fn get_migrations_in_batch(&self, batch: i32) -> OrmResult<Vec<MigrationRecord>>;
29}
30
31impl MigrationRollback for MigrationRunner {
32    async fn rollback_last_batch(&self) -> OrmResult<RollbackResult> {
33        let start_time = std::time::Instant::now();
34
35        // Get the latest batch number
36        let latest_batch = self.get_latest_batch_number().await?;
37
38        if latest_batch == 0 {
39            return Ok(RollbackResult {
40                rolled_back_count: 0,
41                rolled_back_migrations: Vec::new(),
42                execution_time_ms: start_time.elapsed().as_millis(),
43            });
44        }
45
46        self.rollback_batch(latest_batch).await
47    }
48
49    async fn rollback_batch(&self, batch: i32) -> OrmResult<RollbackResult> {
50        let start_time = std::time::Instant::now();
51
52        // Get all migrations in this batch
53        let batch_migrations = self.get_migrations_in_batch(batch).await?;
54
55        if batch_migrations.is_empty() {
56            return Ok(RollbackResult {
57                rolled_back_count: 0,
58                rolled_back_migrations: Vec::new(),
59                execution_time_ms: start_time.elapsed().as_millis(),
60            });
61        }
62
63        // Load all migration files to get DOWN SQL
64        let all_migrations = self.manager().load_migrations().await?;
65        let migration_map: std::collections::HashMap<String, Migration> = all_migrations
66            .into_iter()
67            .map(|m| (m.id.clone(), m))
68            .collect();
69
70        let mut rolled_back_migrations = Vec::new();
71
72        // Rollback migrations in reverse order
73        for record in batch_migrations.iter().rev() {
74            if let Some(migration) = migration_map.get(&record.id) {
75                println!(
76                    "Rolling back migration: {} - {}",
77                    migration.id, migration.name
78                );
79
80                // Begin transaction
81                let mut transaction = self.pool().begin().await.map_err(|e| {
82                    OrmError::Migration(format!("Failed to start rollback transaction: {}", e))
83                })?;
84
85                // Execute DOWN SQL
86                if !migration.down_sql.trim().is_empty() {
87                    for statement in self.manager().split_sql_statements(&migration.down_sql)? {
88                        if !statement.trim().is_empty() {
89                            sqlx::query(&statement)
90                                .execute(&mut *transaction)
91                                .await
92                                .map_err(|e| {
93                                    OrmError::Migration(format!(
94                                        "Failed to rollback migration {}: {}",
95                                        migration.id, e
96                                    ))
97                                })?;
98                        }
99                    }
100                }
101
102                // Remove migration record
103                let (remove_sql, params) = self.remove_migration_sql(&migration.id);
104                let mut query = sqlx::query(&remove_sql);
105                for param in params {
106                    query = query.bind(param);
107                }
108                query.execute(&mut *transaction).await.map_err(|e| {
109                    OrmError::Migration(format!("Failed to remove migration record: {}", e))
110                })?;
111
112                // Commit transaction
113                transaction.commit().await.map_err(|e| {
114                    OrmError::Migration(format!("Failed to commit rollback: {}", e))
115                })?;
116
117                rolled_back_migrations.push(record.id.clone());
118            } else {
119                return Err(OrmError::Migration(format!(
120                    "Migration file not found for applied migration: {}",
121                    record.id
122                )));
123            }
124        }
125
126        Ok(RollbackResult {
127            rolled_back_count: rolled_back_migrations.len(),
128            rolled_back_migrations,
129            execution_time_ms: start_time.elapsed().as_millis(),
130        })
131    }
132
133    async fn rollback_migration(&self, migration_id: &str) -> OrmResult<()> {
134        // Check if migration is applied
135        let applied_migrations = self.get_applied_migrations_ordered().await?;
136        let _migration_record = applied_migrations
137            .iter()
138            .find(|m| m.id == migration_id)
139            .ok_or_else(|| {
140                OrmError::Migration(format!("Migration {} is not applied", migration_id))
141            })?;
142
143        // Check if this is the most recent migration
144        if let Some(most_recent) = applied_migrations.first() {
145            if most_recent.id != migration_id {
146                return Err(OrmError::Migration(
147                    "Can only rollback the most recent migration. Use rollback_batch for batch operations.".to_string()
148                ));
149            }
150        }
151
152        // Load the migration file
153        let migrations = self.manager().load_migrations().await?;
154        let migration = migrations
155            .iter()
156            .find(|m| m.id == migration_id)
157            .ok_or_else(|| {
158                OrmError::Migration(format!("Migration file {} not found", migration_id))
159            })?;
160
161        // Begin transaction
162        let mut transaction = self.pool().begin().await.map_err(|e| {
163            OrmError::Migration(format!("Failed to start rollback transaction: {}", e))
164        })?;
165
166        // Execute DOWN SQL
167        if !migration.down_sql.trim().is_empty() {
168            for statement in self.manager().split_sql_statements(&migration.down_sql)? {
169                if !statement.trim().is_empty() {
170                    sqlx::query(&statement)
171                        .execute(&mut *transaction)
172                        .await
173                        .map_err(|e| {
174                            OrmError::Migration(format!(
175                                "Failed to rollback migration {}: {}",
176                                migration.id, e
177                            ))
178                        })?;
179                }
180            }
181        }
182
183        // Remove migration record
184        let (remove_sql, params) = self.remove_migration_sql(&migration.id);
185        let mut query = sqlx::query(&remove_sql);
186        for param in params {
187            query = query.bind(param);
188        }
189        query.execute(&mut *transaction).await.map_err(|e| {
190            OrmError::Migration(format!("Failed to remove migration record: {}", e))
191        })?;
192
193        // Commit transaction
194        transaction
195            .commit()
196            .await
197            .map_err(|e| OrmError::Migration(format!("Failed to commit rollback: {}", e)))?;
198
199        println!(
200            "Rolled back migration: {} - {}",
201            migration.id, migration.name
202        );
203
204        Ok(())
205    }
206
207    async fn rollback_all(&self) -> OrmResult<RollbackResult> {
208        let start_time = std::time::Instant::now();
209        let mut total_rolled_back = Vec::new();
210
211        loop {
212            let result = self.rollback_last_batch().await?;
213            if result.rolled_back_count == 0 {
214                break;
215            }
216            total_rolled_back.extend(result.rolled_back_migrations);
217        }
218
219        Ok(RollbackResult {
220            rolled_back_count: total_rolled_back.len(),
221            rolled_back_migrations: total_rolled_back,
222            execution_time_ms: start_time.elapsed().as_millis(),
223        })
224    }
225
226    async fn get_migrations_in_batch(&self, batch: i32) -> OrmResult<Vec<MigrationRecord>> {
227        let sql = format!(
228            "SELECT id, applied_at, batch FROM {} WHERE batch = $1 ORDER BY applied_at DESC",
229            self.manager().config().migrations_table
230        );
231
232        let rows = sqlx::query(&sql)
233            .bind(batch)
234            .fetch_all(self.pool())
235            .await
236            .map_err(|e| OrmError::Migration(format!("Failed to query batch migrations: {}", e)))?;
237
238        let mut records = Vec::new();
239        for row in rows {
240            let id: String = row
241                .try_get("id")
242                .map_err(|e| OrmError::Migration(format!("Failed to get migration id: {}", e)))?;
243            let applied_at: chrono::DateTime<chrono::Utc> = row
244                .try_get("applied_at")
245                .map_err(|e| OrmError::Migration(format!("Failed to get applied_at: {}", e)))?;
246            let batch: i32 = row
247                .try_get("batch")
248                .map_err(|e| OrmError::Migration(format!("Failed to get batch: {}", e)))?;
249
250            records.push(MigrationRecord {
251                id,
252                applied_at,
253                batch,
254            });
255        }
256
257        Ok(records)
258    }
259}
260
261// Extension methods for MigrationRunner
262impl MigrationRunner {
263    /// Get applied migrations ordered by batch and time (most recent first)
264    async fn get_applied_migrations_ordered(&self) -> OrmResult<Vec<MigrationRecord>> {
265        let sql = format!(
266            "SELECT id, applied_at, batch FROM {} ORDER BY batch DESC, applied_at DESC",
267            self.manager().config().migrations_table
268        );
269
270        let rows = sqlx::query(&sql)
271            .fetch_all(self.pool())
272            .await
273            .map_err(|e| {
274                OrmError::Migration(format!("Failed to query applied migrations: {}", e))
275            })?;
276
277        let mut records = Vec::new();
278        for row in rows {
279            let id: String = row
280                .try_get("id")
281                .map_err(|e| OrmError::Migration(format!("Failed to get migration id: {}", e)))?;
282            let applied_at: chrono::DateTime<chrono::Utc> = row
283                .try_get("applied_at")
284                .map_err(|e| OrmError::Migration(format!("Failed to get applied_at: {}", e)))?;
285            let batch: i32 = row
286                .try_get("batch")
287                .map_err(|e| OrmError::Migration(format!("Failed to get batch: {}", e)))?;
288
289            records.push(MigrationRecord {
290                id,
291                applied_at,
292                batch,
293            });
294        }
295
296        Ok(records)
297    }
298
299    /// Get the latest batch number
300    async fn get_latest_batch_number(&self) -> OrmResult<i32> {
301        let sql = format!(
302            "SELECT COALESCE(MAX(batch), 0) FROM {}",
303            self.manager().config().migrations_table
304        );
305
306        let row = sqlx::query(&sql)
307            .fetch_one(self.pool())
308            .await
309            .map_err(|e| OrmError::Migration(format!("Failed to get latest batch: {}", e)))?;
310
311        let latest_batch: i32 = row.try_get(0).unwrap_or(0);
312        Ok(latest_batch)
313    }
314
315    /// SQL to remove a migration record
316    fn remove_migration_sql(&self, migration_id: &str) -> (String, Vec<String>) {
317        (
318            format!(
319                "DELETE FROM {} WHERE id = $1",
320                self.manager().config().migrations_table
321            ),
322            vec![migration_id.to_string()],
323        )
324    }
325}