Skip to main content

cast_core/
migration.rs

//! Migration runner. Each migration is a Rust value implementing [`Migration`], with `up` /
//! `down` methods.
//!
//! Multi-driver: when the runner owns a `Pool::Postgres` it emits Postgres DDL, and likewise
//! for MySQL and SQLite. The `Schema` passed to `up`/`down` is pre-configured with the
//! matching dialect.
6
7use crate::pool::{Driver, Pool};
8use crate::schema::Schema;
9use crate::Error;
10
11pub trait Migration: Send + Sync {
12    fn name(&self) -> &'static str;
13    fn up(&self, schema: &mut Schema);
14    fn down(&self, schema: &mut Schema);
15}
16
17inventory::collect!(MigrationRegistration);
18
19pub struct MigrationRegistration {
20    pub builder: fn() -> Box<dyn Migration>,
21}
22
23pub fn collected() -> Vec<Box<dyn Migration>> {
24    inventory::iter::<MigrationRegistration>
25        .into_iter()
26        .map(|r| (r.builder)())
27        .collect()
28}
29
30pub struct MigrationRunner {
31    pool: Pool,
32    migrations: Vec<Box<dyn Migration>>,
33}
34
35impl MigrationRunner {
36    pub fn new(pool: Pool) -> Self {
37        let mut migrations = collected();
38        migrations.sort_by_key(|m| m.name().to_string());
39        Self { pool, migrations }
40    }
41
42    pub fn with_migrations(pool: Pool, mut migrations: Vec<Box<dyn Migration>>) -> Self {
43        migrations.sort_by_key(|m| m.name().to_string());
44        Self { pool, migrations }
45    }
46
47    fn driver(&self) -> Driver {
48        self.pool.driver()
49    }
50
51    // ─── per-driver SQL ──────────────────────────────────────────────────────
52
53    fn migrations_table_ddl(&self) -> &'static str {
54        match self.driver() {
55            Driver::Postgres => {
56                "CREATE TABLE IF NOT EXISTS migrations (
57                id BIGSERIAL PRIMARY KEY,
58                name TEXT NOT NULL UNIQUE,
59                batch INTEGER NOT NULL,
60                applied_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
61            )"
62            }
63            Driver::MySql => {
64                "CREATE TABLE IF NOT EXISTS migrations (
65                id BIGINT AUTO_INCREMENT PRIMARY KEY,
66                name VARCHAR(255) NOT NULL UNIQUE,
67                batch INT NOT NULL,
68                applied_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
69            )"
70            }
71            Driver::Sqlite => {
72                "CREATE TABLE IF NOT EXISTS migrations (
73                id INTEGER PRIMARY KEY AUTOINCREMENT,
74                name TEXT NOT NULL UNIQUE,
75                batch INTEGER NOT NULL,
76                applied_at TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP
77            )"
78            }
79        }
80    }
81
82    fn fresh_ddl(&self) -> Vec<&'static str> {
83        match self.driver() {
84            Driver::Postgres => vec!["DROP SCHEMA public CASCADE", "CREATE SCHEMA public"],
85            Driver::MySql => vec![
86                // sqlx + MySQL is finicky about multi-statement scripts. We instead
87                // enumerate the tables and drop them individually below; this is just a hint.
88                "",
89            ],
90            Driver::Sqlite => vec![
91                "PRAGMA writable_schema = 1",
92                "DELETE FROM sqlite_master WHERE type IN ('table','index','trigger')",
93                "PRAGMA writable_schema = 0",
94                "VACUUM",
95            ],
96        }
97    }
98
99    // ─── helpers that dispatch to the right sqlx pool ────────────────────────
100
101    async fn exec(&self, sql: &str) -> Result<(), Error> {
102        if sql.is_empty() {
103            return Ok(());
104        }
105        match &self.pool {
106            Pool::Postgres(p) => {
107                sqlx::query(sql).execute(p).await?;
108            }
109            Pool::MySql(p) => {
110                sqlx::query(sql).execute(p).await?;
111            }
112            Pool::Sqlite(p) => {
113                sqlx::query(sql).execute(p).await?;
114            }
115        }
116        Ok(())
117    }
118
119    async fn applied_rows(&self) -> Result<Vec<(String, i32)>, Error> {
120        Ok(match &self.pool {
121            Pool::Postgres(p) => {
122                sqlx::query_as::<_, (String, i32)>(
123                    "SELECT name, batch FROM migrations ORDER BY batch, id",
124                )
125                .fetch_all(p)
126                .await?
127            }
128            Pool::MySql(p) => {
129                sqlx::query_as::<_, (String, i32)>(
130                    "SELECT name, batch FROM migrations ORDER BY batch, id",
131                )
132                .fetch_all(p)
133                .await?
134            }
135            Pool::Sqlite(p) => {
136                sqlx::query_as::<_, (String, i32)>(
137                    "SELECT name, batch FROM migrations ORDER BY batch, id",
138                )
139                .fetch_all(p)
140                .await?
141            }
142        })
143    }
144
145    async fn max_batch(&self) -> Result<Option<i32>, Error> {
146        Ok(match &self.pool {
147            Pool::Postgres(p) => {
148                sqlx::query_as::<_, (Option<i32>,)>("SELECT MAX(batch) FROM migrations")
149                    .fetch_one(p)
150                    .await?
151                    .0
152            }
153            Pool::MySql(p) => {
154                sqlx::query_as::<_, (Option<i32>,)>("SELECT MAX(batch) FROM migrations")
155                    .fetch_one(p)
156                    .await?
157                    .0
158            }
159            Pool::Sqlite(p) => {
160                sqlx::query_as::<_, (Option<i32>,)>("SELECT MAX(batch) FROM migrations")
161                    .fetch_one(p)
162                    .await?
163                    .0
164            }
165        })
166    }
167
168    async fn names_in_batch(&self, batch: i32) -> Result<Vec<String>, Error> {
169        let rows: Vec<(String,)> = match &self.pool {
170            Pool::Postgres(p) => {
171                sqlx::query_as("SELECT name FROM migrations WHERE batch = $1 ORDER BY id DESC")
172                    .bind(batch)
173                    .fetch_all(p)
174                    .await?
175            }
176            Pool::MySql(p) => {
177                sqlx::query_as("SELECT name FROM migrations WHERE batch = ? ORDER BY id DESC")
178                    .bind(batch)
179                    .fetch_all(p)
180                    .await?
181            }
182            Pool::Sqlite(p) => {
183                sqlx::query_as("SELECT name FROM migrations WHERE batch = ?1 ORDER BY id DESC")
184                    .bind(batch)
185                    .fetch_all(p)
186                    .await?
187            }
188        };
189        Ok(rows.into_iter().map(|(n,)| n).collect())
190    }
191
192    async fn record_applied(&self, name: &str, batch: i32) -> Result<(), Error> {
193        match &self.pool {
194            Pool::Postgres(p) => {
195                sqlx::query("INSERT INTO migrations (name, batch) VALUES ($1, $2)")
196                    .bind(name)
197                    .bind(batch)
198                    .execute(p)
199                    .await?;
200            }
201            Pool::MySql(p) => {
202                sqlx::query("INSERT INTO migrations (name, batch) VALUES (?, ?)")
203                    .bind(name)
204                    .bind(batch)
205                    .execute(p)
206                    .await?;
207            }
208            Pool::Sqlite(p) => {
209                sqlx::query("INSERT INTO migrations (name, batch) VALUES (?1, ?2)")
210                    .bind(name)
211                    .bind(batch)
212                    .execute(p)
213                    .await?;
214            }
215        }
216        Ok(())
217    }
218
219    async fn delete_applied(&self, name: &str) -> Result<(), Error> {
220        match &self.pool {
221            Pool::Postgres(p) => {
222                sqlx::query("DELETE FROM migrations WHERE name = $1")
223                    .bind(name)
224                    .execute(p)
225                    .await?;
226            }
227            Pool::MySql(p) => {
228                sqlx::query("DELETE FROM migrations WHERE name = ?")
229                    .bind(name)
230                    .execute(p)
231                    .await?;
232            }
233            Pool::Sqlite(p) => {
234                sqlx::query("DELETE FROM migrations WHERE name = ?1")
235                    .bind(name)
236                    .execute(p)
237                    .await?;
238            }
239        }
240        Ok(())
241    }
242
243    async fn exec_many(&self, stmts: &[String]) -> Result<(), Error> {
244        for s in stmts {
245            self.exec(s).await?;
246        }
247        Ok(())
248    }
249
250    // ─── public API ─────────────────────────────────────────────────────────
251
252    pub async fn ensure_table(&self) -> Result<(), Error> {
253        let ddl = self.migrations_table_ddl();
254        self.exec(ddl).await
255    }
256
257    pub async fn applied(&self) -> Result<Vec<String>, Error> {
258        Ok(self
259            .applied_rows()
260            .await?
261            .into_iter()
262            .map(|(n, _)| n)
263            .collect())
264    }
265
266    pub async fn next_batch(&self) -> Result<i32, Error> {
267        Ok(self.max_batch().await?.unwrap_or(0) + 1)
268    }
269
270    pub async fn run_up(&self) -> Result<Vec<String>, Error> {
271        self.ensure_table().await?;
272        let already = self.applied().await?;
273        let batch = self.next_batch().await?;
274        let mut applied = Vec::new();
275        for m in &self.migrations {
276            if already.iter().any(|a| a == m.name()) {
277                continue;
278            }
279            let mut schema = Schema::for_driver(self.driver());
280            m.up(&mut schema);
281            self.exec_many(&schema.statements).await?;
282            self.record_applied(m.name(), batch).await?;
283            applied.push(m.name().to_string());
284            tracing::info!(name = m.name(), "migration applied");
285        }
286        Ok(applied)
287    }
288
289    pub async fn rollback(&self) -> Result<Vec<String>, Error> {
290        self.ensure_table().await?;
291        let Some(batch) = self.max_batch().await? else {
292            return Ok(Vec::new());
293        };
294        let names = self.names_in_batch(batch).await?;
295        let mut rolled = Vec::new();
296        for name in names {
297            let Some(m) = self.migrations.iter().find(|m| m.name() == name) else {
298                tracing::warn!(name, "migration row in DB but not registered; skipping");
299                continue;
300            };
301            let mut schema = Schema::for_driver(self.driver());
302            m.down(&mut schema);
303            self.exec_many(&schema.statements).await?;
304            self.delete_applied(&name).await?;
305            rolled.push(name);
306        }
307        Ok(rolled)
308    }
309
310    pub async fn fresh(&self) -> Result<(), Error> {
311        // Wipe schema. MySQL doesn't have a "DROP SCHEMA public" equivalent in the
312        // user-friendly sense (it's tied to the active database), so we enumerate
313        // and drop tables individually for it.
314        match self.driver() {
315            Driver::Postgres => {
316                for s in self.fresh_ddl() {
317                    self.exec(s).await?;
318                }
319            }
320            Driver::MySql => {
321                self.drop_all_mysql_tables().await?;
322            }
323            Driver::Sqlite => {
324                self.drop_all_sqlite_tables().await?;
325            }
326        }
327        self.run_up().await?;
328        Ok(())
329    }
330
331    async fn drop_all_mysql_tables(&self) -> Result<(), Error> {
332        let Pool::MySql(p) = &self.pool else {
333            return Ok(());
334        };
335        let tables: Vec<(String,)> = sqlx::query_as(
336            "SELECT table_name FROM information_schema.tables WHERE table_schema = DATABASE()",
337        )
338        .fetch_all(p)
339        .await?;
340        sqlx::query("SET FOREIGN_KEY_CHECKS = 0").execute(p).await?;
341        for (t,) in tables {
342            sqlx::query(&format!("DROP TABLE IF EXISTS `{t}`"))
343                .execute(p)
344                .await?;
345        }
346        sqlx::query("SET FOREIGN_KEY_CHECKS = 1").execute(p).await?;
347        Ok(())
348    }
349
350    async fn drop_all_sqlite_tables(&self) -> Result<(), Error> {
351        let Pool::Sqlite(p) = &self.pool else {
352            return Ok(());
353        };
354        let tables: Vec<(String,)> = sqlx::query_as(
355            "SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%'",
356        )
357        .fetch_all(p)
358        .await?;
359        for (t,) in tables {
360            sqlx::query(&format!("DROP TABLE IF EXISTS \"{t}\""))
361                .execute(p)
362                .await?;
363        }
364        Ok(())
365    }
366
367    pub async fn status(&self) -> Result<Vec<MigrationStatus>, Error> {
368        self.ensure_table().await?;
369        let rows = self.applied_rows().await?;
370        let applied_map: std::collections::HashMap<String, i32> = rows.into_iter().collect();
371
372        let mut out = Vec::new();
373        for m in &self.migrations {
374            let name = m.name().to_string();
375            let batch = applied_map.get(&name).copied();
376            out.push(MigrationStatus {
377                name,
378                applied: batch.is_some(),
379                batch,
380            });
381        }
382        for (db_name, batch) in &applied_map {
383            if !self.migrations.iter().any(|m| m.name() == db_name) {
384                out.push(MigrationStatus {
385                    name: db_name.clone(),
386                    applied: true,
387                    batch: Some(*batch),
388                });
389            }
390        }
391        Ok(out)
392    }
393
394    pub async fn reset(&self) -> Result<Vec<String>, Error> {
395        self.ensure_table().await?;
396        let mut rolled_total = Vec::new();
397        loop {
398            let rolled = self.rollback().await?;
399            if rolled.is_empty() {
400                break;
401            }
402            rolled_total.extend(rolled);
403        }
404        Ok(rolled_total)
405    }
406
407    pub async fn refresh(&self) -> Result<Vec<String>, Error> {
408        self.reset().await?;
409        self.run_up().await
410    }
411
412    pub async fn run_up_step(&self) -> Result<Vec<String>, Error> {
413        self.ensure_table().await?;
414        let already = self.applied().await?;
415        let mut applied = Vec::new();
416        for m in &self.migrations {
417            if already.iter().any(|a| a == m.name()) {
418                continue;
419            }
420            let batch = self.next_batch().await?;
421            let mut schema = Schema::for_driver(self.driver());
422            m.up(&mut schema);
423            self.exec_many(&schema.statements).await?;
424            self.record_applied(m.name(), batch).await?;
425            applied.push(m.name().to_string());
426            tracing::info!(name = m.name(), batch, "migration applied (stepped)");
427        }
428        Ok(applied)
429    }
430
431    pub async fn pretend(&self) -> Result<Vec<String>, Error> {
432        self.ensure_table().await?;
433        let already = self.applied().await?;
434        let mut lines = Vec::new();
435        for m in &self.migrations {
436            if already.iter().any(|a| a == m.name()) {
437                continue;
438            }
439            lines.push(format!("-- migration: {}", m.name()));
440            let mut schema = Schema::for_driver(self.driver());
441            m.up(&mut schema);
442            for stmt in &schema.statements {
443                lines.push(format!("{stmt};"));
444            }
445            lines.push(String::new());
446        }
447        Ok(lines)
448    }
449
450    pub async fn install(&self) -> Result<(), Error> {
451        self.ensure_table().await
452    }
453
454    pub fn count(&self) -> usize {
455        self.migrations.len()
456    }
457}
458
/// Returned by `migrate:status`: one entry per migration known to the code or recorded in
/// the database.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MigrationStatus {
    /// Migration name, as given by [`Migration::name`] or stored in the `migrations` table.
    pub name: String,
    /// Whether a row for this migration exists in the `migrations` table.
    pub applied: bool,
    /// Batch the migration was applied in; `None` when it has not been applied.
    pub batch: Option<i32>,
}