1use std::error::Error;
2
3use sqlx::{Database, Executor, Pool};
4
/// A single schema migration: an identifying version, a human-readable name,
/// and the SQL used to apply and to revert it.
#[derive(Clone, Debug)]
pub struct Migration {
    /// Identifier stored in the `_premix_migrations` table; applied versions
    /// are compared and ordered as plain strings.
    pub version: String,
    /// Descriptive label recorded alongside the version when applied.
    pub name: String,
    /// SQL executed when the migration is applied.
    pub up_sql: String,
    /// SQL executed when the migration is rolled back.
    pub down_sql: String,
}
12
13#[derive(Debug, Clone, sqlx::FromRow)]
14struct AppliedMigration {
15 version: String,
16}
17
18pub struct Migrator<DB: Database> {
19 pool: Pool<DB>,
20}
21
22impl<DB: Database> Migrator<DB> {
23 pub fn new(pool: Pool<DB>) -> Self {
24 Self { pool }
25 }
26}
27
#[cfg(feature = "sqlite")]
impl Migrator<sqlx::Sqlite> {
    /// DDL for the bookkeeping table. Kept in one place so `run` and
    /// `rollback_last` can never drift apart on the schema they expect.
    const MIGRATIONS_TABLE_DDL: &'static str = "CREATE TABLE IF NOT EXISTS _premix_migrations (
            version TEXT PRIMARY KEY,
            name TEXT NOT NULL,
            applied_at DATETIME DEFAULT CURRENT_TIMESTAMP
        )";

    /// Applies every migration in `migrations` whose version is not yet
    /// recorded in `_premix_migrations`, in the order given.
    ///
    /// All work (creating the bookkeeping table, running each `up_sql`, and
    /// recording the version) happens inside a single transaction that is
    /// committed only after the last migration succeeds.
    ///
    /// # Errors
    ///
    /// Returns the first database error encountered. The transaction is then
    /// dropped without being committed, so nothing from this call is kept.
    pub async fn run(&self, migrations: Vec<Migration>) -> Result<(), Box<dyn Error>> {
        let mut tx = self.pool.begin().await?;

        sqlx::query(Self::MIGRATIONS_TABLE_DDL)
            .execute(&mut *tx)
            .await?;

        // A set gives constant-time membership checks in the loop below
        // instead of rescanning a Vec for every migration.
        let applied: std::collections::HashSet<String> = sqlx::query_as::<_, AppliedMigration>(
            "SELECT version FROM _premix_migrations ORDER BY version ASC",
        )
        .fetch_all(&mut *tx)
        .await?
        .into_iter()
        .map(|row| row.version)
        .collect();

        for migration in migrations {
            if applied.contains(&migration.version) {
                continue;
            }

            println!(
                "🚚 Applying migration: {} - {}",
                migration.version, migration.name
            );

            // Executed as a raw string so `up_sql` may hold several statements.
            tx.execute(migration.up_sql.as_str()).await?;

            sqlx::query("INSERT INTO _premix_migrations (version, name) VALUES (?, ?)")
                .bind(&migration.version)
                .bind(&migration.name)
                .execute(&mut *tx)
                .await?;
        }

        tx.commit().await?;
        Ok(())
    }

    /// Reverts the most recently applied migration among `migrations`.
    ///
    /// "Most recent" is the greatest `version` (string ordering) that is both
    /// present in `migrations` and recorded in `_premix_migrations`.
    ///
    /// Returns `Ok(true)` when a migration was reverted, and `Ok(false)` when
    /// `migrations` is empty or none of its versions has been applied.
    ///
    /// # Errors
    ///
    /// Returns an error if the selected migration has a blank `down_sql`, or
    /// if any database operation fails. In both cases the transaction is
    /// dropped without being committed.
    pub async fn rollback_last(&self, migrations: Vec<Migration>) -> Result<bool, Box<dyn Error>> {
        let mut tx = self.pool.begin().await?;

        sqlx::query(Self::MIGRATIONS_TABLE_DDL)
            .execute(&mut *tx)
            .await?;

        if migrations.is_empty() {
            tx.commit().await?;
            return Ok(false);
        }

        // SQLite has no array bind, so build one `?` per candidate version.
        let placeholders = vec!["?"; migrations.len()].join(", ");
        let sql = format!(
            "SELECT version FROM _premix_migrations WHERE version IN ({}) ORDER BY version DESC LIMIT 1",
            placeholders
        );
        let mut query = sqlx::query_scalar::<_, String>(&sql);
        // Bind borrowed versions directly; no need to clone them first.
        for candidate in &migrations {
            query = query.bind(candidate.version.as_str());
        }
        let last = query.fetch_optional(&mut *tx).await?;

        let Some(last) = last else {
            tx.commit().await?;
            return Ok(false);
        };

        let migration = migrations
            .into_iter()
            .find(|m| m.version == last)
            .ok_or_else(|| format!("Migration {} not found.", last))?;

        if migration.down_sql.trim().is_empty() {
            return Err("Down migration is empty.".into());
        }

        tx.execute(migration.down_sql.as_str()).await?;
        sqlx::query("DELETE FROM _premix_migrations WHERE version = ?")
            .bind(&migration.version)
            .execute(&mut *tx)
            .await?;

        tx.commit().await?;
        Ok(true)
    }
}
136
#[cfg(feature = "postgres")]
impl Migrator<sqlx::Postgres> {
    /// DDL for the bookkeeping table. Kept in one place so `run` and
    /// `rollback_last` can never drift apart on the schema they expect.
    const MIGRATIONS_TABLE_DDL: &'static str = "CREATE TABLE IF NOT EXISTS _premix_migrations (
            version TEXT PRIMARY KEY,
            name TEXT NOT NULL,
            applied_at TIMESTAMPTZ DEFAULT CURRENT_TIMESTAMP
        )";

    /// Applies every migration in `migrations` whose version is not yet
    /// recorded in `_premix_migrations`, in the order given.
    ///
    /// All work (creating the bookkeeping table, running each `up_sql`, and
    /// recording the version) happens inside a single transaction that is
    /// committed only after the last migration succeeds.
    ///
    /// # Errors
    ///
    /// Returns the first database error encountered. The transaction is then
    /// dropped without being committed, so nothing from this call is kept.
    pub async fn run(&self, migrations: Vec<Migration>) -> Result<(), Box<dyn Error>> {
        let mut tx = self.pool.begin().await?;

        sqlx::query(Self::MIGRATIONS_TABLE_DDL)
            .execute(&mut *tx)
            .await?;

        // A set gives constant-time membership checks in the loop below
        // instead of rescanning a Vec for every migration.
        let applied: std::collections::HashSet<String> = sqlx::query_as::<_, AppliedMigration>(
            "SELECT version FROM _premix_migrations ORDER BY version ASC",
        )
        .fetch_all(&mut *tx)
        .await?
        .into_iter()
        .map(|row| row.version)
        .collect();

        for migration in migrations {
            if applied.contains(&migration.version) {
                continue;
            }

            println!(
                "🚚 Applying migration: {} - {}",
                migration.version, migration.name
            );

            // Executed as a raw string so `up_sql` may hold several statements.
            tx.execute(migration.up_sql.as_str()).await?;

            sqlx::query("INSERT INTO _premix_migrations (version, name) VALUES ($1, $2)")
                .bind(&migration.version)
                .bind(&migration.name)
                .execute(&mut *tx)
                .await?;
        }

        tx.commit().await?;
        Ok(())
    }

    /// Reverts the most recently applied migration among `migrations`.
    ///
    /// "Most recent" is the greatest `version` (string ordering) that is both
    /// present in `migrations` and recorded in `_premix_migrations`.
    ///
    /// Returns `Ok(true)` when a migration was reverted, and `Ok(false)` when
    /// `migrations` is empty or none of its versions has been applied.
    ///
    /// # Errors
    ///
    /// Returns an error if the selected migration has a blank `down_sql`, or
    /// if any database operation fails. In both cases the transaction is
    /// dropped without being committed.
    pub async fn rollback_last(&self, migrations: Vec<Migration>) -> Result<bool, Box<dyn Error>> {
        let mut tx = self.pool.begin().await?;

        sqlx::query(Self::MIGRATIONS_TABLE_DDL)
            .execute(&mut *tx)
            .await?;

        if migrations.is_empty() {
            tx.commit().await?;
            return Ok(false);
        }

        // Bound as a single text[] parameter and matched with `= ANY($1)`.
        let versions: Vec<String> = migrations.iter().map(|m| m.version.clone()).collect();

        let last = sqlx::query_scalar::<_, String>(
            "SELECT version FROM _premix_migrations WHERE version = ANY($1) ORDER BY version DESC LIMIT 1",
        )
        .bind(&versions)
        .fetch_optional(&mut *tx)
        .await?;

        let Some(last) = last else {
            tx.commit().await?;
            return Ok(false);
        };

        let migration = migrations
            .into_iter()
            .find(|m| m.version == last)
            .ok_or_else(|| format!("Migration {} not found.", last))?;

        if migration.down_sql.trim().is_empty() {
            return Err("Down migration is empty.".into());
        }

        tx.execute(migration.down_sql.as_str()).await?;
        sqlx::query("DELETE FROM _premix_migrations WHERE version = $1")
            .bind(&migration.version)
            .execute(&mut *tx)
            .await?;

        tx.commit().await?;
        Ok(true)
    }
}
243
#[cfg(test)]
mod tests {
    #[cfg(feature = "postgres")]
    use std::time::{SystemTime, UNIX_EPOCH};

