use std::collections::HashSet;
use std::error::Error;

use sqlx::{Database, Executor, Pool};

/// One schema change, identified by `version`, with the SQL that applies it
/// (`up_sql`) and the SQL that reverts it (`down_sql`).
#[derive(Clone, Debug)]
pub struct Migration {
    /// Key recorded in the `_premix_migrations.version` column once applied.
    pub version: String,
    /// Human-readable label, recorded alongside the version.
    pub name: String,
    /// SQL executed when the migration is applied.
    pub up_sql: String,
    /// SQL executed when the migration is rolled back.
    pub down_sql: String,
}

13#[derive(Debug, Clone, sqlx::FromRow)]
14struct AppliedMigration {
15 version: String,
16}
17
18pub struct Migrator<DB: Database> {
19 pool: Pool<DB>,
20}
21
22impl<DB: Database> Migrator<DB> {
23 pub fn new(pool: Pool<DB>) -> Self {
24 Self { pool }
25 }
26}
27
#[cfg(feature = "sqlite")]
impl Migrator<sqlx::Sqlite> {
    /// Applies every migration in `migrations` whose version is not yet
    /// recorded in `_premix_migrations`, in the order given.
    ///
    /// Everything runs in one transaction that is committed only after the
    /// last pending migration has been applied and recorded. On any error the
    /// function returns early and the uncommitted transaction is dropped,
    /// which sqlx rolls back. That includes the creation of the bookkeeping
    /// table itself, so a failed run leaves the database unchanged.
    ///
    /// Each `up_sql` is executed as raw SQL with no bound parameters.
    ///
    /// # Errors
    ///
    /// Returns the first database error hit while creating the bookkeeping
    /// table, reading applied versions, running an `up_sql`, recording a
    /// migration, or committing.
    pub async fn run(&self, migrations: Vec<Migration>) -> Result<(), Box<dyn Error>> {
        let mut tx = self.pool.begin().await?;
        Self::ensure_migrations_table(&mut *tx).await?;

        // A set gives O(1) membership checks instead of scanning a Vec once
        // per migration.
        let applied: HashSet<String> =
            sqlx::query_as::<_, AppliedMigration>("SELECT version FROM _premix_migrations")
                .fetch_all(&mut *tx)
                .await?
                .into_iter()
                .map(|m| m.version)
                .collect();

        for migration in migrations {
            if applied.contains(&migration.version) {
                continue;
            }

            tracing::info!(
                operation = "migration_apply",
                version = %migration.version,
                name = %migration.name,
                "premix migration"
            );
            println!(
                "🚚 Applying migration: {} - {}",
                migration.version, migration.name
            );

            tx.execute(migration.up_sql.as_str()).await?;

            sqlx::query("INSERT INTO _premix_migrations (version, name) VALUES (?, ?)")
                .bind(&migration.version)
                .bind(&migration.name)
                .execute(&mut *tx)
                .await?;
        }

        tx.commit().await?;
        Ok(())
    }

