// premix_core/migrator.rs

1use std::error::Error;
2
3use sqlx::{Database, Executor, Pool};
4
/// A single schema migration: a versioned pair of forward and rollback scripts.
///
/// `Migrator::run` executes `up_sql` and records `version` and `name` in the
/// `_premix_migrations` table. `down_sql` is carried along, but nothing visible
/// in this module executes it.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Migration {
    /// Identifier of the migration; stored as the primary key of
    /// `_premix_migrations` and used to decide whether it was already applied.
    pub version: String,
    /// Human-readable label recorded next to the version.
    pub name: String,
    /// SQL executed when the migration is applied.
    pub up_sql: String,
    /// SQL intended to undo `up_sql`.
    pub down_sql: String,
}
12
13#[derive(Debug, Clone, sqlx::FromRow)]
14struct AppliedMigration {
15    version: String,
16}
17
18pub struct Migrator<DB: Database> {
19    pool: Pool<DB>,
20}
21
22impl<DB: Database> Migrator<DB> {
23    pub fn new(pool: Pool<DB>) -> Self {
24        Self { pool }
25    }
26}
27
// Dialect-specific implementations follow (feature-gated for now; possibly trait-based later).
// For version 1 we use the generic `Executor` where possible, but creating the
// migrations table requires dialect-specific SQL.
31
#[cfg(feature = "sqlite")]
impl Migrator<sqlx::Sqlite> {
    /// Applies every migration in `migrations` that is not yet recorded in
    /// `_premix_migrations`, in the order the caller supplied them.
    ///
    /// All work happens inside a single transaction: creating the bookkeeping
    /// table if needed, running each pending `up_sql`, and recording its
    /// version and name. The transaction is committed only after every pending
    /// migration has succeeded.
    ///
    /// # Errors
    ///
    /// Returns the first database error encountered. In that case the function
    /// returns before `commit` is reached, so the transaction is never
    /// committed.
    pub async fn run(&self, migrations: Vec<Migration>) -> Result<(), Box<dyn Error>> {
        let mut tx = self.pool.begin().await?;

        // 1. Make sure the bookkeeping table exists.
        sqlx::query(
            "CREATE TABLE IF NOT EXISTS _premix_migrations (
                version TEXT PRIMARY KEY,
                name TEXT NOT NULL,
                applied_at DATETIME DEFAULT CURRENT_TIMESTAMP
            )",
        )
        .execute(&mut *tx)
        .await?;

        // 2. Load the already-applied versions into a set so each membership
        //    check below is a hash lookup rather than a linear scan.
        let applied: std::collections::HashSet<String> = sqlx::query_as::<_, AppliedMigration>(
            "SELECT version FROM _premix_migrations ORDER BY version ASC",
        )
        .fetch_all(&mut *tx)
        .await?
        .into_iter()
        .map(|row| row.version)
        .collect();

        // 3. Apply whatever is still pending.
        for migration in migrations {
            if applied.contains(migration.version.as_str()) {
                continue;
            }

            println!(
                "🚚 Applying migration: {} - {}",
                migration.version, migration.name
            );

            // The script is passed as a raw string rather than a bound query.
            tx.execute(migration.up_sql.as_str()).await?;

            // Record the version in the same transaction as the schema change.
            sqlx::query("INSERT INTO _premix_migrations (version, name) VALUES (?, ?)")
                .bind(&migration.version)
                .bind(&migration.name)
                .execute(&mut *tx)
                .await?;
        }

        tx.commit().await?;
        Ok(())
    }
}
82
#[cfg(feature = "postgres")]
impl Migrator<sqlx::Postgres> {
    /// Applies every migration in `migrations` that is not yet recorded in
    /// `_premix_migrations`, in the order the caller supplied them.
    ///
    /// All work happens inside a single transaction: creating the bookkeeping
    /// table if needed, running each pending `up_sql`, and recording its
    /// version and name. The transaction is committed only after every pending
    /// migration has succeeded.
    ///
    /// # Errors
    ///
    /// Returns the first database error encountered. In that case the function
    /// returns before `commit` is reached, so the transaction is never
    /// committed.
    pub async fn run(&self, migrations: Vec<Migration>) -> Result<(), Box<dyn Error>> {
        let mut tx = self.pool.begin().await?;

        // 1. Make sure the bookkeeping table exists.
        sqlx::query(
            "CREATE TABLE IF NOT EXISTS _premix_migrations (
                version TEXT PRIMARY KEY,
                name TEXT NOT NULL,
                applied_at TIMESTAMPTZ DEFAULT CURRENT_TIMESTAMP
            )",
        )
        .execute(&mut *tx)
        .await?;

