premix_core/
migrator.rs

1use std::error::Error;
2
3use sqlx::{Database, Executor, Pool};
4
/// A single schema migration: an identifying version, a human-readable name,
/// and the SQL scripts that apply and revert it.
#[derive(Clone, Debug)]
pub struct Migration {
    /// Unique identifier of the migration. It is stored as the primary key of the
    /// `_premix_migrations` table and compared as a string when ordering.
    pub version: String,
    /// Descriptive label recorded next to the version and used in log output.
    pub name: String,
    /// SQL script executed when the migration is applied.
    pub up_sql: String,
    /// SQL script executed when the migration is rolled back.
    pub down_sql: String,
}
12
13#[derive(Debug, Clone, sqlx::FromRow)]
14struct AppliedMigration {
15    version: String,
16}
17
18pub struct Migrator<DB: Database> {
19    pool: Pool<DB>,
20}
21
22impl<DB: Database> Migrator<DB> {
23    pub fn new(pool: Pool<DB>) -> Self {
24        Self { pool }
25    }
26}
27
// Dialect-specific implementations (feature-gated for now; possibly trait-based later).
// For version 1 we use the generic `Executor` API where possible, but creating the
// migrations table often requires dialect-specific SQL, so each backend gets its own `impl`.
31
#[cfg(feature = "sqlite")]
impl Migrator<sqlx::Sqlite> {
    /// DDL for the `_premix_migrations` bookkeeping table (SQLite dialect).
    ///
    /// Kept in one place so that `run` and `rollback_last` always create the
    /// table with the same definition.
    const CREATE_MIGRATIONS_TABLE_SQL: &'static str = "CREATE TABLE IF NOT EXISTS _premix_migrations (
                version TEXT PRIMARY KEY,
                name TEXT NOT NULL,
                applied_at DATETIME DEFAULT CURRENT_TIMESTAMP
            )";

    /// Applies every migration in `migrations` whose version is not yet recorded
    /// in `_premix_migrations`.
    ///
    /// Migrations are applied in the order they appear in `migrations`; no sorting
    /// is performed. The bookkeeping table creation, every `up_sql` script and every
    /// bookkeeping insert run inside a single transaction that is committed only
    /// after all pending migrations succeeded.
    ///
    /// # Errors
    ///
    /// Returns the underlying `sqlx` error if opening the transaction, creating or
    /// reading the bookkeeping table, executing an `up_sql` script, recording a
    /// version, or committing fails. In that case the function returns before
    /// `commit`, so the transaction is dropped uncommitted.
    /// NOTE(review): this relies on sqlx rolling back a transaction that is dropped
    /// without commit.
    pub async fn run(&self, migrations: Vec<Migration>) -> Result<(), Box<dyn Error>> {
        let mut tx = self.pool.begin().await?;

        // 1. Ensure the bookkeeping table exists (inside the transaction).
        sqlx::query(Self::CREATE_MIGRATIONS_TABLE_SQL)
            .execute(&mut *tx)
            .await?;

        // 2. Load the applied versions into a hash set so each candidate below
        //    costs one lookup instead of a linear scan.
        let applied: std::collections::HashSet<String> = sqlx::query_as::<_, AppliedMigration>(
            "SELECT version FROM _premix_migrations ORDER BY version ASC",
        )
        .fetch_all(&mut *tx)
        .await?
        .into_iter()
        .map(|row| row.version)
        .collect();

        // 3. Apply whatever is still pending, in caller order.
        for migration in migrations {
            if applied.contains(&migration.version) {
                continue;
            }

            tracing::info!(
                operation = "migration_apply",
                version = %migration.version,
                name = %migration.name,
                "premix migration"
            );
            println!(
                "🚚 Applying migration: {} - {}",
                migration.version, migration.name
            );

            // Execute the UP script.
            tx.execute(migration.up_sql.as_str()).await?;

            // Record the version so later runs skip it.
            sqlx::query("INSERT INTO _premix_migrations (version, name) VALUES (?, ?)")
                .bind(&migration.version)
                .bind(&migration.name)
                .execute(&mut *tx)
                .await?;
        }

        tx.commit().await?;
        Ok(())
    }

    /// Reverts the most recent applied migration among `migrations`.
    ///
    /// "Most recent" means the greatest `version` (string ordering, as done by
    /// `ORDER BY version DESC`) that is both present in `migrations` and recorded
    /// in `_premix_migrations`. Its `down_sql` is executed and its bookkeeping row
    /// deleted, all inside one transaction.
    ///
    /// Returns `Ok(true)` when a migration was rolled back, and `Ok(false)` when
    /// `migrations` is empty or none of its versions is recorded as applied.
    ///
    /// # Errors
    ///
    /// Returns an error if the selected migration has an empty (whitespace-only)
    /// `down_sql`, or if any database operation fails. In both cases the function
    /// returns before `commit`.
    pub async fn rollback_last(&self, migrations: Vec<Migration>) -> Result<bool, Box<dyn Error>> {
        let mut tx = self.pool.begin().await?;

        sqlx::query(Self::CREATE_MIGRATIONS_TABLE_SQL)
            .execute(&mut *tx)
            .await?;

        if migrations.is_empty() {
            tx.commit().await?;
            return Ok(false);
        }

        // SQLite has no array parameter, so build one `?` placeholder per candidate
        // and bind the versions positionally (borrowed, no copies needed).
        let placeholders = vec!["?"; migrations.len()].join(", ");
        let sql = format!(
            "SELECT version FROM _premix_migrations WHERE version IN ({}) ORDER BY version DESC LIMIT 1",
            placeholders
        );
        let mut query = sqlx::query_scalar::<_, String>(&sql);
        for candidate in &migrations {
            query = query.bind(&candidate.version);
        }
        let last = query.fetch_optional(&mut *tx).await?;

        let Some(last) = last else {
            // None of the supplied migrations has been applied.
            tx.commit().await?;
            return Ok(false);
        };

        let migration = migrations
            .into_iter()
            .find(|m| m.version == last)
            .ok_or_else(|| format!("Migration {} not found.", last))?;

        // Refuse to drop the bookkeeping row when there is nothing to undo with.
        if migration.down_sql.trim().is_empty() {
            return Err("Down migration is empty.".into());
        }

        tracing::info!(
            operation = "migration_rollback",
            version = %migration.version,
            name = %migration.name,
            "premix migration"
        );
        tx.execute(migration.down_sql.as_str()).await?;
        sqlx::query("DELETE FROM _premix_migrations WHERE version = ?")
            .bind(&migration.version)
            .execute(&mut *tx)
            .await?;

        tx.commit().await?;
        Ok(true)
    }
}
148
#[cfg(feature = "postgres")]
impl Migrator<sqlx::Postgres> {
    /// DDL for the `_premix_migrations` bookkeeping table (Postgres dialect).
    ///
    /// Kept in one place so that `run` and `rollback_last` always create the
    /// table with the same definition.
    const CREATE_MIGRATIONS_TABLE_SQL: &'static str = "CREATE TABLE IF NOT EXISTS _premix_migrations (
                version TEXT PRIMARY KEY,
                name TEXT NOT NULL,
                applied_at TIMESTAMPTZ DEFAULT CURRENT_TIMESTAMP
            )";

    /// Applies every migration in `migrations` whose version is not yet recorded
    /// in `_premix_migrations`.
    ///
    /// Migrations are applied in the order they appear in `migrations`; no sorting
    /// is performed. The bookkeeping table creation, every `up_sql` script and every
    /// bookkeeping insert run inside a single transaction that is committed only
    /// after all pending migrations succeeded.
    ///
    /// # Errors
    ///
    /// Returns the underlying `sqlx` error if opening the transaction, creating or
    /// reading the bookkeeping table, executing an `up_sql` script, recording a
    /// version, or committing fails. In that case the function returns before
    /// `commit`, so the transaction is dropped uncommitted.
    /// NOTE(review): this relies on sqlx rolling back a transaction that is dropped
    /// without commit.
    pub async fn run(&self, migrations: Vec<Migration>) -> Result<(), Box<dyn Error>> {
        let mut tx = self.pool.begin().await?;

        // 1. Ensure the bookkeeping table exists (inside the transaction).
        sqlx::query(Self::CREATE_MIGRATIONS_TABLE_SQL)
            .execute(&mut *tx)
            .await?;

        // 2. Load the applied versions into a hash set so each candidate below
        //    costs one lookup instead of a linear scan.
        let applied: std::collections::HashSet<String> = sqlx::query_as::<_, AppliedMigration>(
            "SELECT version FROM _premix_migrations ORDER BY version ASC",
        )
        .fetch_all(&mut *tx)
        .await?
        .into_iter()
        .map(|row| row.version)
        .collect();

        // 3. Apply whatever is still pending, in caller order.
        for migration in migrations {
            if applied.contains(&migration.version) {
                continue;
            }

            tracing::info!(
                operation = "migration_apply",
                version = %migration.version,
                name = %migration.name,
                "premix migration"
            );
            println!(
                "🚚 Applying migration: {} - {}",
                migration.version, migration.name
            );

            // Execute the UP script. It is passed as a bare `&str` rather than
            // through `sqlx::query`, and is not split on `;` here.
            // NOTE(review): multi-statement scripts are assumed to be accepted in
            // this form (the `postgres_migrator_rolls_back_on_error` test feeds two
            // statements); confirm against the pinned sqlx version.
            tx.execute(migration.up_sql.as_str()).await?;

            // Record the version so later runs skip it.
            sqlx::query("INSERT INTO _premix_migrations (version, name) VALUES ($1, $2)")
                .bind(&migration.version)
                .bind(&migration.name)
                .execute(&mut *tx)
                .await?;
        }

        tx.commit().await?;
        Ok(())
    }

    /// Reverts the most recent applied migration among `migrations`.
    ///
    /// "Most recent" means the greatest `version` (string ordering, as done by
    /// `ORDER BY version DESC`) that is both present in `migrations` and recorded
    /// in `_premix_migrations`. Its `down_sql` is executed and its bookkeeping row
    /// deleted, all inside one transaction.
    ///
    /// Returns `Ok(true)` when a migration was rolled back, and `Ok(false)` when
    /// `migrations` is empty or none of its versions is recorded as applied.
    ///
    /// # Errors
    ///
    /// Returns an error if the selected migration has an empty (whitespace-only)
    /// `down_sql`, or if any database operation fails. In both cases the function
    /// returns before `commit`.
    pub async fn rollback_last(&self, migrations: Vec<Migration>) -> Result<bool, Box<dyn Error>> {
        let mut tx = self.pool.begin().await?;

        sqlx::query(Self::CREATE_MIGRATIONS_TABLE_SQL)
            .execute(&mut *tx)
            .await?;

        if migrations.is_empty() {
            tx.commit().await?;
            return Ok(false);
        }

        // All candidate versions are bound as a single array parameter for `ANY($1)`.
        let versions: Vec<String> = migrations.iter().map(|m| m.version.clone()).collect();
        let last = sqlx::query_scalar::<_, String>(
            "SELECT version FROM _premix_migrations WHERE version = ANY($1) ORDER BY version DESC LIMIT 1",
        )
        .bind(&versions)
        .fetch_optional(&mut *tx)
        .await?;

        let Some(last) = last else {
            // None of the supplied migrations has been applied.
            tx.commit().await?;
            return Ok(false);
        };

        let migration = migrations
            .into_iter()
            .find(|m| m.version == last)
            .ok_or_else(|| format!("Migration {} not found.", last))?;

        // Refuse to drop the bookkeeping row when there is nothing to undo with.
        if migration.down_sql.trim().is_empty() {
            return Err("Down migration is empty.".into());
        }

        tracing::info!(
            operation = "migration_rollback",
            version = %migration.version,
            name = %migration.name,
            "premix migration"
        );
        tx.execute(migration.down_sql.as_str()).await?;
        sqlx::query("DELETE FROM _premix_migrations WHERE version = $1")
            .bind(&migration.version)
            .execute(&mut *tx)
            .await?;

        tx.commit().await?;
        Ok(true)
    }
}
267
#[cfg(test)]
mod tests {
    #[cfg(feature = "postgres")]
    use std::time::{SystemTime, UNIX_EPOCH};