    /// Reverts the most recently applied migration among `migrations`.
    ///
    /// "Most recent" means the recorded version that sorts highest in text
    /// order, considering only versions present in `migrations`. Versions
    /// should therefore share a fixed width (for example timestamps) for this
    /// to match chronological order. The `down_sql` and the deletion of the
    /// bookkeeping row run in one transaction.
    ///
    /// Returns `Ok(true)` when a migration was reverted. Returns `Ok(false)`
    /// when `migrations` is empty or none of its versions has been applied.
    ///
    /// # Errors
    ///
    /// Returns an error if the selected migration's `down_sql` is blank, or on
    /// any database error. In both cases the transaction is dropped
    /// uncommitted, so nothing is changed.
    pub async fn rollback_last(&self, migrations: Vec<Migration>) -> Result<bool, Box<dyn Error>> {
        let mut tx = self.pool.begin().await?;
        Self::ensure_migrations_table(&mut *tx).await?;

        if migrations.is_empty() {
            tx.commit().await?;
            return Ok(false);
        }

        // One `?` per known version. SQLite caps the number of bound
        // variables per statement, so extremely long lists could exceed it.
        let placeholders = vec!["?"; migrations.len()].join(", ");
        let sql = format!(
            "SELECT version FROM _premix_migrations WHERE version IN ({}) ORDER BY version DESC LIMIT 1",
            placeholders
        );
        let mut query = sqlx::query_scalar::<_, String>(&sql);
        for migration in &migrations {
            query = query.bind(&migration.version);
        }
        let last = query.fetch_optional(&mut *tx).await?;

        let Some(last) = last else {
            tx.commit().await?;
            return Ok(false);
        };

        let migration = migrations
            .into_iter()
            .find(|m| m.version == last)
            .ok_or_else(|| format!("Migration {} not found.", last))?;

        if migration.down_sql.trim().is_empty() {
            return Err("Down migration is empty.".into());
        }

        tracing::info!(
            operation = "migration_rollback",
            version = %migration.version,
            name = %migration.name,
            "premix migration"
        );
        tx.execute(migration.down_sql.as_str()).await?;
        sqlx::query("DELETE FROM _premix_migrations WHERE version = ?")
            .bind(&migration.version)
            .execute(&mut *tx)
            .await?;

        tx.commit().await?;
        Ok(true)
    }

    /// Creates the `_premix_migrations` bookkeeping table if it is missing.
    /// Shared by `run` and `rollback_last` so both agree on the schema.
    async fn ensure_migrations_table(conn: &mut sqlx::SqliteConnection) -> Result<(), sqlx::Error> {
        sqlx::query(
            "CREATE TABLE IF NOT EXISTS _premix_migrations (
                version TEXT PRIMARY KEY,
                name TEXT NOT NULL,
                applied_at DATETIME DEFAULT CURRENT_TIMESTAMP
            )",
        )
        .execute(conn)
        .await?;
        Ok(())
    }
}

#[cfg(feature = "postgres")]
impl Migrator<sqlx::Postgres> {
    /// Applies every migration in `migrations` whose version is not yet
    /// recorded in `_premix_migrations`, in the order given.
    ///
    /// Everything runs in one transaction that is committed only after the
    /// last pending migration has been applied and recorded. On any error the
    /// function returns early and the uncommitted transaction is dropped,
    /// which sqlx rolls back, so a failed run leaves the database unchanged.
    ///
    /// Each `up_sql` is executed as raw SQL with no bound parameters.
    ///
    /// NOTE(review): no lock is taken, so two processes running this at the
    /// same time may race on the bookkeeping table or on the same pending
    /// migration. Consider `pg_advisory_xact_lock` if concurrent runs are
    /// expected; confirm with the callers.
    ///
    /// # Errors
    ///
    /// Returns the first database error hit while creating the bookkeeping
    /// table, reading applied versions, running an `up_sql`, recording a
    /// migration, or committing.
    pub async fn run(&self, migrations: Vec<Migration>) -> Result<(), Box<dyn Error>> {
        let mut tx = self.pool.begin().await?;
        Self::ensure_migrations_table(&mut *tx).await?;

        // A set gives O(1) membership checks instead of scanning a Vec once
        // per migration.
        let applied: HashSet<String> =
            sqlx::query_as::<_, AppliedMigration>("SELECT version FROM _premix_migrations")
                .fetch_all(&mut *tx)
                .await?
                .into_iter()
                .map(|m| m.version)
                .collect();

        for migration in migrations {
            if applied.contains(&migration.version) {
                continue;
            }

            tracing::info!(
                operation = "migration_apply",
                version = %migration.version,
                name = %migration.name,
                "premix migration"
            );
            println!(
                "🚚 Applying migration: {} - {}",
                migration.version, migration.name
            );

            tx.execute(migration.up_sql.as_str()).await?;

            sqlx::query("INSERT INTO _premix_migrations (version, name) VALUES ($1, $2)")
                .bind(&migration.version)
                .bind(&migration.name)
                .execute(&mut *tx)
                .await?;
        }

        tx.commit().await?;
        Ok(())
    }

