premix_core/
migrator.rs

1use std::error::Error;
2
3use sqlx::{Database, Executor, Pool};
4
/// A single schema migration: an identifying version, a label, and the SQL
/// used to apply and to revert it.
#[derive(Clone, Debug)]
pub struct Migration {
    /// Version key; recorded in the `version` column of `_premix_migrations`.
    pub version: String,
    /// Human-readable label recorded next to the version.
    pub name: String,
    /// SQL executed when the migration is applied.
    pub up_sql: String,
    /// SQL executed when the migration is rolled back.
    pub down_sql: String,
}
12
13#[derive(Debug, Clone, sqlx::FromRow)]
14struct AppliedMigration {
15    version: String,
16}
17
18pub struct Migrator<DB: Database> {
19    pool: Pool<DB>,
20}
21
22impl<DB: Database> Migrator<DB> {
23    pub fn new(pool: Pool<DB>) -> Self {
24        Self { pool }
25    }
26}
27
// Backend-specific implementations (feature-gated for now; possibly trait-based later).
// For version 1 we use the generic `Executor` API where possible, but creating the
// migrations table requires dialect-specific SQL, so each backend gets its own `impl`.
31
#[cfg(feature = "sqlite")]
impl Migrator<sqlx::Sqlite> {
    /// DDL for the bookkeeping table. Kept in one place so `run` and
    /// `rollback_last` cannot drift apart.
    const CREATE_MIGRATIONS_TABLE: &'static str = "\
        CREATE TABLE IF NOT EXISTS _premix_migrations (
            version TEXT PRIMARY KEY,
            name TEXT NOT NULL,
            applied_at DATETIME DEFAULT CURRENT_TIMESTAMP
        )";

    /// Applies every migration in `migrations` whose version is not yet
    /// recorded in `_premix_migrations`, in the order given.
    ///
    /// All work (table creation, each `up_sql`, each bookkeeping insert)
    /// happens inside one transaction that is committed only at the end.
    ///
    /// # Errors
    ///
    /// Returns the first database error encountered. In that case the
    /// transaction is never committed.
    pub async fn run(&self, migrations: Vec<Migration>) -> Result<(), Box<dyn Error>> {
        let mut tx = self.pool.begin().await?;

        // 1. Make sure the bookkeeping table exists.
        sqlx::query(Self::CREATE_MIGRATIONS_TABLE)
            .execute(&mut *tx)
            .await?;

        // 2. Load the already-applied versions into a set so the membership
        //    test below is O(1) per migration instead of a linear scan.
        let applied: std::collections::HashSet<String> = sqlx::query_as::<_, AppliedMigration>(
            "SELECT version FROM _premix_migrations ORDER BY version ASC",
        )
        .fetch_all(&mut *tx)
        .await?
        .into_iter()
        .map(|row| row.version)
        .collect();

        // 3. Apply whatever is still pending.
        for migration in migrations {
            if applied.contains(&migration.version) {
                continue;
            }

            println!(
                "🚚 Applying migration: {} - {}",
                migration.version, migration.name
            );

            // Run the UP script, then record the version in the same transaction.
            tx.execute(migration.up_sql.as_str()).await?;

            sqlx::query("INSERT INTO _premix_migrations (version, name) VALUES (?, ?)")
                .bind(&migration.version)
                .bind(&migration.name)
                .execute(&mut *tx)
                .await?;
        }

        tx.commit().await?;
        Ok(())
    }

    /// Reverts the applied migration with the greatest version among those
    /// listed in `migrations`.
    ///
    /// Returns `Ok(true)` when a migration was rolled back, and `Ok(false)`
    /// when `migrations` is empty or none of its versions is recorded as
    /// applied.
    ///
    /// # Errors
    ///
    /// Returns an error if the selected migration has a blank `down_sql`, or
    /// if any database call fails. On error the transaction is not committed.
    pub async fn rollback_last(&self, migrations: Vec<Migration>) -> Result<bool, Box<dyn Error>> {
        let mut tx = self.pool.begin().await?;

        sqlx::query(Self::CREATE_MIGRATIONS_TABLE)
            .execute(&mut *tx)
            .await?;

        if migrations.is_empty() {
            tx.commit().await?;
            return Ok(false);
        }

        // Read the applied versions newest-first with a static query and match
        // them against the caller's list in Rust. A generated `IN (?, ?, ...)`
        // list would need one bind parameter per known migration, which is
        // capped by SQLite's host-parameter limit and yields a different SQL
        // string for every list length.
        let applied_desc: Vec<String> = sqlx::query_scalar::<_, String>(
            "SELECT version FROM _premix_migrations ORDER BY version DESC",
        )
        .fetch_all(&mut *tx)
        .await?;

        let newest_known = {
            let known: std::collections::HashSet<&str> =
                migrations.iter().map(|m| m.version.as_str()).collect();
            applied_desc
                .into_iter()
                .find(|version| known.contains(version.as_str()))
        };

        let Some(last) = newest_known else {
            tx.commit().await?;
            return Ok(false);
        };

        let migration = migrations
            .into_iter()
            .find(|m| m.version == last)
            .ok_or_else(|| format!("Migration {} not found.", last))?;

        if migration.down_sql.trim().is_empty() {
            return Err("Down migration is empty.".into());
        }

        // Run the DOWN script and forget the version in the same transaction.
        tx.execute(migration.down_sql.as_str()).await?;
        sqlx::query("DELETE FROM _premix_migrations WHERE version = ?")
            .bind(&migration.version)
            .execute(&mut *tx)
            .await?;

        tx.commit().await?;
        Ok(true)
    }
}
136
#[cfg(feature = "postgres")]
impl Migrator<sqlx::Postgres> {
    /// DDL for the bookkeeping table. Kept in one place so `run` and
    /// `rollback_last` cannot drift apart.
    const CREATE_MIGRATIONS_TABLE: &'static str = "\
        CREATE TABLE IF NOT EXISTS _premix_migrations (
            version TEXT PRIMARY KEY,
            name TEXT NOT NULL,
            applied_at TIMESTAMPTZ DEFAULT CURRENT_TIMESTAMP
        )";

    /// Applies every migration in `migrations` whose version is not yet
    /// recorded in `_premix_migrations`, in the order given.
    ///
    /// All work (table creation, each `up_sql`, each bookkeeping insert)
    /// happens inside one transaction that is committed only at the end.
    ///
    /// # Errors
    ///
    /// Returns the first database error encountered. In that case the
    /// transaction is never committed.
    pub async fn run(&self, migrations: Vec<Migration>) -> Result<(), Box<dyn Error>> {
        let mut tx = self.pool.begin().await?;

        // 1. Make sure the bookkeeping table exists.
        sqlx::query(Self::CREATE_MIGRATIONS_TABLE)
            .execute(&mut *tx)
            .await?;

        // 2. Load the already-applied versions into a set so the membership
        //    test below is O(1) per migration instead of a linear scan.
        let applied: std::collections::HashSet<String> = sqlx::query_as::<_, AppliedMigration>(
            "SELECT version FROM _premix_migrations ORDER BY version ASC",
        )
        .fetch_all(&mut *tx)
        .await?
        .into_iter()
        .map(|row| row.version)
        .collect();

        // 3. Apply whatever is still pending.
        for migration in migrations {
            if applied.contains(&migration.version) {
                continue;
            }

            println!(
                "🚚 Applying migration: {} - {}",
                migration.version, migration.name
            );

            // The whole UP script is handed to the driver as a single string;
            // it is not split on `;` here.
            // NOTE(review): multi-statement scripts therefore depend on the
            // driver accepting several statements in one call — confirm
            // against the sqlx version in use.
            tx.execute(migration.up_sql.as_str()).await?;

            // Record the version in the same transaction as the script.
            sqlx::query("INSERT INTO _premix_migrations (version, name) VALUES ($1, $2)")
                .bind(&migration.version)
                .bind(&migration.name)
                .execute(&mut *tx)
                .await?;
        }

        tx.commit().await?;
        Ok(())
    }

    /// Reverts the applied migration with the greatest version among those
    /// listed in `migrations`.
    ///
    /// Returns `Ok(true)` when a migration was rolled back, and `Ok(false)`
    /// when `migrations` is empty or none of its versions is recorded as
    /// applied.
    ///
    /// # Errors
    ///
    /// Returns an error if the selected migration has a blank `down_sql`, or
    /// if any database call fails. On error the transaction is not committed.
    pub async fn rollback_last(&self, migrations: Vec<Migration>) -> Result<bool, Box<dyn Error>> {
        let mut tx = self.pool.begin().await?;

        sqlx::query(Self::CREATE_MIGRATIONS_TABLE)
            .execute(&mut *tx)
            .await?;

        let versions: Vec<String> = migrations.iter().map(|m| m.version.clone()).collect();
        if versions.is_empty() {
            tx.commit().await?;
            return Ok(false);
        }

        // The known versions travel as one array parameter, so the SQL text
        // stays the same regardless of how many migrations are supplied.
        let last = sqlx::query_scalar::<_, String>(
            "SELECT version FROM _premix_migrations WHERE version = ANY($1) ORDER BY version DESC LIMIT 1",
        )
        .bind(&versions)
        .fetch_optional(&mut *tx)
        .await?;

        let Some(last) = last else {
            tx.commit().await?;
            return Ok(false);
        };

        let migration = migrations
            .into_iter()
            .find(|m| m.version == last)
            .ok_or_else(|| format!("Migration {} not found.", last))?;

        if migration.down_sql.trim().is_empty() {
            return Err("Down migration is empty.".into());
        }

        // Run the DOWN script and forget the version in the same transaction.
        tx.execute(migration.down_sql.as_str()).await?;
        sqlx::query("DELETE FROM _premix_migrations WHERE version = $1")
            .bind(&migration.version)
            .execute(&mut *tx)
            .await?;

        tx.commit().await?;
        Ok(true)
    }
}
243
#[cfg(test)]
mod tests {
    //! Integration-style tests for the migrator.
    //!
    //! Each backend's tests (and the imports/helpers they need) are gated on
    //! that backend's cargo feature, because the `run`/`rollback_last`
    //! methods they call only exist when the feature is enabled.

