1use sqlx::{Database, Pool};
2
/// A single schema migration: a version key plus forward and reverse SQL.
///
/// `version` is compared as a string — applied versions are ordered with SQL
/// `ORDER BY version` on a `TEXT` column — so versions should be fixed-width
/// (e.g. timestamps) for that order to match chronological order.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Migration {
    /// Unique version key; stored as the primary key of `_premix_migrations`.
    pub version: String,
    /// Human-readable name, recorded next to the version and used in logs.
    pub name: String,
    /// SQL executed to apply the migration.
    pub up_sql: String,
    /// SQL executed to revert the migration; a blank value makes rollback fail.
    pub down_sql: String,
}
15
/// One row of `_premix_migrations`, narrowed to its `version` column.
///
/// Read with `sqlx::query_as` to collect the versions that are already
/// applied. The `dead_code` allowance covers builds with neither the `sqlite`
/// nor the `postgres` feature, where no code reads this type.
#[cfg_attr(not(any(feature = "sqlite", feature = "postgres")), allow(dead_code))]
#[derive(Debug, Clone, sqlx::FromRow)]
struct AppliedMigration {
    /// Applied migration version; `FromRow` maps it from the `version` column.
    version: String,
}
21
/// Applies and rolls back [`Migration`]s against a sqlx connection pool.
///
/// `run` and `rollback_last` are implemented per backend behind the `sqlite`
/// and `postgres` features. With neither feature enabled only construction is
/// available, hence the `dead_code` allowance for the otherwise unread `pool`.
#[cfg_attr(not(any(feature = "sqlite", feature = "postgres")), allow(dead_code))]
pub struct Migrator<DB: Database> {
    /// Pool from which each operation begins its own transaction.
    pool: Pool<DB>,
}
27
28impl<DB: Database> std::fmt::Debug for Migrator<DB> {
29 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
30 f.debug_struct("Migrator")
31 .field("pool", &"sqlx::Pool")
32 .finish()
33 }
34}
35
36impl<DB: Database> Migrator<DB> {
37 pub fn new(pool: Pool<DB>) -> Self {
40 Self { pool }
41 }
42}
43
#[cfg(feature = "sqlite")]
impl Migrator<sqlx::Sqlite> {
    /// DDL for the bookkeeping table recording which versions are applied.
    ///
    /// Shared by `run` and `rollback_last` so the two definitions cannot drift.
    const MIGRATIONS_TABLE_DDL: &'static str = "CREATE TABLE IF NOT EXISTS _premix_migrations (
        version TEXT PRIMARY KEY,
        name TEXT NOT NULL,
        applied_at DATETIME DEFAULT CURRENT_TIMESTAMP
    )";

    /// Applies, in the order given, every migration whose version is not yet
    /// recorded in `_premix_migrations`.
    ///
    /// The bookkeeping table is created if missing. Table creation, every
    /// pending `up_sql`, and every bookkeeping insert share one transaction,
    /// so if any step fails nothing is committed.
    ///
    /// `up_sql` is executed as a raw script without bind parameters, so it may
    /// contain several `;`-separated statements.
    ///
    /// # Errors
    ///
    /// Returns the first database error raised while beginning the
    /// transaction, creating the table, reading applied versions, executing an
    /// `up_sql`, recording a version, or committing.
    pub async fn run(
        &self,
        migrations: Vec<Migration>,
    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
        // Brings `Executor::execute(&str)` into scope for running raw scripts.
        use sqlx::Executor as _;

        let mut tx = self.pool.begin().await?;

        sqlx::query(Self::MIGRATIONS_TABLE_DDL)
            .execute(&mut *tx)
            .await?;

        // A set gives O(1) membership checks in the loop below.
        let applied_versions: std::collections::HashSet<String> =
            sqlx::query_as::<sqlx::Sqlite, AppliedMigration>(
                "SELECT version FROM _premix_migrations",
            )
            .fetch_all(&mut *tx)
            .await?
            .into_iter()
            .map(|m| m.version)
            .collect();

        for migration in migrations {
            if applied_versions.contains(&migration.version) {
                continue;
            }

            tracing::info!(
                operation = "migration_apply",
                version = %migration.version,
                name = %migration.name,
                "premix migration"
            );
            println!(
                "🚚 Applying migration: {} - {}",
                migration.version, migration.name
            );

            // Executed as a plain `&str` (no bind arguments), which is how
            // sqlx's own migrator runs migration scripts.
            (&mut *tx).execute(migration.up_sql.as_str()).await?;

            sqlx::query("INSERT INTO _premix_migrations (version, name) VALUES (?, ?)")
                .bind(&migration.version)
                .bind(&migration.name)
                .execute(&mut *tx)
                .await?;
        }

        tx.commit().await?;
        Ok(())
    }

    /// Reverts the applied migration with the greatest version among
    /// `migrations`, then removes its bookkeeping row.
    ///
    /// Returns `Ok(false)` when `migrations` is empty or none of its versions
    /// is applied, and `Ok(true)` after a successful rollback. The `down_sql`
    /// and the bookkeeping delete run in a single transaction.
    ///
    /// # Errors
    ///
    /// Returns an error if the selected migration's `down_sql` is blank, or
    /// if any database operation fails; in both cases nothing is committed.
    pub async fn rollback_last(
        &self,
        migrations: Vec<Migration>,
    ) -> Result<bool, Box<dyn std::error::Error + Send + Sync>> {
        // Brings `Executor::execute(&str)` into scope for running raw scripts.
        use sqlx::Executor as _;

        let mut tx = self.pool.begin().await?;

        sqlx::query(Self::MIGRATIONS_TABLE_DDL)
            .execute(&mut *tx)
            .await?;

        if migrations.is_empty() {
            tx.commit().await?;
            return Ok(false);
        }

        let known: std::collections::HashSet<&str> =
            migrations.iter().map(|m| m.version.as_str()).collect();

        // Filter in Rust rather than building an `IN (?, ?, ...)` list whose
        // placeholder count grows with the number of migrations.
        let applied_desc: Vec<String> = sqlx::query_scalar::<sqlx::Sqlite, String>(
            "SELECT version FROM _premix_migrations ORDER BY version DESC",
        )
        .fetch_all(&mut *tx)
        .await?;

        let Some(last) = applied_desc
            .into_iter()
            .find(|version| known.contains(version.as_str()))
        else {
            tx.commit().await?;
            return Ok(false);
        };

        let migration = migrations
            .into_iter()
            .find(|m| m.version == last)
            .ok_or_else(|| format!("Migration {} not found.", last))?;

        if migration.down_sql.trim().is_empty() {
            return Err("Down migration is empty.".into());
        }

        tracing::info!(
            operation = "migration_rollback",
            version = %migration.version,
            name = %migration.name,
            "premix migration"
        );
        (&mut *tx).execute(migration.down_sql.as_str()).await?;
        sqlx::query("DELETE FROM _premix_migrations WHERE version = ?")
            .bind(&migration.version)
            .execute(&mut *tx)
            .await?;

        tx.commit().await?;
        Ok(true)
    }
}
172
#[cfg(feature = "postgres")]
impl Migrator<sqlx::Postgres> {
    /// DDL for the bookkeeping table recording which versions are applied.
    ///
    /// Shared by `run` and `rollback_last` so the two definitions cannot drift.
    const MIGRATIONS_TABLE_DDL: &'static str = "CREATE TABLE IF NOT EXISTS _premix_migrations (
        version TEXT PRIMARY KEY,
        name TEXT NOT NULL,
        applied_at TIMESTAMPTZ DEFAULT CURRENT_TIMESTAMP
    )";

    /// Applies, in the order given, every migration whose version is not yet
    /// recorded in `_premix_migrations`.
    ///
    /// The bookkeeping table is created if missing. Table creation, every
    /// pending `up_sql`, and every bookkeeping insert share one transaction,
    /// so if any step fails nothing is committed.
    ///
    /// `up_sql` may contain several `;`-separated statements. It is sent as
    /// a plain `&str`, not through `sqlx::query`. `sqlx::query` always
    /// prepares the statement, and Postgres rejects multi-command prepared
    /// statements ("cannot insert multiple commands into a prepared
    /// statement").
    ///
    /// # Errors
    ///
    /// Returns the first database error raised while beginning the
    /// transaction, creating the table, reading applied versions, executing an
    /// `up_sql`, recording a version, or committing.
    pub async fn run(
        &self,
        migrations: Vec<Migration>,
    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
        // Brings `Executor::execute(&str)` into scope for running raw scripts.
        use sqlx::Executor as _;

        let mut tx = self.pool.begin().await?;

        sqlx::query(Self::MIGRATIONS_TABLE_DDL)
            .execute(&mut *tx)
            .await?;

        // A set gives O(1) membership checks in the loop below.
        let applied_versions: std::collections::HashSet<String> =
            sqlx::query_as::<sqlx::Postgres, AppliedMigration>(
                "SELECT version FROM _premix_migrations",
            )
            .fetch_all(&mut *tx)
            .await?
            .into_iter()
            .map(|m| m.version)
            .collect();

        for migration in migrations {
            if applied_versions.contains(&migration.version) {
                continue;
            }

            tracing::info!(
                operation = "migration_apply",
                version = %migration.version,
                name = %migration.name,
                "premix migration"
            );
            println!(
                "🚚 Applying migration: {} - {}",
                migration.version, migration.name
            );

            // No bind arguments: sqlx uses the simple-query protocol, which
            // accepts multi-statement scripts (same as sqlx's own migrator).
            (&mut *tx).execute(migration.up_sql.as_str()).await?;

            sqlx::query("INSERT INTO _premix_migrations (version, name) VALUES ($1, $2)")
                .bind(&migration.version)
                .bind(&migration.name)
                .execute(&mut *tx)
                .await?;
        }

        tx.commit().await?;
        Ok(())
    }

    /// Reverts the applied migration with the greatest version among
    /// `migrations`, then removes its bookkeeping row.
    ///
    /// Returns `Ok(false)` when `migrations` is empty or none of its versions
    /// is applied, and `Ok(true)` after a successful rollback. The `down_sql`
    /// (which may hold several statements) and the bookkeeping delete run in
    /// a single transaction.
    ///
    /// # Errors
    ///
    /// Returns an error if the selected migration's `down_sql` is blank, or
    /// if any database operation fails; in both cases nothing is committed.
    pub async fn rollback_last(
        &self,
        migrations: Vec<Migration>,
    ) -> Result<bool, Box<dyn std::error::Error + Send + Sync>> {
        // Brings `Executor::execute(&str)` into scope for running raw scripts.
        use sqlx::Executor as _;

        let mut tx = self.pool.begin().await?;

        sqlx::query(Self::MIGRATIONS_TABLE_DDL)
            .execute(&mut *tx)
            .await?;

        let versions: Vec<String> = migrations.iter().map(|m| m.version.clone()).collect();
        if versions.is_empty() {
            tx.commit().await?;
            return Ok(false);
        }

        // `versions` is bound as a single TEXT[] parameter.
        let last = sqlx::query_scalar::<sqlx::Postgres, String>(
            "SELECT version FROM _premix_migrations WHERE version = ANY($1) ORDER BY version DESC LIMIT 1",
        )
        .bind(&versions)
        .fetch_optional(&mut *tx)
        .await?;

        let Some(last) = last else {
            tx.commit().await?;
            return Ok(false);
        };

        let migration = migrations
            .into_iter()
            .find(|m| m.version == last)
            .ok_or_else(|| format!("Migration {} not found.", last))?;

        if migration.down_sql.trim().is_empty() {
            return Err("Down migration is empty.".into());
        }

        tracing::info!(
            operation = "migration_rollback",
            version = %migration.version,
            name = %migration.name,
            "premix migration"
        );
        // Simple-query protocol, so multi-statement down scripts work.
        (&mut *tx).execute(migration.down_sql.as_str()).await?;
        sqlx::query("DELETE FROM _premix_migrations WHERE version = $1")
            .bind(&migration.version)
            .execute(&mut *tx)
            .await?;

        tx.commit().await?;
        Ok(true)
    }
}
293
#[cfg(test)]
mod tests {
    #[cfg(feature = "postgres")]
    use std::time::{SystemTime, UNIX_EPOCH};