        // 2. Load the already-applied versions into a set so each membership
        //    check below is a hash lookup rather than a linear scan.
        let applied: std::collections::HashSet<String> = sqlx::query_as::<_, AppliedMigration>(
            "SELECT version FROM _premix_migrations ORDER BY version ASC",
        )
        .fetch_all(&mut *tx)
        .await?
        .into_iter()
        .map(|row| row.version)
        .collect();

        // 3. Apply whatever is still pending.
        for migration in migrations {
            if applied.contains(migration.version.as_str()) {
                continue;
            }

            println!(
                "🚚 Applying migration: {} - {}",
                migration.version, migration.name
            );

            // The script is passed as a raw string and is NOT split on `;` here.
            // NOTE(review): sqlx is understood to send argument-less raw strings
            // over PostgreSQL's simple query protocol, which accepts several
            // `;`-separated statements in one call — confirm against the sqlx
            // version in use before relying on multi-statement scripts.
            tx.execute(migration.up_sql.as_str()).await?;

            // Record the version in the same transaction as the schema change.
            sqlx::query("INSERT INTO _premix_migrations (version, name) VALUES ($1, $2)")
                .bind(&migration.version)
                .bind(&migration.name)
                .execute(&mut *tx)
                .await?;
        }

        tx.commit().await?;
        Ok(())
    }
}
139
#[cfg(test)]
mod tests {
    #[cfg(feature = "postgres")]
    use std::time::{SystemTime, UNIX_EPOCH};

    // Gated like the `Migrator<sqlx::Sqlite>` impl it exercises, so the test
    // build does not break when the `sqlite` feature is disabled.
    #[cfg(feature = "sqlite")]
    use sqlx::sqlite::SqlitePoolOptions;

    use super::*;

    /// Connects to the Postgres instance named by `DATABASE_URL` (falling back
    /// to a local default) and returns `None` when the connection fails, so the
    /// Postgres tests can return early instead of failing.
    #[cfg(feature = "postgres")]
    async fn pg_pool_or_skip() -> Option<sqlx::Pool<sqlx::Postgres>> {
        let db_url = std::env::var("DATABASE_URL").unwrap_or_else(|_| {
            "postgres://postgres:admin123@localhost:5432/premix_bench".to_string()
        });
        sqlx::postgres::PgPoolOptions::new()
            .max_connections(1)
            .connect(&db_url)
            .await
            .ok()
    }

    /// Opens an in-memory SQLite pool capped at one connection.
    ///
    /// NOTE(review): the single-connection cap presumably keeps every query on
    /// the same in-memory database — confirm before raising it.
    #[cfg(feature = "sqlite")]
    async fn sqlite_memory_pool() -> sqlx::Pool<sqlx::Sqlite> {
        SqlitePoolOptions::new()
            .max_connections(1)
            .connect("sqlite::memory:")
            .await
            .unwrap()
    }

    /// Running the same migration list twice records the migration only once.
    #[cfg(feature = "sqlite")]
    #[tokio::test]
    async fn sqlite_migrator_applies_pending_once() {
        let pool = sqlite_memory_pool().await;
        let migrator = Migrator::new(pool.clone());

        let migrations = vec![Migration {
            version: "20260101000000".to_string(),
            name: "create_users".to_string(),
            up_sql: "CREATE TABLE users (id INTEGER PRIMARY KEY);".to_string(),
            down_sql: "DROP TABLE users;".to_string(),
        }];

        migrator.run(migrations.clone()).await.unwrap();

        let count: i64 = sqlx::query_scalar("SELECT COUNT(*) FROM _premix_migrations")
            .fetch_one(&pool)
            .await
            .unwrap();
        assert_eq!(count, 1);

        migrator.run(migrations).await.unwrap();
        let count_after: i64 = sqlx::query_scalar("SELECT COUNT(*) FROM _premix_migrations")
            .fetch_one(&pool)
            .await
            .unwrap();
        assert_eq!(count_after, 1);
    }

    /// Two pending migrations in one run are both recorded.
    #[cfg(feature = "sqlite")]
    #[tokio::test]
    async fn sqlite_migrator_applies_multiple() {
        let pool = sqlite_memory_pool().await;
        let migrator = Migrator::new(pool.clone());

        let migrations = vec![
            Migration {
                version: "20260103000000".to_string(),
                name: "create_a".to_string(),
                up_sql: "CREATE TABLE a (id INTEGER PRIMARY KEY);".to_string(),
                down_sql: "DROP TABLE a;".to_string(),
            },
            Migration {
                version: "20260104000000".to_string(),
                name: "create_b".to_string(),
                up_sql: "CREATE TABLE b (id INTEGER PRIMARY KEY);".to_string(),
                down_sql: "DROP TABLE b;".to_string(),
            },
        ];

        migrator.run(migrations).await.unwrap();

        let count: i64 = sqlx::query_scalar("SELECT COUNT(*) FROM _premix_migrations")
            .fetch_one(&pool)
            .await
            .unwrap();
        assert_eq!(count, 2);
    }