    // The SQLite items below are gated on the `sqlite` feature because the
    // `Migrator<sqlx::Sqlite>` methods they exercise only exist under it.
    #[cfg(feature = "sqlite")]
    use sqlx::sqlite::SqlitePoolOptions;

    #[cfg(any(feature = "sqlite", feature = "postgres"))]
    use super::*;

    /// Opens a fresh in-memory SQLite database.
    ///
    /// The pool is capped at one connection so every query in a test goes
    /// through the same connection.
    #[cfg(feature = "sqlite")]
    async fn sqlite_memory_pool() -> sqlx::Pool<sqlx::Sqlite> {
        SqlitePoolOptions::new()
            .max_connections(1)
            .connect("sqlite::memory:")
            .await
            .unwrap()
    }

    /// Builds a migration that creates `table` on the way up and drops it on
    /// the way down.
    #[cfg(feature = "sqlite")]
    fn sqlite_create_table(version: &str, table: &str) -> Migration {
        Migration {
            version: version.to_string(),
            name: format!("create_{}", table),
            up_sql: format!("CREATE TABLE {} (id INTEGER PRIMARY KEY);", table),
            down_sql: format!("DROP TABLE {};", table),
        }
    }

    /// Number of rows currently recorded in `_premix_migrations`.
    #[cfg(feature = "sqlite")]
    async fn sqlite_applied_count(pool: &sqlx::Pool<sqlx::Sqlite>) -> i64 {
        sqlx::query_scalar("SELECT COUNT(*) FROM _premix_migrations")
            .fetch_one(pool)
            .await
            .unwrap()
    }