    // The SQLite pieces are gated on the `sqlite` feature: `Migrator<Sqlite>`
    // only has `run`/`rollback_last` with that feature, so ungated SQLite tests
    // would stop the test crate from compiling in postgres-only builds.
    #[cfg(feature = "sqlite")]
    use sqlx::sqlite::SqlitePoolOptions;

    #[cfg(any(feature = "sqlite", feature = "postgres"))]
    use super::*;

    /// Builds a [`Migration`] from string slices to keep fixtures compact.
    #[cfg(any(feature = "sqlite", feature = "postgres"))]
    fn migration(version: &str, name: &str, up_sql: &str, down_sql: &str) -> Migration {
        Migration {
            version: version.to_string(),
            name: name.to_string(),
            up_sql: up_sql.to_string(),
            down_sql: down_sql.to_string(),
        }
    }

    /// Connects to Postgres at `DATABASE_URL` (or a local default URL),
    /// returning `None` when the connection fails so Postgres tests no-op.
    #[cfg(feature = "postgres")]
    async fn pg_pool_or_skip() -> Option<sqlx::Pool<sqlx::Postgres>> {
        let db_url = std::env::var("DATABASE_URL").unwrap_or_else(|_| {
            "postgres://postgres:admin123@localhost:5432/premix_bench".to_string()
        });
        sqlx::postgres::PgPoolOptions::new()
            .max_connections(1)
            .connect(&db_url)
            .await
            .ok()
    }

