elif_orm/migrations/
rollback.rs1use sqlx::Row;
7
8use super::definitions::{Migration, MigrationRecord, RollbackResult};
9use super::runner::MigrationRunner;
10use crate::error::{OrmError, OrmResult};
11
12#[allow(async_fn_in_trait)]
14pub trait MigrationRollback {
15 async fn rollback_last_batch(&self) -> OrmResult<RollbackResult>;
17
18 async fn rollback_batch(&self, batch: i32) -> OrmResult<RollbackResult>;
20
21 async fn rollback_migration(&self, migration_id: &str) -> OrmResult<()>;
23
24 async fn rollback_all(&self) -> OrmResult<RollbackResult>;
26
27 async fn get_migrations_in_batch(&self, batch: i32) -> OrmResult<Vec<MigrationRecord>>;
29}
30
31impl MigrationRollback for MigrationRunner {
32 async fn rollback_last_batch(&self) -> OrmResult<RollbackResult> {
33 let start_time = std::time::Instant::now();
34
35 let latest_batch = self.get_latest_batch_number().await?;
37
38 if latest_batch == 0 {
39 return Ok(RollbackResult {
40 rolled_back_count: 0,
41 rolled_back_migrations: Vec::new(),
42 execution_time_ms: start_time.elapsed().as_millis(),
43 });
44 }
45
46 self.rollback_batch(latest_batch).await
47 }
48
49 async fn rollback_batch(&self, batch: i32) -> OrmResult<RollbackResult> {
50 let start_time = std::time::Instant::now();
51
52 let batch_migrations = self.get_migrations_in_batch(batch).await?;
54
55 if batch_migrations.is_empty() {
56 return Ok(RollbackResult {
57 rolled_back_count: 0,
58 rolled_back_migrations: Vec::new(),
59 execution_time_ms: start_time.elapsed().as_millis(),
60 });
61 }
62
63 let all_migrations = self.manager().load_migrations().await?;
65 let migration_map: std::collections::HashMap<String, Migration> = all_migrations
66 .into_iter()
67 .map(|m| (m.id.clone(), m))
68 .collect();
69
70 let mut rolled_back_migrations = Vec::new();
71
72 for record in batch_migrations.iter().rev() {
74 if let Some(migration) = migration_map.get(&record.id) {
75 println!(
76 "Rolling back migration: {} - {}",
77 migration.id, migration.name
78 );
79
80 let mut transaction = self.pool().begin().await.map_err(|e| {
82 OrmError::Migration(format!("Failed to start rollback transaction: {}", e))
83 })?;
84
85 if !migration.down_sql.trim().is_empty() {
87 for statement in self.manager().split_sql_statements(&migration.down_sql)? {
88 if !statement.trim().is_empty() {
89 sqlx::query(&statement)
90 .execute(&mut *transaction)
91 .await
92 .map_err(|e| {
93 OrmError::Migration(format!(
94 "Failed to rollback migration {}: {}",
95 migration.id, e
96 ))
97 })?;
98 }
99 }
100 }
101
102 let (remove_sql, params) = self.remove_migration_sql(&migration.id);
104 let mut query = sqlx::query(&remove_sql);
105 for param in params {
106 query = query.bind(param);
107 }
108 query.execute(&mut *transaction).await.map_err(|e| {
109 OrmError::Migration(format!("Failed to remove migration record: {}", e))
110 })?;
111
112 transaction.commit().await.map_err(|e| {
114 OrmError::Migration(format!("Failed to commit rollback: {}", e))
115 })?;
116
117 rolled_back_migrations.push(record.id.clone());
118 } else {
119 return Err(OrmError::Migration(format!(
120 "Migration file not found for applied migration: {}",
121 record.id
122 )));
123 }
124 }
125
126 Ok(RollbackResult {
127 rolled_back_count: rolled_back_migrations.len(),
128 rolled_back_migrations,
129 execution_time_ms: start_time.elapsed().as_millis(),
130 })
131 }
132
133 async fn rollback_migration(&self, migration_id: &str) -> OrmResult<()> {
134 let applied_migrations = self.get_applied_migrations_ordered().await?;
136 let _migration_record = applied_migrations
137 .iter()
138 .find(|m| m.id == migration_id)
139 .ok_or_else(|| {
140 OrmError::Migration(format!("Migration {} is not applied", migration_id))
141 })?;
142
143 if let Some(most_recent) = applied_migrations.first() {
145 if most_recent.id != migration_id {
146 return Err(OrmError::Migration(
147 "Can only rollback the most recent migration. Use rollback_batch for batch operations.".to_string()
148 ));
149 }
150 }
151
152 let migrations = self.manager().load_migrations().await?;
154 let migration = migrations
155 .iter()
156 .find(|m| m.id == migration_id)
157 .ok_or_else(|| {
158 OrmError::Migration(format!("Migration file {} not found", migration_id))
159 })?;
160
161 let mut transaction = self.pool().begin().await.map_err(|e| {
163 OrmError::Migration(format!("Failed to start rollback transaction: {}", e))
164 })?;
165
166 if !migration.down_sql.trim().is_empty() {
168 for statement in self.manager().split_sql_statements(&migration.down_sql)? {
169 if !statement.trim().is_empty() {
170 sqlx::query(&statement)
171 .execute(&mut *transaction)
172 .await
173 .map_err(|e| {
174 OrmError::Migration(format!(
175 "Failed to rollback migration {}: {}",
176 migration.id, e
177 ))
178 })?;
179 }
180 }
181 }
182
183 let (remove_sql, params) = self.remove_migration_sql(&migration.id);
185 let mut query = sqlx::query(&remove_sql);
186 for param in params {
187 query = query.bind(param);
188 }
189 query.execute(&mut *transaction).await.map_err(|e| {
190 OrmError::Migration(format!("Failed to remove migration record: {}", e))
191 })?;
192
193 transaction
195 .commit()
196 .await
197 .map_err(|e| OrmError::Migration(format!("Failed to commit rollback: {}", e)))?;
198
199 println!(
200 "Rolled back migration: {} - {}",
201 migration.id, migration.name
202 );
203
204 Ok(())
205 }
206
207 async fn rollback_all(&self) -> OrmResult<RollbackResult> {
208 let start_time = std::time::Instant::now();
209 let mut total_rolled_back = Vec::new();
210
211 loop {
212 let result = self.rollback_last_batch().await?;
213 if result.rolled_back_count == 0 {
214 break;
215 }
216 total_rolled_back.extend(result.rolled_back_migrations);
217 }
218
219 Ok(RollbackResult {
220 rolled_back_count: total_rolled_back.len(),
221 rolled_back_migrations: total_rolled_back,
222 execution_time_ms: start_time.elapsed().as_millis(),
223 })
224 }
225
226 async fn get_migrations_in_batch(&self, batch: i32) -> OrmResult<Vec<MigrationRecord>> {
227 let sql = format!(
228 "SELECT id, applied_at, batch FROM {} WHERE batch = $1 ORDER BY applied_at DESC",
229 self.manager().config().migrations_table
230 );
231
232 let rows = sqlx::query(&sql)
233 .bind(batch)
234 .fetch_all(self.pool())
235 .await
236 .map_err(|e| OrmError::Migration(format!("Failed to query batch migrations: {}", e)))?;
237
238 let mut records = Vec::new();
239 for row in rows {
240 let id: String = row
241 .try_get("id")
242 .map_err(|e| OrmError::Migration(format!("Failed to get migration id: {}", e)))?;
243 let applied_at: chrono::DateTime<chrono::Utc> = row
244 .try_get("applied_at")
245 .map_err(|e| OrmError::Migration(format!("Failed to get applied_at: {}", e)))?;
246 let batch: i32 = row
247 .try_get("batch")
248 .map_err(|e| OrmError::Migration(format!("Failed to get batch: {}", e)))?;
249
250 records.push(MigrationRecord {
251 id,
252 applied_at,
253 batch,
254 });
255 }
256
257 Ok(records)
258 }
259}
260
261impl MigrationRunner {
263 async fn get_applied_migrations_ordered(&self) -> OrmResult<Vec<MigrationRecord>> {
265 let sql = format!(
266 "SELECT id, applied_at, batch FROM {} ORDER BY batch DESC, applied_at DESC",
267 self.manager().config().migrations_table
268 );
269
270 let rows = sqlx::query(&sql)
271 .fetch_all(self.pool())
272 .await
273 .map_err(|e| {
274 OrmError::Migration(format!("Failed to query applied migrations: {}", e))
275 })?;
276
277 let mut records = Vec::new();
278 for row in rows {
279 let id: String = row
280 .try_get("id")
281 .map_err(|e| OrmError::Migration(format!("Failed to get migration id: {}", e)))?;
282 let applied_at: chrono::DateTime<chrono::Utc> = row
283 .try_get("applied_at")
284 .map_err(|e| OrmError::Migration(format!("Failed to get applied_at: {}", e)))?;
285 let batch: i32 = row
286 .try_get("batch")
287 .map_err(|e| OrmError::Migration(format!("Failed to get batch: {}", e)))?;
288
289 records.push(MigrationRecord {
290 id,
291 applied_at,
292 batch,
293 });
294 }
295
296 Ok(records)
297 }
298
299 async fn get_latest_batch_number(&self) -> OrmResult<i32> {
301 let sql = format!(
302 "SELECT COALESCE(MAX(batch), 0) FROM {}",
303 self.manager().config().migrations_table
304 );
305
306 let row = sqlx::query(&sql)
307 .fetch_one(self.pool())
308 .await
309 .map_err(|e| OrmError::Migration(format!("Failed to get latest batch: {}", e)))?;
310
311 let latest_batch: i32 = row.try_get(0).unwrap_or(0);
312 Ok(latest_batch)
313 }
314
315 fn remove_migration_sql(&self, migration_id: &str) -> (String, Vec<String>) {
317 (
318 format!(
319 "DELETE FROM {} WHERE id = $1",
320 self.manager().config().migrations_table
321 ),
322 vec![migration_id.to_string()],
323 )
324 }
325}