    /// Connects to the Postgres instance named by `DATABASE_URL` (falling back
    /// to a local development default) and returns `None` when no server is
    /// reachable, so the Postgres tests become no-ops instead of failing.
    #[cfg(feature = "postgres")]
    async fn pg_pool_or_skip() -> Option<sqlx::Pool<sqlx::Postgres>> {
        let db_url = std::env::var("DATABASE_URL").unwrap_or_else(|_| {
            "postgres://postgres:admin123@localhost:5432/premix_bench".to_string()
        });
        sqlx::postgres::PgPoolOptions::new()
            .max_connections(1)
            .connect(&db_url)
            .await
            .ok()
    }

    /// Nanosecond timestamp used to make versions and table names unique per
    /// test run against a shared database.
    #[cfg(feature = "postgres")]
    fn unique_suffix() -> u128 {
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_nanos()
    }

    /// Number of `_premix_migrations` rows recorded for `version`.
    #[cfg(feature = "postgres")]
    async fn pg_version_count(pool: &sqlx::Pool<sqlx::Postgres>, version: &str) -> i64 {
        sqlx::query_scalar("SELECT COUNT(*) FROM _premix_migrations WHERE version = $1")
            .bind(version)
            .fetch_one(pool)
            .await
            .unwrap()
    }

    #[cfg(feature = "sqlite")]
    #[tokio::test]
    async fn sqlite_migrator_applies_pending_once() {
        let pool = sqlite_memory_pool().await;
        let migrator = Migrator::new(pool.clone());

        let migrations = vec![sqlite_create_table("20260101000000", "users")];

        migrator.run(migrations.clone()).await.unwrap();
        assert_eq!(sqlite_applied_count(&pool).await, 1);

        // A second run must be a no-op for already-recorded versions.
        migrator.run(migrations).await.unwrap();
        assert_eq!(sqlite_applied_count(&pool).await, 1);
    }