    // The SQLite tests call methods that only exist under `feature = "sqlite"`,
    // so the import and the tests carry the same gate as the code under test.
    #[cfg(feature = "sqlite")]
    use sqlx::sqlite::SqlitePoolOptions;

    #[cfg(any(feature = "sqlite", feature = "postgres"))]
    use super::*;

    // ---------------------------------------------------------------- helpers

    /// Opens a fresh in-memory SQLite pool limited to one connection.
    ///
    /// NOTE(review): the single connection is presumably what lets the migrator and
    /// the assertions observe the same in-memory database — confirm.
    #[cfg(feature = "sqlite")]
    async fn sqlite_memory_pool() -> sqlx::Pool<sqlx::Sqlite> {
        SqlitePoolOptions::new()
            .max_connections(1)
            .connect("sqlite::memory:")
            .await
            .unwrap()
    }

    /// Number of rows currently recorded in `_premix_migrations` (SQLite).
    #[cfg(feature = "sqlite")]
    async fn sqlite_recorded_count(pool: &sqlx::Pool<sqlx::Sqlite>) -> i64 {
        sqlx::query_scalar::<_, i64>("SELECT COUNT(*) FROM _premix_migrations")
            .fetch_one(pool)
            .await
            .unwrap()
    }

    /// Whether a table called `table` exists according to `sqlite_master`.
    #[cfg(feature = "sqlite")]
    async fn sqlite_table_exists(pool: &sqlx::Pool<sqlx::Sqlite>, table: &str) -> bool {
        sqlx::query_scalar::<_, String>(
            "SELECT name FROM sqlite_master WHERE type='table' AND name = ?",
        )
        .bind(table)
        .fetch_optional(pool)
        .await
        .unwrap()
        .is_some()
    }