    /// Opens a single-connection in-memory SQLite pool.
    ///
    /// One connection only: every SQLite `:memory:` connection has its own
    /// private database, so a second connection would not see the schema.
    #[cfg(feature = "sqlite")]
    async fn sqlite_pool() -> sqlx::Pool<sqlx::Sqlite> {
        SqlitePoolOptions::new()
            .max_connections(1)
            .connect("sqlite::memory:")
            .await
            .expect("in-memory SQLite pool should open")
    }

    /// Number of rows in the SQLite bookkeeping table.
    #[cfg(feature = "sqlite")]
    async fn sqlite_migration_count(pool: &sqlx::Pool<sqlx::Sqlite>) -> i64 {
        sqlx::query_scalar("SELECT COUNT(*) FROM _premix_migrations")
            .fetch_one(pool)
            .await
            .unwrap()
    }

    /// Whether a table called `name` exists in the SQLite database.
    #[cfg(feature = "sqlite")]
    async fn sqlite_has_table(pool: &sqlx::Pool<sqlx::Sqlite>, name: &str) -> bool {
        let found: Option<String> =
            sqlx::query_scalar("SELECT name FROM sqlite_master WHERE type='table' AND name = ?")
                .bind(name)
                .fetch_optional(pool)
                .await
                .unwrap();
        found.is_some()
    }