    #[cfg(feature = "sqlite")]
    #[tokio::test]
    async fn sqlite_migrator_applies_multiple() {
        let pool = sqlite_memory_pool().await;
        let migrator = Migrator::new(pool.clone());

        let migrations = vec![
            sqlite_create_table("20260103000000", "a"),
            sqlite_create_table("20260104000000", "b"),
        ];

        migrator.run(migrations).await.unwrap();

        assert_eq!(sqlite_applied_count(&pool).await, 2);
    }

    #[cfg(feature = "sqlite")]
    #[tokio::test]
    async fn sqlite_migrator_rolls_back_last() {
        let pool = sqlite_memory_pool().await;
        let migrator = Migrator::new(pool.clone());

        let migrations = vec![
            sqlite_create_table("20260103000000", "a"),
            sqlite_create_table("20260104000000", "b"),
        ];

        migrator.run(migrations.clone()).await.unwrap();
        let rolled_back = migrator.rollback_last(migrations).await.unwrap();
        assert!(rolled_back);

        assert_eq!(sqlite_applied_count(&pool).await, 1);

        // The newest version must be the one that was reverted.
        let remaining: String = sqlx::query_scalar("SELECT version FROM _premix_migrations")
            .fetch_one(&pool)
            .await
            .unwrap();
        assert_eq!(remaining, "20260103000000");

        let table_b: Option<String> =
            sqlx::query_scalar("SELECT name FROM sqlite_master WHERE type='table' AND name='b'")
                .fetch_optional(&pool)
                .await
                .unwrap();
        assert!(table_b.is_none());
    }

    #[cfg(feature = "sqlite")]
    #[tokio::test]
    async fn sqlite_rollback_last_with_no_migrations_is_noop() {
        let pool = sqlite_memory_pool().await;
        let migrator = Migrator::new(pool.clone());

        let rolled_back = migrator.rollback_last(Vec::new()).await.unwrap();
        assert!(!rolled_back);
    }

    #[cfg(feature = "sqlite")]
    #[tokio::test]
    async fn sqlite_rollback_last_rejects_empty_down_sql() {
        let pool = sqlite_memory_pool().await;
        let migrator = Migrator::new(pool.clone());

        let migrations = vec![Migration {
            version: "20260105000000".to_string(),
            name: "no_down".to_string(),
            up_sql: "CREATE TABLE c (id INTEGER PRIMARY KEY);".to_string(),
            down_sql: "   ".to_string(),
        }];

        migrator.run(migrations.clone()).await.unwrap();
        let err = migrator.rollback_last(migrations).await.unwrap_err();
        assert_eq!(err.to_string(), "Down migration is empty.");

        // The applied migration must still be recorded.
        assert_eq!(sqlite_applied_count(&pool).await, 1);
    }

    #[cfg(feature = "sqlite")]
    #[tokio::test]
    async fn sqlite_migrator_rolls_back_on_error() {
        let pool = sqlite_memory_pool().await;
        let migrator = Migrator::new(pool.clone());

        let migrations = vec![Migration {
            version: "20260102000000".to_string(),
            name: "bad_sql".to_string(),
            up_sql: "CREATE TABLE broken (id INTEGER PRIMARY KEY); INVALID SQL".to_string(),
            down_sql: "DROP TABLE broken;".to_string(),
        }];

        let err = migrator.run(migrations).await.unwrap_err();
        assert!(err.to_string().contains("syntax"));

        // The bookkeeping table is created inside the same transaction, so a
        // failed run must leave no trace of it either.
        let table: Option<String> = sqlx::query_scalar(
            "SELECT name FROM sqlite_master WHERE type='table' AND name='_premix_migrations'",
        )
        .fetch_optional(&pool)
        .await
        .unwrap();
        assert!(table.is_none());
    }

