Skip to main content

cast_core/
migration.rs

1//! Migration runner. Each migration is a Rust value with `up` / `down` methods.
2//!
3//! Multi-driver: when the runner owns a `Pool::Postgres` it emits Postgres DDL;
4//! same for MySQL / SQLite. The `Schema` passed to `up`/`down` is pre-configured
5//! with the right dialect.
6
7use crate::pool::{Driver, Pool};
8use crate::schema::Schema;
9use crate::Error;
10
11pub trait Migration: Send + Sync {
12    fn name(&self) -> &'static str;
13    fn up(&self, schema: &mut Schema);
14    fn down(&self, schema: &mut Schema);
15}
16
17inventory::collect!(MigrationRegistration);
18
19pub struct MigrationRegistration {
20    pub builder: fn() -> Box<dyn Migration>,
21}
22
23pub fn collected() -> Vec<Box<dyn Migration>> {
24    inventory::iter::<MigrationRegistration>
25        .into_iter()
26        .map(|r| (r.builder)())
27        .collect()
28}
29
/// Declares a migration from two closures, in the spirit of Laravel's
/// `Schema::create('posts', function (Blueprint $t) { ... })`.
///
/// The macro expands to:
///
/// * a public unit struct named `$struct_name`,
/// * a `Migration` impl whose `name()` returns `$name` and whose
///   `up` / `down` call the supplied closures,
/// * an `inventory::submit!` of a `MigrationRegistration` that builds the
///   struct.
///
/// Both closures are coerced to `fn(&mut Schema)`, so they must not capture
/// anything from their environment.
///
/// Usage:
///
/// ```ignore
/// use anvilforge::prelude::*;
///
/// migration!(CreatePostsTable, "2026_05_20_create_posts_table",
///     up = |s| {
///         s.create("posts", |t| {
///             t.id();
///             t.string("title").not_null();
///             t.text("body").not_null();
///             t.timestamps();
///         });
///     },
///     down = |s| {
///         s.drop_if_exists("posts");
///     },
/// );
/// ```
///
/// The struct name is spelled out by the caller (mirroring Laravel's class
/// name) rather than generated, so the registered type is predictable and
/// can be referred to directly, e.g. from tests.
#[macro_export]
macro_rules! migration {
    (
        $struct_name:ident,
        $name:expr,
        up = $up:expr,
        down = $down:expr $(,)?
    ) => {
        pub struct $struct_name;

        impl $crate::migration::Migration for $struct_name {
            fn name(&self) -> &'static str {
                $name
            }

            fn up(&self, schema: &mut $crate::schema::Schema) {
                // The explicit fn-pointer type is what lets the closure's
                // parameter type be inferred at the call site.
                let apply_up: fn(&mut $crate::schema::Schema) = $up;
                apply_up(schema);
            }

            fn down(&self, schema: &mut $crate::schema::Schema) {
                let apply_down: fn(&mut $crate::schema::Schema) = $down;
                apply_down(schema);
            }
        }

        $crate::inventory::submit! {
            $crate::migration::MigrationRegistration {
                builder: || -> ::std::boxed::Box<dyn $crate::migration::Migration> {
                    ::std::boxed::Box::new($struct_name)
                },
            }
        }
    };
}
93
94pub struct MigrationRunner {
95    pool: Pool,
96    migrations: Vec<Box<dyn Migration>>,
97}
98
99impl MigrationRunner {
100    pub fn new(pool: Pool) -> Self {
101        let mut migrations = collected();
102        migrations.sort_by_key(|m| m.name().to_string());
103        Self { pool, migrations }
104    }
105
106    pub fn with_migrations(pool: Pool, mut migrations: Vec<Box<dyn Migration>>) -> Self {
107        migrations.sort_by_key(|m| m.name().to_string());
108        Self { pool, migrations }
109    }
110
111    fn driver(&self) -> Driver {
112        self.pool.driver()
113    }
114
115    // ─── per-driver SQL ──────────────────────────────────────────────────────
116
117    fn migrations_table_ddl(&self) -> &'static str {
118        match self.driver() {
119            Driver::Postgres => {
120                "CREATE TABLE IF NOT EXISTS migrations (
121                id BIGSERIAL PRIMARY KEY,
122                name TEXT NOT NULL UNIQUE,
123                batch INTEGER NOT NULL,
124                applied_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
125            )"
126            }
127            Driver::MySql => {
128                "CREATE TABLE IF NOT EXISTS migrations (
129                id BIGINT AUTO_INCREMENT PRIMARY KEY,
130                name VARCHAR(255) NOT NULL UNIQUE,
131                batch INT NOT NULL,
132                applied_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
133            )"
134            }
135            Driver::Sqlite => {
136                "CREATE TABLE IF NOT EXISTS migrations (
137                id INTEGER PRIMARY KEY AUTOINCREMENT,
138                name TEXT NOT NULL UNIQUE,
139                batch INTEGER NOT NULL,
140                applied_at TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP
141            )"
142            }
143        }
144    }
145
146    fn fresh_ddl(&self) -> Vec<&'static str> {
147        match self.driver() {
148            Driver::Postgres => vec!["DROP SCHEMA public CASCADE", "CREATE SCHEMA public"],
149            Driver::MySql => vec![
150                // sqlx + MySQL is finicky about multi-statement scripts. We instead
151                // enumerate the tables and drop them individually below; this is just a hint.
152                "",
153            ],
154            Driver::Sqlite => vec![
155                "PRAGMA writable_schema = 1",
156                "DELETE FROM sqlite_master WHERE type IN ('table','index','trigger')",
157                "PRAGMA writable_schema = 0",
158                "VACUUM",
159            ],
160        }
161    }
162
163    // ─── helpers that dispatch to the right sqlx pool ────────────────────────
164
165    async fn exec(&self, sql: &str) -> Result<(), Error> {
166        if sql.is_empty() {
167            return Ok(());
168        }
169        match &self.pool {
170            Pool::Postgres(p) => {
171                sqlx::query(sql).execute(p).await?;
172            }
173            Pool::MySql(p) => {
174                sqlx::query(sql).execute(p).await?;
175            }
176            Pool::Sqlite(p) => {
177                sqlx::query(sql).execute(p).await?;
178            }
179        }
180        Ok(())
181    }
182
183    async fn applied_rows(&self) -> Result<Vec<(String, i32)>, Error> {
184        Ok(match &self.pool {
185            Pool::Postgres(p) => {
186                sqlx::query_as::<_, (String, i32)>(
187                    "SELECT name, batch FROM migrations ORDER BY batch, id",
188                )
189                .fetch_all(p)
190                .await?
191            }
192            Pool::MySql(p) => {
193                sqlx::query_as::<_, (String, i32)>(
194                    "SELECT name, batch FROM migrations ORDER BY batch, id",
195                )
196                .fetch_all(p)
197                .await?
198            }
199            Pool::Sqlite(p) => {
200                sqlx::query_as::<_, (String, i32)>(
201                    "SELECT name, batch FROM migrations ORDER BY batch, id",
202                )
203                .fetch_all(p)
204                .await?
205            }
206        })
207    }
208
209    async fn max_batch(&self) -> Result<Option<i32>, Error> {
210        Ok(match &self.pool {
211            Pool::Postgres(p) => {
212                sqlx::query_as::<_, (Option<i32>,)>("SELECT MAX(batch) FROM migrations")
213                    .fetch_one(p)
214                    .await?
215                    .0
216            }
217            Pool::MySql(p) => {
218                sqlx::query_as::<_, (Option<i32>,)>("SELECT MAX(batch) FROM migrations")
219                    .fetch_one(p)
220                    .await?
221                    .0
222            }
223            Pool::Sqlite(p) => {
224                sqlx::query_as::<_, (Option<i32>,)>("SELECT MAX(batch) FROM migrations")
225                    .fetch_one(p)
226                    .await?
227                    .0
228            }
229        })
230    }
231
232    async fn names_in_batch(&self, batch: i32) -> Result<Vec<String>, Error> {
233        let rows: Vec<(String,)> = match &self.pool {
234            Pool::Postgres(p) => {
235                sqlx::query_as("SELECT name FROM migrations WHERE batch = $1 ORDER BY id DESC")
236                    .bind(batch)
237                    .fetch_all(p)
238                    .await?
239            }
240            Pool::MySql(p) => {
241                sqlx::query_as("SELECT name FROM migrations WHERE batch = ? ORDER BY id DESC")
242                    .bind(batch)
243                    .fetch_all(p)
244                    .await?
245            }
246            Pool::Sqlite(p) => {
247                sqlx::query_as("SELECT name FROM migrations WHERE batch = ?1 ORDER BY id DESC")
248                    .bind(batch)
249                    .fetch_all(p)
250                    .await?
251            }
252        };
253        Ok(rows.into_iter().map(|(n,)| n).collect())
254    }
255
256    async fn record_applied(&self, name: &str, batch: i32) -> Result<(), Error> {
257        match &self.pool {
258            Pool::Postgres(p) => {
259                sqlx::query("INSERT INTO migrations (name, batch) VALUES ($1, $2)")
260                    .bind(name)
261                    .bind(batch)
262                    .execute(p)
263                    .await?;
264            }
265            Pool::MySql(p) => {
266                sqlx::query("INSERT INTO migrations (name, batch) VALUES (?, ?)")
267                    .bind(name)
268                    .bind(batch)
269                    .execute(p)
270                    .await?;
271            }
272            Pool::Sqlite(p) => {
273                sqlx::query("INSERT INTO migrations (name, batch) VALUES (?1, ?2)")
274                    .bind(name)
275                    .bind(batch)
276                    .execute(p)
277                    .await?;
278            }
279        }
280        Ok(())
281    }
282
283    async fn delete_applied(&self, name: &str) -> Result<(), Error> {
284        match &self.pool {
285            Pool::Postgres(p) => {
286                sqlx::query("DELETE FROM migrations WHERE name = $1")
287                    .bind(name)
288                    .execute(p)
289                    .await?;
290            }
291            Pool::MySql(p) => {
292                sqlx::query("DELETE FROM migrations WHERE name = ?")
293                    .bind(name)
294                    .execute(p)
295                    .await?;
296            }
297            Pool::Sqlite(p) => {
298                sqlx::query("DELETE FROM migrations WHERE name = ?1")
299                    .bind(name)
300                    .execute(p)
301                    .await?;
302            }
303        }
304        Ok(())
305    }
306
307    async fn exec_many(&self, stmts: &[String]) -> Result<(), Error> {
308        for s in stmts {
309            self.exec(s).await?;
310        }
311        Ok(())
312    }
313
314    // ─── public API ─────────────────────────────────────────────────────────
315
316    pub async fn ensure_table(&self) -> Result<(), Error> {
317        let ddl = self.migrations_table_ddl();
318        self.exec(ddl).await
319    }
320
321    pub async fn applied(&self) -> Result<Vec<String>, Error> {
322        Ok(self
323            .applied_rows()
324            .await?
325            .into_iter()
326            .map(|(n, _)| n)
327            .collect())
328    }
329
330    pub async fn next_batch(&self) -> Result<i32, Error> {
331        Ok(self.max_batch().await?.unwrap_or(0) + 1)
332    }
333
334    pub async fn run_up(&self) -> Result<Vec<String>, Error> {
335        self.ensure_table().await?;
336        let already = self.applied().await?;
337        let batch = self.next_batch().await?;
338        let mut applied = Vec::new();
339        for m in &self.migrations {
340            if already.iter().any(|a| a == m.name()) {
341                continue;
342            }
343            let mut schema = Schema::for_driver(self.driver());
344            m.up(&mut schema);
345            self.exec_many(&schema.statements).await?;
346            self.record_applied(m.name(), batch).await?;
347            applied.push(m.name().to_string());
348            tracing::info!(name = m.name(), "migration applied");
349        }
350        Ok(applied)
351    }
352
353    pub async fn rollback(&self) -> Result<Vec<String>, Error> {
354        self.ensure_table().await?;
355        let Some(batch) = self.max_batch().await? else {
356            return Ok(Vec::new());
357        };
358        let names = self.names_in_batch(batch).await?;
359        let mut rolled = Vec::new();
360        for name in names {
361            let Some(m) = self.migrations.iter().find(|m| m.name() == name) else {
362                tracing::warn!(name, "migration row in DB but not registered; skipping");
363                continue;
364            };
365            let mut schema = Schema::for_driver(self.driver());
366            m.down(&mut schema);
367            self.exec_many(&schema.statements).await?;
368            self.delete_applied(&name).await?;
369            rolled.push(name);
370        }
371        Ok(rolled)
372    }
373
374    pub async fn fresh(&self) -> Result<(), Error> {
375        // Wipe schema. MySQL doesn't have a "DROP SCHEMA public" equivalent in the
376        // user-friendly sense (it's tied to the active database), so we enumerate
377        // and drop tables individually for it.
378        match self.driver() {
379            Driver::Postgres => {
380                for s in self.fresh_ddl() {
381                    self.exec(s).await?;
382                }
383            }
384            Driver::MySql => {
385                self.drop_all_mysql_tables().await?;
386            }
387            Driver::Sqlite => {
388                self.drop_all_sqlite_tables().await?;
389            }
390        }
391        self.run_up().await?;
392        Ok(())
393    }
394
395    async fn drop_all_mysql_tables(&self) -> Result<(), Error> {
396        let Pool::MySql(p) = &self.pool else {
397            return Ok(());
398        };
399        let tables: Vec<(String,)> = sqlx::query_as(
400            "SELECT table_name FROM information_schema.tables WHERE table_schema = DATABASE()",
401        )
402        .fetch_all(p)
403        .await?;
404        sqlx::query("SET FOREIGN_KEY_CHECKS = 0").execute(p).await?;
405        for (t,) in tables {
406            sqlx::query(&format!("DROP TABLE IF EXISTS `{t}`"))
407                .execute(p)
408                .await?;
409        }
410        sqlx::query("SET FOREIGN_KEY_CHECKS = 1").execute(p).await?;
411        Ok(())
412    }
413
414    async fn drop_all_sqlite_tables(&self) -> Result<(), Error> {
415        let Pool::Sqlite(p) = &self.pool else {
416            return Ok(());
417        };
418        let tables: Vec<(String,)> = sqlx::query_as(
419            "SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%'",
420        )
421        .fetch_all(p)
422        .await?;
423        for (t,) in tables {
424            sqlx::query(&format!("DROP TABLE IF EXISTS \"{t}\""))
425                .execute(p)
426                .await?;
427        }
428        Ok(())
429    }
430
431    pub async fn status(&self) -> Result<Vec<MigrationStatus>, Error> {
432        self.ensure_table().await?;
433        let rows = self.applied_rows().await?;
434        let applied_map: std::collections::HashMap<String, i32> = rows.into_iter().collect();
435
436        let mut out = Vec::new();
437        for m in &self.migrations {
438            let name = m.name().to_string();
439            let batch = applied_map.get(&name).copied();
440            out.push(MigrationStatus {
441                name,
442                applied: batch.is_some(),
443                batch,
444            });
445        }
446        for (db_name, batch) in &applied_map {
447            if !self.migrations.iter().any(|m| m.name() == db_name) {
448                out.push(MigrationStatus {
449                    name: db_name.clone(),
450                    applied: true,
451                    batch: Some(*batch),
452                });
453            }
454        }
455        Ok(out)
456    }
457
458    pub async fn reset(&self) -> Result<Vec<String>, Error> {
459        self.ensure_table().await?;
460        let mut rolled_total = Vec::new();
461        loop {
462            let rolled = self.rollback().await?;
463            if rolled.is_empty() {
464                break;
465            }
466            rolled_total.extend(rolled);
467        }
468        Ok(rolled_total)
469    }
470
471    pub async fn refresh(&self) -> Result<Vec<String>, Error> {
472        self.reset().await?;
473        self.run_up().await
474    }
475
476    pub async fn run_up_step(&self) -> Result<Vec<String>, Error> {
477        self.ensure_table().await?;
478        let already = self.applied().await?;
479        let mut applied = Vec::new();
480        for m in &self.migrations {
481            if already.iter().any(|a| a == m.name()) {
482                continue;
483            }
484            let batch = self.next_batch().await?;
485            let mut schema = Schema::for_driver(self.driver());
486            m.up(&mut schema);
487            self.exec_many(&schema.statements).await?;
488            self.record_applied(m.name(), batch).await?;
489            applied.push(m.name().to_string());
490            tracing::info!(name = m.name(), batch, "migration applied (stepped)");
491        }
492        Ok(applied)
493    }
494
495    pub async fn pretend(&self) -> Result<Vec<String>, Error> {
496        self.ensure_table().await?;
497        let already = self.applied().await?;
498        let mut lines = Vec::new();
499        for m in &self.migrations {
500            if already.iter().any(|a| a == m.name()) {
501                continue;
502            }
503            lines.push(format!("-- migration: {}", m.name()));
504            let mut schema = Schema::for_driver(self.driver());
505            m.up(&mut schema);
506            for stmt in &schema.statements {
507                lines.push(format!("{stmt};"));
508            }
509            lines.push(String::new());
510        }
511        Ok(lines)
512    }
513
514    pub async fn install(&self) -> Result<(), Error> {
515        self.ensure_table().await
516    }
517
518    pub fn count(&self) -> usize {
519        self.migrations.len()
520    }
521}
522
/// Returned by `migrate:status`: the state of one migration.
///
/// Compares by value, so status lists can be asserted on directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MigrationStatus {
    /// Migration name, as registered or as stored in the `migrations` table.
    pub name: String,
    /// Whether a bookkeeping row exists for this migration.
    pub applied: bool,
    /// Batch the migration was applied in; `None` when it is still pending.
    pub batch: Option<i32>,
}
530
#[cfg(test)]
mod macro_tests {
    use super::*;
    use crate::schema::Schema;