    /// The two-step `create_a` / `create_b` fixture shared by the SQLite tests.
    #[cfg(feature = "sqlite")]
    fn sqlite_two_migrations() -> Vec<Migration> {
        vec![
            Migration {
                version: "20260103000000".to_string(),
                name: "create_a".to_string(),
                up_sql: "CREATE TABLE a (id INTEGER PRIMARY KEY);".to_string(),
                down_sql: "DROP TABLE a;".to_string(),
            },
            Migration {
                version: "20260104000000".to_string(),
                name: "create_b".to_string(),
                up_sql: "CREATE TABLE b (id INTEGER PRIMARY KEY);".to_string(),
                down_sql: "DROP TABLE b;".to_string(),
            },
        ]
    }

    /// Connects to the Postgres instance named by `DATABASE_URL` (falling back to a
    /// local development URL), or returns `None` so the calling test can skip when
    /// no server is reachable.
    #[cfg(feature = "postgres")]
    async fn pg_pool_or_skip() -> Option<sqlx::Pool<sqlx::Postgres>> {
        let db_url = std::env::var("DATABASE_URL").unwrap_or_else(|_| {
            "postgres://postgres:admin123@localhost:5432/premix_bench".to_string()
        });
        match sqlx::postgres::PgPoolOptions::new()
            .max_connections(1)
            .connect(&db_url)
            .await
        {
            Ok(pool) => Some(pool),
            Err(err) => {
                // Skipping stays best-effort, but say so instead of passing silently.
                eprintln!("skipping Postgres migrator test: cannot connect ({err})");
                None
            }
        }
    }