    #[cfg(feature = "postgres")]
    #[tokio::test]
    async fn postgres_migrator_applies_pending_once() {
        let Some(pool) = pg_pool_or_skip().await else {
            return;
        };
        let migrator = Migrator::new(pool.clone());

        let suffix = unique_suffix();
        let version = format!("20260101{:020}", suffix);
        let table_name = format!("premix_mig_test_{}", suffix);
        let migrations = vec![Migration {
            version: version.clone(),
            name: "create_test_table".to_string(),
            up_sql: format!("CREATE TABLE {} (id SERIAL PRIMARY KEY);", table_name),
            down_sql: format!("DROP TABLE {};", table_name),
        }];

        migrator.run(migrations.clone()).await.unwrap();
        assert_eq!(pg_version_count(&pool, &version).await, 1);

        // A second run must be a no-op for already-recorded versions.
        migrator.run(migrations).await.unwrap();
        assert_eq!(pg_version_count(&pool, &version).await, 1);

        // Best-effort cleanup of the shared database.
        let _ = sqlx::query(&format!("DROP TABLE IF EXISTS {}", table_name))
            .execute(&pool)
            .await;
        let _ = sqlx::query("DELETE FROM _premix_migrations WHERE version = $1")
            .bind(&version)
            .execute(&pool)
            .await;
    }

    #[cfg(feature = "postgres")]
    #[tokio::test]
    async fn postgres_migrator_rolls_back_on_error() {
        let Some(pool) = pg_pool_or_skip().await else {
            return;
        };
        let migrator = Migrator::new(pool.clone());

        let suffix = unique_suffix();
        let version = format!("20260102{:020}", suffix);
        let table_name = format!("premix_mig_bad_{}", suffix);
        let migrations = vec![Migration {
            version: version.clone(),
            name: "bad_sql".to_string(),
            up_sql: format!(
                "CREATE TABLE {} (id SERIAL PRIMARY KEY); INVALID SQL",
                table_name
            ),
            down_sql: format!("DROP TABLE {};", table_name),
        }];

        let err = migrator.run(migrations).await.unwrap_err();
        assert!(err.to_string().contains("syntax"));

        // The bookkeeping table may already exist from other runs against the
        // shared database; only the failed version must be absent.
        let table_exists: Option<String> =
            sqlx::query_scalar("SELECT to_regclass('_premix_migrations')::text")
                .fetch_one(&pool)
                .await
                .unwrap();
        if table_exists.is_some() {
            assert_eq!(pg_version_count(&pool, &version).await, 0);
        }

        let _ = sqlx::query(&format!("DROP TABLE IF EXISTS {}", table_name))
            .execute(&pool)
            .await;
    }

    #[cfg(feature = "postgres")]
    #[tokio::test]
    async fn postgres_migrator_rolls_back_last() {
        let Some(pool) = pg_pool_or_skip().await else {
            return;
        };
        let migrator = Migrator::new(pool.clone());

        let suffix = unique_suffix();
        let version_a = format!("20260103{:020}", suffix);
        let version_b = format!("20260104{:020}", suffix);
        let table_a = format!("premix_mig_a_{}", suffix);
        let table_b = format!("premix_mig_b_{}", suffix);
        let migrations = vec![
            Migration {
                version: version_a.clone(),
                name: "create_a".to_string(),
                up_sql: format!("CREATE TABLE {} (id SERIAL PRIMARY KEY);", table_a),
                down_sql: format!("DROP TABLE {};", table_a),
            },
            Migration {
                version: version_b.clone(),
                name: "create_b".to_string(),
                up_sql: format!("CREATE TABLE {} (id SERIAL PRIMARY KEY);", table_b),
                down_sql: format!("DROP TABLE {};", table_b),
            },
        ];

        migrator.run(migrations.clone()).await.unwrap();
        let rolled_back = migrator.rollback_last(migrations).await.unwrap();
        assert!(rolled_back);

        // Scope the assertions to this test's versions: the table is shared,
        // so a global row count proves nothing about what was reverted.
        assert_eq!(pg_version_count(&pool, &version_a).await, 1);
        assert_eq!(pg_version_count(&pool, &version_b).await, 0);

        // Best-effort cleanup of the shared database.
        let _ = sqlx::query(&format!("DROP TABLE IF EXISTS {}", table_a))
            .execute(&pool)
            .await;
        let _ = sqlx::query(&format!("DROP TABLE IF EXISTS {}", table_b))
            .execute(&pool)
            .await;
        let _ = sqlx::query("DELETE FROM _premix_migrations WHERE version = $1 OR version = $2")
            .bind(&version_a)
            .bind(&version_b)
            .execute(&pool)
            .await;
    }
}