    #[cfg(feature = "postgres")]
    use std::time::{SystemTime, UNIX_EPOCH};

    #[cfg(feature = "sqlite")]
    use sqlx::sqlite::SqlitePoolOptions;

    use super::*;

    /// Opens a single-connection in-memory SQLite pool.
    // NOTE(review): presumably limited to one connection so every query sees
    // the same in-memory database — confirm.
    #[cfg(feature = "sqlite")]
    async fn sqlite_memory_pool() -> sqlx::Pool<sqlx::Sqlite> {
        SqlitePoolOptions::new()
            .max_connections(1)
            .connect("sqlite::memory:")
            .await
            .unwrap()
    }

    /// Connects to the database named by `DATABASE_URL` (or a local default),
    /// returning `None` when the connection fails so callers can skip.
    #[cfg(feature = "postgres")]
    async fn pg_pool_or_skip() -> Option<sqlx::Pool<sqlx::Postgres>> {
        let db_url = std::env::var("DATABASE_URL").unwrap_or_else(|_| {
            "postgres://postgres:admin123@localhost:5432/premix_bench".to_string()
        });
        sqlx::postgres::PgPoolOptions::new()
            .max_connections(1)
            .connect(&db_url)
            .await
            .ok()
    }

    /// Nanoseconds since the Unix epoch, used to build per-test versions and
    /// table names for the Postgres tests.
    #[cfg(feature = "postgres")]
    fn unique_suffix() -> u128 {
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_nanos()
    }

