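//! Database migrations for `premix_core` (`premix_core/migrator.rs`): the
//! [`Migration`] description type and the [`Migrator`] that applies and rolls
//! back migrations.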

1use sqlx::{Database, Pool};
2
/// One database migration: a version identifier plus the SQL that applies and reverts it.
#[derive(Debug, Clone)]
pub struct Migration {
    /// Unique version identifier, for example a timestamp such as `20260101000000`.
    ///
    /// The migrator orders versions as text, so they should sort in the order
    /// the migrations are meant to be applied.
    pub version: String,
    /// Human-readable label, used in log output.
    pub name: String,
    /// SQL executed when the migration is applied.
    pub up_sql: String,
    /// SQL executed when the migration is rolled back.
    pub down_sql: String,
}
15
16#[cfg_attr(not(any(feature = "sqlite", feature = "postgres")), allow(dead_code))]
17#[derive(Debug, Clone, sqlx::FromRow)]
18struct AppliedMigration {
19    version: String,
20}
21
22/// A migration manager for applying and rolling back migrations.
23#[cfg_attr(not(any(feature = "sqlite", feature = "postgres")), allow(dead_code))]
24pub struct Migrator<DB: Database> {
25    pool: Pool<DB>,
26}
27
28impl<DB: Database> std::fmt::Debug for Migrator<DB> {
29    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
30        f.debug_struct("Migrator")
31            .field("pool", &"sqlx::Pool")
32            .finish()
33    }
34}
35
36impl<DB: Database> Migrator<DB> {
37    /// Creates a new `Migrator` instance using the provided connection pool.
38    /// Creates a new `Migrator` instance using the provided connection pool.
39    pub fn new(pool: Pool<DB>) -> Self {
40        Self { pool }
41    }
42}
43
// Backend-specific implementations for SQLite and Postgres, each gated behind its
// Cargo feature (these may become trait-based later). Version 1 uses the generic
// `Executor` API where possible, but creating the migrations table and writing
// bind placeholders require dialect-specific SQL.
47
#[cfg(feature = "sqlite")]
impl Migrator<sqlx::Sqlite> {
    /// DDL for the `_premix_migrations` bookkeeping table (SQLite dialect).
    ///
    /// Shared by [`Self::run`] and [`Self::rollback_last`] so both always agree on
    /// the table layout.
    const CREATE_MIGRATIONS_TABLE_SQL: &'static str = "CREATE TABLE IF NOT EXISTS _premix_migrations (
        version TEXT PRIMARY KEY,
        name TEXT NOT NULL,
        applied_at DATETIME DEFAULT CURRENT_TIMESTAMP
    )";

    /// Executes all pending migrations.
    ///
    /// `migrations` is first sorted by `version`, so migrations are applied in the
    /// same order that [`Self::rollback_last`] uses to find the most recent one.
    /// Versions already recorded in `_premix_migrations` are skipped. Everything
    /// runs in a single transaction. If any step fails, none of this call's
    /// changes are committed, including creation of the bookkeeping table.
    ///
    /// # Errors
    ///
    /// Returns the first database error raised while creating the bookkeeping
    /// table, reading applied versions, running an `up_sql` script, or recording
    /// a version.
    pub async fn run(
        &self,
        mut migrations: Vec<Migration>,
    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
        // Stable sort: application order follows version order, matching the
        // `ORDER BY version DESC` that `rollback_last` relies on.
        migrations.sort_by(|a, b| a.version.cmp(&b.version));

        let mut tx = self.pool.begin().await?;

        sqlx::query(Self::CREATE_MIGRATIONS_TABLE_SQL)
            .execute(&mut *tx)
            .await?;

        // A set makes each "already applied?" check O(1) instead of a linear scan.
        let applied: std::collections::HashSet<String> =
            sqlx::query_as::<sqlx::Sqlite, AppliedMigration>(
                "SELECT version FROM _premix_migrations",
            )
            .fetch_all(&mut *tx)
            .await?
            .into_iter()
            .map(|row| row.version)
            .collect();

        for migration in migrations
            .iter()
            .filter(|m| !applied.contains(&m.version))
        {
            tracing::info!(
                operation = "migration_apply",
                version = %migration.version,
                name = %migration.name,
                "premix migration"
            );
            println!(
                "🚚 Applying migration: {} - {}",
                migration.version, migration.name
            );

            // Run the script as raw SQL (no bind arguments), the same way sqlx's own
            // migrator does, so multi-statement scripts are accepted.
            sqlx::Executor::execute(&mut *tx, migration.up_sql.as_str()).await?;

            sqlx::query("INSERT INTO _premix_migrations (version, name) VALUES (?, ?)")
                .bind(&migration.version)
                .bind(&migration.name)
                .execute(&mut *tx)
                .await?;
        }

        tx.commit().await?;
        Ok(())
    }

    /// Reverts the last applied migration.
    ///
    /// "Last" means the highest `version` among `migrations` that is recorded in
    /// `_premix_migrations`. Recorded versions absent from `migrations` are
    /// ignored. Its `down_sql` runs and its record is deleted in one transaction.
    ///
    /// Returns `Ok(true)` when a migration was rolled back, and `Ok(false)` when
    /// `migrations` is empty or none of them has been applied.
    ///
    /// # Errors
    ///
    /// Fails when the selected migration's `down_sql` is empty or whitespace-only,
    /// or when any database operation fails. Nothing is committed in either case.
    pub async fn rollback_last(
        &self,
        migrations: Vec<Migration>,
    ) -> Result<bool, Box<dyn std::error::Error + Send + Sync>> {
        let mut tx = self.pool.begin().await?;

        sqlx::query(Self::CREATE_MIGRATIONS_TABLE_SQL)
            .execute(&mut *tx)
            .await?;

        if migrations.is_empty() {
            tx.commit().await?;
            return Ok(false);
        }

        // SQLite cannot bind a list as one parameter, so expand one `?` per version.
        let placeholders = vec!["?"; migrations.len()].join(", ");
        let sql = format!(
            "SELECT version FROM _premix_migrations WHERE version IN ({placeholders}) ORDER BY version DESC LIMIT 1"
        );
        let mut query = sqlx::query_scalar::<sqlx::Sqlite, String>(&sql);
        for migration in &migrations {
            query = query.bind(&migration.version);
        }
        let last = query.fetch_optional(&mut *tx).await?;

        let Some(last) = last else {
            tx.commit().await?;
            return Ok(false);
        };

        let migration = migrations
            .into_iter()
            .find(|m| m.version == last)
            .ok_or_else(|| format!("Migration {} not found.", last))?;

        if migration.down_sql.trim().is_empty() {
            return Err("Down migration is empty.".into());
        }

        tracing::info!(
            operation = "migration_rollback",
            version = %migration.version,
            name = %migration.name,
            "premix migration"
        );
        // Raw SQL, like `up_sql`, so multi-statement rollback scripts work.
        sqlx::Executor::execute(&mut *tx, migration.down_sql.as_str()).await?;
        sqlx::query("DELETE FROM _premix_migrations WHERE version = ?")
            .bind(&migration.version)
            .execute(&mut *tx)
            .await?;

        tx.commit().await?;
        Ok(true)
    }
}
172
#[cfg(feature = "postgres")]
impl Migrator<sqlx::Postgres> {
    /// DDL for the `_premix_migrations` bookkeeping table (Postgres dialect).
    ///
    /// Shared by [`Self::run`] and [`Self::rollback_last`] so both always agree on
    /// the table layout.
    const CREATE_MIGRATIONS_TABLE_SQL: &'static str = "CREATE TABLE IF NOT EXISTS _premix_migrations (
        version TEXT PRIMARY KEY,
        name TEXT NOT NULL,
        applied_at TIMESTAMPTZ DEFAULT CURRENT_TIMESTAMP
    )";

    /// Executes all pending migrations.
    ///
    /// `migrations` is first sorted by `version`, so migrations are applied in the
    /// same order that [`Self::rollback_last`] uses to find the most recent one.
    /// Versions already recorded in `_premix_migrations` are skipped. Everything
    /// runs in a single transaction. If any step fails, none of this call's
    /// changes are committed.
    ///
    /// # Errors
    ///
    /// Returns the first database error raised while creating the bookkeeping
    /// table, reading applied versions, running an `up_sql` script, or recording
    /// a version.
    pub async fn run(
        &self,
        mut migrations: Vec<Migration>,
    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
        // Stable sort: application order follows version order, matching the
        // `ORDER BY version DESC` that `rollback_last` relies on.
        migrations.sort_by(|a, b| a.version.cmp(&b.version));

        let mut tx = self.pool.begin().await?;

        sqlx::query(Self::CREATE_MIGRATIONS_TABLE_SQL)
            .execute(&mut *tx)
            .await?;

        // A set makes each "already applied?" check O(1) instead of a linear scan.
        let applied: std::collections::HashSet<String> =
            sqlx::query_as::<sqlx::Postgres, AppliedMigration>(
                "SELECT version FROM _premix_migrations",
            )
            .fetch_all(&mut *tx)
            .await?
            .into_iter()
            .map(|row| row.version)
            .collect();

        for migration in migrations
            .iter()
            .filter(|m| !applied.contains(&m.version))
        {
            tracing::info!(
                operation = "migration_apply",
                version = %migration.version,
                name = %migration.name,
                "premix migration"
            );
            println!(
                "🚚 Applying migration: {} - {}",
                migration.version, migration.name
            );

            // `sqlx::query(..)` always prepares the statement, and Postgres rejects
            // prepared statements that contain more than one command. Executing the
            // raw `&str` (no bind arguments) goes through the simple-query protocol
            // instead, so multi-statement migration scripts work.
            sqlx::Executor::execute(&mut *tx, migration.up_sql.as_str()).await?;

            sqlx::query("INSERT INTO _premix_migrations (version, name) VALUES ($1, $2)")
                .bind(&migration.version)
                .bind(&migration.name)
                .execute(&mut *tx)
                .await?;
        }

        tx.commit().await?;
        Ok(())
    }

    /// Reverts the last applied migration.
    ///
    /// "Last" means the highest `version` among `migrations` that is recorded in
    /// `_premix_migrations`. Recorded versions absent from `migrations` are
    /// ignored. Its `down_sql` runs and its record is deleted in one transaction.
    ///
    /// Returns `Ok(true)` when a migration was rolled back, and `Ok(false)` when
    /// `migrations` is empty or none of them has been applied.
    ///
    /// # Errors
    ///
    /// Fails when the selected migration's `down_sql` is empty or whitespace-only,
    /// or when any database operation fails. Nothing is committed in either case.
    pub async fn rollback_last(
        &self,
        migrations: Vec<Migration>,
    ) -> Result<bool, Box<dyn std::error::Error + Send + Sync>> {
        let mut tx = self.pool.begin().await?;

        sqlx::query(Self::CREATE_MIGRATIONS_TABLE_SQL)
            .execute(&mut *tx)
            .await?;

        if migrations.is_empty() {
            tx.commit().await?;
            return Ok(false);
        }

        // Postgres binds the whole list as a single array parameter.
        let versions: Vec<String> = migrations.iter().map(|m| m.version.clone()).collect();
        let last = sqlx::query_scalar::<sqlx::Postgres, String>(
            "SELECT version FROM _premix_migrations WHERE version = ANY($1) ORDER BY version DESC LIMIT 1",
        )
        .bind(&versions)
        .fetch_optional(&mut *tx)
        .await?;

        let Some(last) = last else {
            tx.commit().await?;
            return Ok(false);
        };

        let migration = migrations
            .into_iter()
            .find(|m| m.version == last)
            .ok_or_else(|| format!("Migration {} not found.", last))?;

        if migration.down_sql.trim().is_empty() {
            return Err("Down migration is empty.".into());
        }

        tracing::info!(
            operation = "migration_rollback",
            version = %migration.version,
            name = %migration.name,
            "premix migration"
        );
        // Raw SQL (simple-query protocol) so multi-statement rollback scripts work.
        sqlx::Executor::execute(&mut *tx, migration.down_sql.as_str()).await?;
        sqlx::query("DELETE FROM _premix_migrations WHERE version = $1")
            .bind(&migration.version)
            .execute(&mut *tx)
            .await?;

        tx.commit().await?;
        Ok(true)
    }
}
293
// Only compiled when at least one backend is enabled; each test is additionally
// gated on the feature whose `Migrator` impl it exercises.
#[cfg(all(test, any(feature = "sqlite", feature = "postgres")))]
mod tests {
    #[cfg(feature = "postgres")]
    use std::time::{SystemTime, UNIX_EPOCH};

    #[cfg(feature = "sqlite")]
    use sqlx::sqlite::SqlitePoolOptions;

    use super::*;

    /// Connects to the Postgres database named by `DATABASE_URL`, falling back to a
    /// local default. Returns `None` when it is unreachable so the caller can skip.
    #[cfg(feature = "postgres")]
    async fn pg_pool_or_skip() -> Option<sqlx::Pool<sqlx::Postgres>> {
        let db_url = std::env::var("DATABASE_URL").unwrap_or_else(|_| {
            "postgres://postgres:admin123@localhost:5432/premix_bench".to_string()
        });
        sqlx::postgres::PgPoolOptions::new()
            .max_connections(1)
            .connect(&db_url)
            .await
            .ok()
    }

    /// Number of `_premix_migrations` rows recorded for `version`.
    #[cfg(feature = "postgres")]
    async fn pg_version_count(pool: &sqlx::Pool<sqlx::Postgres>, version: &str) -> i64 {
        sqlx::query_scalar("SELECT COUNT(*) FROM _premix_migrations WHERE version = $1")
            .bind(version)
            .fetch_one(pool)
            .await
            .unwrap()
    }

    /// Opens the single-connection in-memory SQLite pool used by every SQLite test.
    #[cfg(feature = "sqlite")]
    async fn sqlite_pool() -> sqlx::Pool<sqlx::Sqlite> {
        SqlitePoolOptions::new()
            .max_connections(1)
            .connect("sqlite::memory:")
            .await
            .unwrap()
    }

    /// Whether a table called `name` exists in the SQLite database behind `pool`.
    #[cfg(feature = "sqlite")]
    async fn sqlite_table_exists(pool: &sqlx::Pool<sqlx::Sqlite>, name: &str) -> bool {
        sqlx::query_scalar::<sqlx::Sqlite, String>(
            "SELECT name FROM sqlite_master WHERE type = 'table' AND name = ?",
        )
        .bind(name)
        .fetch_optional(pool)
        .await
        .unwrap()
        .is_some()
    }

    /// Total number of rows in `_premix_migrations`.
    #[cfg(feature = "sqlite")]
    async fn sqlite_migration_count(pool: &sqlx::Pool<sqlx::Sqlite>) -> i64 {
        sqlx::query_scalar("SELECT COUNT(*) FROM _premix_migrations")
            .fetch_one(pool)
            .await
            .unwrap()
    }

    #[cfg(feature = "sqlite")]
    #[tokio::test]
    async fn sqlite_migrator_applies_pending_once() {
        let pool = sqlite_pool().await;
        let migrator = Migrator::new(pool.clone());

        let migrations = vec![Migration {
            version: "20260101000000".to_string(),
            name: "create_users".to_string(),
            up_sql: "CREATE TABLE users (id INTEGER PRIMARY KEY);".to_string(),
            down_sql: "DROP TABLE users;".to_string(),
        }];

        migrator.run(migrations.clone()).await.unwrap();
        assert_eq!(sqlite_migration_count(&pool).await, 1);

        // Re-applying would fail because `users` already exists, so a clean
        // second run proves the migration was skipped.
        migrator.run(migrations).await.unwrap();
        assert_eq!(sqlite_migration_count(&pool).await, 1);
    }

    #[cfg(feature = "sqlite")]
    #[tokio::test]
    async fn sqlite_migrator_applies_multiple() {
        let pool = sqlite_pool().await;
        let migrator = Migrator::new(pool.clone());

        let migrations = vec![
            Migration {
                version: "20260103000000".to_string(),
                name: "create_a".to_string(),
                up_sql: "CREATE TABLE a (id INTEGER PRIMARY KEY);".to_string(),
                down_sql: "DROP TABLE a;".to_string(),
            },
            Migration {
                version: "20260104000000".to_string(),
                name: "create_b".to_string(),
                up_sql: "CREATE TABLE b (id INTEGER PRIMARY KEY);".to_string(),
                down_sql: "DROP TABLE b;".to_string(),
            },
        ];

        migrator.run(migrations).await.unwrap();

        assert_eq!(sqlite_migration_count(&pool).await, 2);
        assert!(sqlite_table_exists(&pool, "a").await);
        assert!(sqlite_table_exists(&pool, "b").await);
    }

    #[cfg(feature = "sqlite")]
    #[tokio::test]
    async fn sqlite_migrator_applies_multi_statement_script() {
        let pool = sqlite_pool().await;
        let migrator = Migrator::new(pool.clone());

        let migrations = vec![Migration {
            version: "20260105000000".to_string(),
            name: "create_c_and_d".to_string(),
            up_sql: "CREATE TABLE c (id INTEGER PRIMARY KEY); CREATE TABLE d (id INTEGER PRIMARY KEY);"
                .to_string(),
            down_sql: "DROP TABLE d; DROP TABLE c;".to_string(),
        }];

        migrator.run(migrations.clone()).await.unwrap();
        assert!(sqlite_table_exists(&pool, "c").await);
        assert!(sqlite_table_exists(&pool, "d").await);

        assert!(migrator.rollback_last(migrations).await.unwrap());
        assert!(!sqlite_table_exists(&pool, "c").await);
        assert!(!sqlite_table_exists(&pool, "d").await);
        assert_eq!(sqlite_migration_count(&pool).await, 0);
    }

    #[cfg(feature = "sqlite")]
    #[tokio::test]
    async fn sqlite_migrator_rolls_back_last() {
        let pool = sqlite_pool().await;
        let migrator = Migrator::new(pool.clone());

        let migrations = vec![
            Migration {
                version: "20260103000000".to_string(),
                name: "create_a".to_string(),
                up_sql: "CREATE TABLE a (id INTEGER PRIMARY KEY);".to_string(),
                down_sql: "DROP TABLE a;".to_string(),
            },
            Migration {
                version: "20260104000000".to_string(),
                name: "create_b".to_string(),
                up_sql: "CREATE TABLE b (id INTEGER PRIMARY KEY);".to_string(),
                down_sql: "DROP TABLE b;".to_string(),
            },
        ];

        migrator.run(migrations.clone()).await.unwrap();
        let rolled_back = migrator.rollback_last(migrations).await.unwrap();
        assert!(rolled_back);

        // Only the newest migration is reverted: its record and its table are gone.
        let remaining: Vec<String> =
            sqlx::query_scalar("SELECT version FROM _premix_migrations ORDER BY version")
                .fetch_all(&pool)
                .await
                .unwrap();
        assert_eq!(remaining, vec!["20260103000000".to_string()]);
        assert!(sqlite_table_exists(&pool, "a").await);
        assert!(!sqlite_table_exists(&pool, "b").await);
    }

    #[cfg(feature = "sqlite")]
    #[tokio::test]
    async fn sqlite_rollback_without_applied_migrations_returns_false() {
        let pool = sqlite_pool().await;
        let migrator = Migrator::new(pool.clone());

        assert!(!migrator.rollback_last(Vec::new()).await.unwrap());

        let never_applied = vec![Migration {
            version: "20260106000000".to_string(),
            name: "create_e".to_string(),
            up_sql: "CREATE TABLE e (id INTEGER PRIMARY KEY);".to_string(),
            down_sql: "DROP TABLE e;".to_string(),
        }];
        assert!(!migrator.rollback_last(never_applied).await.unwrap());
    }

    #[cfg(feature = "sqlite")]
    #[tokio::test]
    async fn sqlite_migrator_rolls_back_on_error() {
        let pool = sqlite_pool().await;
        let migrator = Migrator::new(pool.clone());

        let migrations = vec![Migration {
            version: "20260102000000".to_string(),
            name: "bad_sql".to_string(),
            up_sql: "CREATE TABLE broken (id INTEGER PRIMARY KEY); INVALID SQL".to_string(),
            down_sql: "DROP TABLE broken;".to_string(),
        }];

        let err = migrator.run(migrations).await.unwrap_err();
        assert!(err.to_string().contains("syntax"));

        // The whole transaction, including the bookkeeping table, was rolled back.
        assert!(!sqlite_table_exists(&pool, "_premix_migrations").await);
        assert!(!sqlite_table_exists(&pool, "broken").await);
    }

    #[cfg(feature = "postgres")]
    #[tokio::test]
    async fn postgres_migrator_applies_pending_once() {
        let Some(pool) = pg_pool_or_skip().await else {
            return;
        };
        let migrator = Migrator::new(pool.clone());

        let suffix = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_nanos();
        let version = format!("20260101{:020}", suffix);
        let table_name = format!("premix_mig_test_{}", suffix);
        let migrations = vec![Migration {
            version: version.clone(),
            name: "create_test_table".to_string(),
            up_sql: format!("CREATE TABLE {} (id SERIAL PRIMARY KEY);", table_name),
            down_sql: format!("DROP TABLE {};", table_name),
        }];

        migrator.run(migrations.clone()).await.unwrap();
        assert_eq!(pg_version_count(&pool, &version).await, 1);

        migrator.run(migrations).await.unwrap();
        assert_eq!(pg_version_count(&pool, &version).await, 1);

        let _ = sqlx::query(&format!("DROP TABLE IF EXISTS {}", table_name))
            .execute(&pool)
            .await;
        let _ = sqlx::query("DELETE FROM _premix_migrations WHERE version = $1")
            .bind(&version)
            .execute(&pool)
            .await;
    }

    #[cfg(feature = "postgres")]
    #[tokio::test]
    async fn postgres_migrator_rolls_back_on_error() {
        let Some(pool) = pg_pool_or_skip().await else {
            return;
        };
        let migrator = Migrator::new(pool.clone());

        let suffix = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_nanos();
        let version = format!("20260102{:020}", suffix);
        let table_name = format!("premix_mig_bad_{}", suffix);
        let migrations = vec![Migration {
            version: version.clone(),
            name: "bad_sql".to_string(),
            up_sql: format!(
                "CREATE TABLE {} (id SERIAL PRIMARY KEY); INVALID SQL",
                table_name
            ),
            down_sql: format!("DROP TABLE {};", table_name),
        }];

        let err = migrator.run(migrations).await.unwrap_err();
        assert!(err.to_string().contains("syntax"));

        // The bookkeeping table may pre-exist from other runs; only check it if present.
        let table_exists: Option<String> =
            sqlx::query_scalar("SELECT to_regclass('_premix_migrations')::text")
                .fetch_one(&pool)
                .await
                .unwrap();
        if table_exists.is_some() {
            assert_eq!(pg_version_count(&pool, &version).await, 0);
        }

        let _ = sqlx::query(&format!("DROP TABLE IF EXISTS {}", table_name))
            .execute(&pool)
            .await;
    }

    #[cfg(feature = "postgres")]
    #[tokio::test]
    async fn postgres_migrator_rolls_back_last() {
        let Some(pool) = pg_pool_or_skip().await else {
            return;
        };
        let migrator = Migrator::new(pool.clone());

        let suffix = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_nanos();
        let version_a = format!("20260103{:020}", suffix);
        let version_b = format!("20260104{:020}", suffix);
        let table_a = format!("premix_mig_a_{}", suffix);
        let table_b = format!("premix_mig_b_{}", suffix);
        let migrations = vec![
            Migration {
                version: version_a.clone(),
                name: "create_a".to_string(),
                up_sql: format!("CREATE TABLE {} (id SERIAL PRIMARY KEY);", table_a),
                down_sql: format!("DROP TABLE {};", table_a),
            },
            Migration {
                version: version_b.clone(),
                name: "create_b".to_string(),
                up_sql: format!("CREATE TABLE {} (id SERIAL PRIMARY KEY);", table_b),
                down_sql: format!("DROP TABLE {};", table_b),
            },
        ];

        migrator.run(migrations.clone()).await.unwrap();
        let rolled_back = migrator.rollback_last(migrations).await.unwrap();
        assert!(rolled_back);

        // Exactly the newer migration of this test's pair was reverted.
        assert_eq!(pg_version_count(&pool, &version_a).await, 1);
        assert_eq!(pg_version_count(&pool, &version_b).await, 0);

        let _ = sqlx::query(&format!("DROP TABLE IF EXISTS {}", table_a))
            .execute(&pool)
            .await;
        let _ = sqlx::query(&format!("DROP TABLE IF EXISTS {}", table_b))
            .execute(&pool)
            .await;
        let _ = sqlx::query("DELETE FROM _premix_migrations WHERE version = $1 OR version = $2")
            .bind(&version_a)
            .bind(&version_b)
            .execute(&pool)
            .await;
    }
}