    /// Reverts the most recently applied migration among `migrations`.
    ///
    /// "Most recent" means the recorded version that sorts highest under the
    /// database's text ordering, considering only versions present in
    /// `migrations`. Versions should therefore share a fixed width (for
    /// example timestamps). The `down_sql` and the deletion of the bookkeeping
    /// row run in one transaction.
    ///
    /// Returns `Ok(true)` when a migration was reverted. Returns `Ok(false)`
    /// when `migrations` is empty or none of its versions has been applied.
    ///
    /// # Errors
    ///
    /// Returns an error if the selected migration's `down_sql` is blank, or on
    /// any database error. In both cases the transaction is dropped
    /// uncommitted, so nothing is changed.
    pub async fn rollback_last(&self, migrations: Vec<Migration>) -> Result<bool, Box<dyn Error>> {
        let mut tx = self.pool.begin().await?;
        Self::ensure_migrations_table(&mut *tx).await?;

        if migrations.is_empty() {
            tx.commit().await?;
            return Ok(false);
        }

        // Bound as a single TEXT[] parameter, so the statement shape does not
        // depend on how many migrations are known.
        let versions: Vec<String> = migrations.iter().map(|m| m.version.clone()).collect();
        let last = sqlx::query_scalar::<_, String>(
            "SELECT version FROM _premix_migrations WHERE version = ANY($1) ORDER BY version DESC LIMIT 1",
        )
        .bind(&versions)
        .fetch_optional(&mut *tx)
        .await?;

        let Some(last) = last else {
            tx.commit().await?;
            return Ok(false);
        };

        let migration = migrations
            .into_iter()
            .find(|m| m.version == last)
            .ok_or_else(|| format!("Migration {} not found.", last))?;

        if migration.down_sql.trim().is_empty() {
            return Err("Down migration is empty.".into());
        }

        tracing::info!(
            operation = "migration_rollback",
            version = %migration.version,
            name = %migration.name,
            "premix migration"
        );
        tx.execute(migration.down_sql.as_str()).await?;
        sqlx::query("DELETE FROM _premix_migrations WHERE version = $1")
            .bind(&migration.version)
            .execute(&mut *tx)
            .await?;

        tx.commit().await?;
        Ok(true)
    }

    /// Creates the `_premix_migrations` bookkeeping table if it is missing.
    /// Shared by `run` and `rollback_last` so both agree on the schema.
    async fn ensure_migrations_table(conn: &mut sqlx::PgConnection) -> Result<(), sqlx::Error> {
        sqlx::query(
            "CREATE TABLE IF NOT EXISTS _premix_migrations (
                version TEXT PRIMARY KEY,
                name TEXT NOT NULL,
                applied_at TIMESTAMPTZ DEFAULT CURRENT_TIMESTAMP
            )",
        )
        .execute(conn)
        .await?;
        Ok(())
    }
}

#[cfg(test)]
mod tests {
    #[cfg(feature = "postgres")]
    use std::time::{SystemTime, UNIX_EPOCH};

    // The SQLite impl only exists with the `sqlite` feature, so everything
    // SQLite-specific here must be gated on it as well.
    #[cfg(feature = "sqlite")]
    use sqlx::sqlite::SqlitePoolOptions;

    // Only the backend-gated tests use this glob; without either backend
    // feature it would otherwise raise an unused-import warning.
    #[allow(unused_imports)]
    use super::*;

    /// Builds a [`Migration`] from string slices to keep fixtures short.
    #[cfg(any(feature = "sqlite", feature = "postgres"))]
    fn migration(version: &str, name: &str, up_sql: &str, down_sql: &str) -> Migration {
        Migration {
            version: version.to_string(),
            name: name.to_string(),
            up_sql: up_sql.to_string(),
            down_sql: down_sql.to_string(),
        }
    }