    #[cfg(feature = "sqlite")]
    #[tokio::test]
    async fn sqlite_migrator_applies_pending_once() {
        let pool = sqlite_pool().await;
        let migrator = Migrator::new(pool.clone());

        let migrations = vec![migration(
            "20260101000000",
            "create_users",
            "CREATE TABLE users (id INTEGER PRIMARY KEY);",
            "DROP TABLE users;",
        )];

        migrator.run(migrations.clone()).await.unwrap();
        assert_eq!(sqlite_migration_count(&pool).await, 1);
        assert!(sqlite_has_table(&pool, "users").await);

        // A second run must skip the already-applied version.
        migrator.run(migrations).await.unwrap();
        assert_eq!(sqlite_migration_count(&pool).await, 1);
    }

    #[cfg(feature = "sqlite")]
    #[tokio::test]
    async fn sqlite_migrator_applies_multiple() {
        let pool = sqlite_pool().await;
        let migrator = Migrator::new(pool.clone());

        let migrations = vec![
            migration(
                "20260103000000",
                "create_a",
                "CREATE TABLE a (id INTEGER PRIMARY KEY);",
                "DROP TABLE a;",
            ),
            migration(
                "20260104000000",
                "create_b",
                "CREATE TABLE b (id INTEGER PRIMARY KEY);",
                "DROP TABLE b;",
            ),
        ];

        migrator.run(migrations).await.unwrap();

        assert_eq!(sqlite_migration_count(&pool).await, 2);
        assert!(sqlite_has_table(&pool, "a").await);
        assert!(sqlite_has_table(&pool, "b").await);
    }

    #[cfg(feature = "sqlite")]
    #[tokio::test]
    async fn sqlite_migrator_rolls_back_last() {
        let pool = sqlite_pool().await;
        let migrator = Migrator::new(pool.clone());

        let migrations = vec![
            migration(
                "20260103000000",
                "create_a",
                "CREATE TABLE a (id INTEGER PRIMARY KEY);",
                "DROP TABLE a;",
            ),
            migration(
                "20260104000000",
                "create_b",
                "CREATE TABLE b (id INTEGER PRIMARY KEY);",
                "DROP TABLE b;",
            ),
        ];

        migrator.run(migrations.clone()).await.unwrap();
        let rolled_back = migrator.rollback_last(migrations).await.unwrap();
        assert!(rolled_back);

        // Only the newest version is reverted, and its down script ran.
        let remaining: Vec<String> =
            sqlx::query_scalar("SELECT version FROM _premix_migrations ORDER BY version")
                .fetch_all(&pool)
                .await
                .unwrap();
        assert_eq!(remaining, vec!["20260103000000".to_string()]);
        assert!(sqlite_has_table(&pool, "a").await);
        assert!(!sqlite_has_table(&pool, "b").await);
    }