    /// Nanosecond timestamp used to make versions and table names unique per test
    /// run, since the Postgres tests share one database.
    #[cfg(feature = "postgres")]
    fn unique_suffix() -> u128 {
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_nanos()
    }

    /// Number of `_premix_migrations` rows recorded for `version` (Postgres).
    #[cfg(feature = "postgres")]
    async fn pg_version_count(pool: &sqlx::Pool<sqlx::Postgres>, version: &str) -> i64 {
        sqlx::query_scalar::<_, i64>("SELECT COUNT(*) FROM _premix_migrations WHERE version = $1")
            .bind(version)
            .fetch_one(pool)
            .await
            .unwrap()
    }

    // ----------------------------------------------------------- SQLite tests

    #[cfg(feature = "sqlite")]
    #[tokio::test]
    async fn sqlite_migrator_applies_pending_once() {
        let pool = sqlite_memory_pool().await;
        let migrator = Migrator::new(pool.clone());

        let migrations = vec![Migration {
            version: "20260101000000".to_string(),
            name: "create_users".to_string(),
            up_sql: "CREATE TABLE users (id INTEGER PRIMARY KEY);".to_string(),
            down_sql: "DROP TABLE users;".to_string(),
        }];

        migrator.run(migrations.clone()).await.unwrap();
        assert_eq!(sqlite_recorded_count(&pool).await, 1);

        // A second run must be a no-op: re-running CREATE TABLE would fail.
        migrator.run(migrations).await.unwrap();
        assert_eq!(sqlite_recorded_count(&pool).await, 1);
    }