    // ---------------------------------------------------------------- SQLite

    /// Opens an in-memory SQLite pool capped at a single connection.
    #[cfg(feature = "sqlite")]
    async fn sqlite_pool() -> sqlx::SqlitePool {
        SqlitePoolOptions::new()
            .max_connections(1)
            .connect("sqlite::memory:")
            .await
            .unwrap()
    }

    /// Number of rows currently recorded in `_premix_migrations`.
    #[cfg(feature = "sqlite")]
    async fn sqlite_migration_count(pool: &sqlx::SqlitePool) -> i64 {
        sqlx::query_scalar("SELECT COUNT(*) FROM _premix_migrations")
            .fetch_one(pool)
            .await
            .unwrap()
    }

    /// Whether a table named `name` exists in the SQLite schema.
    #[cfg(feature = "sqlite")]
    async fn sqlite_table_exists(pool: &sqlx::SqlitePool, name: &str) -> bool {
        let found: Option<String> =
            sqlx::query_scalar("SELECT name FROM sqlite_master WHERE type='table' AND name=?")
                .bind(name)
                .fetch_optional(pool)
                .await
                .unwrap();
        found.is_some()
    }

    /// Two independent table-creating migrations, `a` older than `b`.
    #[cfg(feature = "sqlite")]
    fn sqlite_two_table_migrations() -> Vec<Migration> {
        vec![
            migration(
                "20260103000000",
                "create_a",
                "CREATE TABLE a (id INTEGER PRIMARY KEY);",
                "DROP TABLE a;",
            ),
            migration(
                "20260104000000",
                "create_b",
                "CREATE TABLE b (id INTEGER PRIMARY KEY);",
                "DROP TABLE b;",
            ),
        ]
    }

    #[cfg(feature = "sqlite")]
    #[tokio::test]
    async fn sqlite_migrator_applies_pending_once() {
        let pool = sqlite_pool().await;
        let migrator = Migrator::new(pool.clone());
        let migrations = vec![migration(
            "20260101000000",
            "create_users",
            "CREATE TABLE users (id INTEGER PRIMARY KEY);",
            "DROP TABLE users;",
        )];

        migrator.run(migrations.clone()).await.unwrap();
        assert_eq!(sqlite_migration_count(&pool).await, 1);

        // Re-running must skip the recorded version. Re-executing up_sql would
        // fail here because `users` already exists.
        migrator.run(migrations).await.unwrap();
        assert_eq!(sqlite_migration_count(&pool).await, 1);
    }

    #[cfg(feature = "sqlite")]
    #[tokio::test]
    async fn sqlite_migrator_applies_multiple() {
        let pool = sqlite_pool().await;
        let migrator = Migrator::new(pool.clone());

        migrator.run(sqlite_two_table_migrations()).await.unwrap();

        assert_eq!(sqlite_migration_count(&pool).await, 2);
        assert!(sqlite_table_exists(&pool, "a").await);
        assert!(sqlite_table_exists(&pool, "b").await);
    }

    #[cfg(feature = "sqlite")]
    #[tokio::test]
    async fn sqlite_migrator_rolls_back_last() {
        let pool = sqlite_pool().await;
        let migrator = Migrator::new(pool.clone());
        let migrations = sqlite_two_table_migrations();

        migrator.run(migrations.clone()).await.unwrap();
        let rolled_back = migrator.rollback_last(migrations).await.unwrap();
        assert!(rolled_back);

        // Only the newest migration is reverted.
        let remaining: Vec<String> =
            sqlx::query_scalar("SELECT version FROM _premix_migrations ORDER BY version")
                .fetch_all(&pool)
                .await
                .unwrap();
        assert_eq!(remaining, vec!["20260103000000".to_string()]);
        assert!(sqlite_table_exists(&pool, "a").await);
        assert!(!sqlite_table_exists(&pool, "b").await);
    }

