Skip to main content

cast_core/
migration.rs

1//! Migration runner. Each migration is a Rust value with `up` / `down` methods.
2//!
3//! Multi-driver: when the runner owns a `Pool::Postgres` it emits Postgres DDL;
4//! same for MySQL / SQLite. The `Schema` passed to `up`/`down` is pre-configured
5//! with the right dialect.
6
7use crate::pool::{Driver, Pool};
8use crate::schema::Schema;
9use crate::Error;
10
11pub trait Migration: Send + Sync {
12    fn name(&self) -> &'static str;
13    fn up(&self, schema: &mut Schema);
14    fn down(&self, schema: &mut Schema);
15}
16
17inventory::collect!(MigrationRegistration);
18
19pub struct MigrationRegistration {
20    pub builder: fn() -> Box<dyn Migration>,
21}
22
23pub fn collected() -> Vec<Box<dyn Migration>> {
24    inventory::iter::<MigrationRegistration>
25        .into_iter()
26        .map(|r| (r.builder)())
27        .collect()
28}
29
30/// Panic with a clear message if two registered migrations return the same
31/// `name()`. The migrations table has a UNIQUE constraint on `name`, but a
32/// duplicate registration silently masks the second migration at apply time —
33/// failing early at runner construction catches the rename footgun (file
34/// renamed, `name()` left stale → collision with the new file's `name()`).
35fn check_unique_names(migrations: &[Box<dyn Migration>]) {
36    use std::collections::HashSet;
37    let mut seen: HashSet<&'static str> = HashSet::with_capacity(migrations.len());
38    let mut dups: Vec<&'static str> = Vec::new();
39    for m in migrations {
40        if !seen.insert(m.name()) {
41            dups.push(m.name());
42        }
43    }
44    if !dups.is_empty() {
45        panic!(
46            "duplicate Migration::name() values: {dups:?}. \
47             A `name()` collision lets one migration silently shadow another. \
48             Check that each migration file's `fn name(&self) -> &'static str` matches its filename stem."
49        );
50    }
51}
52
/// Declare a closure-style migration, in the spirit of Laravel's
/// `Schema::create('posts', function (Blueprint $t) { ... })`.
///
/// A single invocation generates three items:
///
/// 1. a `pub` unit struct named by the first argument,
/// 2. an implementation of `Migration` for it whose `name()` is the second
///    argument, and
/// 3. an `inventory::submit!` registration so the runner discovers it.
///
/// These are the same pieces `#[derive(Migration)]` produces, without the
/// boilerplate.
///
/// The `up` and `down` expressions are coerced to `fn(&mut Schema)`, so they
/// must be non-capturing closures (or plain functions).
///
/// Usage:
///
/// ```ignore
/// use anvilforge::prelude::*;
///
/// migration!(CreatePostsTable, "2026_05_20_create_posts_table",
///     up = |s| {
///         s.create("posts", |t| {
///             t.id();
///             t.string("title").not_null();
///             t.text("body").not_null();
///             t.timestamps();
///         });
///     },
///     down = |s| {
///         s.drop_if_exists("posts");
///     },
/// );
/// ```
///
/// The struct name is spelled out explicitly (like Laravel's class name). That
/// keeps the inventory registration deterministic and lets rollback
/// diagnostics name the migration in panics and errors.
#[macro_export]
macro_rules! migration {
    (
        $struct_name:ident,
        $name:expr,
        up = $up:expr,
        down = $down:expr $(,)?
    ) => {
        pub struct $struct_name;

        impl $crate::migration::Migration for $struct_name {
            fn name(&self) -> &'static str {
                $name
            }

            fn up(&self, schema: &mut $crate::schema::Schema) {
                // The fn-pointer annotation gives the closure its parameter
                // type and rejects capturing closures at compile time.
                let run_up: fn(&mut $crate::schema::Schema) = $up;
                run_up(schema)
            }

            fn down(&self, schema: &mut $crate::schema::Schema) {
                let run_down: fn(&mut $crate::schema::Schema) = $down;
                run_down(schema)
            }
        }

        $crate::inventory::submit! {
            $crate::migration::MigrationRegistration {
                builder: || -> ::std::boxed::Box<dyn $crate::migration::Migration> {
                    ::std::boxed::Box::new($struct_name)
                },
            }
        }
    };
}
116
117pub struct MigrationRunner {
118    pool: Pool,
119    migrations: Vec<Box<dyn Migration>>,
120}
121
122impl MigrationRunner {
123    pub fn new(pool: Pool) -> Self {
124        let mut migrations = collected();
125        check_unique_names(&migrations);
126        migrations.sort_by_key(|m| m.name().to_string());
127        Self { pool, migrations }
128    }
129
130    pub fn with_migrations(pool: Pool, mut migrations: Vec<Box<dyn Migration>>) -> Self {
131        check_unique_names(&migrations);
132        migrations.sort_by_key(|m| m.name().to_string());
133        Self { pool, migrations }
134    }
135
136    fn driver(&self) -> Driver {
137        self.pool.driver()
138    }
139
140    // ─── per-driver SQL ──────────────────────────────────────────────────────
141
142    fn migrations_table_ddl(&self) -> &'static str {
143        match self.driver() {
144            Driver::Postgres => {
145                "CREATE TABLE IF NOT EXISTS migrations (
146                id BIGSERIAL PRIMARY KEY,
147                name TEXT NOT NULL UNIQUE,
148                batch INTEGER NOT NULL,
149                applied_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
150            )"
151            }
152            Driver::MySql => {
153                "CREATE TABLE IF NOT EXISTS migrations (
154                id BIGINT AUTO_INCREMENT PRIMARY KEY,
155                name VARCHAR(255) NOT NULL UNIQUE,
156                batch INT NOT NULL,
157                applied_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
158            )"
159            }
160            Driver::Sqlite => {
161                "CREATE TABLE IF NOT EXISTS migrations (
162                id INTEGER PRIMARY KEY AUTOINCREMENT,
163                name TEXT NOT NULL UNIQUE,
164                batch INTEGER NOT NULL,
165                applied_at TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP
166            )"
167            }
168        }
169    }
170
171    fn fresh_ddl(&self) -> Vec<&'static str> {
172        match self.driver() {
173            Driver::Postgres => vec!["DROP SCHEMA public CASCADE", "CREATE SCHEMA public"],
174            Driver::MySql => vec![
175                // sqlx + MySQL is finicky about multi-statement scripts. We instead
176                // enumerate the tables and drop them individually below; this is just a hint.
177                "",
178            ],
179            Driver::Sqlite => vec![
180                "PRAGMA writable_schema = 1",
181                "DELETE FROM sqlite_master WHERE type IN ('table','index','trigger')",
182                "PRAGMA writable_schema = 0",
183                "VACUUM",
184            ],
185        }
186    }
187
188    // ─── helpers that dispatch to the right sqlx pool ────────────────────────
189
190    async fn exec(&self, sql: &str) -> Result<(), Error> {
191        if sql.is_empty() {
192            return Ok(());
193        }
194        match &self.pool {
195            Pool::Postgres(p) => {
196                sqlx::query(sql).execute(p).await?;
197            }
198            Pool::MySql(p) => {
199                sqlx::query(sql).execute(p).await?;
200            }
201            Pool::Sqlite(p) => {
202                sqlx::query(sql).execute(p).await?;
203            }
204        }
205        Ok(())
206    }
207
208    async fn applied_rows(&self) -> Result<Vec<(String, i32)>, Error> {
209        Ok(match &self.pool {
210            Pool::Postgres(p) => {
211                sqlx::query_as::<_, (String, i32)>(
212                    "SELECT name, batch FROM migrations ORDER BY batch, id",
213                )
214                .fetch_all(p)
215                .await?
216            }
217            Pool::MySql(p) => {
218                sqlx::query_as::<_, (String, i32)>(
219                    "SELECT name, batch FROM migrations ORDER BY batch, id",
220                )
221                .fetch_all(p)
222                .await?
223            }
224            Pool::Sqlite(p) => {
225                sqlx::query_as::<_, (String, i32)>(
226                    "SELECT name, batch FROM migrations ORDER BY batch, id",
227                )
228                .fetch_all(p)
229                .await?
230            }
231        })
232    }
233
234    async fn max_batch(&self) -> Result<Option<i32>, Error> {
235        Ok(match &self.pool {
236            Pool::Postgres(p) => {
237                sqlx::query_as::<_, (Option<i32>,)>("SELECT MAX(batch) FROM migrations")
238                    .fetch_one(p)
239                    .await?
240                    .0
241            }
242            Pool::MySql(p) => {
243                sqlx::query_as::<_, (Option<i32>,)>("SELECT MAX(batch) FROM migrations")
244                    .fetch_one(p)
245                    .await?
246                    .0
247            }
248            Pool::Sqlite(p) => {
249                sqlx::query_as::<_, (Option<i32>,)>("SELECT MAX(batch) FROM migrations")
250                    .fetch_one(p)
251                    .await?
252                    .0
253            }
254        })
255    }
256
257    async fn names_in_batch(&self, batch: i32) -> Result<Vec<String>, Error> {
258        let rows: Vec<(String,)> = match &self.pool {
259            Pool::Postgres(p) => {
260                sqlx::query_as("SELECT name FROM migrations WHERE batch = $1 ORDER BY id DESC")
261                    .bind(batch)
262                    .fetch_all(p)
263                    .await?
264            }
265            Pool::MySql(p) => {
266                sqlx::query_as("SELECT name FROM migrations WHERE batch = ? ORDER BY id DESC")
267                    .bind(batch)
268                    .fetch_all(p)
269                    .await?
270            }
271            Pool::Sqlite(p) => {
272                sqlx::query_as("SELECT name FROM migrations WHERE batch = ?1 ORDER BY id DESC")
273                    .bind(batch)
274                    .fetch_all(p)
275                    .await?
276            }
277        };
278        Ok(rows.into_iter().map(|(n,)| n).collect())
279    }
280
281    async fn record_applied(&self, name: &str, batch: i32) -> Result<(), Error> {
282        match &self.pool {
283            Pool::Postgres(p) => {
284                sqlx::query("INSERT INTO migrations (name, batch) VALUES ($1, $2)")
285                    .bind(name)
286                    .bind(batch)
287                    .execute(p)
288                    .await?;
289            }
290            Pool::MySql(p) => {
291                sqlx::query("INSERT INTO migrations (name, batch) VALUES (?, ?)")
292                    .bind(name)
293                    .bind(batch)
294                    .execute(p)
295                    .await?;
296            }
297            Pool::Sqlite(p) => {
298                sqlx::query("INSERT INTO migrations (name, batch) VALUES (?1, ?2)")
299                    .bind(name)
300                    .bind(batch)
301                    .execute(p)
302                    .await?;
303            }
304        }
305        Ok(())
306    }
307
308    async fn delete_applied(&self, name: &str) -> Result<(), Error> {
309        match &self.pool {
310            Pool::Postgres(p) => {
311                sqlx::query("DELETE FROM migrations WHERE name = $1")
312                    .bind(name)
313                    .execute(p)
314                    .await?;
315            }
316            Pool::MySql(p) => {
317                sqlx::query("DELETE FROM migrations WHERE name = ?")
318                    .bind(name)
319                    .execute(p)
320                    .await?;
321            }
322            Pool::Sqlite(p) => {
323                sqlx::query("DELETE FROM migrations WHERE name = ?1")
324                    .bind(name)
325                    .execute(p)
326                    .await?;
327            }
328        }
329        Ok(())
330    }
331
332    async fn exec_many(&self, stmts: &[String]) -> Result<(), Error> {
333        for s in stmts {
334            self.exec(s).await?;
335        }
336        Ok(())
337    }
338
339    // ─── public API ─────────────────────────────────────────────────────────
340
341    pub async fn ensure_table(&self) -> Result<(), Error> {
342        let ddl = self.migrations_table_ddl();
343        self.exec(ddl).await
344    }
345
346    pub async fn applied(&self) -> Result<Vec<String>, Error> {
347        Ok(self
348            .applied_rows()
349            .await?
350            .into_iter()
351            .map(|(n, _)| n)
352            .collect())
353    }
354
355    pub async fn next_batch(&self) -> Result<i32, Error> {
356        Ok(self.max_batch().await?.unwrap_or(0) + 1)
357    }
358
359    pub async fn run_up(&self) -> Result<Vec<String>, Error> {
360        self.ensure_table().await?;
361        let already = self.applied().await?;
362        let batch = self.next_batch().await?;
363        let mut applied = Vec::new();
364        for m in &self.migrations {
365            if already.iter().any(|a| a == m.name()) {
366                continue;
367            }
368            let mut schema = Schema::for_driver(self.driver());
369            m.up(&mut schema);
370            self.exec_many(&schema.statements).await?;
371            self.record_applied(m.name(), batch).await?;
372            applied.push(m.name().to_string());
373            tracing::info!(name = m.name(), "migration applied");
374        }
375        Ok(applied)
376    }
377
378    pub async fn rollback(&self) -> Result<Vec<String>, Error> {
379        self.ensure_table().await?;
380        let Some(batch) = self.max_batch().await? else {
381            return Ok(Vec::new());
382        };
383        let names = self.names_in_batch(batch).await?;
384        let mut rolled = Vec::new();
385        for name in names {
386            let Some(m) = self.migrations.iter().find(|m| m.name() == name) else {
387                tracing::warn!(name, "migration row in DB but not registered; skipping");
388                continue;
389            };
390            let mut schema = Schema::for_driver(self.driver());
391            m.down(&mut schema);
392            self.exec_many(&schema.statements).await?;
393            self.delete_applied(&name).await?;
394            rolled.push(name);
395        }
396        Ok(rolled)
397    }
398
399    pub async fn fresh(&self) -> Result<(), Error> {
400        self.wipe().await?;
401        self.run_up().await?;
402        Ok(())
403    }
404
405    /// Drop every table in the current schema, regardless of driver. Doesn't
406    /// re-run migrations — use `fresh()` for that.
407    ///
408    /// - Postgres: `DROP SCHEMA public CASCADE; CREATE SCHEMA public`.
409    /// - MySQL: enumerate user tables and drop each (with `FOREIGN_KEY_CHECKS=0`).
410    /// - SQLite: enumerate user tables in `sqlite_master` and drop each.
411    pub async fn wipe(&self) -> Result<(), Error> {
412        match self.driver() {
413            Driver::Postgres => {
414                for s in self.fresh_ddl() {
415                    self.exec(s).await?;
416                }
417            }
418            Driver::MySql => {
419                self.drop_all_mysql_tables().await?;
420            }
421            Driver::Sqlite => {
422                self.drop_all_sqlite_tables().await?;
423            }
424        }
425        Ok(())
426    }
427
428    async fn drop_all_mysql_tables(&self) -> Result<(), Error> {
429        let Pool::MySql(p) = &self.pool else {
430            return Ok(());
431        };
432        let tables: Vec<(String,)> = sqlx::query_as(
433            "SELECT table_name FROM information_schema.tables WHERE table_schema = DATABASE()",
434        )
435        .fetch_all(p)
436        .await?;
437        sqlx::query("SET FOREIGN_KEY_CHECKS = 0").execute(p).await?;
438        for (t,) in tables {
439            sqlx::query(&format!("DROP TABLE IF EXISTS `{t}`"))
440                .execute(p)
441                .await?;
442        }
443        sqlx::query("SET FOREIGN_KEY_CHECKS = 1").execute(p).await?;
444        Ok(())
445    }
446
447    async fn drop_all_sqlite_tables(&self) -> Result<(), Error> {
448        let Pool::Sqlite(p) = &self.pool else {
449            return Ok(());
450        };
451        let tables: Vec<(String,)> = sqlx::query_as(
452            "SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%'",
453        )
454        .fetch_all(p)
455        .await?;
456        for (t,) in tables {
457            sqlx::query(&format!("DROP TABLE IF EXISTS \"{t}\""))
458                .execute(p)
459                .await?;
460        }
461        Ok(())
462    }
463
464    pub async fn status(&self) -> Result<Vec<MigrationStatus>, Error> {
465        self.ensure_table().await?;
466        let rows = self.applied_rows().await?;
467        let applied_map: std::collections::HashMap<String, i32> = rows.into_iter().collect();
468
469        let mut out = Vec::new();
470        for m in &self.migrations {
471            let name = m.name().to_string();
472            let batch = applied_map.get(&name).copied();
473            out.push(MigrationStatus {
474                name,
475                applied: batch.is_some(),
476                batch,
477            });
478        }
479        for (db_name, batch) in &applied_map {
480            if !self.migrations.iter().any(|m| m.name() == db_name) {
481                out.push(MigrationStatus {
482                    name: db_name.clone(),
483                    applied: true,
484                    batch: Some(*batch),
485                });
486            }
487        }
488        Ok(out)
489    }
490
491    pub async fn reset(&self) -> Result<Vec<String>, Error> {
492        self.ensure_table().await?;
493        let mut rolled_total = Vec::new();
494        loop {
495            let rolled = self.rollback().await?;
496            if rolled.is_empty() {
497                break;
498            }
499            rolled_total.extend(rolled);
500        }
501        Ok(rolled_total)
502    }
503
504    pub async fn refresh(&self) -> Result<Vec<String>, Error> {
505        self.reset().await?;
506        self.run_up().await
507    }
508
509    pub async fn run_up_step(&self) -> Result<Vec<String>, Error> {
510        self.ensure_table().await?;
511        let already = self.applied().await?;
512        let mut applied = Vec::new();
513        for m in &self.migrations {
514            if already.iter().any(|a| a == m.name()) {
515                continue;
516            }
517            let batch = self.next_batch().await?;
518            let mut schema = Schema::for_driver(self.driver());
519            m.up(&mut schema);
520            self.exec_many(&schema.statements).await?;
521            self.record_applied(m.name(), batch).await?;
522            applied.push(m.name().to_string());
523            tracing::info!(name = m.name(), batch, "migration applied (stepped)");
524        }
525        Ok(applied)
526    }
527
528    pub async fn pretend(&self) -> Result<Vec<String>, Error> {
529        self.ensure_table().await?;
530        let already = self.applied().await?;
531        let mut lines = Vec::new();
532        for m in &self.migrations {
533            if already.iter().any(|a| a == m.name()) {
534                continue;
535            }
536            lines.push(format!("-- migration: {}", m.name()));
537            let mut schema = Schema::for_driver(self.driver());
538            m.up(&mut schema);
539            for stmt in &schema.statements {
540                lines.push(format!("{stmt};"));
541            }
542            lines.push(String::new());
543        }
544        Ok(lines)
545    }
546
547    pub async fn install(&self) -> Result<(), Error> {
548        self.ensure_table().await
549    }
550
551    pub fn count(&self) -> usize {
552        self.migrations.len()
553    }
554}
555
/// One row of the `migrate:status` report, as returned by
/// `MigrationRunner::status`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MigrationStatus {
    /// The migration's `name()`. For a recorded row with no registered
    /// migration, this is the name stored in the migrations table.
    pub name: String,
    /// Whether a row for `name` exists in the migrations table.
    pub applied: bool,
    /// Batch the migration was recorded in; `None` while it is pending.
    pub batch: Option<i32>,
}
563
#[cfg(test)]
mod macro_tests {
    use super::*;
    use crate::schema::Schema;