    // Invoking `migration!` here checks at compile time that the expansion
    // is well-formed; the test below checks the generated impl at runtime.
    crate::migration!(
        TestCreateThingsTable,
        "2026_01_01_000003_create_things_table",
        up = |schema| {
            schema.create("things", |table| {
                table.id();
                table.string("name").not_null();
            });
        },
        down = |schema| {
            schema.drop_if_exists("things");
        },
    );

    /// Runs `direction` against a fresh SQLite-flavoured schema and reports
    /// whether it recorded at least one statement.
    fn emits_statements(direction: impl FnOnce(&mut Schema)) -> bool {
        let mut schema = Schema::for_driver(Driver::Sqlite);
        direction(&mut schema);
        !schema.statements.is_empty()
    }

    #[test]
    fn closure_migration_macro_expands_into_a_working_migration() {
        let migration = TestCreateThingsTable;
        assert_eq!(migration.name(), "2026_01_01_000003_create_things_table");

        // The schema builder records DDL as a side effect of the builder
        // calls, so it is enough to confirm that neither direction panics
        // and that each one records *something*.
        assert!(
            emits_statements(|s| migration.up(s)),
            "up() should emit at least one DDL statement"
        );
        assert!(
            emits_statements(|s| migration.down(s)),
            "down() should emit at least one DDL statement"
        );
    }
}