    #[cfg(feature = "sqlite")]
    #[tokio::test]
    async fn sqlite_rollback_last_returns_false_when_nothing_applied() {
        let pool = sqlite_pool().await;
        let migrator = Migrator::new(pool.clone());

        assert!(!migrator.rollback_last(Vec::new()).await.unwrap());
        assert!(!migrator
            .rollback_last(sqlite_two_table_migrations())
            .await
            .unwrap());
    }

    #[cfg(feature = "sqlite")]
    #[tokio::test]
    async fn sqlite_rollback_last_rejects_empty_down() {
        let pool = sqlite_pool().await;
        let migrator = Migrator::new(pool.clone());
        let migrations = vec![migration(
            "20260105000000",
            "no_down",
            "CREATE TABLE c (id INTEGER PRIMARY KEY);",
            "   ",
        )];

        migrator.run(migrations.clone()).await.unwrap();
        let err = migrator.rollback_last(migrations).await.unwrap_err();
        assert_eq!(err.to_string(), "Down migration is empty.");

        // The failed rollback must leave the record and the table in place.
        assert_eq!(sqlite_migration_count(&pool).await, 1);
        assert!(sqlite_table_exists(&pool, "c").await);
    }

    #[cfg(feature = "sqlite")]
    #[tokio::test]
    async fn sqlite_migrator_rolls_back_on_error() {
        let pool = sqlite_pool().await;
        let migrator = Migrator::new(pool.clone());
        let migrations = vec![migration(
            "20260102000000",
            "bad_sql",
            "CREATE TABLE broken (id INTEGER PRIMARY KEY); INVALID SQL",
            "DROP TABLE broken;",
        )];

        let err = migrator.run(migrations).await.unwrap_err();
        assert!(err.to_string().contains("syntax"));

        // The whole transaction, including bookkeeping-table creation, is
        // rolled back.
        assert!(!sqlite_table_exists(&pool, "_premix_migrations").await);
        assert!(!sqlite_table_exists(&pool, "broken").await);
    }

    // -------------------------------------------------------------- Postgres

    /// Connects to `DATABASE_URL` (or a local default). Returns `None` so
    /// tests can skip when no server is reachable.
    #[cfg(feature = "postgres")]
    async fn pg_pool_or_skip() -> Option<sqlx::Pool<sqlx::Postgres>> {
        let db_url = std::env::var("DATABASE_URL").unwrap_or_else(|_| {
            "postgres://postgres:admin123@localhost:5432/premix_bench".to_string()
        });
        sqlx::postgres::PgPoolOptions::new()
            .max_connections(1)
            .connect(&db_url)
            .await
            .ok()
    }

    /// Nanosecond timestamp used to make versions and table names unique per run.
    #[cfg(feature = "postgres")]
    fn unique_suffix() -> u128 {
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_nanos()
    }

    /// Number of `_premix_migrations` rows with the given version.
    #[cfg(feature = "postgres")]
    async fn pg_version_count(pool: &sqlx::Pool<sqlx::Postgres>, version: &str) -> i64 {
        sqlx::query_scalar("SELECT COUNT(*) FROM _premix_migrations WHERE version = $1")
            .bind(version)
            .fetch_one(pool)
            .await
            .unwrap()
    }

    /// Whether a relation named `name` resolves on the current search path.
    #[cfg(feature = "postgres")]
    async fn pg_table_exists(pool: &sqlx::Pool<sqlx::Postgres>, name: &str) -> bool {
        let found: Option<String> = sqlx::query_scalar("SELECT to_regclass($1)::text")
            .bind(name)
            .fetch_one(pool)
            .await
            .unwrap();
        found.is_some()
    }