    #[cfg(feature = "sqlite")]
    #[tokio::test]
    async fn sqlite_migrator_applies_multiple() {
        let pool = sqlite_memory_pool().await;
        let migrator = Migrator::new(pool.clone());

        migrator.run(sqlite_two_migrations()).await.unwrap();

        assert_eq!(sqlite_recorded_count(&pool).await, 2);
    }

    #[cfg(feature = "sqlite")]
    #[tokio::test]
    async fn sqlite_migrator_rolls_back_last() {
        let pool = sqlite_memory_pool().await;
        let migrator = Migrator::new(pool.clone());

        let migrations = sqlite_two_migrations();

        migrator.run(migrations.clone()).await.unwrap();
        let rolled_back = migrator.rollback_last(migrations).await.unwrap();
        assert!(rolled_back);

        assert_eq!(sqlite_recorded_count(&pool).await, 1);

        // It must be the newest migration that was reverted, not just "one of them".
        let remaining: String = sqlx::query_scalar("SELECT version FROM _premix_migrations")
            .fetch_one(&pool)
            .await
            .unwrap();
        assert_eq!(remaining, "20260103000000");
        assert!(sqlite_table_exists(&pool, "a").await);
        assert!(!sqlite_table_exists(&pool, "b").await);
    }

    #[cfg(feature = "sqlite")]
    #[tokio::test]
    async fn sqlite_migrator_rolls_back_on_error() {
        let pool = sqlite_memory_pool().await;
        let migrator = Migrator::new(pool.clone());

        let migrations = vec![Migration {
            version: "20260102000000".to_string(),
            name: "bad_sql".to_string(),
            up_sql: "CREATE TABLE broken (id INTEGER PRIMARY KEY); INVALID SQL".to_string(),
            down_sql: "DROP TABLE broken;".to_string(),
        }];

        let err = migrator.run(migrations).await.unwrap_err();
        assert!(err.to_string().contains("syntax"));

        // Nothing from the failed transaction may survive: neither the bookkeeping
        // table nor the table created by the first half of the script.
        assert!(!sqlite_table_exists(&pool, "_premix_migrations").await);
        assert!(!sqlite_table_exists(&pool, "broken").await);
    }

    // --------------------------------------------------------- Postgres tests

    #[cfg(feature = "postgres")]
    #[tokio::test]
    async fn postgres_migrator_applies_pending_once() {
        let Some(pool) = pg_pool_or_skip().await else {
            return;
        };
        let migrator = Migrator::new(pool.clone());

        let suffix = unique_suffix();
        let version = format!("20260101{:020}", suffix);
        let table_name = format!("premix_mig_test_{}", suffix);
        let migrations = vec![Migration {
            version: version.clone(),
            name: "create_test_table".to_string(),
            up_sql: format!("CREATE TABLE {} (id SERIAL PRIMARY KEY);", table_name),
            down_sql: format!("DROP TABLE {};", table_name),
        }];

        migrator.run(migrations.clone()).await.unwrap();
        assert_eq!(pg_version_count(&pool, &version).await, 1);

        // A second run must be a no-op.
        migrator.run(migrations).await.unwrap();
        assert_eq!(pg_version_count(&pool, &version).await, 1);

        // Best-effort cleanup of the shared database.
        let _ = sqlx::query(&format!("DROP TABLE IF EXISTS {}", table_name))
            .execute(&pool)
            .await;
        let _ = sqlx::query("DELETE FROM _premix_migrations WHERE version = $1")
            .bind(&version)
            .execute(&pool)
            .await;
    }