    #[cfg(feature = "sqlite")]
    #[tokio::test]
    async fn sqlite_migrator_rolls_back_on_error() {
        let pool = sqlite_pool().await;
        let migrator = Migrator::new(pool.clone());

        let migrations = vec![migration(
            "20260102000000",
            "bad_sql",
            "CREATE TABLE broken (id INTEGER PRIMARY KEY); INVALID SQL",
            "DROP TABLE broken;",
        )];

        let err = migrator.run(migrations).await.unwrap_err();
        assert!(err.to_string().contains("syntax"));

        // The whole transaction, including the bookkeeping table, is undone.
        assert!(!sqlite_has_table(&pool, "_premix_migrations").await);
        assert!(!sqlite_has_table(&pool, "broken").await);
    }

    #[cfg(feature = "sqlite")]
    #[tokio::test]
    async fn sqlite_rollback_last_without_applied_returns_false() {
        let migrator = Migrator::new(sqlite_pool().await);

        assert!(!migrator.rollback_last(Vec::new()).await.unwrap());

        let never_applied = vec![migration(
            "20260105000000",
            "never_applied",
            "CREATE TABLE n (id INTEGER PRIMARY KEY);",
            "DROP TABLE n;",
        )];
        assert!(!migrator.rollback_last(never_applied).await.unwrap());
    }

    #[cfg(feature = "sqlite")]
    #[tokio::test]
    async fn sqlite_rollback_last_rejects_blank_down() {
        let pool = sqlite_pool().await;
        let migrator = Migrator::new(pool.clone());

        let migrations = vec![migration(
            "20260106000000",
            "irreversible",
            "CREATE TABLE c (id INTEGER PRIMARY KEY);",
            "   ",
        )];

        migrator.run(migrations.clone()).await.unwrap();
        let err = migrator.rollback_last(migrations).await.unwrap_err();
        assert_eq!(err.to_string(), "Down migration is empty.");

        // Nothing was reverted.
        assert_eq!(sqlite_migration_count(&pool).await, 1);
        assert!(sqlite_has_table(&pool, "c").await);
    }

    #[cfg(feature = "sqlite")]
    #[tokio::test]
    async fn sqlite_migrator_runs_multi_statement_scripts() {
        let pool = sqlite_pool().await;
        let migrator = Migrator::new(pool.clone());

        let migrations = vec![migration(
            "20260107000000",
            "create_pair",
            "CREATE TABLE m1 (id INTEGER PRIMARY KEY); CREATE TABLE m2 (id INTEGER PRIMARY KEY);",
            "DROP TABLE m2; DROP TABLE m1;",
        )];

        migrator.run(migrations.clone()).await.unwrap();
        assert!(sqlite_has_table(&pool, "m1").await);
        assert!(sqlite_has_table(&pool, "m2").await);

        assert!(migrator.rollback_last(migrations).await.unwrap());
        assert!(!sqlite_has_table(&pool, "m1").await);
        assert!(!sqlite_has_table(&pool, "m2").await);
        assert_eq!(sqlite_migration_count(&pool).await, 0);
    }

    #[cfg(feature = "postgres")]
    #[tokio::test]
    async fn postgres_migrator_applies_pending_once() {
        let Some(pool) = pg_pool_or_skip().await else {
            return;
        };
        let migrator = Migrator::new(pool.clone());

        let suffix = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_nanos();
        let version = format!("20260101{:020}", suffix);
        let table_name = format!("premix_mig_test_{}", suffix);
        let migrations = vec![migration(
            &version,
            "create_test_table",
            &format!("CREATE TABLE {} (id SERIAL PRIMARY KEY);", table_name),
            &format!("DROP TABLE {};", table_name),
        )];

        migrator.run(migrations.clone()).await.unwrap();

        let count: i64 =
            sqlx::query_scalar("SELECT COUNT(*) FROM _premix_migrations WHERE version = $1")
                .bind(&version)
                .fetch_one(&pool)
                .await
                .unwrap();
        assert_eq!(count, 1);

        migrator.run(migrations).await.unwrap();
        let count_after: i64 =
            sqlx::query_scalar("SELECT COUNT(*) FROM _premix_migrations WHERE version = $1")
                .bind(&version)
                .fetch_one(&pool)
                .await
                .unwrap();
        assert_eq!(count_after, 1);

        let _ = sqlx::query(&format!("DROP TABLE IF EXISTS {}", table_name))
            .execute(&pool)
            .await;
        let _ = sqlx::query("DELETE FROM _premix_migrations WHERE version = $1")
            .bind(&version)
            .execute(&pool)
            .await;
    }