    #[cfg(feature = "sqlite")]
    #[tokio::test]
    async fn sqlite_migrator_applies_pending_once() {
        let pool = sqlite_memory_pool().await;
        let migrator = Migrator::new(pool.clone());

        let migrations = vec![Migration {
            version: "20260101000000".to_string(),
            name: "create_users".to_string(),
            up_sql: "CREATE TABLE users (id INTEGER PRIMARY KEY);".to_string(),
            down_sql: "DROP TABLE users;".to_string(),
        }];

        migrator.run(migrations.clone()).await.unwrap();

        let count: i64 = sqlx::query_scalar("SELECT COUNT(*) FROM _premix_migrations")
            .fetch_one(&pool)
            .await
            .unwrap();
        assert_eq!(count, 1);

        // A second run must be a no-op: the version is already recorded.
        migrator.run(migrations).await.unwrap();
        let count_after: i64 = sqlx::query_scalar("SELECT COUNT(*) FROM _premix_migrations")
            .fetch_one(&pool)
            .await
            .unwrap();
        assert_eq!(count_after, 1);
    }

    #[cfg(feature = "sqlite")]
    #[tokio::test]
    async fn sqlite_migrator_applies_multiple() {
        let pool = sqlite_memory_pool().await;
        let migrator = Migrator::new(pool.clone());

        let migrations = vec![
            Migration {
                version: "20260103000000".to_string(),
                name: "create_a".to_string(),
                up_sql: "CREATE TABLE a (id INTEGER PRIMARY KEY);".to_string(),
                down_sql: "DROP TABLE a;".to_string(),
            },
            Migration {
                version: "20260104000000".to_string(),
                name: "create_b".to_string(),
                up_sql: "CREATE TABLE b (id INTEGER PRIMARY KEY);".to_string(),
                down_sql: "DROP TABLE b;".to_string(),
            },
        ];

        migrator.run(migrations).await.unwrap();

        let count: i64 = sqlx::query_scalar("SELECT COUNT(*) FROM _premix_migrations")
            .fetch_one(&pool)
            .await
            .unwrap();
        assert_eq!(count, 2);
    }

