1use crate::pool::{Driver, Pool};
8use crate::schema::Schema;
9use crate::Error;
10
11pub trait Migration: Send + Sync {
12 fn name(&self) -> &'static str;
13 fn up(&self, schema: &mut Schema);
14 fn down(&self, schema: &mut Schema);
15}
16
17inventory::collect!(MigrationRegistration);
18
19pub struct MigrationRegistration {
20 pub builder: fn() -> Box<dyn Migration>,
21}
22
23pub fn collected() -> Vec<Box<dyn Migration>> {
24 inventory::iter::<MigrationRegistration>
25 .into_iter()
26 .map(|r| (r.builder)())
27 .collect()
28}
29
30pub struct MigrationRunner {
31 pool: Pool,
32 migrations: Vec<Box<dyn Migration>>,
33}
34
35impl MigrationRunner {
36 pub fn new(pool: Pool) -> Self {
37 let mut migrations = collected();
38 migrations.sort_by_key(|m| m.name().to_string());
39 Self { pool, migrations }
40 }
41
42 pub fn with_migrations(pool: Pool, mut migrations: Vec<Box<dyn Migration>>) -> Self {
43 migrations.sort_by_key(|m| m.name().to_string());
44 Self { pool, migrations }
45 }
46
47 fn driver(&self) -> Driver {
48 self.pool.driver()
49 }
50
51 fn migrations_table_ddl(&self) -> &'static str {
54 match self.driver() {
55 Driver::Postgres => "CREATE TABLE IF NOT EXISTS migrations (
56 id BIGSERIAL PRIMARY KEY,
57 name TEXT NOT NULL UNIQUE,
58 batch INTEGER NOT NULL,
59 applied_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
60 )",
61 Driver::MySql => "CREATE TABLE IF NOT EXISTS migrations (
62 id BIGINT AUTO_INCREMENT PRIMARY KEY,
63 name VARCHAR(255) NOT NULL UNIQUE,
64 batch INT NOT NULL,
65 applied_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
66 )",
67 Driver::Sqlite => "CREATE TABLE IF NOT EXISTS migrations (
68 id INTEGER PRIMARY KEY AUTOINCREMENT,
69 name TEXT NOT NULL UNIQUE,
70 batch INTEGER NOT NULL,
71 applied_at TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP
72 )",
73 }
74 }
75
76 fn fresh_ddl(&self) -> Vec<&'static str> {
77 match self.driver() {
78 Driver::Postgres => vec!["DROP SCHEMA public CASCADE", "CREATE SCHEMA public"],
79 Driver::MySql => vec![
80 "",
83 ],
84 Driver::Sqlite => vec![
85 "PRAGMA writable_schema = 1",
86 "DELETE FROM sqlite_master WHERE type IN ('table','index','trigger')",
87 "PRAGMA writable_schema = 0",
88 "VACUUM",
89 ],
90 }
91 }
92
93 async fn exec(&self, sql: &str) -> Result<(), Error> {
96 if sql.is_empty() {
97 return Ok(());
98 }
99 match &self.pool {
100 Pool::Postgres(p) => {
101 sqlx::query(sql).execute(p).await?;
102 }
103 Pool::MySql(p) => {
104 sqlx::query(sql).execute(p).await?;
105 }
106 Pool::Sqlite(p) => {
107 sqlx::query(sql).execute(p).await?;
108 }
109 }
110 Ok(())
111 }
112
113 async fn applied_rows(&self) -> Result<Vec<(String, i32)>, Error> {
114 Ok(match &self.pool {
115 Pool::Postgres(p) => {
116 sqlx::query_as::<_, (String, i32)>("SELECT name, batch FROM migrations ORDER BY batch, id")
117 .fetch_all(p)
118 .await?
119 }
120 Pool::MySql(p) => {
121 sqlx::query_as::<_, (String, i32)>("SELECT name, batch FROM migrations ORDER BY batch, id")
122 .fetch_all(p)
123 .await?
124 }
125 Pool::Sqlite(p) => {
126 sqlx::query_as::<_, (String, i32)>("SELECT name, batch FROM migrations ORDER BY batch, id")
127 .fetch_all(p)
128 .await?
129 }
130 })
131 }
132
133 async fn max_batch(&self) -> Result<Option<i32>, Error> {
134 Ok(match &self.pool {
135 Pool::Postgres(p) => {
136 sqlx::query_as::<_, (Option<i32>,)>("SELECT MAX(batch) FROM migrations")
137 .fetch_one(p)
138 .await?
139 .0
140 }
141 Pool::MySql(p) => {
142 sqlx::query_as::<_, (Option<i32>,)>("SELECT MAX(batch) FROM migrations")
143 .fetch_one(p)
144 .await?
145 .0
146 }
147 Pool::Sqlite(p) => {
148 sqlx::query_as::<_, (Option<i32>,)>("SELECT MAX(batch) FROM migrations")
149 .fetch_one(p)
150 .await?
151 .0
152 }
153 })
154 }
155
156 async fn names_in_batch(&self, batch: i32) -> Result<Vec<String>, Error> {
157 let rows: Vec<(String,)> = match &self.pool {
158 Pool::Postgres(p) => sqlx::query_as("SELECT name FROM migrations WHERE batch = $1 ORDER BY id DESC")
159 .bind(batch).fetch_all(p).await?,
160 Pool::MySql(p) => sqlx::query_as("SELECT name FROM migrations WHERE batch = ? ORDER BY id DESC")
161 .bind(batch).fetch_all(p).await?,
162 Pool::Sqlite(p) => sqlx::query_as("SELECT name FROM migrations WHERE batch = ?1 ORDER BY id DESC")
163 .bind(batch).fetch_all(p).await?,
164 };
165 Ok(rows.into_iter().map(|(n,)| n).collect())
166 }
167
168 async fn record_applied(&self, name: &str, batch: i32) -> Result<(), Error> {
169 match &self.pool {
170 Pool::Postgres(p) => {
171 sqlx::query("INSERT INTO migrations (name, batch) VALUES ($1, $2)")
172 .bind(name).bind(batch).execute(p).await?;
173 }
174 Pool::MySql(p) => {
175 sqlx::query("INSERT INTO migrations (name, batch) VALUES (?, ?)")
176 .bind(name).bind(batch).execute(p).await?;
177 }
178 Pool::Sqlite(p) => {
179 sqlx::query("INSERT INTO migrations (name, batch) VALUES (?1, ?2)")
180 .bind(name).bind(batch).execute(p).await?;
181 }
182 }
183 Ok(())
184 }
185
186 async fn delete_applied(&self, name: &str) -> Result<(), Error> {
187 match &self.pool {
188 Pool::Postgres(p) => {
189 sqlx::query("DELETE FROM migrations WHERE name = $1").bind(name).execute(p).await?;
190 }
191 Pool::MySql(p) => {
192 sqlx::query("DELETE FROM migrations WHERE name = ?").bind(name).execute(p).await?;
193 }
194 Pool::Sqlite(p) => {
195 sqlx::query("DELETE FROM migrations WHERE name = ?1").bind(name).execute(p).await?;
196 }
197 }
198 Ok(())
199 }
200
201 async fn exec_many(&self, stmts: &[String]) -> Result<(), Error> {
202 for s in stmts {
203 self.exec(s).await?;
204 }
205 Ok(())
206 }
207
208 pub async fn ensure_table(&self) -> Result<(), Error> {
211 let ddl = self.migrations_table_ddl();
212 self.exec(ddl).await
213 }
214
215 pub async fn applied(&self) -> Result<Vec<String>, Error> {
216 Ok(self.applied_rows().await?.into_iter().map(|(n, _)| n).collect())
217 }
218
219 pub async fn next_batch(&self) -> Result<i32, Error> {
220 Ok(self.max_batch().await?.unwrap_or(0) + 1)
221 }
222
223 pub async fn run_up(&self) -> Result<Vec<String>, Error> {
224 self.ensure_table().await?;
225 let already = self.applied().await?;
226 let batch = self.next_batch().await?;
227 let mut applied = Vec::new();
228 for m in &self.migrations {
229 if already.iter().any(|a| a == m.name()) {
230 continue;
231 }
232 let mut schema = Schema::for_driver(self.driver());
233 m.up(&mut schema);
234 self.exec_many(&schema.statements).await?;
235 self.record_applied(m.name(), batch).await?;
236 applied.push(m.name().to_string());
237 tracing::info!(name = m.name(), "migration applied");
238 }
239 Ok(applied)
240 }
241
242 pub async fn rollback(&self) -> Result<Vec<String>, Error> {
243 self.ensure_table().await?;
244 let Some(batch) = self.max_batch().await? else {
245 return Ok(Vec::new());
246 };
247 let names = self.names_in_batch(batch).await?;
248 let mut rolled = Vec::new();
249 for name in names {
250 let Some(m) = self.migrations.iter().find(|m| m.name() == name) else {
251 tracing::warn!(name, "migration row in DB but not registered; skipping");
252 continue;
253 };
254 let mut schema = Schema::for_driver(self.driver());
255 m.down(&mut schema);
256 self.exec_many(&schema.statements).await?;
257 self.delete_applied(&name).await?;
258 rolled.push(name);
259 }
260 Ok(rolled)
261 }
262
263 pub async fn fresh(&self) -> Result<(), Error> {
264 match self.driver() {
268 Driver::Postgres => {
269 for s in self.fresh_ddl() {
270 self.exec(s).await?;
271 }
272 }
273 Driver::MySql => {
274 self.drop_all_mysql_tables().await?;
275 }
276 Driver::Sqlite => {
277 self.drop_all_sqlite_tables().await?;
278 }
279 }
280 self.run_up().await?;
281 Ok(())
282 }
283
284 async fn drop_all_mysql_tables(&self) -> Result<(), Error> {
285 let Pool::MySql(p) = &self.pool else {
286 return Ok(());
287 };
288 let tables: Vec<(String,)> = sqlx::query_as(
289 "SELECT table_name FROM information_schema.tables WHERE table_schema = DATABASE()",
290 )
291 .fetch_all(p)
292 .await?;
293 sqlx::query("SET FOREIGN_KEY_CHECKS = 0").execute(p).await?;
294 for (t,) in tables {
295 sqlx::query(&format!("DROP TABLE IF EXISTS `{t}`"))
296 .execute(p)
297 .await?;
298 }
299 sqlx::query("SET FOREIGN_KEY_CHECKS = 1").execute(p).await?;
300 Ok(())
301 }
302
303 async fn drop_all_sqlite_tables(&self) -> Result<(), Error> {
304 let Pool::Sqlite(p) = &self.pool else {
305 return Ok(());
306 };
307 let tables: Vec<(String,)> = sqlx::query_as(
308 "SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%'",
309 )
310 .fetch_all(p)
311 .await?;
312 for (t,) in tables {
313 sqlx::query(&format!("DROP TABLE IF EXISTS \"{t}\""))
314 .execute(p)
315 .await?;
316 }
317 Ok(())
318 }
319
320 pub async fn status(&self) -> Result<Vec<MigrationStatus>, Error> {
321 self.ensure_table().await?;
322 let rows = self.applied_rows().await?;
323 let applied_map: std::collections::HashMap<String, i32> = rows.into_iter().collect();
324
325 let mut out = Vec::new();
326 for m in &self.migrations {
327 let name = m.name().to_string();
328 let batch = applied_map.get(&name).copied();
329 out.push(MigrationStatus {
330 name,
331 applied: batch.is_some(),
332 batch,
333 });
334 }
335 for (db_name, batch) in &applied_map {
336 if !self.migrations.iter().any(|m| m.name() == db_name) {
337 out.push(MigrationStatus {
338 name: db_name.clone(),
339 applied: true,
340 batch: Some(*batch),
341 });
342 }
343 }
344 Ok(out)
345 }
346
347 pub async fn reset(&self) -> Result<Vec<String>, Error> {
348 self.ensure_table().await?;
349 let mut rolled_total = Vec::new();
350 loop {
351 let rolled = self.rollback().await?;
352 if rolled.is_empty() {
353 break;
354 }
355 rolled_total.extend(rolled);
356 }
357 Ok(rolled_total)
358 }
359
360 pub async fn refresh(&self) -> Result<Vec<String>, Error> {
361 self.reset().await?;
362 self.run_up().await
363 }
364
365 pub async fn run_up_step(&self) -> Result<Vec<String>, Error> {
366 self.ensure_table().await?;
367 let already = self.applied().await?;
368 let mut applied = Vec::new();
369 for m in &self.migrations {
370 if already.iter().any(|a| a == m.name()) {
371 continue;
372 }
373 let batch = self.next_batch().await?;
374 let mut schema = Schema::for_driver(self.driver());
375 m.up(&mut schema);
376 self.exec_many(&schema.statements).await?;
377 self.record_applied(m.name(), batch).await?;
378 applied.push(m.name().to_string());
379 tracing::info!(name = m.name(), batch, "migration applied (stepped)");
380 }
381 Ok(applied)
382 }
383
384 pub async fn pretend(&self) -> Result<Vec<String>, Error> {
385 self.ensure_table().await?;
386 let already = self.applied().await?;
387 let mut lines = Vec::new();
388 for m in &self.migrations {
389 if already.iter().any(|a| a == m.name()) {
390 continue;
391 }
392 lines.push(format!("-- migration: {}", m.name()));
393 let mut schema = Schema::for_driver(self.driver());
394 m.up(&mut schema);
395 for stmt in &schema.statements {
396 lines.push(format!("{stmt};"));
397 }
398 lines.push(String::new());
399 }
400 Ok(lines)
401 }
402
403 pub async fn install(&self) -> Result<(), Error> {
404 self.ensure_table().await
405 }
406
407 pub fn count(&self) -> usize {
408 self.migrations.len()
409 }
410}
411
/// Applied/pending state of one migration, as reported by `MigrationRunner::status`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MigrationStatus {
    /// Migration name: from `Migration::name`, or as stored in the
    /// `migrations` table for rows with no registered migration.
    pub name: String,
    /// Whether a row for this migration exists in the `migrations` table.
    pub applied: bool,
    /// Batch the migration was applied in; `None` while it is pending.
    pub batch: Option<i32>,
}