    #[cfg(feature = "postgres")]
    #[tokio::test]
    async fn postgres_migrator_rolls_back_on_error() {
        let Some(pool) = pg_pool_or_skip().await else {
            return;
        };
        let migrator = Migrator::new(pool.clone());

        let suffix = unique_suffix();
        let version = format!("20260102{:020}", suffix);
        let table_name = format!("premix_mig_bad_{}", suffix);
        let migrations = vec![Migration {
            version: version.clone(),
            name: "bad_sql".to_string(),
            up_sql: format!(
                "CREATE TABLE {} (id SERIAL PRIMARY KEY); INVALID SQL",
                table_name
            ),
            down_sql: format!("DROP TABLE {};", table_name),
        }];

        let err = migrator.run(migrations).await.unwrap_err();
        assert!(err.to_string().contains("syntax"));

        // The bookkeeping table may already exist from other tests against the same
        // database, so only the absence of this run's version is asserted.
        let table_exists: Option<String> =
            sqlx::query_scalar("SELECT to_regclass('_premix_migrations')::text")
                .fetch_one(&pool)
                .await
                .unwrap();
        if table_exists.is_some() {
            assert_eq!(pg_version_count(&pool, &version).await, 0);
        }

        // Best-effort cleanup of the shared database.
        let _ = sqlx::query(&format!("DROP TABLE IF EXISTS {}", table_name))
            .execute(&pool)
            .await;
    }

    #[cfg(feature = "postgres")]
    #[tokio::test]
    async fn postgres_migrator_rolls_back_last() {
        let Some(pool) = pg_pool_or_skip().await else {
            return;
        };
        let migrator = Migrator::new(pool.clone());

        let suffix = unique_suffix();
        let version_a = format!("20260103{:020}", suffix);
        let version_b = format!("20260104{:020}", suffix);
        let table_a = format!("premix_mig_a_{}", suffix);
        let table_b = format!("premix_mig_b_{}", suffix);
        let migrations = vec![
            Migration {
                version: version_a.clone(),
                name: "create_a".to_string(),
                up_sql: format!("CREATE TABLE {} (id SERIAL PRIMARY KEY);", table_a),
                down_sql: format!("DROP TABLE {};", table_a),
            },
            Migration {
                version: version_b.clone(),
                name: "create_b".to_string(),
                up_sql: format!("CREATE TABLE {} (id SERIAL PRIMARY KEY);", table_b),
                down_sql: format!("DROP TABLE {};", table_b),
            },
        ];

        migrator.run(migrations.clone()).await.unwrap();
        let rolled_back = migrator.rollback_last(migrations).await.unwrap();
        assert!(rolled_back);

        // A bare row count over the shared table proves nothing, so check this
        // test's own versions: the newer one is gone, the older one is still there.
        assert_eq!(pg_version_count(&pool, &version_b).await, 0);
        assert_eq!(pg_version_count(&pool, &version_a).await, 1);

        // Best-effort cleanup of the shared database.
        let _ = sqlx::query(&format!("DROP TABLE IF EXISTS {}", table_a))
            .execute(&pool)
            .await;
        let _ = sqlx::query(&format!("DROP TABLE IF EXISTS {}", table_b))
            .execute(&pool)
            .await;
        let _ = sqlx::query("DELETE FROM _premix_migrations WHERE version = $1 OR version = $2")
            .bind(&version_a)
            .bind(&version_b)
            .execute(&pool)
            .await;
    }
}