    #[cfg(feature = "sqlite")]
    #[tokio::test]
    async fn sqlite_migrator_rolls_back_last() {
        let pool = sqlite_memory_pool().await;
        let migrator = Migrator::new(pool.clone());

        let migrations = vec![
            Migration {
                version: "20260103000000".to_string(),
                name: "create_a".to_string(),
                up_sql: "CREATE TABLE a (id INTEGER PRIMARY KEY);".to_string(),
                down_sql: "DROP TABLE a;".to_string(),
            },
            Migration {
                version: "20260104000000".to_string(),
                name: "create_b".to_string(),
                up_sql: "CREATE TABLE b (id INTEGER PRIMARY KEY);".to_string(),
                down_sql: "DROP TABLE b;".to_string(),
            },
        ];

        migrator.run(migrations.clone()).await.unwrap();
        let rolled_back = migrator.rollback_last(migrations).await.unwrap();
        assert!(rolled_back);

        // Only the newest migration is reverted: its record is gone and the
        // older one is the single remaining row.
        let remaining: Vec<String> = sqlx::query_scalar("SELECT version FROM _premix_migrations")
            .fetch_all(&pool)
            .await
            .unwrap();
        assert_eq!(remaining, vec!["20260103000000".to_string()]);

        // The DOWN script of the reverted migration actually ran.
        let table_b: Option<String> =
            sqlx::query_scalar("SELECT name FROM sqlite_master WHERE type='table' AND name='b'")
                .fetch_optional(&pool)
                .await
                .unwrap();
        assert!(table_b.is_none());
    }

    #[cfg(feature = "sqlite")]
    #[tokio::test]
    async fn sqlite_migrator_rolls_back_on_error() {
        let pool = sqlite_memory_pool().await;
        let migrator = Migrator::new(pool.clone());

        let migrations = vec![Migration {
            version: "20260102000000".to_string(),
            name: "bad_sql".to_string(),
            up_sql: "CREATE TABLE broken (id INTEGER PRIMARY KEY); INVALID SQL".to_string(),
            down_sql: "DROP TABLE broken;".to_string(),
        }];

        let err = migrator.run(migrations).await.unwrap_err();
        assert!(err.to_string().contains("syntax"));

        // The failed run never committed, so even the bookkeeping table
        // created at the start of the transaction must be absent.
        let table: Option<String> = sqlx::query_scalar(
            "SELECT name FROM sqlite_master WHERE type='table' AND name='_premix_migrations'",
        )
        .fetch_optional(&pool)
        .await
        .unwrap();
        assert!(table.is_none());
    }

    #[cfg(feature = "postgres")]
    #[tokio::test]
    async fn postgres_migrator_applies_pending_once() {
        let Some(pool) = pg_pool_or_skip().await else {
            return;
        };
        let migrator = Migrator::new(pool.clone());

        let suffix = unique_suffix();
        let version = format!("20260101{:020}", suffix);
        let table_name = format!("premix_mig_test_{}", suffix);
        let migrations = vec![Migration {
            version: version.clone(),
            name: "create_test_table".to_string(),
            up_sql: format!("CREATE TABLE {} (id SERIAL PRIMARY KEY);", table_name),
            down_sql: format!("DROP TABLE {};", table_name),
        }];

        migrator.run(migrations.clone()).await.unwrap();

        let count: i64 =
            sqlx::query_scalar("SELECT COUNT(*) FROM _premix_migrations WHERE version = $1")
                .bind(&version)
                .fetch_one(&pool)
                .await
                .unwrap();
        assert_eq!(count, 1);

        // A second run must be a no-op: the version is already recorded.
        migrator.run(migrations).await.unwrap();
        let count_after: i64 =
            sqlx::query_scalar("SELECT COUNT(*) FROM _premix_migrations WHERE version = $1")
                .bind(&version)
                .fetch_one(&pool)
                .await
                .unwrap();
        assert_eq!(count_after, 1);

        // Best-effort cleanup of the shared database; failures are ignored.
        let _ = sqlx::query(&format!("DROP TABLE IF EXISTS {}", table_name))
            .execute(&pool)
            .await;
        let _ = sqlx::query("DELETE FROM _premix_migrations WHERE version = $1")
            .bind(&version)
            .execute(&pool)
            .await;
    }

