1use crate::pool::{Driver, Pool};
8use crate::schema::Schema;
9use crate::Error;
10
11pub trait Migration: Send + Sync {
12 fn name(&self) -> &'static str;
13 fn up(&self, schema: &mut Schema);
14 fn down(&self, schema: &mut Schema);
15}
16
17inventory::collect!(MigrationRegistration);
18
19pub struct MigrationRegistration {
20 pub builder: fn() -> Box<dyn Migration>,
21}
22
23pub fn collected() -> Vec<Box<dyn Migration>> {
24 inventory::iter::<MigrationRegistration>
25 .into_iter()
26 .map(|r| (r.builder)())
27 .collect()
28}
29
30pub struct MigrationRunner {
31 pool: Pool,
32 migrations: Vec<Box<dyn Migration>>,
33}
34
35impl MigrationRunner {
36 pub fn new(pool: Pool) -> Self {
37 let mut migrations = collected();
38 migrations.sort_by_key(|m| m.name().to_string());
39 Self { pool, migrations }
40 }
41
42 pub fn with_migrations(pool: Pool, mut migrations: Vec<Box<dyn Migration>>) -> Self {
43 migrations.sort_by_key(|m| m.name().to_string());
44 Self { pool, migrations }
45 }
46
47 fn driver(&self) -> Driver {
48 self.pool.driver()
49 }
50
51 fn migrations_table_ddl(&self) -> &'static str {
54 match self.driver() {
55 Driver::Postgres => {
56 "CREATE TABLE IF NOT EXISTS migrations (
57 id BIGSERIAL PRIMARY KEY,
58 name TEXT NOT NULL UNIQUE,
59 batch INTEGER NOT NULL,
60 applied_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
61 )"
62 }
63 Driver::MySql => {
64 "CREATE TABLE IF NOT EXISTS migrations (
65 id BIGINT AUTO_INCREMENT PRIMARY KEY,
66 name VARCHAR(255) NOT NULL UNIQUE,
67 batch INT NOT NULL,
68 applied_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
69 )"
70 }
71 Driver::Sqlite => {
72 "CREATE TABLE IF NOT EXISTS migrations (
73 id INTEGER PRIMARY KEY AUTOINCREMENT,
74 name TEXT NOT NULL UNIQUE,
75 batch INTEGER NOT NULL,
76 applied_at TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP
77 )"
78 }
79 }
80 }
81
82 fn fresh_ddl(&self) -> Vec<&'static str> {
83 match self.driver() {
84 Driver::Postgres => vec!["DROP SCHEMA public CASCADE", "CREATE SCHEMA public"],
85 Driver::MySql => vec![
86 "",
89 ],
90 Driver::Sqlite => vec![
91 "PRAGMA writable_schema = 1",
92 "DELETE FROM sqlite_master WHERE type IN ('table','index','trigger')",
93 "PRAGMA writable_schema = 0",
94 "VACUUM",
95 ],
96 }
97 }
98
99 async fn exec(&self, sql: &str) -> Result<(), Error> {
102 if sql.is_empty() {
103 return Ok(());
104 }
105 match &self.pool {
106 Pool::Postgres(p) => {
107 sqlx::query(sql).execute(p).await?;
108 }
109 Pool::MySql(p) => {
110 sqlx::query(sql).execute(p).await?;
111 }
112 Pool::Sqlite(p) => {
113 sqlx::query(sql).execute(p).await?;
114 }
115 }
116 Ok(())
117 }
118
119 async fn applied_rows(&self) -> Result<Vec<(String, i32)>, Error> {
120 Ok(match &self.pool {
121 Pool::Postgres(p) => {
122 sqlx::query_as::<_, (String, i32)>(
123 "SELECT name, batch FROM migrations ORDER BY batch, id",
124 )
125 .fetch_all(p)
126 .await?
127 }
128 Pool::MySql(p) => {
129 sqlx::query_as::<_, (String, i32)>(
130 "SELECT name, batch FROM migrations ORDER BY batch, id",
131 )
132 .fetch_all(p)
133 .await?
134 }
135 Pool::Sqlite(p) => {
136 sqlx::query_as::<_, (String, i32)>(
137 "SELECT name, batch FROM migrations ORDER BY batch, id",
138 )
139 .fetch_all(p)
140 .await?
141 }
142 })
143 }
144
145 async fn max_batch(&self) -> Result<Option<i32>, Error> {
146 Ok(match &self.pool {
147 Pool::Postgres(p) => {
148 sqlx::query_as::<_, (Option<i32>,)>("SELECT MAX(batch) FROM migrations")
149 .fetch_one(p)
150 .await?
151 .0
152 }
153 Pool::MySql(p) => {
154 sqlx::query_as::<_, (Option<i32>,)>("SELECT MAX(batch) FROM migrations")
155 .fetch_one(p)
156 .await?
157 .0
158 }
159 Pool::Sqlite(p) => {
160 sqlx::query_as::<_, (Option<i32>,)>("SELECT MAX(batch) FROM migrations")
161 .fetch_one(p)
162 .await?
163 .0
164 }
165 })
166 }
167
168 async fn names_in_batch(&self, batch: i32) -> Result<Vec<String>, Error> {
169 let rows: Vec<(String,)> = match &self.pool {
170 Pool::Postgres(p) => {
171 sqlx::query_as("SELECT name FROM migrations WHERE batch = $1 ORDER BY id DESC")
172 .bind(batch)
173 .fetch_all(p)
174 .await?
175 }
176 Pool::MySql(p) => {
177 sqlx::query_as("SELECT name FROM migrations WHERE batch = ? ORDER BY id DESC")
178 .bind(batch)
179 .fetch_all(p)
180 .await?
181 }
182 Pool::Sqlite(p) => {
183 sqlx::query_as("SELECT name FROM migrations WHERE batch = ?1 ORDER BY id DESC")
184 .bind(batch)
185 .fetch_all(p)
186 .await?
187 }
188 };
189 Ok(rows.into_iter().map(|(n,)| n).collect())
190 }
191
192 async fn record_applied(&self, name: &str, batch: i32) -> Result<(), Error> {
193 match &self.pool {
194 Pool::Postgres(p) => {
195 sqlx::query("INSERT INTO migrations (name, batch) VALUES ($1, $2)")
196 .bind(name)
197 .bind(batch)
198 .execute(p)
199 .await?;
200 }
201 Pool::MySql(p) => {
202 sqlx::query("INSERT INTO migrations (name, batch) VALUES (?, ?)")
203 .bind(name)
204 .bind(batch)
205 .execute(p)
206 .await?;
207 }
208 Pool::Sqlite(p) => {
209 sqlx::query("INSERT INTO migrations (name, batch) VALUES (?1, ?2)")
210 .bind(name)
211 .bind(batch)
212 .execute(p)
213 .await?;
214 }
215 }
216 Ok(())
217 }
218
219 async fn delete_applied(&self, name: &str) -> Result<(), Error> {
220 match &self.pool {
221 Pool::Postgres(p) => {
222 sqlx::query("DELETE FROM migrations WHERE name = $1")
223 .bind(name)
224 .execute(p)
225 .await?;
226 }
227 Pool::MySql(p) => {
228 sqlx::query("DELETE FROM migrations WHERE name = ?")
229 .bind(name)
230 .execute(p)
231 .await?;
232 }
233 Pool::Sqlite(p) => {
234 sqlx::query("DELETE FROM migrations WHERE name = ?1")
235 .bind(name)
236 .execute(p)
237 .await?;
238 }
239 }
240 Ok(())
241 }
242
243 async fn exec_many(&self, stmts: &[String]) -> Result<(), Error> {
244 for s in stmts {
245 self.exec(s).await?;
246 }
247 Ok(())
248 }
249
250 pub async fn ensure_table(&self) -> Result<(), Error> {
253 let ddl = self.migrations_table_ddl();
254 self.exec(ddl).await
255 }
256
257 pub async fn applied(&self) -> Result<Vec<String>, Error> {
258 Ok(self
259 .applied_rows()
260 .await?
261 .into_iter()
262 .map(|(n, _)| n)
263 .collect())
264 }
265
266 pub async fn next_batch(&self) -> Result<i32, Error> {
267 Ok(self.max_batch().await?.unwrap_or(0) + 1)
268 }
269
270 pub async fn run_up(&self) -> Result<Vec<String>, Error> {
271 self.ensure_table().await?;
272 let already = self.applied().await?;
273 let batch = self.next_batch().await?;
274 let mut applied = Vec::new();
275 for m in &self.migrations {
276 if already.iter().any(|a| a == m.name()) {
277 continue;
278 }
279 let mut schema = Schema::for_driver(self.driver());
280 m.up(&mut schema);
281 self.exec_many(&schema.statements).await?;
282 self.record_applied(m.name(), batch).await?;
283 applied.push(m.name().to_string());
284 tracing::info!(name = m.name(), "migration applied");
285 }
286 Ok(applied)
287 }
288
289 pub async fn rollback(&self) -> Result<Vec<String>, Error> {
290 self.ensure_table().await?;
291 let Some(batch) = self.max_batch().await? else {
292 return Ok(Vec::new());
293 };
294 let names = self.names_in_batch(batch).await?;
295 let mut rolled = Vec::new();
296 for name in names {
297 let Some(m) = self.migrations.iter().find(|m| m.name() == name) else {
298 tracing::warn!(name, "migration row in DB but not registered; skipping");
299 continue;
300 };
301 let mut schema = Schema::for_driver(self.driver());
302 m.down(&mut schema);
303 self.exec_many(&schema.statements).await?;
304 self.delete_applied(&name).await?;
305 rolled.push(name);
306 }
307 Ok(rolled)
308 }
309
310 pub async fn fresh(&self) -> Result<(), Error> {
311 match self.driver() {
315 Driver::Postgres => {
316 for s in self.fresh_ddl() {
317 self.exec(s).await?;
318 }
319 }
320 Driver::MySql => {
321 self.drop_all_mysql_tables().await?;
322 }
323 Driver::Sqlite => {
324 self.drop_all_sqlite_tables().await?;
325 }
326 }
327 self.run_up().await?;
328 Ok(())
329 }
330
331 async fn drop_all_mysql_tables(&self) -> Result<(), Error> {
332 let Pool::MySql(p) = &self.pool else {
333 return Ok(());
334 };
335 let tables: Vec<(String,)> = sqlx::query_as(
336 "SELECT table_name FROM information_schema.tables WHERE table_schema = DATABASE()",
337 )
338 .fetch_all(p)
339 .await?;
340 sqlx::query("SET FOREIGN_KEY_CHECKS = 0").execute(p).await?;
341 for (t,) in tables {
342 sqlx::query(&format!("DROP TABLE IF EXISTS `{t}`"))
343 .execute(p)
344 .await?;
345 }
346 sqlx::query("SET FOREIGN_KEY_CHECKS = 1").execute(p).await?;
347 Ok(())
348 }
349
350 async fn drop_all_sqlite_tables(&self) -> Result<(), Error> {
351 let Pool::Sqlite(p) = &self.pool else {
352 return Ok(());
353 };
354 let tables: Vec<(String,)> = sqlx::query_as(
355 "SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%'",
356 )
357 .fetch_all(p)
358 .await?;
359 for (t,) in tables {
360 sqlx::query(&format!("DROP TABLE IF EXISTS \"{t}\""))
361 .execute(p)
362 .await?;
363 }
364 Ok(())
365 }
366
367 pub async fn status(&self) -> Result<Vec<MigrationStatus>, Error> {
368 self.ensure_table().await?;
369 let rows = self.applied_rows().await?;
370 let applied_map: std::collections::HashMap<String, i32> = rows.into_iter().collect();
371
372 let mut out = Vec::new();
373 for m in &self.migrations {
374 let name = m.name().to_string();
375 let batch = applied_map.get(&name).copied();
376 out.push(MigrationStatus {
377 name,
378 applied: batch.is_some(),
379 batch,
380 });
381 }
382 for (db_name, batch) in &applied_map {
383 if !self.migrations.iter().any(|m| m.name() == db_name) {
384 out.push(MigrationStatus {
385 name: db_name.clone(),
386 applied: true,
387 batch: Some(*batch),
388 });
389 }
390 }
391 Ok(out)
392 }
393
394 pub async fn reset(&self) -> Result<Vec<String>, Error> {
395 self.ensure_table().await?;
396 let mut rolled_total = Vec::new();
397 loop {
398 let rolled = self.rollback().await?;
399 if rolled.is_empty() {
400 break;
401 }
402 rolled_total.extend(rolled);
403 }
404 Ok(rolled_total)
405 }
406
407 pub async fn refresh(&self) -> Result<Vec<String>, Error> {
408 self.reset().await?;
409 self.run_up().await
410 }
411
412 pub async fn run_up_step(&self) -> Result<Vec<String>, Error> {
413 self.ensure_table().await?;
414 let already = self.applied().await?;
415 let mut applied = Vec::new();
416 for m in &self.migrations {
417 if already.iter().any(|a| a == m.name()) {
418 continue;
419 }
420 let batch = self.next_batch().await?;
421 let mut schema = Schema::for_driver(self.driver());
422 m.up(&mut schema);
423 self.exec_many(&schema.statements).await?;
424 self.record_applied(m.name(), batch).await?;
425 applied.push(m.name().to_string());
426 tracing::info!(name = m.name(), batch, "migration applied (stepped)");
427 }
428 Ok(applied)
429 }
430
431 pub async fn pretend(&self) -> Result<Vec<String>, Error> {
432 self.ensure_table().await?;
433 let already = self.applied().await?;
434 let mut lines = Vec::new();
435 for m in &self.migrations {
436 if already.iter().any(|a| a == m.name()) {
437 continue;
438 }
439 lines.push(format!("-- migration: {}", m.name()));
440 let mut schema = Schema::for_driver(self.driver());
441 m.up(&mut schema);
442 for stmt in &schema.statements {
443 lines.push(format!("{stmt};"));
444 }
445 lines.push(String::new());
446 }
447 Ok(lines)
448 }
449
450 pub async fn install(&self) -> Result<(), Error> {
451 self.ensure_table().await
452 }
453
454 pub fn count(&self) -> usize {
455 self.migrations.len()
456 }
457}
458
/// One line of the report produced by `MigrationRunner::status`.
///
/// Equality is derived so reports can be compared directly (e.g. in tests).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MigrationStatus {
    /// Migration name, either registered in code or found in the
    /// `migrations` table.
    pub name: String,
    /// Whether a row for this migration exists in the `migrations` table.
    pub applied: bool,
    /// Batch the migration was recorded under; `None` while still pending.
    pub batch: Option<i32>,
}