// cast_core/migration.rs

//! Migration runner. Each migration is a Rust value with `up` / `down` methods.

use std::collections::HashSet;

use crate::pool::Pool;
use crate::schema::Schema;
use crate::Error;

7pub trait Migration: Send + Sync {
8    fn name(&self) -> &'static str;
9    fn up(&self, schema: &mut Schema);
10    fn down(&self, schema: &mut Schema);
11}
12
13inventory::collect!(MigrationRegistration);
14
15pub struct MigrationRegistration {
16    pub builder: fn() -> Box<dyn Migration>,
17}
18
19pub fn collected() -> Vec<Box<dyn Migration>> {
20    inventory::iter::<MigrationRegistration>
21        .into_iter()
22        .map(|r| (r.builder)())
23        .collect()
24}
25
26pub struct MigrationRunner {
27    pool: Pool,
28    migrations: Vec<Box<dyn Migration>>,
29}
30
31impl MigrationRunner {
32    pub fn new(pool: Pool) -> Self {
33        let mut migrations = collected();
34        migrations.sort_by_key(|m| m.name().to_string());
35        Self { pool, migrations }
36    }
37
38    pub fn with_migrations(pool: Pool, mut migrations: Vec<Box<dyn Migration>>) -> Self {
39        migrations.sort_by_key(|m| m.name().to_string());
40        Self { pool, migrations }
41    }
42
43    pub async fn ensure_table(&self) -> Result<(), Error> {
44        sqlx::query(
45            "CREATE TABLE IF NOT EXISTS migrations (
46                id BIGSERIAL PRIMARY KEY,
47                name TEXT NOT NULL UNIQUE,
48                batch INTEGER NOT NULL,
49                applied_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
50            )",
51        )
52        .execute(&self.pool)
53        .await?;
54        Ok(())
55    }
56
57    pub async fn applied(&self) -> Result<Vec<String>, Error> {
58        let rows: Vec<(String,)> = sqlx::query_as("SELECT name FROM migrations ORDER BY batch, id")
59            .fetch_all(&self.pool)
60            .await?;
61        Ok(rows.into_iter().map(|(n,)| n).collect())
62    }
63
64    pub async fn next_batch(&self) -> Result<i32, Error> {
65        let (max_batch,): (Option<i32>,) =
66            sqlx::query_as("SELECT MAX(batch) FROM migrations")
67                .fetch_one(&self.pool)
68                .await?;
69        Ok(max_batch.unwrap_or(0) + 1)
70    }
71
72    pub async fn run_up(&self) -> Result<Vec<String>, Error> {
73        self.ensure_table().await?;
74        let already = self.applied().await?;
75        let batch = self.next_batch().await?;
76        let mut applied = Vec::new();
77        for m in &self.migrations {
78            if already.iter().any(|a| a == m.name()) {
79                continue;
80            }
81            let mut schema = Schema::new();
82            m.up(&mut schema);
83
84            let mut tx = self.pool.begin().await?;
85            for stmt in &schema.statements {
86                sqlx::query(stmt).execute(&mut *tx).await?;
87            }
88            sqlx::query("INSERT INTO migrations (name, batch) VALUES ($1, $2)")
89                .bind(m.name())
90                .bind(batch)
91                .execute(&mut *tx)
92                .await?;
93            tx.commit().await?;
94            applied.push(m.name().to_string());
95            tracing::info!(name = m.name(), "migration applied");
96        }
97        Ok(applied)
98    }
99
100    pub async fn rollback(&self) -> Result<Vec<String>, Error> {
101        self.ensure_table().await?;
102        let (max_batch,): (Option<i32>,) = sqlx::query_as("SELECT MAX(batch) FROM migrations")
103            .fetch_one(&self.pool)
104            .await?;
105        let Some(batch) = max_batch else {
106            return Ok(Vec::new());
107        };
108        let rows: Vec<(String,)> = sqlx::query_as(
109            "SELECT name FROM migrations WHERE batch = $1 ORDER BY id DESC",
110        )
111        .bind(batch)
112        .fetch_all(&self.pool)
113        .await?;
114        let names: Vec<String> = rows.into_iter().map(|(n,)| n).collect();
115
116        let mut rolled = Vec::new();
117        for name in names {
118            let Some(m) = self.migrations.iter().find(|m| m.name() == name) else {
119                tracing::warn!(name, "migration row in DB but not registered; skipping");
120                continue;
121            };
122            let mut schema = Schema::new();
123            m.down(&mut schema);
124            let mut tx = self.pool.begin().await?;
125            for stmt in &schema.statements {
126                sqlx::query(stmt).execute(&mut *tx).await?;
127            }
128            sqlx::query("DELETE FROM migrations WHERE name = $1")
129                .bind(&name)
130                .execute(&mut *tx)
131                .await?;
132            tx.commit().await?;
133            rolled.push(name);
134        }
135        Ok(rolled)
136    }
137
138    pub async fn fresh(&self) -> Result<(), Error> {
139        sqlx::query("DROP SCHEMA public CASCADE; CREATE SCHEMA public;")
140            .execute(&self.pool)
141            .await?;
142        self.run_up().await?;
143        Ok(())
144    }
145}