    #[cfg(feature = "postgres")]
    #[tokio::test]
    async fn postgres_migrator_applies_pending_once() {
        let Some(pool) = pg_pool_or_skip().await else {
            return;
        };
        let migrator = Migrator::new(pool.clone());

        let suffix = unique_suffix();
        let version = format!("20260101{:020}", suffix);
        let table_name = format!("premix_mig_test_{}", suffix);
        let migrations = vec![migration(
            &version,
            "create_test_table",
            &format!("CREATE TABLE {} (id SERIAL PRIMARY KEY);", table_name),
            &format!("DROP TABLE {};", table_name),
        )];

        migrator.run(migrations.clone()).await.unwrap();
        assert_eq!(pg_version_count(&pool, &version).await, 1);

        migrator.run(migrations).await.unwrap();
        assert_eq!(pg_version_count(&pool, &version).await, 1);

        let _ = sqlx::query(&format!("DROP TABLE IF EXISTS {}", table_name))
            .execute(&pool)
            .await;
        let _ = sqlx::query("DELETE FROM _premix_migrations WHERE version = $1")
            .bind(&version)
            .execute(&pool)
            .await;
    }

    #[cfg(feature = "postgres")]
    #[tokio::test]
    async fn postgres_migrator_rolls_back_on_error() {
        let Some(pool) = pg_pool_or_skip().await else {
            return;
        };
        let migrator = Migrator::new(pool.clone());

        let suffix = unique_suffix();
        let version = format!("20260102{:020}", suffix);
        let table_name = format!("premix_mig_bad_{}", suffix);
        let migrations = vec![migration(
            &version,
            "bad_sql",
            &format!(
                "CREATE TABLE {} (id SERIAL PRIMARY KEY); INVALID SQL",
                table_name
            ),
            &format!("DROP TABLE {};", table_name),
        )];

        let err = migrator.run(migrations).await.unwrap_err();
        assert!(err.to_string().contains("syntax"));

        // Other tests may already have created the shared bookkeeping table,
        // so only the absence of this run's row is asserted.
        if pg_table_exists(&pool, "_premix_migrations").await {
            assert_eq!(pg_version_count(&pool, &version).await, 0);
        }
        assert!(!pg_table_exists(&pool, &table_name).await);

        let _ = sqlx::query(&format!("DROP TABLE IF EXISTS {}", table_name))
            .execute(&pool)
            .await;
    }

    #[cfg(feature = "postgres")]
    #[tokio::test]
    async fn postgres_migrator_rolls_back_last() {
        let Some(pool) = pg_pool_or_skip().await else {
            return;
        };
        let migrator = Migrator::new(pool.clone());

        let suffix = unique_suffix();
        let version_a = format!("20260103{:020}", suffix);
        let version_b = format!("20260104{:020}", suffix);
        let table_a = format!("premix_mig_a_{}", suffix);
        let table_b = format!("premix_mig_b_{}", suffix);
        let migrations = vec![
            migration(
                &version_a,
                "create_a",
                &format!("CREATE TABLE {} (id SERIAL PRIMARY KEY);", table_a),
                &format!("DROP TABLE {};", table_a),
            ),
            migration(
                &version_b,
                "create_b",
                &format!("CREATE TABLE {} (id SERIAL PRIMARY KEY);", table_b),
                &format!("DROP TABLE {};", table_b),
            ),
        ];

        migrator.run(migrations.clone()).await.unwrap();
        let rolled_back = migrator.rollback_last(migrations).await.unwrap();
        assert!(rolled_back);

        // Only the newer migration (version_b) is reverted.
        assert_eq!(pg_version_count(&pool, &version_a).await, 1);
        assert_eq!(pg_version_count(&pool, &version_b).await, 0);
        assert!(pg_table_exists(&pool, &table_a).await);
        assert!(!pg_table_exists(&pool, &table_b).await);

        let _ = sqlx::query(&format!("DROP TABLE IF EXISTS {}", table_a))
            .execute(&pool)
            .await;
        let _ = sqlx::query(&format!("DROP TABLE IF EXISTS {}", table_b))
            .execute(&pool)
            .await;
        let _ = sqlx::query("DELETE FROM _premix_migrations WHERE version = $1 OR version = $2")
            .bind(&version_a)
            .bind(&version_b)
            .execute(&pool)
            .await;
    }
}