    #[cfg(feature = "postgres")]
    #[tokio::test]
    async fn postgres_migrator_rolls_back_on_error() {
        let Some(pool) = pg_pool_or_skip().await else {
            return;
        };
        let migrator = Migrator::new(pool.clone());

        let suffix = unique_suffix();
        let version = format!("20260102{:020}", suffix);
        let table_name = format!("premix_mig_bad_{}", suffix);
        let migrations = vec![Migration {
            version: version.clone(),
            name: "bad_sql".to_string(),
            up_sql: format!(
                "CREATE TABLE {} (id SERIAL PRIMARY KEY); INVALID SQL",
                table_name
            ),
            down_sql: format!("DROP TABLE {};", table_name),
        }];

        let err = migrator.run(migrations).await.unwrap_err();
        assert!(err.to_string().contains("syntax"));

        // The bookkeeping table may already exist from other runs against the
        // shared database; if it does, this version must not be recorded.
        let table_exists: Option<String> =
            sqlx::query_scalar("SELECT to_regclass('_premix_migrations')::text")
                .fetch_one(&pool)
                .await
                .unwrap();
        if table_exists.is_some() {
            let count: i64 =
                sqlx::query_scalar("SELECT COUNT(*) FROM _premix_migrations WHERE version = $1")
                    .bind(&version)
                    .fetch_one(&pool)
                    .await
                    .unwrap();
            assert_eq!(count, 0);
        }

        // Best-effort cleanup; failures are ignored.
        let _ = sqlx::query(&format!("DROP TABLE IF EXISTS {}", table_name))
            .execute(&pool)
            .await;
    }

    #[cfg(feature = "postgres")]
    #[tokio::test]
    async fn postgres_migrator_rolls_back_last() {
        let Some(pool) = pg_pool_or_skip().await else {
            return;
        };
        let migrator = Migrator::new(pool.clone());

        let suffix = unique_suffix();
        let version_a = format!("20260103{:020}", suffix);
        let version_b = format!("20260104{:020}", suffix);
        let table_a = format!("premix_mig_a_{}", suffix);
        let table_b = format!("premix_mig_b_{}", suffix);
        let migrations = vec![
            Migration {
                version: version_a.clone(),
                name: "create_a".to_string(),
                up_sql: format!("CREATE TABLE {} (id SERIAL PRIMARY KEY);", table_a),
                down_sql: format!("DROP TABLE {};", table_a),
            },
            Migration {
                version: version_b.clone(),
                name: "create_b".to_string(),
                up_sql: format!("CREATE TABLE {} (id SERIAL PRIMARY KEY);", table_b),
                down_sql: format!("DROP TABLE {};", table_b),
            },
        ];

        migrator.run(migrations.clone()).await.unwrap();
        let rolled_back = migrator.rollback_last(migrations).await.unwrap();
        assert!(rolled_back);

        // Look only at this test's own versions (the table is shared with
        // other runs): the newer one must be gone, the older one must remain.
        let remaining: Vec<String> = sqlx::query_scalar(
            "SELECT version FROM _premix_migrations WHERE version = $1 OR version = $2 ORDER BY version",
        )
        .bind(&version_a)
        .bind(&version_b)
        .fetch_all(&pool)
        .await
        .unwrap();
        assert_eq!(remaining, vec![version_a.clone()]);

        // Best-effort cleanup; failures are ignored.
        let _ = sqlx::query(&format!("DROP TABLE IF EXISTS {}", table_a))
            .execute(&pool)
            .await;
        let _ = sqlx::query(&format!("DROP TABLE IF EXISTS {}", table_b))
            .execute(&pool)
            .await;
        let _ = sqlx::query("DELETE FROM _premix_migrations WHERE version = $1 OR version = $2")
            .bind(&version_a)
            .bind(&version_b)
            .execute(&pool)
            .await;
    }
}