    /// A failing script leaves nothing behind, not even the bookkeeping table
    /// that was created earlier in the same transaction.
    #[cfg(feature = "sqlite")]
    #[tokio::test]
    async fn sqlite_migrator_rolls_back_on_error() {
        let pool = sqlite_memory_pool().await;
        let migrator = Migrator::new(pool.clone());

        let migrations = vec![Migration {
            version: "20260102000000".to_string(),
            name: "bad_sql".to_string(),
            up_sql: "CREATE TABLE broken (id INTEGER PRIMARY KEY); INVALID SQL".to_string(),
            down_sql: "DROP TABLE broken;".to_string(),
        }];

        let err = migrator.run(migrations).await.unwrap_err();
        assert!(err.to_string().contains("syntax"));

        let table: Option<String> = sqlx::query_scalar(
            "SELECT name FROM sqlite_master WHERE type='table' AND name='_premix_migrations'",
        )
        .fetch_optional(&pool)
        .await
        .unwrap();
        assert!(table.is_none());
    }

    /// Running the same migration list twice records the migration only once.
    /// Version and table name carry a nanosecond suffix so repeated runs
    /// against a shared database do not collide.
    #[cfg(feature = "postgres")]
    #[tokio::test]
    async fn postgres_migrator_applies_pending_once() {
        let Some(pool) = pg_pool_or_skip().await else {
            return;
        };
        let migrator = Migrator::new(pool.clone());

        let suffix = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_nanos();
        let version = format!("20260101{:020}", suffix);
        let table_name = format!("premix_mig_test_{}", suffix);
        let migrations = vec![Migration {
            version: version.clone(),
            name: "create_test_table".to_string(),
            up_sql: format!("CREATE TABLE {} (id SERIAL PRIMARY KEY);", table_name),
            down_sql: format!("DROP TABLE {};", table_name),
        }];

        migrator.run(migrations.clone()).await.unwrap();

        let count: i64 =
            sqlx::query_scalar("SELECT COUNT(*) FROM _premix_migrations WHERE version = $1")
                .bind(&version)
                .fetch_one(&pool)
                .await
                .unwrap();
        assert_eq!(count, 1);

        migrator.run(migrations).await.unwrap();
        let count_after: i64 =
            sqlx::query_scalar("SELECT COUNT(*) FROM _premix_migrations WHERE version = $1")
                .bind(&version)
                .fetch_one(&pool)
                .await
                .unwrap();
        assert_eq!(count_after, 1);

        // Best-effort cleanup; failures here are deliberately ignored.
        let _ = sqlx::query(&format!("DROP TABLE IF EXISTS {}", table_name))
            .execute(&pool)
            .await;
        let _ = sqlx::query("DELETE FROM _premix_migrations WHERE version = $1")
            .bind(&version)
            .execute(&pool)
            .await;
    }

    /// A failing script must not leave its version recorded.
    #[cfg(feature = "postgres")]
    #[tokio::test]
    async fn postgres_migrator_rolls_back_on_error() {
        let Some(pool) = pg_pool_or_skip().await else {
            return;
        };
        let migrator = Migrator::new(pool.clone());

        let suffix = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_nanos();
        let version = format!("20260102{:020}", suffix);
        let table_name = format!("premix_mig_bad_{}", suffix);
        let migrations = vec![Migration {
            version: version.clone(),
            name: "bad_sql".to_string(),
            up_sql: format!(
                "CREATE TABLE {} (id SERIAL PRIMARY KEY); INVALID SQL",
                table_name
            ),
            down_sql: format!("DROP TABLE {};", table_name),
        }];

        let err = migrator.run(migrations).await.unwrap_err();
        assert!(err.to_string().contains("syntax"));

        // The bookkeeping table may already exist from earlier runs against the
        // same database, so only check that this version was not recorded.
        let table_exists: Option<String> =
            sqlx::query_scalar("SELECT to_regclass('_premix_migrations')::text")
                .fetch_one(&pool)
                .await
                .unwrap();
        if table_exists.is_some() {
            let count: i64 =
                sqlx::query_scalar("SELECT COUNT(*) FROM _premix_migrations WHERE version = $1")
                    .bind(&version)
                    .fetch_one(&pool)
                    .await
                    .unwrap();
            assert_eq!(count, 0);
        }

        // Best-effort cleanup; failures here are deliberately ignored.
        let _ = sqlx::query(&format!("DROP TABLE IF EXISTS {}", table_name))
            .execute(&pool)
            .await;
    }
}