    // Expanding `migration!` here proves at compile time that the macro yields
    // valid items; the tests below check the resulting name and up/down output.
    crate::migration!(
        TestCreateThingsTable,
        "2026_01_01_000003_create_things_table",
        up = |s| {
            s.create("things", |t| {
                t.id();
                t.string("name").not_null();
            });
        },
        down = |s| {
            s.drop_if_exists("things");
        },
    );

    /// Run one direction of a migration against a fresh SQLite schema and
    /// report how many DDL statements it recorded.
    fn statement_count(run: impl FnOnce(&mut Schema)) -> usize {
        let mut schema = Schema::for_driver(Driver::Sqlite);
        run(&mut schema);
        schema.statements.len()
    }

    #[test]
    fn closure_migration_macro_expands_into_a_working_migration() {
        let migration = TestCreateThingsTable;
        assert_eq!(migration.name(), "2026_01_01_000003_create_things_table");

        // The builder records DDL as a side effect of `s.create` /
        // `s.drop_if_exists`; we only require that *something* was recorded.
        assert!(
            statement_count(|s| migration.up(s)) > 0,
            "up() should emit at least one DDL statement"
        );
        assert!(
            statement_count(|s| migration.down(s)) > 0,
            "down() should emit at least one DDL statement"
        );
    }

    /// A `Migration` whose only behaviour is its name.
    struct NamedMigration(&'static str);

    impl Migration for NamedMigration {
        fn name(&self) -> &'static str {
            self.0
        }
        fn up(&self, _schema: &mut Schema) {}
        fn down(&self, _schema: &mut Schema) {}
    }

    /// Box one `NamedMigration` per entry of `names`.
    fn named(names: &[&'static str]) -> Vec<Box<dyn Migration>> {
        names
            .iter()
            .map(|&n| Box::new(NamedMigration(n)) as Box<dyn Migration>)
            .collect()
    }

    #[test]
    fn check_unique_names_accepts_unique() {
        check_unique_names(&named(&[
            "2026_01_01_000001_a",
            "2026_01_01_000002_b",
            "2026_01_01_000003_c",
        ]));
    }

    #[test]
    #[should_panic(expected = "duplicate Migration::name() values")]
    fn check_unique_names_panics_on_collision() {
        check_unique_names(&named(&["2026_01_01_000001_a", "2026_01_01_000001_a"]));
    }
}