1use crate::pool::Pool;
4use crate::schema::Schema;
5use crate::Error;
6
7pub trait Migration: Send + Sync {
8 fn name(&self) -> &'static str;
9 fn up(&self, schema: &mut Schema);
10 fn down(&self, schema: &mut Schema);
11}
12
13inventory::collect!(MigrationRegistration);
14
15pub struct MigrationRegistration {
16 pub builder: fn() -> Box<dyn Migration>,
17}
18
19pub fn collected() -> Vec<Box<dyn Migration>> {
20 inventory::iter::<MigrationRegistration>
21 .into_iter()
22 .map(|r| (r.builder)())
23 .collect()
24}
25
/// Applies and reverts [`Migration`]s against the database behind `pool`,
/// recording progress in a `migrations` table.
pub struct MigrationRunner {
    // Used for bookkeeping queries and for each migration's transaction.
    pool: Pool,
    // Migrations known to this runner; both constructors sort them by name.
    migrations: Vec<Box<dyn Migration>>,
}
30
31impl MigrationRunner {
32 pub fn new(pool: Pool) -> Self {
33 let mut migrations = collected();
34 migrations.sort_by_key(|m| m.name().to_string());
35 Self { pool, migrations }
36 }
37
38 pub fn with_migrations(pool: Pool, mut migrations: Vec<Box<dyn Migration>>) -> Self {
39 migrations.sort_by_key(|m| m.name().to_string());
40 Self { pool, migrations }
41 }
42
43 pub async fn ensure_table(&self) -> Result<(), Error> {
44 sqlx::query(
45 "CREATE TABLE IF NOT EXISTS migrations (
46 id BIGSERIAL PRIMARY KEY,
47 name TEXT NOT NULL UNIQUE,
48 batch INTEGER NOT NULL,
49 applied_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
50 )",
51 )
52 .execute(&self.pool)
53 .await?;
54 Ok(())
55 }
56
57 pub async fn applied(&self) -> Result<Vec<String>, Error> {
58 let rows: Vec<(String,)> = sqlx::query_as("SELECT name FROM migrations ORDER BY batch, id")
59 .fetch_all(&self.pool)
60 .await?;
61 Ok(rows.into_iter().map(|(n,)| n).collect())
62 }
63
64 pub async fn next_batch(&self) -> Result<i32, Error> {
65 let (max_batch,): (Option<i32>,) =
66 sqlx::query_as("SELECT MAX(batch) FROM migrations")
67 .fetch_one(&self.pool)
68 .await?;
69 Ok(max_batch.unwrap_or(0) + 1)
70 }
71
72 pub async fn run_up(&self) -> Result<Vec<String>, Error> {
73 self.ensure_table().await?;
74 let already = self.applied().await?;
75 let batch = self.next_batch().await?;
76 let mut applied = Vec::new();
77 for m in &self.migrations {
78 if already.iter().any(|a| a == m.name()) {
79 continue;
80 }
81 let mut schema = Schema::new();
82 m.up(&mut schema);
83
84 let mut tx = self.pool.begin().await?;
85 for stmt in &schema.statements {
86 sqlx::query(stmt).execute(&mut *tx).await?;
87 }
88 sqlx::query("INSERT INTO migrations (name, batch) VALUES ($1, $2)")
89 .bind(m.name())
90 .bind(batch)
91 .execute(&mut *tx)
92 .await?;
93 tx.commit().await?;
94 applied.push(m.name().to_string());
95 tracing::info!(name = m.name(), "migration applied");
96 }
97 Ok(applied)
98 }
99
100 pub async fn rollback(&self) -> Result<Vec<String>, Error> {
101 self.ensure_table().await?;
102 let (max_batch,): (Option<i32>,) = sqlx::query_as("SELECT MAX(batch) FROM migrations")
103 .fetch_one(&self.pool)
104 .await?;
105 let Some(batch) = max_batch else {
106 return Ok(Vec::new());
107 };
108 let rows: Vec<(String,)> = sqlx::query_as(
109 "SELECT name FROM migrations WHERE batch = $1 ORDER BY id DESC",
110 )
111 .bind(batch)
112 .fetch_all(&self.pool)
113 .await?;
114 let names: Vec<String> = rows.into_iter().map(|(n,)| n).collect();
115
116 let mut rolled = Vec::new();
117 for name in names {
118 let Some(m) = self.migrations.iter().find(|m| m.name() == name) else {
119 tracing::warn!(name, "migration row in DB but not registered; skipping");
120 continue;
121 };
122 let mut schema = Schema::new();
123 m.down(&mut schema);
124 let mut tx = self.pool.begin().await?;
125 for stmt in &schema.statements {
126 sqlx::query(stmt).execute(&mut *tx).await?;
127 }
128 sqlx::query("DELETE FROM migrations WHERE name = $1")
129 .bind(&name)
130 .execute(&mut *tx)
131 .await?;
132 tx.commit().await?;
133 rolled.push(name);
134 }
135 Ok(rolled)
136 }
137
138 pub async fn fresh(&self) -> Result<(), Error> {
139 sqlx::query("DROP SCHEMA public CASCADE; CREATE SCHEMA public;")
140 .execute(&self.pool)
141 .await?;
142 self.run_up().await?;
143 Ok(())
144 }
145}