    #[cfg(feature = "postgres")]
    #[tokio::test]
    async fn postgres_migrator_rolls_back_on_error() {
        let Some(pool) = pg_pool_or_skip().await else {
            return;
        };
        let migrator = Migrator::new(pool.clone());

        let suffix = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_nanos();
        let version = format!("20260102{:020}", suffix);
        let table_name = format!("premix_mig_bad_{}", suffix);
        let migrations = vec![migration(
            &version,
            "bad_sql",
            &format!(
                "CREATE TABLE {} (id SERIAL PRIMARY KEY); INVALID SQL",
                table_name
            ),
            &format!("DROP TABLE {};", table_name),
        )];

        let err = migrator.run(migrations).await.unwrap_err();
        assert!(err.to_string().contains("syntax"));

        let table_exists: Option<String> =
            sqlx::query_scalar("SELECT to_regclass('_premix_migrations')::text")
                .fetch_one(&pool)
                .await
                .unwrap();
        if table_exists.is_some() {
            let count: i64 =
                sqlx::query_scalar("SELECT COUNT(*) FROM _premix_migrations WHERE version = $1")
                    .bind(&version)
                    .fetch_one(&pool)
                    .await
                    .unwrap();
            assert_eq!(count, 0);
        }

        let _ = sqlx::query(&format!("DROP TABLE IF EXISTS {}", table_name))
            .execute(&pool)
            .await;
    }

    #[cfg(feature = "postgres")]
    #[tokio::test]
    async fn postgres_migrator_rolls_back_last() {
        let Some(pool) = pg_pool_or_skip().await else {
            return;
        };
        let migrator = Migrator::new(pool.clone());

        let suffix = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_nanos();
        let version_a = format!("20260103{:020}", suffix);
        let version_b = format!("20260104{:020}", suffix);
        let table_a = format!("premix_mig_a_{}", suffix);
        let table_b = format!("premix_mig_b_{}", suffix);
        let migrations = vec![
            migration(
                &version_a,
                "create_a",
                &format!("CREATE TABLE {} (id SERIAL PRIMARY KEY);", table_a),
                &format!("DROP TABLE {};", table_a),
            ),
            migration(
                &version_b,
                "create_b",
                &format!("CREATE TABLE {} (id SERIAL PRIMARY KEY);", table_b),
                &format!("DROP TABLE {};", table_b),
            ),
        ];

        migrator.run(migrations.clone()).await.unwrap();
        let rolled_back = migrator.rollback_last(migrations).await.unwrap();
        assert!(rolled_back);

        // The table is shared with other runs, so check only this test's rows:
        // the newer version must be gone and the older one must remain.
        let remaining: Vec<String> = sqlx::query_scalar(
            "SELECT version FROM _premix_migrations WHERE version = $1 OR version = $2 ORDER BY version",
        )
        .bind(&version_a)
        .bind(&version_b)
        .fetch_all(&pool)
        .await
        .unwrap();
        assert_eq!(remaining, vec![version_a.clone()]);

        let table_b_exists: Option<String> = sqlx::query_scalar("SELECT to_regclass($1)::text")
            .bind(&table_b)
            .fetch_one(&pool)
            .await
            .unwrap();
        assert!(table_b_exists.is_none());

        let _ = sqlx::query(&format!("DROP TABLE IF EXISTS {}", table_a))
            .execute(&pool)
            .await;
        let _ = sqlx::query(&format!("DROP TABLE IF EXISTS {}", table_b))
            .execute(&pool)
            .await;
        let _ = sqlx::query("DELETE FROM _premix_migrations WHERE version = $1 OR version = $2")
            .bind(&version_a)
            .bind(&version_b)
            .execute(&pool)
            .await;
    }
}