Skip to main content

cast_core/
migration.rs

1//! Migration runner. Each migration is a Rust value with `up` / `down` methods.
2//!
3//! Multi-driver: when the runner owns a `Pool::Postgres` it emits Postgres DDL;
4//! same for MySQL / SQLite. The `Schema` passed to `up`/`down` is pre-configured
5//! with the right dialect.
6
7use crate::pool::{Driver, Pool};
8use crate::schema::Schema;
9use crate::Error;
10
11pub trait Migration: Send + Sync {
12    fn name(&self) -> &'static str;
13    fn up(&self, schema: &mut Schema);
14    fn down(&self, schema: &mut Schema);
15}
16
17inventory::collect!(MigrationRegistration);
18
19pub struct MigrationRegistration {
20    pub builder: fn() -> Box<dyn Migration>,
21}
22
23pub fn collected() -> Vec<Box<dyn Migration>> {
24    inventory::iter::<MigrationRegistration>
25        .into_iter()
26        .map(|r| (r.builder)())
27        .collect()
28}
29
30pub struct MigrationRunner {
31    pool: Pool,
32    migrations: Vec<Box<dyn Migration>>,
33}
34
35impl MigrationRunner {
36    pub fn new(pool: Pool) -> Self {
37        let mut migrations = collected();
38        migrations.sort_by_key(|m| m.name().to_string());
39        Self { pool, migrations }
40    }
41
42    pub fn with_migrations(pool: Pool, mut migrations: Vec<Box<dyn Migration>>) -> Self {
43        migrations.sort_by_key(|m| m.name().to_string());
44        Self { pool, migrations }
45    }
46
47    fn driver(&self) -> Driver {
48        self.pool.driver()
49    }
50
51    // ─── per-driver SQL ──────────────────────────────────────────────────────
52
53    fn migrations_table_ddl(&self) -> &'static str {
54        match self.driver() {
55            Driver::Postgres => "CREATE TABLE IF NOT EXISTS migrations (
56                id BIGSERIAL PRIMARY KEY,
57                name TEXT NOT NULL UNIQUE,
58                batch INTEGER NOT NULL,
59                applied_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
60            )",
61            Driver::MySql => "CREATE TABLE IF NOT EXISTS migrations (
62                id BIGINT AUTO_INCREMENT PRIMARY KEY,
63                name VARCHAR(255) NOT NULL UNIQUE,
64                batch INT NOT NULL,
65                applied_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
66            )",
67            Driver::Sqlite => "CREATE TABLE IF NOT EXISTS migrations (
68                id INTEGER PRIMARY KEY AUTOINCREMENT,
69                name TEXT NOT NULL UNIQUE,
70                batch INTEGER NOT NULL,
71                applied_at TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP
72            )",
73        }
74    }
75
76    fn fresh_ddl(&self) -> Vec<&'static str> {
77        match self.driver() {
78            Driver::Postgres => vec!["DROP SCHEMA public CASCADE", "CREATE SCHEMA public"],
79            Driver::MySql => vec![
80                // sqlx + MySQL is finicky about multi-statement scripts. We instead
81                // enumerate the tables and drop them individually below; this is just a hint.
82                "",
83            ],
84            Driver::Sqlite => vec![
85                "PRAGMA writable_schema = 1",
86                "DELETE FROM sqlite_master WHERE type IN ('table','index','trigger')",
87                "PRAGMA writable_schema = 0",
88                "VACUUM",
89            ],
90        }
91    }
92
93    // ─── helpers that dispatch to the right sqlx pool ────────────────────────
94
95    async fn exec(&self, sql: &str) -> Result<(), Error> {
96        if sql.is_empty() {
97            return Ok(());
98        }
99        match &self.pool {
100            Pool::Postgres(p) => {
101                sqlx::query(sql).execute(p).await?;
102            }
103            Pool::MySql(p) => {
104                sqlx::query(sql).execute(p).await?;
105            }
106            Pool::Sqlite(p) => {
107                sqlx::query(sql).execute(p).await?;
108            }
109        }
110        Ok(())
111    }
112
113    async fn applied_rows(&self) -> Result<Vec<(String, i32)>, Error> {
114        Ok(match &self.pool {
115            Pool::Postgres(p) => {
116                sqlx::query_as::<_, (String, i32)>("SELECT name, batch FROM migrations ORDER BY batch, id")
117                    .fetch_all(p)
118                    .await?
119            }
120            Pool::MySql(p) => {
121                sqlx::query_as::<_, (String, i32)>("SELECT name, batch FROM migrations ORDER BY batch, id")
122                    .fetch_all(p)
123                    .await?
124            }
125            Pool::Sqlite(p) => {
126                sqlx::query_as::<_, (String, i32)>("SELECT name, batch FROM migrations ORDER BY batch, id")
127                    .fetch_all(p)
128                    .await?
129            }
130        })
131    }
132
133    async fn max_batch(&self) -> Result<Option<i32>, Error> {
134        Ok(match &self.pool {
135            Pool::Postgres(p) => {
136                sqlx::query_as::<_, (Option<i32>,)>("SELECT MAX(batch) FROM migrations")
137                    .fetch_one(p)
138                    .await?
139                    .0
140            }
141            Pool::MySql(p) => {
142                sqlx::query_as::<_, (Option<i32>,)>("SELECT MAX(batch) FROM migrations")
143                    .fetch_one(p)
144                    .await?
145                    .0
146            }
147            Pool::Sqlite(p) => {
148                sqlx::query_as::<_, (Option<i32>,)>("SELECT MAX(batch) FROM migrations")
149                    .fetch_one(p)
150                    .await?
151                    .0
152            }
153        })
154    }
155
156    async fn names_in_batch(&self, batch: i32) -> Result<Vec<String>, Error> {
157        let rows: Vec<(String,)> = match &self.pool {
158            Pool::Postgres(p) => sqlx::query_as("SELECT name FROM migrations WHERE batch = $1 ORDER BY id DESC")
159                .bind(batch).fetch_all(p).await?,
160            Pool::MySql(p) => sqlx::query_as("SELECT name FROM migrations WHERE batch = ? ORDER BY id DESC")
161                .bind(batch).fetch_all(p).await?,
162            Pool::Sqlite(p) => sqlx::query_as("SELECT name FROM migrations WHERE batch = ?1 ORDER BY id DESC")
163                .bind(batch).fetch_all(p).await?,
164        };
165        Ok(rows.into_iter().map(|(n,)| n).collect())
166    }
167
168    async fn record_applied(&self, name: &str, batch: i32) -> Result<(), Error> {
169        match &self.pool {
170            Pool::Postgres(p) => {
171                sqlx::query("INSERT INTO migrations (name, batch) VALUES ($1, $2)")
172                    .bind(name).bind(batch).execute(p).await?;
173            }
174            Pool::MySql(p) => {
175                sqlx::query("INSERT INTO migrations (name, batch) VALUES (?, ?)")
176                    .bind(name).bind(batch).execute(p).await?;
177            }
178            Pool::Sqlite(p) => {
179                sqlx::query("INSERT INTO migrations (name, batch) VALUES (?1, ?2)")
180                    .bind(name).bind(batch).execute(p).await?;
181            }
182        }
183        Ok(())
184    }
185
186    async fn delete_applied(&self, name: &str) -> Result<(), Error> {
187        match &self.pool {
188            Pool::Postgres(p) => {
189                sqlx::query("DELETE FROM migrations WHERE name = $1").bind(name).execute(p).await?;
190            }
191            Pool::MySql(p) => {
192                sqlx::query("DELETE FROM migrations WHERE name = ?").bind(name).execute(p).await?;
193            }
194            Pool::Sqlite(p) => {
195                sqlx::query("DELETE FROM migrations WHERE name = ?1").bind(name).execute(p).await?;
196            }
197        }
198        Ok(())
199    }
200
201    async fn exec_many(&self, stmts: &[String]) -> Result<(), Error> {
202        for s in stmts {
203            self.exec(s).await?;
204        }
205        Ok(())
206    }
207
208    // ─── public API ─────────────────────────────────────────────────────────
209
210    pub async fn ensure_table(&self) -> Result<(), Error> {
211        let ddl = self.migrations_table_ddl();
212        self.exec(ddl).await
213    }
214
215    pub async fn applied(&self) -> Result<Vec<String>, Error> {
216        Ok(self.applied_rows().await?.into_iter().map(|(n, _)| n).collect())
217    }
218
219    pub async fn next_batch(&self) -> Result<i32, Error> {
220        Ok(self.max_batch().await?.unwrap_or(0) + 1)
221    }
222
223    pub async fn run_up(&self) -> Result<Vec<String>, Error> {
224        self.ensure_table().await?;
225        let already = self.applied().await?;
226        let batch = self.next_batch().await?;
227        let mut applied = Vec::new();
228        for m in &self.migrations {
229            if already.iter().any(|a| a == m.name()) {
230                continue;
231            }
232            let mut schema = Schema::for_driver(self.driver());
233            m.up(&mut schema);
234            self.exec_many(&schema.statements).await?;
235            self.record_applied(m.name(), batch).await?;
236            applied.push(m.name().to_string());
237            tracing::info!(name = m.name(), "migration applied");
238        }
239        Ok(applied)
240    }
241
242    pub async fn rollback(&self) -> Result<Vec<String>, Error> {
243        self.ensure_table().await?;
244        let Some(batch) = self.max_batch().await? else {
245            return Ok(Vec::new());
246        };
247        let names = self.names_in_batch(batch).await?;
248        let mut rolled = Vec::new();
249        for name in names {
250            let Some(m) = self.migrations.iter().find(|m| m.name() == name) else {
251                tracing::warn!(name, "migration row in DB but not registered; skipping");
252                continue;
253            };
254            let mut schema = Schema::for_driver(self.driver());
255            m.down(&mut schema);
256            self.exec_many(&schema.statements).await?;
257            self.delete_applied(&name).await?;
258            rolled.push(name);
259        }
260        Ok(rolled)
261    }
262
263    pub async fn fresh(&self) -> Result<(), Error> {
264        // Wipe schema. MySQL doesn't have a "DROP SCHEMA public" equivalent in the
265        // user-friendly sense (it's tied to the active database), so we enumerate
266        // and drop tables individually for it.
267        match self.driver() {
268            Driver::Postgres => {
269                for s in self.fresh_ddl() {
270                    self.exec(s).await?;
271                }
272            }
273            Driver::MySql => {
274                self.drop_all_mysql_tables().await?;
275            }
276            Driver::Sqlite => {
277                self.drop_all_sqlite_tables().await?;
278            }
279        }
280        self.run_up().await?;
281        Ok(())
282    }
283
284    async fn drop_all_mysql_tables(&self) -> Result<(), Error> {
285        let Pool::MySql(p) = &self.pool else {
286            return Ok(());
287        };
288        let tables: Vec<(String,)> = sqlx::query_as(
289            "SELECT table_name FROM information_schema.tables WHERE table_schema = DATABASE()",
290        )
291        .fetch_all(p)
292        .await?;
293        sqlx::query("SET FOREIGN_KEY_CHECKS = 0").execute(p).await?;
294        for (t,) in tables {
295            sqlx::query(&format!("DROP TABLE IF EXISTS `{t}`"))
296                .execute(p)
297                .await?;
298        }
299        sqlx::query("SET FOREIGN_KEY_CHECKS = 1").execute(p).await?;
300        Ok(())
301    }
302
303    async fn drop_all_sqlite_tables(&self) -> Result<(), Error> {
304        let Pool::Sqlite(p) = &self.pool else {
305            return Ok(());
306        };
307        let tables: Vec<(String,)> = sqlx::query_as(
308            "SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%'",
309        )
310        .fetch_all(p)
311        .await?;
312        for (t,) in tables {
313            sqlx::query(&format!("DROP TABLE IF EXISTS \"{t}\""))
314                .execute(p)
315                .await?;
316        }
317        Ok(())
318    }
319
320    pub async fn status(&self) -> Result<Vec<MigrationStatus>, Error> {
321        self.ensure_table().await?;
322        let rows = self.applied_rows().await?;
323        let applied_map: std::collections::HashMap<String, i32> = rows.into_iter().collect();
324
325        let mut out = Vec::new();
326        for m in &self.migrations {
327            let name = m.name().to_string();
328            let batch = applied_map.get(&name).copied();
329            out.push(MigrationStatus {
330                name,
331                applied: batch.is_some(),
332                batch,
333            });
334        }
335        for (db_name, batch) in &applied_map {
336            if !self.migrations.iter().any(|m| m.name() == db_name) {
337                out.push(MigrationStatus {
338                    name: db_name.clone(),
339                    applied: true,
340                    batch: Some(*batch),
341                });
342            }
343        }
344        Ok(out)
345    }
346
347    pub async fn reset(&self) -> Result<Vec<String>, Error> {
348        self.ensure_table().await?;
349        let mut rolled_total = Vec::new();
350        loop {
351            let rolled = self.rollback().await?;
352            if rolled.is_empty() {
353                break;
354            }
355            rolled_total.extend(rolled);
356        }
357        Ok(rolled_total)
358    }
359
360    pub async fn refresh(&self) -> Result<Vec<String>, Error> {
361        self.reset().await?;
362        self.run_up().await
363    }
364
365    pub async fn run_up_step(&self) -> Result<Vec<String>, Error> {
366        self.ensure_table().await?;
367        let already = self.applied().await?;
368        let mut applied = Vec::new();
369        for m in &self.migrations {
370            if already.iter().any(|a| a == m.name()) {
371                continue;
372            }
373            let batch = self.next_batch().await?;
374            let mut schema = Schema::for_driver(self.driver());
375            m.up(&mut schema);
376            self.exec_many(&schema.statements).await?;
377            self.record_applied(m.name(), batch).await?;
378            applied.push(m.name().to_string());
379            tracing::info!(name = m.name(), batch, "migration applied (stepped)");
380        }
381        Ok(applied)
382    }
383
384    pub async fn pretend(&self) -> Result<Vec<String>, Error> {
385        self.ensure_table().await?;
386        let already = self.applied().await?;
387        let mut lines = Vec::new();
388        for m in &self.migrations {
389            if already.iter().any(|a| a == m.name()) {
390                continue;
391            }
392            lines.push(format!("-- migration: {}", m.name()));
393            let mut schema = Schema::for_driver(self.driver());
394            m.up(&mut schema);
395            for stmt in &schema.statements {
396                lines.push(format!("{stmt};"));
397            }
398            lines.push(String::new());
399        }
400        Ok(lines)
401    }
402
403    pub async fn install(&self) -> Result<(), Error> {
404        self.ensure_table().await
405    }
406
407    pub fn count(&self) -> usize {
408        self.migrations.len()
409    }
410}
411
/// One entry of the report returned by `migrate:status`.
#[derive(Debug, Clone)]
pub struct MigrationStatus {
    /// Migration name.
    pub name: String,
    /// Whether a row for this migration exists in the `migrations` table.
    pub applied: bool,
    /// Batch the migration was recorded under; `None` when not applied.
    pub batch: Option<i32>,
}